diff --git a/src/SSLClient.h b/src/SSLClient.h index 19b5d52..945c59d 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -99,14 +99,13 @@ public: * @param debug whether to enable or disable debug logging, must be constexpr */ explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_WARN) - : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug) + : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, NULL, debug) , m_client(client) , m_sessions{SSLSession()} - , m_index(0) { // since we are copying the client in the ctor, we have to set // the client pointer after the class is constructed - set_client(&m_client); + set_client(&m_client, m_sessions); // set the timeout to a reasonable number (it can always be changes later) // SSL Connections take a really long time so we don't want to time out a legitimate thing setTimeout(10 * 1000); @@ -371,6 +370,38 @@ public: //= Functions Not in the Client Interface //======================================== + /** + * @brief Get a sesssion reference corressponding to a host and IP, or a reference to a emptey session if none exist + * + * If no session corresponding to the host and ip exist, then this function will cycle through + * sessions in a rotating order. This allows the ssession cache to continuially store sessions, + * however it will also result in old sessions being cleared and returned. In general, it is a + * good idea to use a SessionCache size equal to the number of domains you plan on connecting to. + * + * The implementation for this function can be found at SSLClientImpl::get_session_impl. + * + * @param host A hostname c string, or NULL if one is not availible + * @param ip An IP address + * @returns A reference to an SSLSession object + */ + virtual SSLSession& getSession(const char* host, const IPAddress& addr) { return get_session_impl(host, addr); } + + /** + * @brief Clear the session corresponding to a host and IP + * + * The implementation for this function can be found at SSLClientImpl::remove_session_impl. + * + * @param host A hostname c string, or NULL if one is not availible + * @param ip An IP address + */ + virtual void removeSession(const char* host, const IPAddress& addr) { return remove_session_impl(host, addr); } + + /** + * @brief Get the meximum number of SSL sessions that can be stored at once + * @returns The SessionCache template parameter. + */ + virtual size_t getSessionCount() const { return SessionCache; } + /** * @brief Equivalent to SSLClient::connected() > 0 * @returns true if connected, false if not @@ -412,88 +443,15 @@ public: /** @brief returns a refernence to the client object stored in this class. Take care not to break it. */ C& getClient() { return m_client; } - /** - * @brief Get a sesssion reference corressponding to a host and IP, or a reference to a emptey session if none exist - * - * If no session corresponding to the host and ip exist, then this function will cycle through - * sessions in a rotating order. This allows the ssession cache to continuially store sessions, - * however it will also result in old sessions being cleared and returned. In general, it is a - * good idea to use a SessionCache size equal to the number of domains you plan on connecting to. - * - * @param host A hostname c string, or NULL if one is not availible - * @param ip An IP address - * @returns A reference to an SSLSession object - */ - virtual SSLSession& getSession(const char* host, const IPAddress& addr); +protected: + //virtual Client& get_arduino_client() { return m_client; } + //virtual SSLSession* get_session_array() { return m_sessions; } - /** - * @brief Clear the session corresponding to a host and IP - * - * @param host A hostname c string, or NULL if one is not availible - * @param ip An IP address - */ - virtual void removeSession(const char* host, const IPAddress& addr); private: - // utility function to find a session index based off of a host and IP - int m_getSessionIndex(const char* host, const IPAddress& addr) const; // create a copy of the client C m_client; // also store an array of SSLSessions, so we can resume communication with multiple websites SSLSession m_sessions[SessionCache]; - // store an index of where a new session can be placed if we don't have any corresponding sessions - size_t m_index; }; -template -SSLSession& SSLClient::getSession(const char* host, const IPAddress& addr) { - const char* func_name = __func__; - // search for a matching session with the IP - int temp_index = m_getSessionIndex(host, addr); - // if none are availible, use m_index - if (temp_index == -1) { - temp_index = m_index; - // reset the session so we don't try to send one sites session to another - m_sessions[temp_index].clear_parameters(); - } - // increment m_index so the session cache is a circular buffer - if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; - // return the pointed to value - m_info("Using session index: ", func_name); - m_info(temp_index, func_name); - return m_sessions[temp_index]; -} - -template -void SSLClient::removeSession(const char* host, const IPAddress& addr) { - const char* func_name = __func__; - int temp_index = m_getSessionIndex(host, addr); - if (temp_index != -1) { - m_info(" Deleted session ", func_name); - m_info(temp_index, func_name); - m_sessions[temp_index].clear_parameters(); - } -} - -template -int SSLClient::m_getSessionIndex(const char* host, const IPAddress& addr) const { - const char* func_name = __func__; - // search for a matching session with the IP - for (uint8_t i = 0; i < SessionCache; i++) { - // if we're looking at a real session - if (m_sessions[i].is_valid_session() - && ( - // and the hostname matches, or - (host != NULL && m_sessions[i].get_hostname().equals(host)) - // there is no hostname and the IP address matches - || (host == NULL && addr == m_sessions[i].get_ip()) - )) { - m_info("Found session match: ", func_name); - m_info(m_sessions[i].get_hostname(), func_name); - return i; - } - } - // none found - return -1; -} - #endif /** SSLClient_H_ */ \ No newline at end of file diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index 3e653d8..6453eba 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -51,11 +51,14 @@ static int freeMemory() { /** see SSLClientImpl.h */ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_anchors, - const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug) + const size_t trust_anchors_num, const int analog_pin, SSLSession* session_ray, + const DebugLevel debug) : m_client(client) , m_trust_anchors(trust_anchors) , m_trust_anchors_num(trust_anchors_num) , m_analog_pin(analog_pin) + , m_session_ptr(session_ray) + , m_session_index(0) , m_debug(debug) , m_write_idx(0) { @@ -90,7 +93,7 @@ int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) { return 0; } m_info("Base client connected!", func_name); - return m_start_ssl(NULL, getSession(NULL, ip)); + return m_start_ssl(NULL, get_session_impl(NULL, ip)); } /* see SSLClientImpl.h*/ @@ -106,7 +109,7 @@ int SSLClientImpl::connect_impl(const char *host, uint16_t port) { // first, if we have a session, check if we're trying to resolve the same host // as before bool connect_ok; - SSLSession& ses = getSession(host, INADDR_NONE); + SSLSession& ses = get_session_impl(host, INADDR_NONE); if (ses.is_valid_session()) { // if so, then connect using the stored session m_info("Connecting using a cached IP", func_name); @@ -187,7 +190,7 @@ int SSLClientImpl::available_impl() { /** see SSLClientImpl.h */ int SSLClientImpl::read_impl(uint8_t *buf, size_t size) { // check that the engine is ready to read - if (available() <= 0) return -1; + if (available_impl() <= 0) return -1; // read the buffer, send the ack, and return the bytes read size_t alen; unsigned char* br_buf = br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen); @@ -202,7 +205,7 @@ int SSLClientImpl::read_impl(uint8_t *buf, size_t size) { /** see SSLClientImpl.h */ int SSLClientImpl::peek_impl() { // check that the engine is ready to read - if (available() <= 0) return -1; + if (available_impl() <= 0) return -1; // read the buffer, send the ack, and return the bytes read size_t alen; uint8_t read_num; @@ -211,7 +214,7 @@ int SSLClientImpl::peek_impl() { return (int)read_num; } -/** see SSLClientImpl.h*/ +/** see SSLClientImpl.h */ void SSLClientImpl::flush_impl() { // trigger a flush, incase there's any leftover data br_ssl_engine_flush(&m_sslctx.eng, 0); @@ -219,7 +222,7 @@ void SSLClientImpl::flush_impl() { if(m_run_until(BR_SSL_RECVAPP) < 0) m_error("Could not flush write buffer!", __func__); } -/** see SSLClientImpl.h*/ +/** see SSLClientImpl.h */ void SSLClientImpl::stop_impl() { // tell the SSL connection to gracefully close br_ssl_engine_close(&m_sslctx.eng); @@ -240,6 +243,7 @@ void SSLClientImpl::stop_impl() { m_client->stop(); } +/** see SSLClientImpl.h */ uint8_t SSLClientImpl::connected_impl() { const char* func_name = __func__; // check all of the error cases @@ -257,7 +261,7 @@ uint8_t SSLClientImpl::connected_impl() { else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name); // set the write error so the engine doesn't try to close the connection setWriteError(SSL_CLIENT_WRTIE_ERROR); - stop(); + stop_impl(); } else if (!wr_ok) { m_error("Not connected because write error is set", func_name); @@ -265,6 +269,36 @@ uint8_t SSLClientImpl::connected_impl() { return c_con && br_con && wr_ok; } +/** see SSLClientImpl.h */ +SSLSession& SSLClientImpl::get_session_impl(const char* host, const IPAddress& addr) { + const char* func_name = __func__; + // search for a matching session with the IP + int temp_index = m_get_session_index(host, addr); + // if none are availible, use m_session_index + if (temp_index == -1) { + temp_index = m_session_index; + // reset the session so we don't try to send one sites session to another + m_session_ptr[temp_index].clear_parameters(); + } + // increment m_session_index so the session cache is a circular buffer + if (temp_index == m_session_index && ++m_session_index >= getSessionCount()) m_session_index = 0; + // return the pointed to value + m_info("Using session index: ", func_name); + m_info(temp_index, func_name); + return m_session_ptr[temp_index]; +} + +/** see SSLClientImpl.h */ +void SSLClientImpl::remove_session_impl(const char* host, const IPAddress& addr) { + const char* func_name = __func__; + int temp_index = m_get_session_index(host, addr); + if (temp_index != -1) { + m_info(" Deleted session ", func_name); + m_info(temp_index, func_name); + m_session_ptr[temp_index].clear_parameters(); + } +} + bool SSLClientImpl::m_soft_connected(const char* func_name) { // check if the socket is still open and such if (getWriteError()) { @@ -346,7 +380,7 @@ int SSLClientImpl::m_run_until(const unsigned target) { if (millis() - start > getTimeout()) { m_error("SSL internals timed out! This could be an internal error or bad data sent from the server", func_name); setWriteError(SSL_BR_WRITE_ERROR); - stop(); + stop_impl(); return -1; } // debug @@ -399,7 +433,7 @@ int SSLClientImpl::m_run_until(const unsigned target) { else { m_error("SSL engine state is RECVAPP, however the buffer was null! (This is a problem with BearSSL internals)", func_name); setWriteError(SSL_BR_WRITE_ERROR); - stop(); + stop_impl(); return -1; } } @@ -446,7 +480,7 @@ unsigned SSLClientImpl::m_update_engine() { * wait for it. */ if (!&m_sslctx.eng.shutdown_recv) return 0; - stop(); + stop_impl(); return 0; } if (wlen > 0) { @@ -465,7 +499,7 @@ unsigned SSLClientImpl::m_update_engine() { if (!(state & BR_SSL_SENDAPP)) { m_error("Error m_write_idx > 0 but the ssl engine is not ready for data", func_name); setWriteError(SSL_BR_WRITE_ERROR); - stop(); + stop_impl(); return 0; } // else time to send the application data @@ -476,14 +510,14 @@ unsigned SSLClientImpl::m_update_engine() { if (alen == 0 || buf == NULL) { m_error("Engine set write flag but returned null buffer", func_name); setWriteError(SSL_BR_WRITE_ERROR); - stop(); + stop_impl(); return 0; } // sanity check if (alen < m_write_idx) { m_error("Alen is less than m_write_idx", func_name); setWriteError(SSL_INTERNAL_ERROR); - stop(); + stop_impl(); return 0; } // all good? lets send the data @@ -510,7 +544,7 @@ unsigned SSLClientImpl::m_update_engine() { unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len); // do we have the record you're looking for? const auto avail = m_client->available(); - if (avail >= len) { + if (avail > 0 && avail >= len) { int mem = freeMemory(); // check for a stack overflow // if the stack overflows we basically have to crash, and @@ -531,7 +565,7 @@ unsigned SSLClientImpl::m_update_engine() { if(mem < 8000) { m_error("Out of memory! Decrease the number of sessions or the size of m_iobuf", func_name); setWriteError(SSL_OUT_OF_MEMORY); - stop(); + stop_impl(); return 0; } m_info("Read bytes from client: ", func_name); @@ -543,7 +577,7 @@ unsigned SSLClientImpl::m_update_engine() { m_error("Error reading bytes from m_client. Write Error: ", func_name); m_error(m_client->getWriteError(), func_name); setWriteError(SSL_CLIENT_WRTIE_ERROR); - stop(); + stop_impl(); return 0; } if (rlen > 0) { @@ -568,6 +602,27 @@ unsigned SSLClientImpl::m_update_engine() { } } +/** see SSLClientImpl.h */ +int SSLClientImpl::m_get_session_index(const char* host, const IPAddress& addr) const { + const char* func_name = __func__; + // search for a matching session with the IP + for (uint8_t i = 0; i < getSessionCount(); i++) { + // if we're looking at a real session + if (m_session_ptr[i].is_valid_session() + && ( + // and the hostname matches, or + (host != NULL && m_session_ptr[i].get_hostname().equals(host)) + // there is no hostname and the IP address matches + || (host == NULL && addr == m_session_ptr[i].get_ip()) + )) { + m_info("Found session match: ", func_name); + m_info(m_session_ptr[i].get_hostname(), func_name); + return i; + } + } + // none found + return -1; +} /** See SSLClientImpl.h */ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level) const diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index 0c29e74..e5a9069 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -71,7 +71,7 @@ enum DebugLevel { * On error, any function in this class will terminate the socket. * TODO: Write what this is */ -class SSLClientImpl: public Client { +class SSLClientImpl : public Client { public: /** * @brief initializes SSL contexts for bearSSL @@ -80,22 +80,23 @@ public: * based off of the domains you want to make SSL connections to. Check out the * Wiki on the pycert-bearssl tool for a simple way to do this. * @pre The analog_pin should be set to input. + * @pre The session_ray must be an array of the size returned by SSLClient::getSessionCount() + * filled with SSLSession objects. * * @post set_client must be called immediatly after to set the client class - * pointer. + * pointer and Session pointer. * * @param trust_anchors Trust anchors used in the verification * of the SSL server certificate, generated using the `brssl` command * line utility. For more information see the samples or bearssl.org * @param trust_anchors_num The number of trust anchors stored * @param analog_pin An analog pin to pull random bytes from, used in seeding the RNG - * @param get_remote_ip Function pointer to get the remote ip from the client. We - * need this value since the Client abstract class has no remoteIP() function, - * however most of the arduino internet client implementations do. + * @param session_ray A pointer to the array of SSLSessions created by SSLClient * @param debug whether to enable or disable debug logging, must be constexpr */ explicit SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, - const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug); + const size_t trust_anchors_num, const int analog_pin, SSLSession* session_ray, + const DebugLevel debug); //============================================ //= Functions implemented in SSLClientImpl.cpp @@ -119,6 +120,10 @@ public: void stop_impl(); /** @see SSLClient::connected */ uint8_t connected_impl(); + /** See SSLClient::getSession */ + SSLSession& get_session_impl(const char* host, const IPAddress& addr); + /** See SSLClient::removeSession */ + void remove_session_impl(const char* host, const IPAddress& addr); //============================================ //= Functions implemented in SSLClient.h @@ -129,11 +134,10 @@ public: virtual IPAddress remoteIP() = 0; /** See SSLClient::localPort */ virtual uint16_t remotePort() = 0; - /** See SSLClient::getSession */ - virtual SSLSession& getSession(const char* host, const IPAddress& addr) = 0; + /** See SSLClient::getSessionCount */ + virtual size_t getSessionCount() const = 0; protected: - //============================================ //= Functions implemented in SSLClientImpl.cpp //============================================ @@ -145,7 +149,7 @@ protected: * is placed in it's own function for flexibility reasons, but it * is critical that this function is called before anything else */ - void set_client(Client* c) { m_client = c; } + void set_client(Client* c, SSLSession* sessions) { m_client = c; m_session_ptr = sessions; } /** @brief Prints a debugging prefix to all logs, so we can attatch them to useful information */ void m_print_prefix(const char* func_name, const DebugLevel level) const; @@ -185,7 +189,9 @@ private: /** run the bearssl engine until a certain state */ int m_run_until(const unsigned target); /** proxy for availble that returns the state */ - unsigned m_update_engine(); + unsigned m_update_engine(); + /** utility function to find a session index based off of a host and IP */ + int m_get_session_index(const char* host, const IPAddress& addr) const; //============================================ //= Data Members @@ -199,6 +205,11 @@ private: const size_t m_trust_anchors_num; // store the pin to fetch an RNG see from const int m_analog_pin; + // store a pointer to the SSL Session array, since it's size + // is deduced at compile time + SSLSession* m_session_ptr; + // store an index of where a new session can be placed if we don't have any corresponding sessions + size_t m_session_index; // store whether to enable debug logging const DebugLevel m_debug; // store the context values required for SSL