diff --git a/src/SSLClient.h b/src/SSLClient.h index 1ce3647..f25d4d8 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -108,20 +108,18 @@ public: set_client(&m_client); } - /** + /* * The special functions most clients have are below * Most of them smply pass through */ - virtual int availableForWrite(void) { return m_client.availableForWrite(); } virtual operator bool() { return connected() > 0; } virtual bool operator==(const bool value) { return bool() == value; } virtual bool operator!=(const bool value) { return bool() != value; } virtual bool operator==(const C& rhs) { return m_client == rhs; } virtual bool operator!=(const C& rhs) { return m_client != rhs; } - virtual uint16_t localPort() { return m_client.localPort(); } - virtual IPAddress remoteIP() { return m_client.remoteIP(); } - virtual uint16_t remotePort() { return m_client.remotePort(); } - virtual void setConnectionTimeout(uint16_t timeout) { m_client.setConnectionTimeout(timeout); } + virtual uint16_t localPort() { return std::is_member_function_pointer::value ? m_client.localPort() : 0; } + virtual IPAddress remoteIP() { return std::is_member_function_pointer::value ? m_client.remoteIP() : INADDR_NONE; } + virtual uint16_t remotePort() { return std::is_member_function_pointer::value ? m_client.remotePort() : 0; } //! get the client object C& getClient() { return m_client; } diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index 4778308..3217766 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -21,18 +21,18 @@ #include "SSLClient.h" /** see SSLClientImpl.h */ -SSLClientImpl::SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug) +SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_anchors, + const size_t trust_anchors_num, const int analog_pin, const bool debug) : m_client(client) , m_trust_anchors(trust_anchors) , m_trust_anchors_num(trust_anchors_num) , m_analog_pin(analog_pin) , m_debug(debug) - , m_write_idx(0) { + , m_write_idx(0) + , m_session() { // zero the iobuf just in case it's still garbage memset(m_iobuf, 0, sizeof m_iobuf); - // zero the session parameters for similar reason - memset(&m_ses_param, 0, sizeof m_ses_param); // initlalize the various bearssl libraries so they're ready to go when we connect br_client_init_TLS12_only(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num); // br_ssl_client_init_full(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num); @@ -62,16 +62,25 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) { int SSLClientImpl::connect(const char *host, uint16_t port) { // reset indexs for saftey m_write_idx = 0; + // first, if we have a session, check if we're trying to resolve the same host + // as before + bool connect_ok; + if (m_session.is_valid_session() + && strcmp(m_session.get_hostname(), host) == 0) { + // if so, then connect using the stored session + m_print("Connecting using a cached IP"); + connect_ok = m_client->connect(m_session.get_ip(), port); + } + // else connect with the provided hostname + else connect_ok = m_client->connect(host, port); // first we need our hidden client member to negotiate the socket for us, // since most times socket functionality is implemented in hardeware. - if (!m_client->connect(host, port)) { + if (!connect_ok) { m_print("Error: Failed to connect using m_client"); setWriteError(SSL_CLIENT_CONNECT_FAIL); return 0; } m_print("Base ethernet client connected!"); - // clear the write error - setWriteError(SSL_OK); // start ssl! return m_start_ssl(host); } @@ -221,6 +230,8 @@ uint8_t SSLClientImpl::connected() { /** see SSLClientImpl.h */ int SSLClientImpl::m_start_ssl(const char* host) { + // clear the write error + setWriteError(SSL_OK); // get some random data by reading the analog pin we've been handed // we want 128 bits to be safe, as recommended by the bearssl docs uint8_t rng_seeds[16]; @@ -228,9 +239,9 @@ int SSLClientImpl::m_start_ssl(const char* host) { for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast(analogRead(m_analog_pin)); br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds); // inject session parameters for faster reconnection, if we have any - if(m_ses_param.session_id_len > 0) { + if(m_session.is_valid_session()) { + br_ssl_engine_set_session_parameters(&m_sslctx.eng, m_session.to_br_session()); m_print("Set session!"); - br_ssl_engine_set_session_parameters(&m_sslctx.eng, &m_ses_param); } // reset the engine, but make sure that it reset successfully int ret = br_ssl_client_reset(&m_sslctx, host, 1); @@ -249,14 +260,16 @@ int SSLClientImpl::m_start_ssl(const char* host) { } // all good to go! the SSL socket should be up and running // debug print the session parameters to see if they exist - br_ssl_engine_get_session_parameters(&m_sslctx.eng, &m_ses_param); + br_ssl_engine_get_session_parameters(&m_sslctx.eng, m_session.to_br_session()); + // set the hostname and ip in the session as well + m_session.set_parameters(remoteIP(), host); m_print("Session:"); - for (uint8_t i = 0; i < m_ses_param.session_id_len; i++) { + for (uint8_t i = 0; i < m_session.session_id_len; i++) { Serial.print(", 0x"); - Serial.print(m_ses_param.session_id[i], HEX); + Serial.print(m_session.session_id[i], HEX); } Serial.println(); - Serial.println(m_ses_param.cipher_suite, HEX); + Serial.println(m_session.cipher_suite, HEX); return 1; } diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index 8e323ef..ea9ce0f 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -19,8 +19,9 @@ */ #include "bearssl.h" -#include "Client.h" #include "Arduino.h" +#include "Client.h" +#include "SSLSession.h" #ifndef SSLClientImpl_H_ #define SSLClientImpl_H_ @@ -44,9 +45,13 @@ public: * line utility. For more information see the samples or bearssl.org * @param trust_anchors_num The number of trust anchors stored * @param analog_pin An analog pin to pull random bytes from, used in seeding the RNG + * @param get_remote_ip Function pointer to get the remote ip from the client. We + * need this value since the Client abstract class has no remoteIP() function, + * however most of the arduino internet client implementations do. * @param debug whether to enable or disable debug logging, must be constexpr */ - explicit SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug = true); + explicit SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, + const size_t trust_anchors_num, const int analog_pin, const bool debug = true); /** Dtor is implicit since unique_ptr handles it fine */ /** functions specific to the EthernetClient which I'll have to override */ @@ -111,6 +116,10 @@ public: virtual void stop(); virtual uint8_t connected(); + // stub virtual functions to get things from the client + virtual uint16_t localPort() = 0; + virtual IPAddress remoteIP() = 0; + virtual uint16_t remotePort() = 0; protected: /** * @brief set the pointer to the Client class that we wil use @@ -120,6 +129,7 @@ protected: * is critical that this function is called before anything else */ void set_client(Client* c) { m_client = c; } + private: /** @brief debugging print function, only prints if m_debug is true */ @@ -174,7 +184,7 @@ private: // weird timing issues size_t m_write_idx; // store the last SSL session, so reconnection later is speedy fast - br_ssl_session_parameters m_ses_param; + SSLSession m_session; }; #endif /* SSLClientImpl_H_ */ \ No newline at end of file