debugged stack overflow error, fixed connection timeout and some error flow problems

This commit is contained in:
Noah Laptop 2019-03-12 16:59:45 -07:00
parent c212e355a4
commit 7f72073fa6
5 changed files with 129 additions and 51 deletions

View file

@ -67,6 +67,7 @@ class SSLClient : public SSLClientImpl {
*/
static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!");
static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!");
static_assert(SessionCache <= 3, "You need to decrease the size of m_iobuf in order to have more than 3 sessions at once, otherwise memory issues will occur.");
// static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
public:
@ -90,12 +91,16 @@ public:
explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_ERROR)
: SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)
, m_client(client)
, m_sessions{}
, m_sessions{SSLSession()}
, m_index(0)
{
// for (uint8_t i = 0; i < SessionCache; i++) m_sessions[i] = SSLSession();
// since we are copying the client in the ctor, we have to set
// the client pointer after the class is constructed
set_client(&m_client);
// set the timeout to a reasonable number (it can always be changes later)
// SSL Connections take a really long time so we don't want to time out a legitimate thing
setTimeout(10 * 1000);
}
/*
@ -114,38 +119,13 @@ public:
//! get the client object
C& getClient() { return m_client; }
virtual SSLSession& getSession(const char* host, const IPAddress& addr) {
// search for a matching session with the IP
int temp_index = -1;
for (size_t i = 0; i < SessionCache; i++) {
// if we're looking at a real session
if (m_sessions[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
// there is no hostname and the IP address matches
|| (host == NULL && addr == m_sessions[i].get_ip())
)) {
virtual SSLSession& getSession(const char* host, const IPAddress& addr);
temp_index = i;
break;
}
}
// if none are availible, use m_index
if (temp_index == -1) {
temp_index = m_index;
// reset the session so we don't try to send one sites session to another
m_sessions[temp_index] = SSLSession();
}
// increment m_index so the session cache is a circular buffer
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
// return the pointed to value
m_info("Using session index: ", __func__);
Serial.println(temp_index);
return m_sessions[temp_index];
}
virtual void removeSession(const char* host, const IPAddress& addr);
private:
// utility function to find a session index based off of a host and IP
int m_getSessionIndex(const char* host, const IPAddress& addr) const;
// create a copy of the client
C m_client;
// also store an array of SSLSessions, so we can resume communication with multiple websites
@ -154,4 +134,56 @@ private:
size_t m_index;
};
template <class C, size_t SessionCache>
SSLSession& SSLClient<C, SessionCache>::getSession(const char* host, const IPAddress& addr) {
const char* func_name = __func__;
// search for a matching session with the IP
int temp_index = m_getSessionIndex(host, addr);
// if none are availible, use m_index
if (temp_index == -1) {
temp_index = m_index;
// reset the session so we don't try to send one sites session to another
m_sessions[temp_index] = SSLSession();
}
// increment m_index so the session cache is a circular buffer
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
// return the pointed to value
m_info("Using session index: ", func_name);
Serial.println(temp_index);
return m_sessions[temp_index];
}
template <class C, size_t SessionCache>
void SSLClient<C, SessionCache>::removeSession(const char* host, const IPAddress& addr) {
const char* func_name = __func__;
int temp_index = m_getSessionIndex(host, addr);
if (temp_index != -1) {
m_info(" Deleted session ", func_name);
m_info(temp_index, func_name);
m_sessions[temp_index] = SSLSession();
}
}
template <class C, size_t SessionCache>
int SSLClient<C, SessionCache>::m_getSessionIndex(const char* host, const IPAddress& addr) const {
const char* func_name = __func__;
// search for a matching session with the IP
for (uint8_t i = 0; i < SessionCache; i++) {
// if we're looking at a real session
if (m_sessions[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
// there is no hostname and the IP address matches
|| (host == NULL && addr == m_sessions[i].get_ip())
)) {
m_info("Found session match: ", func_name);
m_info(m_sessions[i].get_hostname(), func_name);
return i;
}
}
// none found
return -1;
}
#endif /** SSLClient_H_ */

