diff --git a/src/SSLClient.h b/src/SSLClient.h index 158bf19..176e51d 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -36,29 +36,16 @@ */ #include -#include "bearssl.h" -#include "Arduino.h" #include "Client.h" +#include "SSLClientImpl.h" #ifndef SSLClient_H_ #define SSLClient_H_ -template -class SSLClient : public Client { -/** static type checks - * I'm a java developer, so I want to ensure that my inheritance is safe. - * These checks ensure that all the functions we use on class C are - * actually present on class C. It does this by first checking that the - * class inherits from Client, and then that it contains a status() function. - */ -static_assert(std::is_base_of::value, "C must be a Client Class!"); -// static_assert(std::is_function::value, "C must have a status() function!"); - /** error enums * Static constants defining the possible errors encountered * Read from getWriteError(); */ - enum Error { SSL_OK = 0, SSL_CLIENT_CONNECT_FAIL, @@ -68,9 +55,35 @@ enum Error { SSL_INTERNAL_ERROR }; +/** + * \brief This class serves as a templating proxy class for the SSLClientImpl to do the real work. + * + * A problem arose when writing this class: I wanted the user to be able to construct + * this class in a single line of code (e.g. SSLClient(EthernetClient())), but I also + * wanted to avoid the use of dynamic memory if possible. In an attempt to solve this + * problem I used a templated classes. However, becuase of the Arduino build process + * this meant that the implementations for all the functions had to be in a header + * file (a weird effect of using templated classes and linking) which would slow down + * the build quite a bit. As a comprimise, I instead decided to build the main class (SSLCLient) + * as a templated class, and have use a not templated implementation class (SSLClientImpl) + * that would be able to reside in a seperate file. This gets the best of both worlds + * from the client side, however from the developer side it can be a bit confusing. + */ + +template +class SSLClient : public SSLClientImpl { +/** static type checks + * I'm a java developer, so I want to ensure that my inheritance is safe. + * These checks ensure that all the functions we use on class C are + * actually present on class C. It does this by first checking that the + * class inherits from Client, and then that it contains a status() function. + */ +static_assert(std::is_base_of::value, "C must be a Client Class!"); +// static_assert(std::is_function::value, "C must have a status() function!"); + public: /** - * @brief copies the client object and initializes SSL contexts for bearSSL + * @brief copies the client object, and passes the various parameters to the SSLCLientImpl functions. * * We copy the client because we aren't sure the Client object * is going to exists past the inital creation of the SSLClient. @@ -84,123 +97,36 @@ public: * @param trust_anchors_num The number of trust anchors stored * @param debug whether to enable or disable debug logging, must be constexpr */ - explicit SSLClient(const C &client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const bool debug = true); - /** Dtor is implicit since unique_ptr handles it fine */ - + SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const bool debug = true) + : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, debug) + , m_client(client) + { + // since we are copying the client in the ctor, we have to set + // the client pointer after the class is constructed + set_client(&m_client); + } + /** - * The virtual functions defining a Client are below + * The special functions most clients have are below * Most of them smply pass through */ - virtual int availableForWrite(void) { return m_client.availableForWrite(); }; - virtual operator bool() { return static_cast(m_client); } - // virtual bool operator==(const bool value) { return bool() == value; } - // virtual bool operator!=(const bool value) { return bool() != value; } - // virtual bool operator==(const C& rhs) const { return m_client.operator==(rhs); } - // virtual bool operator!=(const C& rhs) const { return !this->operator==(rhs); } + virtual int availableForWrite(void) { return m_client.availableForWrite(); } + virtual operator bool() { return connected() > 0; } + virtual bool operator==(const bool value) { return bool() == value; } + virtual bool operator!=(const bool value) { return bool() != value; } + virtual bool operator==(const C& rhs) { return m_client == rhs; } + virtual bool operator!=(const C& rhs) { return m_client != rhs; } virtual uint16_t localPort() { return m_client.localPort(); } virtual IPAddress remoteIP() { return m_client.remoteIP(); } virtual uint16_t remotePort() { return m_client.remotePort(); } virtual void setConnectionTimeout(uint16_t timeout) { m_client.setConnectionTimeout(timeout); } - /** functions specific to the EthernetClient which I'll have to override */ - // uint8_t status(); - // uint8_t getSocketNumber() const; - - /** functions dealing with read/write that BearSSL will be injected into */ - /** - * @brief Connect over SSL to a host specified by an ip address - * - * SSLClient::connect(host, port) should be preffered over this function, - * as verifying the domain name is a step in ensuring the certificate is - * legitimate, which is important to the security of the device. Additionally, - * SSL sessions cannot be resumed, which can drastically increase initial - * connect time. - * - * This function initializes EthernetClient by calling EthernetClient::connect - * with the parameters supplied, then once the socket is open initializes - * the appropriete bearssl contexts using the TLS_only_profile. Due to the - * design of the SSL standard, this function will probably take an extended - * period (1-2sec) to negotiate the handshake and finish the connection. - * - * @param ip The ip address to connect to - * @param port the port to connect to - * @returns 1 if success, 0 if failure (as found in EthernetClient) - * - * @error SSL_CLIENT_CONNECT_FAIL The client object could not connect to the host or port - * @error SSL_BR_CONNECT_FAIL BearSSL could not initialize the SSL connection. - */ - virtual int connect(IPAddress ip, uint16_t port); - /** - * @brief Connect over SSL using connect(ip, port), but use a DNS lookup to - * get the IP Address first. - * - * This function initializes EthernetClient by calling EthernetClient::connect - * with the parameters supplied, then once the socket is open initializes - * the appropriete bearssl contexts using the TLS_only_profile. - * - * Due to the design of the SSL standard, this function will probably take an - * extended period (1-2sec) to negotiate the handshake and finish the - * connection. Since the hostname is provided, however, BearSSL is able to keep - * a session cache of the clients we have connected to. This should reduce - * connection time greatly. In order to use this feature, you must reuse the - * same SSLClient object to connect to the reused host. Doing this will allow - * BearSSL to automatically match the hostname to a cached session. - * - * @param host The cstring host ("www.google.com") - * @param port the port to connect to - * @returns 1 of success, 0 if failure (as found in EthernetClient) - * - * @error SSL_CLIENT_CONNECT_FAIL The client object could not connect to the host or port - * @error SSL_BR_CONNECT_FAIL BearSSL could not initialize the SSL connection. - */ - virtual int connect(const char *host, uint16_t port); - virtual size_t write(uint8_t b) { return write(&b, 1); } - virtual size_t write(const uint8_t *buf, size_t size); - virtual int available(); - virtual int read() { int peeked = peek(); if(peeked != -1) br_ssl_engine_recvapp_ack(&m_sslctx.eng, 1); return peeked; } - virtual int read(uint8_t *buf, size_t size); - virtual int peek(); - virtual void flush(); - virtual void stop(); - virtual uint8_t connected(); - //! get the client object C& getClient() { return m_client; } private: - /** @brief debugging print function, only prints if m_debug is true */ - template - constexpr void m_print(const T str) const { - if (m_debug) { - Serial.print("SSLClient: "); - Serial.println(str); - } - } - /** run the bearssl engine until a certain state */ - int m_run_until(const unsigned target); - /** proxy for availble that returns the state */ - unsigned m_update_engine(); // create a copy of the client C m_client; - // store pointers to the trust anchors - // should not be computed at runtime - const br_x509_trust_anchor *m_trust_anchors; - const size_t m_trust_anchors_num; - // store whether to enable debug logging - const bool m_debug; - // store the context values required for SSL - br_ssl_client_context m_sslctx; - br_x509_minimal_context m_x509ctx; - // use a mono-directional buffer by default to cut memory in half - // can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI - // or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically - // simply edit this value to change the buffer size to the desired value - unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO]; - static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size"); - // store the index of where we are writing in the buffer - // so we can send our records all at once to prevent - // weird timing issues - size_t m_write_idx; }; #endif /** SSLClient_H_ */ \ No newline at end of file diff --git a/src/SSLClientImpl.cpp b/src/SSLClientImpl.cpp index 0a0e6ce..f28ed75 100644 --- a/src/SSLClientImpl.cpp +++ b/src/SSLClientImpl.cpp @@ -20,9 +20,8 @@ #include "SSLClient.h" -/** see SSLClient.h */ -template -SSLClient::SSLClient(const C &client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const bool debug) +/** see SSLClientImpl.h */ +SSLClientImpl::SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const bool debug) : m_client(client) , m_trust_anchors(trust_anchors) , m_trust_anchors_num(trust_anchors_num) @@ -38,16 +37,15 @@ SSLClient::SSLClient(const C &client, const br_x509_trust_anchor *trust_ancho br_ssl_engine_set_buffer(&m_sslctx.eng, m_iobuf, sizeof m_iobuf, duplex); } -/* see SSLClient.h */ -template -int SSLClient::connect(IPAddress ip, uint16_t port) { +/* see SSLClientImpl.h*/ +int SSLClientImpl::connect(IPAddress ip, uint16_t port) { // reset indexs for saftey m_write_idx = 0; // Warning for security m_print("Warning! Using a raw IP Address for an SSL connection bypasses some important verification steps\nYou should use a domain name (www.google.com) whenever possible."); // first we need our hidden client member to negotiate the socket for us, // since most times socket functionality is implemented in hardeware. - if (!this->m_client.connect(ip, port)) { + if (!m_client->connect(ip, port)) { m_print("Failed to connect using m_client"); setWriteError(SSL_CLIENT_CONNECT_FAIL); return 0; @@ -68,15 +66,13 @@ int SSLClient::connect(IPAddress ip, uint16_t port) { return 1; } -/* see SSLClient.h */ -template -int SSLClient::connect(const char *host, uint16_t port) { +/* see SSLClientImpl.h*/ +int SSLClientImpl::connect(const char *host, uint16_t port) { // reset indexs for saftey - m_write_idx = 0; // first we need our hidden client member to negotiate the socket for us, // since most times socket functionality is implemented in hardeware. - if (!this->m_client.connect(host, port)) { + if (!m_client->connect(host, port)) { m_print("Failed to connect using m_client"); setWriteError(SSL_CLIENT_CONNECT_FAIL); return 0; @@ -97,9 +93,8 @@ int SSLClient::connect(const char *host, uint16_t port) { return 1; } -/** see SSLClient.h */ -template -size_t SSLClient::write(const uint8_t *buf, size_t size) { +/** see SSLClientImpl.h*/ +size_t SSLClientImpl::write(const uint8_t *buf, size_t size) { // check if the socket is still open and such if(!connected()) { m_print("Client is not connected! Perhaps something has happened?"); @@ -143,9 +138,8 @@ size_t SSLClient::write(const uint8_t *buf, size_t size) { return size; } -/** see SSLClient.h */ -template -int SSLClient::available() { +/** see SSLClientImpl.h*/ +int SSLClientImpl::available() { // connection check if (!connected()) { m_print("Warn: Cannot check available of disconnected client"); @@ -171,9 +165,8 @@ int SSLClient::available() { return 0; } -/** see SSLClient.h */ -template -int SSLClient::read(uint8_t *buf, size_t size) { +/** see SSLClientImpl.h */ +int SSLClientImpl::read(uint8_t *buf, size_t size) { // check that the engine is ready to read if (available() <= 0) return -1; // read the buffer, send the ack, and return the bytes read @@ -187,9 +180,8 @@ int SSLClient::read(uint8_t *buf, size_t size) { return read_amount; } -/** see SSLClient.h */ -template -int SSLClient::peek() { +/** see SSLClientImpl.h */ +int SSLClientImpl::peek() { // check that the engine is ready to read if (available() <= 0) return -1; // read the buffer, send the ack, and return the bytes read @@ -200,18 +192,16 @@ int SSLClient::peek() { return (int)read_num; } -/** see SSLClient.h */ -template -void SSLClient::flush() { +/** see SSLClientImpl.h*/ +void SSLClientImpl::flush() { // trigger a flush, incase there's any leftover data br_ssl_engine_flush(&m_sslctx.eng, 0); // run until application data is ready for pickup if(m_run_until(BR_SSL_RECVAPP) < 0) m_print("Error: could not flush write buffer!"); } -/** see SSLClient.h */ -template -void SSLClient::stop() { +/** see SSLClientImpl.h*/ +void SSLClientImpl::stop() { // tell the SSL connection to gracefully close br_ssl_engine_close(&m_sslctx.eng); while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED) { @@ -226,29 +216,27 @@ void SSLClient::stop() { } } // close the ethernet socket - m_client.stop(); + m_client->stop(); } -template -uint8_t SSLClient::connected() { +uint8_t SSLClientImpl::connected() { // check all of the error cases - const auto c_con = m_client.connected(); + const auto c_con = m_client->connected(); const auto br_con = br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED; const auto wr_ok = getWriteError() == 0; // if we're in an error state, close the connection and set a write error if ((br_con && !c_con) || !wr_ok) { m_print("Error: Socket was unexpectedly interrupted"); m_print("Terminated with: "); - m_print(m_client.getWriteError()); + m_print(m_client->getWriteError()); setWriteError(SSL_CLIENT_WRTIE_ERROR); stop(); } return c_con && br_con && wr_ok; } -/** see SSLClient.h */ -template -int SSLClient::m_run_until(const unsigned target) { +/** see SSLClientImpl.h*/ +int SSLClientImpl::m_run_until(const unsigned target) { for (;;) { unsigned state = m_update_engine(); /* @@ -291,9 +279,8 @@ int SSLClient::m_run_until(const unsigned target) { } } -/** see SSLClient.h */ -template -unsigned SSLClient::m_update_engine() { +/** see SSLClientImpl.h*/ +unsigned SSLClientImpl::m_update_engine() { for(;;) { // get the state unsigned state = br_ssl_engine_current_state(&m_sslctx.eng); @@ -308,7 +295,7 @@ unsigned SSLClient::m_update_engine() { int wlen; buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len); - wlen = m_client.write(buf, len); + wlen = m_client->write(buf, len); if (wlen < 0) { m_print("Error writing to m_client"); /* @@ -381,9 +368,9 @@ unsigned SSLClient::m_update_engine() { size_t len; unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len); // do we have the record you're looking for? - if (m_client.available() >= len) { + if (m_client->available() >= len) { // I suppose so! - int rlen = m_client.readBytes((char *)buf, len); + int rlen = m_client->readBytes((char *)buf, len); if (rlen < 0) { m_print("Error reading bytes from m_client"); setWriteError(SSL_BR_WRITE_ERROR); diff --git a/src/SSLClientImpl.h b/src/SSLClientImpl.h index a88cf64..0bde340 100644 --- a/src/SSLClientImpl.h +++ b/src/SSLClientImpl.h @@ -18,8 +18,143 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +#include "bearssl.h" +#include "Client.h" +#include "Arduino.h" + #ifndef SSLClientImpl_H_ #define SSLClientImpl_H_ +/** TODO: Write what this is */ + +class SSLClientImpl : public Client { +public: + /** + * @brief initializes SSL contexts for bearSSL + * + * @pre The client class must be able to access the internet, as SSLClient + * cannot manage this for you. + * + * @post set_client must be called immediatly after to set the client class + * pointer. + * + * @param trust_anchors Trust anchors used in the verification + * of the SSL server certificate, generated using the `brssl` command + * line utility. For more information see the samples or bearssl.org + * @param trust_anchors_num The number of trust anchors stored + * @param debug whether to enable or disable debug logging, must be constexpr + */ + explicit SSLClientImpl(Client* client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const bool debug = true); + /** Dtor is implicit since unique_ptr handles it fine */ + + /** functions specific to the EthernetClient which I'll have to override */ + // uint8_t status(); + // uint8_t getSocketNumber() const; + + /** functions dealing with read/write that BearSSL will be injected into */ + /** + * @brief Connect over SSL to a host specified by an ip address + * + * SSLClient::connect(host, port) should be preffered over this function, + * as verifying the domain name is a step in ensuring the certificate is + * legitimate, which is important to the security of the device. Additionally, + * SSL sessions cannot be resumed, which can drastically increase initial + * connect time. + * + * This function initializes EthernetClient by calling EthernetClient::connect + * with the parameters supplied, then once the socket is open initializes + * the appropriete bearssl contexts using the TLS_only_profile. Due to the + * design of the SSL standard, this function will probably take an extended + * period (1-2sec) to negotiate the handshake and finish the connection. + * + * @param ip The ip address to connect to + * @param port the port to connect to + * @returns 1 if success, 0 if failure (as found in EthernetClient) + * + * @error SSL_CLIENT_CONNECT_FAIL The client object could not connect to the host or port + * @error SSL_BR_CONNECT_FAIL BearSSL could not initialize the SSL connection. + */ + virtual int connect(IPAddress ip, uint16_t port); + /** + * @brief Connect over SSL using connect(ip, port), but use a DNS lookup to + * get the IP Address first. + * + * This function initializes EthernetClient by calling EthernetClient::connect + * with the parameters supplied, then once the socket is open initializes + * the appropriete bearssl contexts using the TLS_only_profile. + * + * Due to the design of the SSL standard, this function will probably take an + * extended period (1-2sec) to negotiate the handshake and finish the + * connection. Since the hostname is provided, however, BearSSL is able to keep + * a session cache of the clients we have connected to. This should reduce + * connection time greatly. In order to use this feature, you must reuse the + * same SSLClient object to connect to the reused host. Doing this will allow + * BearSSL to automatically match the hostname to a cached session. + * + * @param host The cstring host ("www.google.com") + * @param port the port to connect to + * @returns 1 of success, 0 if failure (as found in EthernetClient) + * + * @error SSL_CLIENT_CONNECT_FAIL The client object could not connect to the host or port + * @error SSL_BR_CONNECT_FAIL BearSSL could not initialize the SSL connection. + */ + virtual int connect(const char *host, uint16_t port); + virtual size_t write(uint8_t b) { return write(&b, 1); } + virtual size_t write(const uint8_t *buf, size_t size); + virtual int available(); + virtual int read() { uint8_t read_val; return read(&read_val, 1) > 0 ? read_val : -1; } + virtual int read(uint8_t *buf, size_t size); + virtual int peek(); + virtual void flush(); + virtual void stop(); + virtual uint8_t connected(); + +protected: + /** + * @brief set the pointer to the Client class that we wil use + * + * Call this function immediatly after the ctor. This functionality + * is placed in it's own function for flexibility reasons, but it + * is critical that this function is called before anything else + */ + void set_client(Client* c) { m_client = c; } +private: + + /** @brief debugging print function, only prints if m_debug is true */ + template + constexpr void m_print(const T str) const { + if (m_debug) { + Serial.print("SSLClientImpl: "); + Serial.println(str); + } + } + /** run the bearssl engine until a certain state */ + int m_run_until(const unsigned target); + /** proxy for availble that returns the state */ + unsigned m_update_engine(); + // hold a reference to the client + Client* m_client; + // store pointers to the trust anchors + // should not be computed at runtime + const br_x509_trust_anchor *m_trust_anchors; + const size_t m_trust_anchors_num; + // store whether to enable debug logging + const bool m_debug; + // store the context values required for SSL + br_ssl_client_context m_sslctx; + br_x509_minimal_context m_x509ctx; + // use a mono-directional buffer by default to cut memory in half + // can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI + // or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically + // simply edit this value to change the buffer size to the desired value + unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO]; + static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size"); + // store the index of where we are writing in the buffer + // so we can send our records all at once to prevent + // weird timing issues + size_t m_write_idx; + // store the last error code + +}; #endif /* SSLClientImpl_H_ */ \ No newline at end of file diff --git a/src/bearssl/TLS12_only_profile.c b/src/bearssl/TLS12_only_profile.c index 4400b5a..f253e2b 100644 --- a/src/bearssl/TLS12_only_profile.c +++ b/src/bearssl/TLS12_only_profile.c @@ -23,6 +23,7 @@ */ #include "bearssl.h" +#include "bearssl_ssl.h" /* * A "profile" is an initialisation function for a SSL context, that @@ -234,7 +235,7 @@ br_client_init_TLS12_only(br_ssl_client_context *cc, // br_ssl_engine_set_default_ecdsa(&cc->eng); //* Alternate: set implementations explicitly. // br_ssl_client_set_rsapub(cc, &br_rsa_i31_public); - br_ssl_client_set_rsavrfy(cc, &br_rsa_i15_pkcs1_vrfy); + br_ssl_engine_set_rsavrfy(&cc->eng, &br_rsa_i15_pkcs1_vrfy); br_ssl_engine_set_ec(&cc->eng, &br_ec_all_m15); br_ssl_engine_set_ecdsa(&cc->eng, &br_ecdsa_i15_vrfy_asn1); //*/