diff --git a/src/SSLClient.h b/src/SSLClient.h index b47758a..83d8446 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -67,6 +67,7 @@ class SSLClient : public SSLClientImpl { */ static_assert(std::is_base_of::value, "C must be a Client Class!"); static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!"); +static_assert(SessionCache <= 3, "You need to decrease the size of m_iobuf in order to have more than 3 sessions at once, otherwise memory issues will occur."); // static_assert(std::is_function::value, "C must have a status() function!"); public: @@ -90,12 +91,16 @@ public: explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_ERROR) : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug) , m_client(client) - , m_sessions{} + , m_sessions{SSLSession()} , m_index(0) { + // for (uint8_t i = 0; i < SessionCache; i++) m_sessions[i] = SSLSession(); // since we are copying the client in the ctor, we have to set // the client pointer after the class is constructed set_client(&m_client); + // set the timeout to a reasonable number (it can always be changes later) + // SSL Connections take a really long time so we don't want to time out a legitimate thing + setTimeout(10 * 1000); } /* @@ -114,38 +119,13 @@ public: //! get the client object C& getClient() { return m_client; } - virtual SSLSession& getSession(const char* host, const IPAddress& addr) { - // search for a matching session with the IP - int temp_index = -1; - for (size_t i = 0; i < SessionCache; i++) { - // if we're looking at a real session - if (m_sessions[i].is_valid_session() - && ( - // and the hostname matches, or - (host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0) - // there is no hostname and the IP address matches - || (host == NULL && addr == m_sessions[i].get_ip()) - )) { + virtual SSLSession& getSession(const char* host, const IPAddress& addr); - temp_index = i; - break; - } - } - // if none are availible, use m_index - if (temp_index == -1) { - temp_index = m_index; - // reset the session so we don't try to send one sites session to another - m_sessions[temp_index] = SSLSession(); - } - // increment m_index so the session cache is a circular buffer - if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; - // return the pointed to value - m_info("Using session index: ", __func__); - Serial.println(temp_index); - return m_sessions[temp_index]; - } + virtual void removeSession(const char* host, const IPAddress& addr); private: + // utility function to find a session index based off of a host and IP + int m_getSessionIndex(const char* host, const IPAddress& addr) const; // create a copy of the client C m_client; // also store an array of SSLSessions, so we can resume communication with multiple websites @@ -154,4 +134,56 @@ private: size_t m_index; }; +template +SSLSession& SSLClient::getSession(const char* host, const IPAddress& addr) { + const char* func_name = __func__; + // search for a matching session with the IP + int temp_index = m_getSessionIndex(host, addr); + // if none are availible, use m_index + if (temp_index == -1) { + temp_index = m_index; + // reset the session so we don't try to send one sites session to another + m_sessions[temp_index] = SSLSession(); + } + // increment m_index so the session cache is a circular buffer + if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; + // return the pointed to value + m_info("Using session index: ", func_name); + Serial.println(temp_index); + return m_sessions[temp_index]; +} + +template +void SSLClient::removeSession(const char* host, const IPAddress& addr) { + const char* func_name = __func__; + int temp_index = m_getSessionIndex(host, addr); + if (temp_index != -1) { + m_info(" Deleted session ", func_name); + m_info(temp_index, func_name); + m_sessions[temp_index] = SSLSession(); + } +} + +template +int SSLClient::m_getSessionIndex(const char* host, const IPAddress& addr) const { + const char* func_name = __func__; + // search for a matching session with the IP + for (uint8_t i = 0; i < SessionCache; i++) { + // if we're looking at a real session + if (m_sessions[i].is_valid_session() + && ( + // and the hostname matches, or + (host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0) + // there is no hostname and the IP address matches + || (host == NULL && addr == m_sessions[i].get_ip()) + )) { + m_info("Found session match: ", func_name); + m_info(m_sessions[i].get_hostname(), func_name); + return i; + } + } + // none found + return -1; +} + #endif /** SSLClient_H_ */ \ No newline at end of file diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index d4e85fd..997fcb9 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -45,7 +45,7 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a int SSLClientImpl::connect(IPAddress ip, uint16_t port) { const char* func_name = __func__; // connection check - if (connected()) { + if (m_client->connected()) { m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); return -1; } @@ -68,7 +68,7 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) { int SSLClientImpl::connect(const char *host, uint16_t port) { const char* func_name = __func__; // connection check - if (connected()) { + if (m_client->connected()) { m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); return -1; } @@ -148,7 +148,7 @@ int SSLClientImpl::available() { br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen); return (int)(alen); } - else if (state == BR_SSL_CLOSED) m_warn("Engine closed after update", func_name); + else if (state == BR_SSL_CLOSED) m_info("Engine closed after update", func_name); // flush the buffer if it's stuck in the SENDAPP state else if (state & BR_SSL_SENDAPP) br_ssl_engine_flush(&m_sslctx.eng, 0); // other state, or client is closed @@ -194,8 +194,6 @@ void SSLClientImpl::flush() { void SSLClientImpl::stop() { // tell the SSL connection to gracefully close br_ssl_engine_close(&m_sslctx.eng); - // info about the socket connection - if (br_ssl_engine_current_state(&m_sslctx.eng) == BR_SSL_CLOSED) m_info("Socket was terminated before graceful closure (probably fine)", __func__); // if the engine isn't closed, and the socket is still open while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED && m_run_until(BR_SSL_RECVAPP) == 0) { @@ -220,8 +218,14 @@ uint8_t SSLClientImpl::connected() { const auto wr_ok = getWriteError() == 0; // if we're in an error state, close the connection and set a write error if (br_con && !c_con) { - m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); - m_error(m_client->getWriteError(), func_name); + // If we've got a write error, the client probably failed for some reason + if (m_client->getWriteError()) { + m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); + m_error(m_client->getWriteError(), func_name); + } + // Else tell the user the endpoint closed the socket on us (ouch) + else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name); + // set the write error so the engine doesn't try to close the connection setWriteError(SSL_CLIENT_WRTIE_ERROR); stop(); } @@ -280,11 +284,18 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) { m_print_br_error(br_ssl_engine_last_error(&m_sslctx.eng), SSL_ERROR); return 0; } + m_info("Connection successful!", func_name); // all good to go! the SSL socket should be up and running // overwrite the session we got with new parameters br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session()); // set the hostname and ip in the session as well ssl_ses.set_parameters(remoteIP(), host); + // print the handshake cipher chioce + m_info("Cipher suite: ", func_name); + if (m_debug >= SSL_INFO) { + m_print_prefix(func_name, SSL_INFO); + Serial.println(ssl_ses.cipher_suite, HEX); + } return 1; } @@ -293,6 +304,7 @@ int SSLClientImpl::m_run_until(const unsigned target) { const char* func_name = __func__; unsigned lastState = 0; size_t lastLen = 0; + const unsigned long start = millis(); for (;;) { unsigned state = m_update_engine(); // error check @@ -300,11 +312,20 @@ int SSLClientImpl::m_run_until(const unsigned target) { m_warn("Tried to run_until when the engine is closed", func_name); return -1; } + // timeout check + if (millis() - start > getTimeout()) { + m_error("SSL internals timed out! This could be an internal error or bad data sent from the server", func_name); + setWriteError(SSL_BR_WRITE_ERROR); + stop(); + return -1; + } // debug if (state != lastState) { lastState = state; - m_info("m_run waiting:", func_name); + m_info("m_run changed state:", func_name); printState(state); + m_info("Memory: ", func_name); + m_info(freeMemory(), func_name); } if (state & BR_SSL_RECVREC) { size_t len; @@ -455,7 +476,8 @@ unsigned SSLClientImpl::m_update_engine() { m_info("Read bytes from client: ", func_name); m_info(avail, func_name); m_info(len, func_name); - + m_info("Memory: ", func_name); + m_info(freeMemory(), func_name); // I suppose so! int rlen = m_client->read(buf, len); if (rlen <= 0) { @@ -495,20 +517,20 @@ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level Serial.print("(SSLClient)"); // print the debug level switch (level) { - case SSL_INFO: Serial.print("SSL_INFO"); break; - case SSL_WARN: Serial.print("SSL_WARN"); break; - case SSL_ERROR: Serial.print("SSL_ERROR"); break; - default: Serial.print("Unknown level"); + case SSL_INFO: Serial.print("(SSL_INFO)"); break; + case SSL_WARN: Serial.print("(SSL_WARN)"); break; + case SSL_ERROR: Serial.print("(SSL_ERROR)"); break; + default: Serial.print("(Unknown level)"); } // print the function name + Serial.print("("); Serial.print(func_name); - // get ready - Serial.print(": "); + Serial.print("): "); } /** See SSLClientImpl.h */ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const { - if (level < m_debug) return; + if (level > m_debug) return; m_print_prefix(__func__, level); switch(ssl_error) { case SSL_OK: Serial.println("SSL_OK"); break; @@ -522,7 +544,7 @@ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel leve /* See SSLClientImpl.h */ void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const { - if (level < m_debug) return; + if (level > m_debug) return; m_print_prefix(__func__, level); switch (br_error_code) { case BR_ERR_BAD_PARAM: Serial.println("Caller-provided parameter is incorrect."); break; diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index e91bc68..7ea6774 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -45,11 +45,31 @@ enum Error { */ enum DebugLevel { SSL_NONE = 0, - SSL_INFO = 1, + SSL_ERROR = 1, SSL_WARN = 2, - SSL_ERROR = 3 + SSL_INFO = 3, }; + +#ifdef __arm__ +// should use uinstd.h to define sbrk but Due causes a conflict +extern "C" char* sbrk(int incr); +#else // __ARM__ +extern char *__brkval; +#endif // __arm__ + +static int freeMemory() { + char top; +#ifdef __arm__ + return &top - reinterpret_cast(sbrk(0)); +#elif defined(CORE_TEENSY) || (ARDUINO > 103 && ARDUINO != 151) + return &top - __brkval; +#else // __arm__ + return __brkval ? &top - __brkval : &top - __malloc_heap_start; +#endif // __arm__ +} + + /** TODO: Write what this is */ class SSLClientImpl : public Client { @@ -170,7 +190,7 @@ protected: template void m_print(const T str, const char* func_name, const DebugLevel level) const { // check the current debug level - if (level < m_debug) return; + if (level > m_debug) return; // print prefix m_print_prefix(func_name, level); // print the message @@ -226,7 +246,9 @@ private: // can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI // or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically // simply edit this value to change the buffer size to the desired value - unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO]; + // additionally, we need to correct buffer size based off of how many sessions we decide to cache + // since SSL takes so much memory if we don't it will cause the stack and heap to collide + unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO / 4]; static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size"); // store the index of where we are writing in the buffer // so we can send our records all at once to prevent diff --git a/src/TLS12_only_profile.c b/src/TLS12_only_profile.c index e0b8295..bfe242a 100644 --- a/src/TLS12_only_profile.c +++ b/src/TLS12_only_profile.c @@ -418,6 +418,7 @@ br_client_init_TLS12_only(br_ssl_client_context *cc, * supported hash function is appropriate; here we use SHA-256. * The trust an */ + memset(xc, 0, sizeof *xc); br_x509_minimal_init(xc, &br_sha256_vtable, trust_anchors, trust_anchors_num); diff --git a/src/bearssl/src/ssl/ssl_client_full.c b/src/bearssl/src/ssl/ssl_client_full.c index fd35b3c..bc34e92 100644 --- a/src/bearssl/src/ssl/ssl_client_full.c +++ b/src/bearssl/src/ssl/ssl_client_full.c @@ -119,6 +119,7 @@ br_ssl_client_init_full(br_ssl_client_context *cc, * to TLS-1.2 (inclusive). */ br_ssl_client_zero(cc); + memset(xc, 0, sizeof *xc); br_ssl_engine_set_versions(&cc->eng, BR_TLS10, BR_TLS12); /*