diff --git a/src/SSLClient.h b/src/SSLClient.h index 945c59d..0155339 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -99,13 +99,10 @@ public: * @param debug whether to enable or disable debug logging, must be constexpr */ explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_WARN) - : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, NULL, debug) + : SSLClientImpl(trust_anchors, trust_anchors_num, analog_pin, debug) , m_client(client) , m_sessions{SSLSession()} { - // since we are copying the client in the ctor, we have to set - // the client pointer after the class is constructed - set_client(&m_client, m_sessions); // set the timeout to a reasonable number (it can always be changes later) // SSL Connections take a really long time so we don't want to time out a legitimate thing setTimeout(10 * 1000); @@ -444,8 +441,12 @@ public: C& getClient() { return m_client; } protected: - //virtual Client& get_arduino_client() { return m_client; } - //virtual SSLSession* get_session_array() { return m_sessions; } + /** @brief return an instance of m_client that is polymorphic and can be used by SSLClientImpl */ + virtual Client& get_arduino_client() { return m_client; } + virtual const Client& get_arduino_client() const { return m_client; } + /** @brief return an instance of the session array that is on the stack */ + virtual SSLSession* get_session_array() { return m_sessions; } + virtual const SSLSession* get_session_array() const { return m_sessions; } private: // create a copy of the client diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index 6453eba..1c1fcce 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -50,14 +50,11 @@ static int freeMemory() { } /** see SSLClientImpl.h */ -SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_anchors, - const size_t trust_anchors_num, const int analog_pin, SSLSession* session_ray, - const DebugLevel debug) - : m_client(client) - , m_trust_anchors(trust_anchors) +SSLClientImpl::SSLClientImpl(const br_x509_trust_anchor *trust_anchors, + const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug) + : m_trust_anchors(trust_anchors) , m_trust_anchors_num(trust_anchors_num) , m_analog_pin(analog_pin) - , m_session_ptr(session_ray) , m_session_index(0) , m_debug(debug) , m_write_idx(0) { @@ -77,7 +74,7 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) { const char* func_name = __func__; // connection check - if (m_client->connected()) { + if (get_arduino_client().connected()) { m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); return -1; } @@ -87,7 +84,7 @@ int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) { m_warn("Using a raw IP Address for an SSL connection bypasses some important verification steps. You should use a domain name (www.google.com) whenever possible.", func_name); // first we need our hidden client member to negotiate the socket for us, // since most times socket functionality is implemented in hardeware. - if (!m_client->connect(ip, port)) { + if (!get_arduino_client().connect(ip, port)) { m_error("Failed to connect using m_client. Are you connected to the internet?", func_name); setWriteError(SSL_CLIENT_CONNECT_FAIL); return 0; @@ -100,7 +97,7 @@ int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) { int SSLClientImpl::connect_impl(const char *host, uint16_t port) { const char* func_name = __func__; // connection check - if (m_client->connected()) { + if (get_arduino_client().connected()) { m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); return -1; } @@ -113,10 +110,10 @@ int SSLClientImpl::connect_impl(const char *host, uint16_t port) { if (ses.is_valid_session()) { // if so, then connect using the stored session m_info("Connecting using a cached IP", func_name); - connect_ok = m_client->connect(ses.get_ip(), port); + connect_ok = get_arduino_client().connect(ses.get_ip(), port); } // else connect with the provided hostname - else connect_ok = m_client->connect(host, port); + else connect_ok = get_arduino_client().connect(host, port); // first we need our hidden client member to negotiate the socket for us, // since most times socket functionality is implemented in hardeware. if (!connect_ok) { @@ -240,22 +237,22 @@ void SSLClientImpl::stop_impl() { } } // close the ethernet socket - m_client->stop(); + get_arduino_client().stop(); } /** see SSLClientImpl.h */ uint8_t SSLClientImpl::connected_impl() { const char* func_name = __func__; // check all of the error cases - const auto c_con = m_client->connected(); + const auto c_con = get_arduino_client().connected(); const auto br_con = br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED; const auto wr_ok = getWriteError() == 0; // if we're in an error state, close the connection and set a write error if (br_con && !c_con) { // If we've got a write error, the client probably failed for some reason - if (m_client->getWriteError()) { + if (get_arduino_client().getWriteError()) { m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); - m_error(m_client->getWriteError(), func_name); + m_error(get_arduino_client().getWriteError(), func_name); } // Else tell the user the endpoint closed the socket on us (ouch) else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name); @@ -278,14 +275,14 @@ SSLSession& SSLClientImpl::get_session_impl(const char* host, const IPAddress& a if (temp_index == -1) { temp_index = m_session_index; // reset the session so we don't try to send one sites session to another - m_session_ptr[temp_index].clear_parameters(); + get_session_array()[temp_index].clear_parameters(); } // increment m_session_index so the session cache is a circular buffer if (temp_index == m_session_index && ++m_session_index >= getSessionCount()) m_session_index = 0; // return the pointed to value m_info("Using session index: ", func_name); m_info(temp_index, func_name); - return m_session_ptr[temp_index]; + return get_session_array()[temp_index]; } /** see SSLClientImpl.h */ @@ -295,7 +292,7 @@ void SSLClientImpl::remove_session_impl(const char* host, const IPAddress& addr) if (temp_index != -1) { m_info(" Deleted session ", func_name); m_info(temp_index, func_name); - m_session_ptr[temp_index].clear_parameters(); + get_session_array()[temp_index].clear_parameters(); } } @@ -466,11 +463,11 @@ unsigned SSLClientImpl::m_update_engine() { int wlen; buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len); - wlen = m_client->write(buf, len); + wlen = get_arduino_client().write(buf, len); // let the chip recover if (wlen < 0) { m_error("Error writing to m_client", func_name); - m_error(m_client->getWriteError(), func_name); + m_error(get_arduino_client().getWriteError(), func_name); setWriteError(SSL_CLIENT_WRTIE_ERROR); /* * If we received a close_notify and we @@ -543,7 +540,7 @@ unsigned SSLClientImpl::m_update_engine() { size_t len; unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len); // do we have the record you're looking for? - const auto avail = m_client->available(); + const auto avail = get_arduino_client().available(); if (avail > 0 && avail >= len) { int mem = freeMemory(); // check for a stack overflow @@ -572,10 +569,10 @@ unsigned SSLClientImpl::m_update_engine() { m_info(avail, func_name); m_info(len, func_name); // I suppose so! - int rlen = m_client->read(buf, len); + int rlen = get_arduino_client().read(buf, len); if (rlen <= 0) { m_error("Error reading bytes from m_client. Write Error: ", func_name); - m_error(m_client->getWriteError(), func_name); + m_error(get_arduino_client().getWriteError(), func_name); setWriteError(SSL_CLIENT_WRTIE_ERROR); stop_impl(); return 0; @@ -591,7 +588,7 @@ unsigned SSLClientImpl::m_update_engine() { // m_print(avail); // m_print("Bytes needed: "); // m_print(len); - // add a delay since spamming m_client->availible breaks the poor wiz chip + // add a delay since spamming get_arduino_client().availible breaks the poor wiz chip delay(10); return state; } @@ -608,15 +605,15 @@ int SSLClientImpl::m_get_session_index(const char* host, const IPAddress& addr) // search for a matching session with the IP for (uint8_t i = 0; i < getSessionCount(); i++) { // if we're looking at a real session - if (m_session_ptr[i].is_valid_session() + if (get_session_array()[i].is_valid_session() && ( // and the hostname matches, or - (host != NULL && m_session_ptr[i].get_hostname().equals(host)) + (host != NULL && get_session_array()[i].get_hostname().equals(host)) // there is no hostname and the IP address matches - || (host == NULL && addr == m_session_ptr[i].get_ip()) + || (host == NULL && addr == get_session_array()[i].get_ip()) )) { m_info("Found session match: ", func_name); - m_info(m_session_ptr[i].get_hostname(), func_name); + m_info(get_session_array()[i].get_hostname(), func_name); return i; } } diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index e5a9069..e4ffc5d 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -94,9 +94,8 @@ public: * @param session_ray A pointer to the array of SSLSessions created by SSLClient * @param debug whether to enable or disable debug logging, must be constexpr */ - explicit SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, - const size_t trust_anchors_num, const int analog_pin, SSLSession* session_ray, - const DebugLevel debug); + explicit SSLClientImpl(const br_x509_trust_anchor *trust_anchors, + const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug); //============================================ //= Functions implemented in SSLClientImpl.cpp @@ -138,19 +137,17 @@ public: virtual size_t getSessionCount() const = 0; protected: + /** See SSLClient::get_arduino_client */ + virtual Client& get_arduino_client() = 0; + virtual const Client& get_arduino_client() const = 0; + /** See SSLClient::get_session_array */ + virtual SSLSession* get_session_array() = 0; + virtual const SSLSession* get_session_array() const = 0; + //============================================ //= Functions implemented in SSLClientImpl.cpp //============================================ - /** - * @brief set the pointer to the Client class that we wil use - * - * Call this function immediatly after the ctor. This functionality - * is placed in it's own function for flexibility reasons, but it - * is critical that this function is called before anything else - */ - void set_client(Client* c, SSLSession* sessions) { m_client = c; m_session_ptr = sessions; } - /** @brief Prints a debugging prefix to all logs, so we can attatch them to useful information */ void m_print_prefix(const char* func_name, const DebugLevel level) const; @@ -197,17 +194,12 @@ private: //= Data Members //============================================ - // hold a reference to the client - Client* m_client; // store pointers to the trust anchors // should not be computed at runtime const br_x509_trust_anchor *m_trust_anchors; const size_t m_trust_anchors_num; // store the pin to fetch an RNG see from const int m_analog_pin; - // store a pointer to the SSL Session array, since it's size - // is deduced at compile time - SSLSession* m_session_ptr; // store an index of where a new session can be placed if we don't have any corresponding sessions size_t m_session_index; // store whether to enable debug logging