From ab0cf9d52be9faa5b8270fd3f938ae9e64f30288 Mon Sep 17 00:00:00 2001 From: Noah Laptop Date: Thu, 7 Mar 2019 18:15:39 -0800 Subject: [PATCH] implemented session cache of size n, need to figure out failure cases and account for them --- src/SSLClient.h | 40 +++++++++++++++++++++++++++++++- src/SSLClientImpl.cpp | 54 ++++++++++++++++++++++--------------------- src/SSLClientImpl.h | 11 +++++---- src/SSLSession.h | 6 ++--- 4 files changed, 76 insertions(+), 35 deletions(-) diff --git a/src/SSLClient.h b/src/SSLClient.h index f25d4d8..8310bea 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -70,7 +70,7 @@ enum Error { * from the client side, however from the developer side it can be a bit confusing. */ -template +template class SSLClient : public SSLClientImpl { /** static type checks * I'm a java developer, so I want to ensure that my inheritance is safe. @@ -79,6 +79,7 @@ class SSLClient : public SSLClientImpl { * class inherits from Client, and then that it contains a status() function. */ static_assert(std::is_base_of::value, "C must be a Client Class!"); +static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!"); // static_assert(std::is_function::value, "C must have a status() function!"); public: @@ -102,6 +103,8 @@ public: explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug = true) : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug) , m_client(client) + , m_sessions{} + , m_index(0) { // since we are copying the client in the ctor, we have to set // the client pointer after the class is constructed @@ -124,9 +127,44 @@ public: //! get the client object C& getClient() { return m_client; } + virtual SSLSession& getSession(const char* host, const IPAddress& addr) { + // search for a matching session with the IP + int temp_index = -1; + for (size_t i = 0; i < SessionCache; i++) { + // if we're looking at a real session + if (m_sessions[i].is_valid_session() + && ( + // and the hostname matches, or + (host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0) + // there is no hostname and the IP address matches + || (host == NULL && addr == m_sessions[i].get_ip()) + )) { + + temp_index = i; + break; + } + } + // if none are availible, use m_index + if (temp_index == -1) { + temp_index = m_index; + // reset the session so we don't try to send one sites session to another + m_sessions[temp_index] = SSLSession(); + } + // increment m_index so the session cache is a circular buffer + if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; + // return the pointed to value + m_print("Using index: "); + m_print(temp_index); + return m_sessions[temp_index]; + } + private: // create a copy of the client C m_client; + // also store an array of SSLSessions, so we can resume communication with multiple websites + SSLSession m_sessions[SessionCache]; + // store an index of where a new session can be placed if we don't have any corresponding sessions + size_t m_index; }; #endif /** SSLClient_H_ */ \ No newline at end of file diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index 3217766..66f1042 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -28,13 +28,13 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a , m_trust_anchors_num(trust_anchors_num) , m_analog_pin(analog_pin) , m_debug(debug) - , m_write_idx(0) - , m_session() { + , m_write_idx(0) { // zero the iobuf just in case it's still garbage memset(m_iobuf, 0, sizeof m_iobuf); // initlalize the various bearssl libraries so they're ready to go when we connect br_client_init_TLS12_only(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num); + // comment the above line and uncomment the line below if you're having trouble connecting over SSL // br_ssl_client_init_full(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num); // check if the buffer size is half or full duplex constexpr auto duplex = sizeof m_iobuf <= BR_SSL_BUFSIZE_MONO ? 0 : 1; @@ -43,6 +43,11 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a /* see SSLClientImpl.h*/ int SSLClientImpl::connect(IPAddress ip, uint16_t port) { + // connection check + if (connected()) { + m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance."); + return -1; + } // reset indexs for saftey m_write_idx = 0; // Warning for security @@ -55,21 +60,26 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) { return 0; } m_print("Base ethernet client connected!"); - return m_start_ssl(); + return m_start_ssl(NULL, getSession(NULL, ip)); } /* see SSLClientImpl.h*/ int SSLClientImpl::connect(const char *host, uint16_t port) { + // connection check + if (connected()) { + m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance."); + return -1; + } // reset indexs for saftey m_write_idx = 0; // first, if we have a session, check if we're trying to resolve the same host // as before bool connect_ok; - if (m_session.is_valid_session() - && strcmp(m_session.get_hostname(), host) == 0) { + SSLSession& ses = getSession(host, INADDR_NONE); + if (ses.is_valid_session()) { // if so, then connect using the stored session m_print("Connecting using a cached IP"); - connect_ok = m_client->connect(m_session.get_ip(), port); + connect_ok = m_client->connect(ses.get_ip(), port); } // else connect with the provided hostname else connect_ok = m_client->connect(host, port); @@ -82,7 +92,7 @@ int SSLClientImpl::connect(const char *host, uint16_t port) { } m_print("Base ethernet client connected!"); // start ssl! - return m_start_ssl(host); + return m_start_ssl(host, ses); } /** see SSLClientImpl.h*/ @@ -229,7 +239,7 @@ uint8_t SSLClientImpl::connected() { } /** see SSLClientImpl.h */ -int SSLClientImpl::m_start_ssl(const char* host) { +int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) { // clear the write error setWriteError(SSL_OK); // get some random data by reading the analog pin we've been handed @@ -239,8 +249,8 @@ int SSLClientImpl::m_start_ssl(const char* host) { for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast(analogRead(m_analog_pin)); br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds); // inject session parameters for faster reconnection, if we have any - if(m_session.is_valid_session()) { - br_ssl_engine_set_session_parameters(&m_sslctx.eng, m_session.to_br_session()); + if(ssl_ses.is_valid_session()) { + br_ssl_engine_set_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session()); m_print("Set session!"); } // reset the engine, but make sure that it reset successfully @@ -259,17 +269,18 @@ int SSLClientImpl::m_start_ssl(const char* host) { return 0; } // all good to go! the SSL socket should be up and running - // debug print the session parameters to see if they exist - br_ssl_engine_get_session_parameters(&m_sslctx.eng, m_session.to_br_session()); + // overwrite the session we got with new parameters + br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session()); // set the hostname and ip in the session as well - m_session.set_parameters(remoteIP(), host); + ssl_ses.set_parameters(remoteIP(), host); + // print the session details m_print("Session:"); - for (uint8_t i = 0; i < m_session.session_id_len; i++) { + for (uint8_t i = 0; i < ssl_ses.session_id_len; i++) { Serial.print(", 0x"); - Serial.print(m_session.session_id[i], HEX); + Serial.print(ssl_ses.session_id[i], HEX); } Serial.println(); - Serial.println(m_session.cipher_suite, HEX); + Serial.println(ssl_ses.cipher_suite, HEX); return 1; } @@ -292,7 +303,7 @@ int SSLClientImpl::m_run_until(const unsigned target) { } if (state & BR_SSL_RECVREC) { size_t len; - unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len); + br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len); if (lastLen != len) { m_print("Expected bytes count: "); m_print(lastLen = len); @@ -355,15 +366,6 @@ unsigned SSLClientImpl::m_update_engine() { int wlen; buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len); - Serial.print("Payload: "); - for (int i = 0; i < len; i++) { - if (buf[i] <= 0x0f) Serial.print("0x0"); - else Serial.print("0x"); - Serial.print(buf[i], HEX); - Serial.print(", "); - } - Serial.println(); - //delay(100); wlen = m_client->write(buf, len); // let the chip recover if (wlen < 0) { diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index ea9ce0f..a8e7345 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -120,6 +120,9 @@ public: virtual uint16_t localPort() = 0; virtual IPAddress remoteIP() = 0; virtual uint16_t remotePort() = 0; + + // as well as store and retrieve session data + virtual SSLSession& getSession(const char* host, const IPAddress& addr) = 0; protected: /** * @brief set the pointer to the Client class that we wil use @@ -130,8 +133,6 @@ protected: */ void set_client(Client* c) { m_client = c; } -private: - /** @brief debugging print function, only prints if m_debug is true */ template constexpr void m_print(const T str) const { @@ -141,6 +142,8 @@ private: } } +private: + void printState(unsigned state) const { if(m_debug) { m_print("State: "); @@ -155,7 +158,7 @@ private: } } /** start the ssl engine on the connected client */ - int m_start_ssl(const char* host = NULL); + int m_start_ssl(const char* host, SSLSession& ssl_ses); /** run the bearssl engine until a certain state */ int m_run_until(const unsigned target); /** proxy for availble that returns the state */ @@ -183,8 +186,6 @@ private: // so we can send our records all at once to prevent // weird timing issues size_t m_write_idx; - // store the last SSL session, so reconnection later is speedy fast - SSLSession m_session; }; #endif /* SSLClientImpl_H_ */ \ No newline at end of file diff --git a/src/SSLSession.h b/src/SSLSession.h index a4e1f22..042091f 100644 --- a/src/SSLSession.h +++ b/src/SSLSession.h @@ -52,7 +52,7 @@ class SSLSession : public br_ssl_session_parameters { public: explicit SSLSession() : m_valid_session(false) - , m_hostname({}) + , m_hostname{} , m_ip(INADDR_NONE) {} /** @@ -75,14 +75,14 @@ public: /** * \pre must check isValidSession */ - const char* const get_hostname() const { return m_hostname; } + const char* get_hostname() const { return m_hostname; } /** * \pre must check isValidSession */ const IPAddress& get_ip() const { return m_ip; } - const bool is_valid_session() const { return m_valid_session; } + bool is_valid_session() const { return m_valid_session; } private: bool m_valid_session; // aparently a hostname has a max length of 256 chars. Go figure.