View file

@ -45,7 +45,7 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
const char* func_name = __func__;
// connection check
if (connected()) {
if (m_client->connected()) {
m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name);
return -1;
}
@ -68,7 +68,7 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
int SSLClientImpl::connect(const char *host, uint16_t port) {
const char* func_name = __func__;
// connection check
if (connected()) {
if (m_client->connected()) {
m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name);
return -1;
}
@ -148,7 +148,7 @@ int SSLClientImpl::available() {
br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen);
return (int)(alen);
}
else if (state == BR_SSL_CLOSED) m_warn("Engine closed after update", func_name);
else if (state == BR_SSL_CLOSED) m_info("Engine closed after update", func_name);
// flush the buffer if it's stuck in the SENDAPP state
else if (state & BR_SSL_SENDAPP) br_ssl_engine_flush(&m_sslctx.eng, 0);
// other state, or client is closed
@ -194,8 +194,6 @@ void SSLClientImpl::flush() {
void SSLClientImpl::stop() {
// tell the SSL connection to gracefully close
br_ssl_engine_close(&m_sslctx.eng);
// info about the socket connection
if (br_ssl_engine_current_state(&m_sslctx.eng) == BR_SSL_CLOSED) m_info("Socket was terminated before graceful closure (probably fine)", __func__);
// if the engine isn't closed, and the socket is still open
while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED
&& m_run_until(BR_SSL_RECVAPP) == 0) {
@ -220,8 +218,14 @@ uint8_t SSLClientImpl::connected() {
const auto wr_ok = getWriteError() == 0;
// if we're in an error state, close the connection and set a write error
if (br_con && !c_con) {
m_error("Socket was unexpectedly interrupted. m_client error: ", func_name);
m_error(m_client->getWriteError(), func_name);
// If we've got a write error, the client probably failed for some reason
if (m_client->getWriteError()) {
m_error("Socket was unexpectedly interrupted. m_client error: ", func_name);
m_error(m_client->getWriteError(), func_name);
}
// Else tell the user the endpoint closed the socket on us (ouch)
else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name);
// set the write error so the engine doesn't try to close the connection
setWriteError(SSL_CLIENT_WRTIE_ERROR);
stop();
}
@ -280,11 +284,18 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
m_print_br_error(br_ssl_engine_last_error(&m_sslctx.eng), SSL_ERROR);
return 0;
}
m_info("Connection successful!", func_name);
// all good to go! the SSL socket should be up and running
// overwrite the session we got with new parameters
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
// set the hostname and ip in the session as well
ssl_ses.set_parameters(remoteIP(), host);
// print the handshake cipher chioce
m_info("Cipher suite: ", func_name);
if (m_debug >= SSL_INFO) {
m_print_prefix(func_name, SSL_INFO);
Serial.println(ssl_ses.cipher_suite, HEX);
}
return 1;
}
@ -293,6 +304,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
const char* func_name = __func__;
unsigned lastState = 0;
size_t lastLen = 0;
const unsigned long start = millis();
for (;;) {
unsigned state = m_update_engine();
// error check
@ -300,11 +312,20 @@ int SSLClientImpl::m_run_until(const unsigned target) {
m_warn("Tried to run_until when the engine is closed", func_name);
return -1;
}
// timeout check
if (millis() - start > getTimeout()) {
m_error("SSL internals timed out! This could be an internal error or bad data sent from the server", func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop();
return -1;
}
// debug
if (state != lastState) {
lastState = state;
m_info("m_run waiting:", func_name);
m_info("m_run changed state:", func_name);
printState(state);
m_info("Memory: ", func_name);
m_info(freeMemory(), func_name);
}
if (state & BR_SSL_RECVREC) {
size_t len;
@ -455,7 +476,8 @@ unsigned SSLClientImpl::m_update_engine() {
m_info("Read bytes from client: ", func_name);
m_info(avail, func_name);
m_info(len, func_name);
m_info("Memory: ", func_name);
m_info(freeMemory(), func_name);
// I suppose so!
int rlen = m_client->read(buf, len);
if (rlen <= 0) {
@ -495,20 +517,20 @@ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level
Serial.print("(SSLClient)");
// print the debug level
switch (level) {
case SSL_INFO: Serial.print("SSL_INFO"); break;
case SSL_WARN: Serial.print("SSL_WARN"); break;
case SSL_ERROR: Serial.print("SSL_ERROR"); break;
default: Serial.print("Unknown level");
case SSL_INFO: Serial.print("(SSL_INFO)"); break;
case SSL_WARN: Serial.print("(SSL_WARN)"); break;
case SSL_ERROR: Serial.print("(SSL_ERROR)"); break;
default: Serial.print("(Unknown level)");
}
// print the function name
Serial.print("(");
Serial.print(func_name);
// get ready
Serial.print(": ");
Serial.print("): ");
}
/** See SSLClientImpl.h */
void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const {
if (level < m_debug) return;
if (level > m_debug) return;
m_print_prefix(__func__, level);
switch(ssl_error) {
case SSL_OK: Serial.println("SSL_OK"); break;
@ -522,7 +544,7 @@ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel leve
/* See SSLClientImpl.h */
void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const {
if (level < m_debug) return;
if (level > m_debug) return;
m_print_prefix(__func__, level);
switch (br_error_code) {
case BR_ERR_BAD_PARAM: Serial.println("Caller-provided parameter is incorrect."); break;

View file

@ -45,11 +45,31 @@ enum Error {
*/
enum DebugLevel {
SSL_NONE = 0,
SSL_INFO = 1,
SSL_ERROR = 1,
SSL_WARN = 2,
SSL_ERROR = 3
SSL_INFO = 3,
};
#ifdef __arm__
// should use uinstd.h to define sbrk but Due causes a conflict
extern "C" char* sbrk(int incr);
#else // __ARM__
extern char *__brkval;
#endif // __arm__
static int freeMemory() {
char top;
#ifdef __arm__
return &top - reinterpret_cast<char*>(sbrk(0));
#elif defined(CORE_TEENSY) || (ARDUINO > 103 && ARDUINO != 151)
return &top - __brkval;
#else // __arm__
return __brkval ? &top - __brkval : &top - __malloc_heap_start;
#endif // __arm__
}
/** TODO: Write what this is */
class SSLClientImpl : public Client {
@ -170,7 +190,7 @@ protected:
template<typename T>
void m_print(const T str, const char* func_name, const DebugLevel level) const {
// check the current debug level
if (level < m_debug) return;
if (level > m_debug) return;
// print prefix
m_print_prefix(func_name, level);
// print the message
@ -226,7 +246,9 @@ private:
// can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI
// or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically
// simply edit this value to change the buffer size to the desired value
unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO];
// additionally, we need to correct buffer size based off of how many sessions we decide to cache
// since SSL takes so much memory if we don't it will cause the stack and heap to collide
unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO / 4];
static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size");
// store the index of where we are writing in the buffer
// so we can send our records all at once to prevent

View file

@ -418,6 +418,7 @@ br_client_init_TLS12_only(br_ssl_client_context *cc,
* supported hash function is appropriate; here we use SHA-256.
* The trust an
*/
memset(xc, 0, sizeof *xc);
br_x509_minimal_init(xc, &br_sha256_vtable,
trust_anchors, trust_anchors_num);

View file

@ -119,6 +119,7 @@ br_ssl_client_init_full(br_ssl_client_context *cc,
* to TLS-1.2 (inclusive).
*/
br_ssl_client_zero(cc);
memset(xc, 0, sizeof *xc);
br_ssl_engine_set_versions(&cc->eng, BR_TLS10, BR_TLS12);
/*