debugged stack overflow error, fixed connection timeout and some error flow problems

This commit is contained in:
Noah Laptop 2019-03-12 16:59:45 -07:00
parent c212e355a4
commit 7f72073fa6
5 changed files with 129 additions and 51 deletions

View file

@ -67,6 +67,7 @@ class SSLClient : public SSLClientImpl {
*/ */
static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!"); static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!");
static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!"); static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!");
static_assert(SessionCache <= 3, "You need to decrease the size of m_iobuf in order to have more than 3 sessions at once, otherwise memory issues will occur.");
// static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!"); // static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
public: public:
@ -90,12 +91,16 @@ public:
explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_ERROR) explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_ERROR)
: SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug) : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)
, m_client(client) , m_client(client)
, m_sessions{} , m_sessions{SSLSession()}
, m_index(0) , m_index(0)
{ {
// for (uint8_t i = 0; i < SessionCache; i++) m_sessions[i] = SSLSession();
// since we are copying the client in the ctor, we have to set // since we are copying the client in the ctor, we have to set
// the client pointer after the class is constructed // the client pointer after the class is constructed
set_client(&m_client); set_client(&m_client);
// set the timeout to a reasonable number (it can always be changes later)
// SSL Connections take a really long time so we don't want to time out a legitimate thing
setTimeout(10 * 1000);
} }
/* /*
@ -114,38 +119,13 @@ public:
//! get the client object //! get the client object
C& getClient() { return m_client; } C& getClient() { return m_client; }
virtual SSLSession& getSession(const char* host, const IPAddress& addr) { virtual SSLSession& getSession(const char* host, const IPAddress& addr);
// search for a matching session with the IP
int temp_index = -1;
for (size_t i = 0; i < SessionCache; i++) {
// if we're looking at a real session
if (m_sessions[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
// there is no hostname and the IP address matches
|| (host == NULL && addr == m_sessions[i].get_ip())
)) {
temp_index = i; virtual void removeSession(const char* host, const IPAddress& addr);
break;
}
}
// if none are availible, use m_index
if (temp_index == -1) {
temp_index = m_index;
// reset the session so we don't try to send one sites session to another
m_sessions[temp_index] = SSLSession();
}
// increment m_index so the session cache is a circular buffer
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
// return the pointed to value
m_info("Using session index: ", __func__);
Serial.println(temp_index);
return m_sessions[temp_index];
}
private: private:
// utility function to find a session index based off of a host and IP
int m_getSessionIndex(const char* host, const IPAddress& addr) const;
// create a copy of the client // create a copy of the client
C m_client; C m_client;
// also store an array of SSLSessions, so we can resume communication with multiple websites // also store an array of SSLSessions, so we can resume communication with multiple websites
@ -154,4 +134,56 @@ private:
size_t m_index; size_t m_index;
}; };
template <class C, size_t SessionCache>
SSLSession& SSLClient<C, SessionCache>::getSession(const char* host, const IPAddress& addr) {
const char* func_name = __func__;
// search for a matching session with the IP
int temp_index = m_getSessionIndex(host, addr);
// if none are availible, use m_index
if (temp_index == -1) {
temp_index = m_index;
// reset the session so we don't try to send one sites session to another
m_sessions[temp_index] = SSLSession();
}
// increment m_index so the session cache is a circular buffer
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
// return the pointed to value
m_info("Using session index: ", func_name);
Serial.println(temp_index);
return m_sessions[temp_index];
}
template <class C, size_t SessionCache>
void SSLClient<C, SessionCache>::removeSession(const char* host, const IPAddress& addr) {
const char* func_name = __func__;
int temp_index = m_getSessionIndex(host, addr);
if (temp_index != -1) {
m_info(" Deleted session ", func_name);
m_info(temp_index, func_name);
m_sessions[temp_index] = SSLSession();
}
}
template <class C, size_t SessionCache>
int SSLClient<C, SessionCache>::m_getSessionIndex(const char* host, const IPAddress& addr) const {
const char* func_name = __func__;
// search for a matching session with the IP
for (uint8_t i = 0; i < SessionCache; i++) {
// if we're looking at a real session
if (m_sessions[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
// there is no hostname and the IP address matches
|| (host == NULL && addr == m_sessions[i].get_ip())
)) {
m_info("Found session match: ", func_name);
m_info(m_sessions[i].get_hostname(), func_name);
return i;
}
}
// none found
return -1;
}
#endif /** SSLClient_H_ */ #endif /** SSLClient_H_ */

View file

@ -45,7 +45,7 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
int SSLClientImpl::connect(IPAddress ip, uint16_t port) { int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
const char* func_name = __func__; const char* func_name = __func__;
// connection check // connection check
if (connected()) { if (m_client->connected()) {
m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name);
return -1; return -1;
} }
@ -68,7 +68,7 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
int SSLClientImpl::connect(const char *host, uint16_t port) { int SSLClientImpl::connect(const char *host, uint16_t port) {
const char* func_name = __func__; const char* func_name = __func__;
// connection check // connection check
if (connected()) { if (m_client->connected()) {
m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name);
return -1; return -1;
} }
@ -148,7 +148,7 @@ int SSLClientImpl::available() {
br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen); br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen);
return (int)(alen); return (int)(alen);
} }
else if (state == BR_SSL_CLOSED) m_warn("Engine closed after update", func_name); else if (state == BR_SSL_CLOSED) m_info("Engine closed after update", func_name);
// flush the buffer if it's stuck in the SENDAPP state // flush the buffer if it's stuck in the SENDAPP state
else if (state & BR_SSL_SENDAPP) br_ssl_engine_flush(&m_sslctx.eng, 0); else if (state & BR_SSL_SENDAPP) br_ssl_engine_flush(&m_sslctx.eng, 0);
// other state, or client is closed // other state, or client is closed
@ -194,8 +194,6 @@ void SSLClientImpl::flush() {
void SSLClientImpl::stop() { void SSLClientImpl::stop() {
// tell the SSL connection to gracefully close // tell the SSL connection to gracefully close
br_ssl_engine_close(&m_sslctx.eng); br_ssl_engine_close(&m_sslctx.eng);
// info about the socket connection
if (br_ssl_engine_current_state(&m_sslctx.eng) == BR_SSL_CLOSED) m_info("Socket was terminated before graceful closure (probably fine)", __func__);
// if the engine isn't closed, and the socket is still open // if the engine isn't closed, and the socket is still open
while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED
&& m_run_until(BR_SSL_RECVAPP) == 0) { && m_run_until(BR_SSL_RECVAPP) == 0) {
@ -220,8 +218,14 @@ uint8_t SSLClientImpl::connected() {
const auto wr_ok = getWriteError() == 0; const auto wr_ok = getWriteError() == 0;
// if we're in an error state, close the connection and set a write error // if we're in an error state, close the connection and set a write error
if (br_con && !c_con) { if (br_con && !c_con) {
m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); // If we've got a write error, the client probably failed for some reason
m_error(m_client->getWriteError(), func_name); if (m_client->getWriteError()) {
m_error("Socket was unexpectedly interrupted. m_client error: ", func_name);
m_error(m_client->getWriteError(), func_name);
}
// Else tell the user the endpoint closed the socket on us (ouch)
else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name);
// set the write error so the engine doesn't try to close the connection
setWriteError(SSL_CLIENT_WRTIE_ERROR); setWriteError(SSL_CLIENT_WRTIE_ERROR);
stop(); stop();
} }
@ -280,11 +284,18 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
m_print_br_error(br_ssl_engine_last_error(&m_sslctx.eng), SSL_ERROR); m_print_br_error(br_ssl_engine_last_error(&m_sslctx.eng), SSL_ERROR);
return 0; return 0;
} }
m_info("Connection successful!", func_name);
// all good to go! the SSL socket should be up and running // all good to go! the SSL socket should be up and running
// overwrite the session we got with new parameters // overwrite the session we got with new parameters
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session()); br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
// set the hostname and ip in the session as well // set the hostname and ip in the session as well
ssl_ses.set_parameters(remoteIP(), host); ssl_ses.set_parameters(remoteIP(), host);
// print the handshake cipher chioce
m_info("Cipher suite: ", func_name);
if (m_debug >= SSL_INFO) {
m_print_prefix(func_name, SSL_INFO);
Serial.println(ssl_ses.cipher_suite, HEX);
}
return 1; return 1;
} }
@ -293,6 +304,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
const char* func_name = __func__; const char* func_name = __func__;
unsigned lastState = 0; unsigned lastState = 0;
size_t lastLen = 0; size_t lastLen = 0;
const unsigned long start = millis();
for (;;) { for (;;) {
unsigned state = m_update_engine(); unsigned state = m_update_engine();
// error check // error check
@ -300,11 +312,20 @@ int SSLClientImpl::m_run_until(const unsigned target) {
m_warn("Tried to run_until when the engine is closed", func_name); m_warn("Tried to run_until when the engine is closed", func_name);
return -1; return -1;
} }
// timeout check
if (millis() - start > getTimeout()) {
m_error("SSL internals timed out! This could be an internal error or bad data sent from the server", func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop();
return -1;
}
// debug // debug
if (state != lastState) { if (state != lastState) {
lastState = state; lastState = state;
m_info("m_run waiting:", func_name); m_info("m_run changed state:", func_name);
printState(state); printState(state);
m_info("Memory: ", func_name);
m_info(freeMemory(), func_name);
} }
if (state & BR_SSL_RECVREC) { if (state & BR_SSL_RECVREC) {
size_t len; size_t len;
@ -455,7 +476,8 @@ unsigned SSLClientImpl::m_update_engine() {
m_info("Read bytes from client: ", func_name); m_info("Read bytes from client: ", func_name);
m_info(avail, func_name); m_info(avail, func_name);
m_info(len, func_name); m_info(len, func_name);
m_info("Memory: ", func_name);
m_info(freeMemory(), func_name);
// I suppose so! // I suppose so!
int rlen = m_client->read(buf, len); int rlen = m_client->read(buf, len);
if (rlen <= 0) { if (rlen <= 0) {
@ -495,20 +517,20 @@ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level
Serial.print("(SSLClient)"); Serial.print("(SSLClient)");
// print the debug level // print the debug level
switch (level) { switch (level) {
case SSL_INFO: Serial.print("SSL_INFO"); break; case SSL_INFO: Serial.print("(SSL_INFO)"); break;
case SSL_WARN: Serial.print("SSL_WARN"); break; case SSL_WARN: Serial.print("(SSL_WARN)"); break;
case SSL_ERROR: Serial.print("SSL_ERROR"); break; case SSL_ERROR: Serial.print("(SSL_ERROR)"); break;
default: Serial.print("Unknown level"); default: Serial.print("(Unknown level)");
} }
// print the function name // print the function name
Serial.print("(");
Serial.print(func_name); Serial.print(func_name);
// get ready Serial.print("): ");
Serial.print(": ");
} }
/** See SSLClientImpl.h */ /** See SSLClientImpl.h */
void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const { void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const {
if (level < m_debug) return; if (level > m_debug) return;
m_print_prefix(__func__, level); m_print_prefix(__func__, level);
switch(ssl_error) { switch(ssl_error) {
case SSL_OK: Serial.println("SSL_OK"); break; case SSL_OK: Serial.println("SSL_OK"); break;
@ -522,7 +544,7 @@ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel leve
/* See SSLClientImpl.h */ /* See SSLClientImpl.h */
void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const { void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const {
if (level < m_debug) return; if (level > m_debug) return;
m_print_prefix(__func__, level); m_print_prefix(__func__, level);
switch (br_error_code) { switch (br_error_code) {
case BR_ERR_BAD_PARAM: Serial.println("Caller-provided parameter is incorrect."); break; case BR_ERR_BAD_PARAM: Serial.println("Caller-provided parameter is incorrect."); break;

View file

@ -45,11 +45,31 @@ enum Error {
*/ */
enum DebugLevel { enum DebugLevel {
SSL_NONE = 0, SSL_NONE = 0,
SSL_INFO = 1, SSL_ERROR = 1,
SSL_WARN = 2, SSL_WARN = 2,
SSL_ERROR = 3 SSL_INFO = 3,
}; };
#ifdef __arm__
// should use uinstd.h to define sbrk but Due causes a conflict
extern "C" char* sbrk(int incr);
#else // __ARM__
extern char *__brkval;
#endif // __arm__
static int freeMemory() {
char top;
#ifdef __arm__
return &top - reinterpret_cast<char*>(sbrk(0));
#elif defined(CORE_TEENSY) || (ARDUINO > 103 && ARDUINO != 151)
return &top - __brkval;
#else // __arm__
return __brkval ? &top - __brkval : &top - __malloc_heap_start;
#endif // __arm__
}
/** TODO: Write what this is */ /** TODO: Write what this is */
class SSLClientImpl : public Client { class SSLClientImpl : public Client {
@ -170,7 +190,7 @@ protected:
template<typename T> template<typename T>
void m_print(const T str, const char* func_name, const DebugLevel level) const { void m_print(const T str, const char* func_name, const DebugLevel level) const {
// check the current debug level // check the current debug level
if (level < m_debug) return; if (level > m_debug) return;
// print prefix // print prefix
m_print_prefix(func_name, level); m_print_prefix(func_name, level);
// print the message // print the message
@ -226,7 +246,9 @@ private:
// can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI // can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI
// or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically // or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically
// simply edit this value to change the buffer size to the desired value // simply edit this value to change the buffer size to the desired value
unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO]; // additionally, we need to correct buffer size based off of how many sessions we decide to cache
// since SSL takes so much memory if we don't it will cause the stack and heap to collide
unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO / 4];
static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size"); static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size");
// store the index of where we are writing in the buffer // store the index of where we are writing in the buffer
// so we can send our records all at once to prevent // so we can send our records all at once to prevent

View file

@ -418,6 +418,7 @@ br_client_init_TLS12_only(br_ssl_client_context *cc,
* supported hash function is appropriate; here we use SHA-256. * supported hash function is appropriate; here we use SHA-256.
* The trust an * The trust an
*/ */
memset(xc, 0, sizeof *xc);
br_x509_minimal_init(xc, &br_sha256_vtable, br_x509_minimal_init(xc, &br_sha256_vtable,
trust_anchors, trust_anchors_num); trust_anchors, trust_anchors_num);

View file

@ -119,6 +119,7 @@ br_ssl_client_init_full(br_ssl_client_context *cc,
* to TLS-1.2 (inclusive). * to TLS-1.2 (inclusive).
*/ */
br_ssl_client_zero(cc); br_ssl_client_zero(cc);
memset(xc, 0, sizeof *xc);
br_ssl_engine_set_versions(&cc->eng, BR_TLS10, BR_TLS12); br_ssl_engine_set_versions(&cc->eng, BR_TLS10, BR_TLS12);
/* /*