refactored SSLClient to use a reference to a client as opposed to an instance.

This commit is contained in:
Noah Laptop 2019-11-07 12:08:39 -08:00
parent 0ca631c627
commit 00f78f18e8
10 changed files with 281 additions and 505 deletions

View file

@ -20,6 +20,7 @@
#include "SSLClient.h"
#if defined(ARDUINO_ARCH_SAMD)
// system reset definitions
static constexpr auto SYSRESETREQ = (1<<2);
static constexpr auto VECTKEY = (0x05fa0000UL);
@ -29,6 +30,7 @@ static constexpr auto VECTKEY_MASK = (0x0000ffffUL);
(*(uint32_t*)0xe000ed0cUL)=((*(uint32_t*)0xe000ed0cUL)&VECTKEY_MASK)|VECTKEY|SYSRESETREQ;
while(1) { }
}
#endif
#ifdef __arm__
// should use uinstd.h to define sbrk but Due causes a conflict
@ -49,15 +51,22 @@ static int freeMemory() {
#endif // __arm__
}
/* see SSLClientImpl.h */
SSLClientImpl::SSLClientImpl(const br_x509_trust_anchor *trust_anchors,
const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug)
: m_analog_pin(analog_pin)
, m_session_index(0)
/* see SSLClient.h */
SSLClient::SSLClient( Client& client,
const br_x509_trust_anchor *trust_anchors,
const size_t trust_anchors_num,
const int analog_pin,
const size_t max_sessions,
const DebugLevel debug)
: m_client(client)
, m_sessions()
, m_max_sessions(max_sessions)
, m_analog_pin(analog_pin)
, m_debug(debug)
, m_is_connected(false)
, m_write_idx(0) {
setTimeout(30*1000);
// zero the iobuf just in case it's still garbage
memset(m_iobuf, 0, sizeof m_iobuf);
// initlalize the various bearssl libraries so they're ready to go when we connect
@ -69,8 +78,8 @@ SSLClientImpl::SSLClientImpl(const br_x509_trust_anchor *trust_anchors,
br_ssl_engine_set_buffer(&m_sslctx.eng, m_iobuf, sizeof m_iobuf, duplex);
}
/* see SSLClientImpl.h*/
int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) {
/* see SSLClient.h*/
int SSLClient::connect(IPAddress ip, uint16_t port) {
const char* func_name = __func__;
// connection check
if (get_arduino_client().connected()) {
@ -89,11 +98,11 @@ int SSLClientImpl::connect_impl(IPAddress ip, uint16_t port) {
return 0;
}
m_info("Base client connected!", func_name);
return m_start_ssl(NULL, get_session_impl(NULL, ip));
return m_start_ssl(nullptr);
}
/* see SSLClientImpl.h*/
int SSLClientImpl::connect_impl(const char *host, uint16_t port) {
/* see SSLClient.h*/
int SSLClient::connect(const char *host, uint16_t port) {
const char* func_name = __func__;
// connection check
if (get_arduino_client().connected()) {
@ -105,15 +114,7 @@ int SSLClientImpl::connect_impl(const char *host, uint16_t port) {
m_write_idx = 0;
// first, if we have a session, check if we're trying to resolve the same host
// as before
bool connect_ok;
SSLSession& ses = get_session_impl(host, INADDR_NONE);
if (ses.is_valid_session()) {
// if so, then connect using the stored session
m_info("Connecting using a cached IP", func_name);
connect_ok = get_arduino_client().connect(ses.get_ip(), port);
}
// else connect with the provided hostname
else connect_ok = get_arduino_client().connect(host, port);
const bool connect_ok = get_arduino_client().connect(host, port);
// first we need our hidden client member to negotiate the socket for us,
// since most times socket functionality is implemented in hardeware.
if (!connect_ok) {
@ -123,11 +124,11 @@ int SSLClientImpl::connect_impl(const char *host, uint16_t port) {
}
m_info("Base client connected!", func_name);
// start ssl!
return m_start_ssl(host, ses);
return m_start_ssl(host, getSession(host));
}
/* see SSLClientImpl.h*/
size_t SSLClientImpl::write_impl(const uint8_t *buf, size_t size) {
/* see SSLClient.h*/
size_t SSLClient::write(const uint8_t *buf, size_t size) {
const char* func_name = __func__;
// check if the socket is still open and such
if (!m_soft_connected(func_name) || !buf || !size) return 0;
@ -169,8 +170,8 @@ size_t SSLClientImpl::write_impl(const uint8_t *buf, size_t size) {
return size;
}
/* see SSLClientImpl.h*/
int SSLClientImpl::available_impl() {
/* see SSLClient.h*/
int SSLClient::available() {
const char* func_name = __func__;
// connection check
if (!m_soft_connected(func_name)) return 0;
@ -190,10 +191,10 @@ int SSLClientImpl::available_impl() {
return 0;
}
/* see SSLClientImpl.h */
int SSLClientImpl::read_impl(uint8_t *buf, size_t size) {
/* see SSLClient.h */
int SSLClient::read(uint8_t *buf, size_t size) {
// check that the engine is ready to read
if (available_impl() <= 0 || !size) return -1;
if (available() <= 0 || !size) return -1;
// read the buffer, send the ack, and return the bytes read
size_t alen;
unsigned char* br_buf = br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen);
@ -205,10 +206,10 @@ int SSLClientImpl::read_impl(uint8_t *buf, size_t size) {
return read_amount;
}
/* see SSLClientImpl.h */
int SSLClientImpl::peek_impl() {
/* see SSLClient.h */
int SSLClient::peek() {
// check that the engine is ready to read
if (available_impl() <= 0) return -1;
if (available() <= 0) return -1;
// read the buffer, send the ack, and return the bytes read
size_t alen;
uint8_t read_num;
@ -217,15 +218,14 @@ int SSLClientImpl::peek_impl() {
return (int)read_num;
}
/* see SSLClientImpl.h */
void SSLClientImpl::flush_impl() {
/* see SSLClient.h */
void SSLClient::flush() {
if (m_write_idx > 0)
if(m_run_until(BR_SSL_RECVAPP) < 0) m_error("Could not flush write buffer!", __func__);
}
/* see SSLClientImpl.h */
void SSLClientImpl::stop_impl() {
const char* func_name = __func__;
/* see SSLClient.h */
void SSLClient::stop() {
// tell the SSL connection to gracefully close
br_ssl_engine_close(&m_sslctx.eng);
// if the engine isn't closed, and the socket is still open
@ -240,7 +240,7 @@ void SSLClientImpl::stop_impl() {
*/
size_t len;
if (br_ssl_engine_recvapp_buf(&m_sslctx.eng, &len) != NULL) {
if (br_ssl_engine_recvapp_buf(&m_sslctx.eng, &len) != nullptr) {
br_ssl_engine_recvapp_ack(&m_sslctx.eng, len);
}
}
@ -251,8 +251,8 @@ void SSLClientImpl::stop_impl() {
m_is_connected = false;
}
/* see SSLClientImpl.h */
uint8_t SSLClientImpl::connected_impl() {
/* see SSLClient.h */
uint8_t SSLClient::connected() {
const char* func_name = __func__;
// check all of the error cases
const auto c_con = get_arduino_client().connected();
@ -273,7 +273,7 @@ uint8_t SSLClientImpl::connected_impl() {
// we are not connected
m_is_connected = false;
// set the write error so the engine doesn't try to close the connection
stop_impl();
stop();
}
else if (!wr_ok) {
m_error("Not connected because write error is set", func_name);
@ -282,38 +282,32 @@ uint8_t SSLClientImpl::connected_impl() {
return c_con && br_con;
}
/* see SSLClientImpl.h */
SSLSession& SSLClientImpl::get_session_impl(const char* host, const IPAddress& addr) {
/* see SSLClient.h */
SSLSession* SSLClient::getSession(const char* host) {
const char* func_name = __func__;
// search for a matching session with the IP
int temp_index = m_get_session_index(host, addr);
int temp_index = m_get_session_index(host);
// if none are availible, use m_session_index
if (temp_index == -1) {
temp_index = m_session_index;
// reset the session so we don't try to send one sites session to another
get_session_array()[temp_index].clear_parameters();
}
// increment m_session_index so the session cache is a circular buffer
if (temp_index == m_session_index && ++m_session_index >= getSessionCount()) m_session_index = 0;
if (temp_index < 0) return nullptr;
// return the pointed to value
m_info("Using session index: ", func_name);
m_info(temp_index, func_name);
return get_session_array()[temp_index];
return &(m_sessions[temp_index]);
}
/* see SSLClientImpl.h */
void SSLClientImpl::remove_session_impl(const char* host, const IPAddress& addr) {
/* see SSLClient.h */
void SSLClient::removeSession(const char* host) {
const char* func_name = __func__;
int temp_index = m_get_session_index(host, addr);
if (temp_index != -1) {
int temp_index = m_get_session_index(host);
if (temp_index >= 0) {
m_info(" Deleted session ", func_name);
m_info(temp_index, func_name);
get_session_array()[temp_index].clear_parameters();
m_sessions.erase(m_sessions.begin() + static_cast<size_t>(temp_index));
}
}
/* see SSLClientImpl.h */
void SSLClientImpl::set_mutual_impl(const SSLClientParameters* params) {
/* see SSLClient.h */
void SSLClient::setMutualAuthParams(const SSLClientParameters* params) {
// if mutual authentication if needed, configure bearssl to support it.
if (params != nullptr)
br_ssl_client_set_single_ec( &m_sslctx,
@ -326,7 +320,7 @@ void SSLClientImpl::set_mutual_impl(const SSLClientParameters* params) {
&br_ecdsa_i15_sign_asn1);
}
bool SSLClientImpl::m_soft_connected(const char* func_name) {
bool SSLClient::m_soft_connected(const char* func_name) {
// check if the socket is still open and such
if (getWriteError()) {
m_error("Cannot operate if the write error is not reset: ", func_name);
@ -343,8 +337,8 @@ bool SSLClientImpl::m_soft_connected(const char* func_name) {
return true;
}
/* see SSLClientImpl.h */
int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
/* see SSLClient.h */
int SSLClient::m_start_ssl(const char* host, SSLSession* ssl_ses) {
const char* func_name = __func__;
// clear the write error
setWriteError(SSL_OK);
@ -352,11 +346,12 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
// we want 128 bits to be safe, as recommended by the bearssl docs
uint8_t rng_seeds[16];
// take the bottom 8 bits of the analog read
for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast<uint8_t>(analogRead(m_analog_pin));
for (uint8_t i = 0; i < sizeof rng_seeds; i++)
rng_seeds[i] = static_cast<uint8_t>(analogRead(m_analog_pin));
br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds);
// inject session parameters for faster reconnection, if we have any
if(ssl_ses.is_valid_session()) {
br_ssl_engine_set_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
if(ssl_ses != nullptr) {
br_ssl_engine_set_session_parameters(&m_sslctx.eng, ssl_ses->to_br_session());
m_info("Set SSL session!", func_name);
}
// reset the engine, but make sure that it reset successfully
@ -379,24 +374,27 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
m_is_connected = true;
// all good to go! the SSL socket should be up and running
// overwrite the session we got with new parameters
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
// print the cipher suite
m_info("Used cipher suite: ", func_name);
m_info(ssl_ses.cipher_suite, func_name);
// set the hostname and ip in the session as well
ssl_ses.set_parameters(remoteIP(), host);
if (ssl_ses != nullptr)
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses->to_br_session());
else if (host != nullptr) {
if (m_sessions.size() >= m_max_sessions)
m_sessions.erase(m_sessions.begin());
SSLSession session(host);
br_ssl_engine_get_session_parameters(&m_sslctx.eng, session.to_br_session());
m_sessions.push_back(session);
}
return 1;
}
/* see SSLClientImpl.h*/
int SSLClientImpl::m_run_until(const unsigned target) {
/* see SSLClient.h */
int SSLClient::m_run_until(const unsigned target) {
const char* func_name = __func__;
unsigned lastState = 0;
size_t lastLen = 0;
const unsigned long start = millis();
for (;;) {
unsigned state = m_update_engine();
// error check
// error check
if (state == BR_SSL_CLOSED || getWriteError() != SSL_OK) {
return -1;
}
@ -404,7 +402,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
if (millis() - start > getTimeout()) {
m_error("SSL internals timed out! This could be an internal error, bad data sent from the server, or data being discarded due to a buffer overflow. If you are using Ethernet, did you modify the library properly (see README)?", func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop_impl();
stop();
return -1;
}
// debug
@ -448,7 +446,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
*/
if (state & BR_SSL_RECVAPP && target & BR_SSL_SENDAPP) {
size_t len;
if (br_ssl_engine_recvapp_buf(&m_sslctx.eng, &len) != NULL) {
if (br_ssl_engine_recvapp_buf(&m_sslctx.eng, &len) != nullptr) {
m_write_idx = 0;
m_warn("Discarded unread data to favor a write operation", func_name);
br_ssl_engine_recvapp_ack(&m_sslctx.eng, len);
@ -457,7 +455,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
else {
m_error("SSL engine state is RECVAPP, however the buffer was null! (This is a problem with BearSSL internals)", func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop_impl();
stop();
return -1;
}
}
@ -473,8 +471,8 @@ int SSLClientImpl::m_run_until(const unsigned target) {
}
}
/* see SSLClientImpl.h*/
unsigned SSLClientImpl::m_update_engine() {
/* see SSLClient.h*/
unsigned SSLClient::m_update_engine() {
const char* func_name = __func__;
for(;;) {
// get the state
@ -491,26 +489,21 @@ unsigned SSLClientImpl::m_update_engine() {
buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len);
wlen = get_arduino_client().write(buf, len);
// let the chip recover
if (wlen < 0) {
m_error("Error writing to m_client", func_name);
m_error(get_arduino_client().getWriteError(), func_name);
setWriteError(SSL_CLIENT_WRTIE_ERROR);
/*
* If we received a close_notify and we
* still send something, then we have our
* own response close_notify to send, and
* the peer is allowed by RFC 5246 not to
* wait for it.
*/
if (!&m_sslctx.eng.shutdown_recv) return 0;
stop_impl();
if (wlen <= 0) {
// if the arduino client encountered an error
if (get_arduino_client().getWriteError() || !get_arduino_client().connected()) {
m_error("Error writing to m_client", func_name);
m_error(get_arduino_client().getWriteError(), func_name);
setWriteError(SSL_CLIENT_WRTIE_ERROR);
}
// else presumably the socket just closed itself, so just stop the engine
stop();
return 0;
}
if (wlen > 0) {
br_ssl_engine_sendrec_ack(&m_sslctx.eng, wlen);
}
continue;
continue;
}
/*
@ -525,7 +518,7 @@ unsigned SSLClientImpl::m_update_engine() {
m_error(br_ssl_engine_current_state(&m_sslctx.eng), func_name);
m_error(br_ssl_engine_last_error(&m_sslctx.eng), func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop_impl();
stop();
return 0;
}
// else time to send the application data
@ -533,17 +526,17 @@ unsigned SSLClientImpl::m_update_engine() {
size_t alen;
unsigned char *buf = br_ssl_engine_sendapp_buf(&m_sslctx.eng, &alen);
// engine check
if (alen == 0 || buf == NULL) {
if (alen == 0 || buf == nullptr) {
m_error("Engine set write flag but returned null buffer", func_name);
setWriteError(SSL_BR_WRITE_ERROR);
stop_impl();
stop();
return 0;
}
// sanity check
if (alen < m_write_idx) {
m_error("Alen is less than m_write_idx", func_name);
setWriteError(SSL_INTERNAL_ERROR);
stop_impl();
stop();
return 0;
}
// all good? lets send the data
@ -570,8 +563,9 @@ unsigned SSLClientImpl::m_update_engine() {
unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len);
// do we have the record you're looking for?
const auto avail = get_arduino_client().available();
if (avail > 0 && avail >= len) {
if (avail > 0 && static_cast<size_t>(avail) >= len) {
int mem = freeMemory();
#if defined(ARDUINO_ARCH_SAMD)
// check for a stack overflow
// if the stack overflows we basically have to crash, and
// hope the user is ok with that
@ -581,6 +575,7 @@ unsigned SSLClientImpl::m_update_engine() {
// software reset
RESET();
}
#endif
// debug info
m_info("Memory: ", func_name);
m_info(mem, func_name);
@ -591,7 +586,7 @@ unsigned SSLClientImpl::m_update_engine() {
if(mem < 7000) {
m_error("Out of memory! Decrease the number of sessions or the size of m_iobuf", func_name);
setWriteError(SSL_OUT_OF_MEMORY);
stop_impl();
stop();
return 0;
}
// I suppose so!
@ -600,7 +595,7 @@ unsigned SSLClientImpl::m_update_engine() {
m_error("Error reading bytes from m_client. Write Error: ", func_name);
m_error(get_arduino_client().getWriteError(), func_name);
setWriteError(SSL_CLIENT_WRTIE_ERROR);
stop_impl();
stop();
return 0;
}
if (rlen > 0) {
@ -626,19 +621,14 @@ unsigned SSLClientImpl::m_update_engine() {
}
/* see SSLClientImpl.h */
int SSLClientImpl::m_get_session_index(const char* host, const IPAddress& addr) const {
int SSLClient::m_get_session_index(const char* host) const {
const char* func_name = __func__;
if(host == nullptr) return -1;
// search for a matching session with the IP
for (uint8_t i = 0; i < getSessionCount(); i++) {
// if we're looking at a real session
if (get_session_array()[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && get_session_array()[i].get_hostname().equals(host))
// there is no hostname and the IP address matches
|| (host == NULL && addr == get_session_array()[i].get_ip())
)) {
m_info(get_session_array()[i].get_hostname(), func_name);
if (m_sessions[i].get_hostname().equals(host)) {
m_info(m_sessions[i].get_hostname(), func_name);
return i;
}
}
@ -646,8 +636,8 @@ int SSLClientImpl::m_get_session_index(const char* host, const IPAddress& addr)
return -1;
}
/* See SSLClientImpl.h */
void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level) const
/* See SSLClient.h */
void SSLClient::m_print_prefix(const char* func_name, const DebugLevel level) const
{
// print the sslclient prefix
Serial.print("(SSLClient)");
@ -664,8 +654,8 @@ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level
Serial.print("): ");
}
/* See SSLClientImpl.h */
void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const {
/* See SSLClient.h */
void SSLClient::m_print_ssl_error(const int ssl_error, const DebugLevel level) const {
if (level > m_debug) return;
m_print_prefix(__func__, level);
switch(ssl_error) {
@ -679,8 +669,8 @@ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel leve
}
}
/* See SSLClientImpl.h */
void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const {
/* See SSLClient.h */
void SSLClient::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const {
if (level > m_debug) return;
m_print_prefix(__func__, level);
switch (br_error_code) {
@ -744,4 +734,4 @@ void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLe
case BR_ERR_X509_NOT_TRUSTED: Serial.println("Chain could not be linked to a trust anchor."); break;
default: Serial.println("Unknown error code."); break;
}
}
}

View file

@ -19,10 +19,10 @@
*/
#include "Client.h"
#include "SSLClientImpl.h"
#include "SSLSession.h"
#include "SSLClientParameters.h"
#include "SSLObj.h"
#include <vector>
#ifndef SSLClient_H_
#define SSLClient_H_
@ -32,26 +32,49 @@
* Check out README.md for more info.
*/
template <class C, size_t SessionCache = 1>
class SSLClient : public SSLClientImpl {
/*
* static checks
* I'm a java developer, so I want to ensure that my inheritance is safe.
* These checks ensure that all the functions we use on class C are
* actually present on class C. It does this by checking that the
* class inherits from Client.
*
* Additionally, I ran into a lot of memory issues with large sessions caches.
* Since each session contains at max 352 bytes of memory, they eat of the
* stack quite quickly and can cause overflows. As a result, I have added a
* warning here to discourage the use of more than 3 sessions at a time. Any
* amount past that will require special modification of this library, and
* assumes you know what you are doing.
*/
static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!");
static_assert(SessionCache <= 3, "You need to decrease the size of m_iobuf in order to have more than 3 sessions at once, otherwise memory issues will occur.");
class SSLClient : public Client {
public:
/**
* @brief Static constants defining the possible errors encountered.
*
* If SSLClient encounters an error, it will generally output
* logs into the serial monitor. If you need a way of programmatically
* checking the errors, you can do so with SSLClient::getWriteError(),
* which will return one of these values.
*/
enum Error {
SSL_OK = 0,
/** The underlying client failed to connect, probably not an issue with SSL */
SSL_CLIENT_CONNECT_FAIL,
/** BearSSL failed to complete the SSL handshake, check logs for bear ssl error output */
SSL_BR_CONNECT_FAIL,
/** The underlying client failed to write a payload, probably not an issue with SSL */
SSL_CLIENT_WRTIE_ERROR,
/** An internal error occurred with BearSSL, check logs for diagnosis. */
SSL_BR_WRITE_ERROR,
/** An internal error occurred with SSLClient, and you probably need to submit an issue on Github. */
SSL_INTERNAL_ERROR,
/** SSLClient detected that there was not enough memory (>8000 bytes) to continue. */
SSL_OUT_OF_MEMORY
};
/**
* @brief Level of verbosity used in logging for SSLClient.
*
* Use these values when initializing SSLClient to set how many logs you
* would like to see in the Serial monitor.
*/
enum DebugLevel {
/** No logging output */
SSL_NONE = 0,
/** Only output errors that result in connection failure */
SSL_ERROR = 1,
/** Output errors and warnings (useful when just starting to develop) */
SSL_WARN = 2,
/** Output errors, warnings, and internal information (very verbose) */
SSL_INFO = 3,
};
/**
* @brief Initialize SSLClient with all of the prerequisites needed.
*
@ -66,25 +89,18 @@ public:
* of the SSL server certificate. Check out TrustAnchors.md for more info.
* @param trust_anchors_num The number of objects in the trust_anchors array.
* @param analog_pin An analog pin to pull random bytes from, used in seeding the RNG.
* @param max_sessions The maximum number of SSL sessions to store connection information from.
* @param debug The level of debug logging (use the ::DebugLevel enum).
* @param mutual_auth_params Configuration to use for mutual authentication, nullptr to disable mutual auth. (see ::SSLClientParameters).
*/
explicit SSLClient( const C& client,
explicit SSLClient( Client& client,
const br_x509_trust_anchor *trust_anchors,
const size_t trust_anchors_num,
const int analog_pin,
const DebugLevel debug = SSL_WARN)
: SSLClientImpl(trust_anchors, trust_anchors_num, analog_pin, debug)
, m_client(client)
, m_sessions{}
{
// set the timeout to a reasonable number (it can always be changes later)
// SSL Connections take a really long time so we don't want to time out a legitimate thing
setTimeout(30 * 1000);
}
const size_t max_sessions = 1,
const DebugLevel debug = SSL_WARN);
//========================================
//= Functions implemented in SSLClientImpl
//= Functions implemented in SSLClient.cpp
//========================================
/**
@ -126,7 +142,7 @@ public:
* @param port the port to connect to
* @returns 1 if success, 0 if failure
*/
int connect(IPAddress ip, uint16_t port) override { return connect_impl(ip, port); }
int connect(IPAddress ip, uint16_t port) override;
/**
* @brief Connect over SSL to a host specified by a hostname.
@ -164,10 +180,8 @@ public:
* @param port The port to connect to on the host (443 for HTTPS)
* @returns 1 of success, 0 if failure
*/
int connect(const char *host, uint16_t port) override { return connect_impl(host, port); }
int connect(const char *host, uint16_t port) override;
/** @see SSLClient::write(uint8_t*, size_t) */
size_t write(uint8_t b) override { return write_impl(&b, 1); }
/**
* @brief Write some bytes to the SSL connection
*
@ -191,7 +205,9 @@ public:
* @returns The number of bytes copied to the buffer (size), or zero if the BearSSL engine
* fails to become ready for writing data.
*/
size_t write(const uint8_t *buf, size_t size) override { return write_impl(buf, size); }
size_t write(const uint8_t *buf, size_t size) override;
/** @see SSLClient::write(uint8_t*, size_t) */
size_t write(uint8_t b) override { return write(&b, 1); }
/**
* @brief Returns the number of bytes available to read from the data that has been received and decrypted.
@ -211,13 +227,8 @@ public:
* @returns The number of bytes available (can be zero), or zero if any of the pre
* conditions aren't satisfied.
*/
int available() override { return available_impl(); }
int available() override;
/**
* @brief Read a single byte, or -1 if none is available.
* @see SSLClient::read(uint8_t*, size_t)
*/
int read() override { uint8_t read_val; return read(&read_val, 1) > 0 ? read_val : -1; };
/**
* @brief Read size bytes from the SSL client buffer, copying them into *buf, and return the number of bytes read.
*
@ -239,7 +250,12 @@ public:
*
* @returns The number of bytes copied (<= size), or -1 if the preconditions are not satisfied.
*/
int read(uint8_t *buf, size_t size) override { return read_impl(buf, size); }
int read(uint8_t *buf, size_t size) override;
/**
* @brief Read a single byte, or -1 if none is available.
* @see SSLClient::read(uint8_t*, size_t)
*/
int read() override { uint8_t read_val; return read(&read_val, 1) > 0 ? read_val : -1; };
/**
* @brief View the first byte of the buffer, without removing it from the SSLClient Buffer
@ -249,7 +265,7 @@ public:
* @returns The first byte received, or -1 if the preconditions are not satisfied (warning:
* do not use if your data may be -1, as the return value is ambiguous)
*/
int peek() override { return peek_impl(); }
int peek() override;
/**
* @brief Force writing the buffered bytes from SSLClient::write to the network.
@ -258,7 +274,7 @@ public:
* an explanation of how writing with SSLClient works, please see SSLClient::write.
* The implementation for this function can be found in SSLClientImpl::flush.
*/
void flush() override { return flush_impl(); }
void flush() override;
/**
* @brief Close the connection
@ -268,7 +284,7 @@ public:
* error was encountered previously, this function will simply call m_client::stop.
* The implementation for this function can be found in SSLClientImpl::peek.
*/
void stop() override { return stop_impl(); }
void stop() override;
/**
* @brief Check if the device is connected.
@ -283,7 +299,7 @@ public:
*
* @returns 1 if connected, 0 if not
*/
uint8_t connected() override { return connected_impl(); }
uint8_t connected() override;
//========================================
//= Functions Not in the Client Interface
@ -297,7 +313,7 @@ public:
*
* @pre SSLClient has not already started an SSL connection.
*/
void setMutualAuthParams(const SSLClientParameters* params) { return set_mutual_impl(params); }
void setMutualAuthParams(const SSLClientParameters* params);
/**
* @brief Gets a session reference corresponding to a host and IP, or a reference to a empty session if none exist
@ -311,26 +327,26 @@ public:
*
* @param host A hostname c string, or NULL if one is not available
* @param addr An IP address
* @returns A reference to an SSLSession object
* @returns A pointer to the SSLSession, or NULL of none matched the criteria available
*/
SSLSession& getSession(const char* host, const IPAddress& addr) { return get_session_impl(host, addr); }
SSLSession* getSession(const char* host);
/**
* @brief Clear the session corresponding to a host and IP
*
* The implementation for this function can be found at SSLClientImpl::remove_session_impl.
*
* @param host A hostname c string, or NULL if one is not available
* @param host A hostname c string, or nullptr if one is not available
* @param addr An IP address
*/
void removeSession(const char* host, const IPAddress& addr) { return remove_session_impl(host, addr); }
void removeSession(const char* host);
/**
* @brief Get the maximum number of SSL sessions that can be stored at once
*
* @returns The SessionCache template parameter.
*/
size_t getSessionCount() const override { return SessionCache; }
size_t getSessionCount() const { return m_sessions.size(); }
/**
* @brief Equivalent to SSLClient::connected() > 0
@ -338,37 +354,92 @@ public:
* @returns true if connected, false if not
*/
operator bool() { return connected() > 0; }
/** @see SSLClient::operator bool */
bool operator==(const bool value) { return bool() == value; }
/** @see SSLClient::operator bool */
bool operator!=(const bool value) { return bool() != value; }
/** @brief Returns whether or not two SSLClient objects have the same underlying client object */
bool operator==(const C& rhs) { return m_client == rhs; }
/** @brief Returns whether or not two SSLClient objects do not have the same underlying client object */
bool operator!=(const C& rhs) { return m_client != rhs; }
/** @brief Returns the local port, if C::localPort exists */
uint16_t localPort() override { return m_client.localPort(); }
/** @brief Returns the remote IP, if C::remoteIP exists. */
IPAddress remoteIP() override { return m_client.remoteIP(); }
/** @brief Returns the remote port, if C::remotePort exists. Else return 0. */
uint16_t remotePort() override { return m_client.remotePort(); }
/** @brief Returns a reference to the client object stored in this class. Take care not to break it. */
C& getClient() { return m_client; }
protected:
/** @brief Returns an instance of m_client that is polymorphic and can be used by SSLClientImpl */
Client& get_arduino_client() override { return m_client; }
const Client& get_arduino_client() const override { return m_client; }
/** @brief Returns an instance of the session array that is on the stack */
SSLSession* get_session_array() override { return m_sessions; }
const SSLSession* get_session_array() const override { return m_sessions; }
Client& getClient() { return m_client; }
private:
/** @brief Returns an instance of m_client that is polymorphic and can be used by SSLClientImpl */
Client& get_arduino_client() { return m_client; }
const Client& get_arduino_client() const { return m_client; }
/** Returns whether or not the engine is connected, without polling the client over SPI or other (as opposed to connected()) */
bool m_soft_connected(const char* func_name);
/** start the ssl engine on the connected client */
int m_start_ssl(const char* host = nullptr, SSLSession* ssl_ses = nullptr);
/** run the bearssl engine until a certain state */
int m_run_until(const unsigned target);
/** proxy for available that returns the state */
unsigned m_update_engine();
/** utility function to find a session index based off of a host and IP */
int m_get_session_index(const char* host) const;
/** @brief Prints a debugging prefix to all logs, so we can attatch them to useful information */
void m_print_prefix(const char* func_name, const DebugLevel level) const;
/** @brief Prints the string associated with a write error */
void m_print_ssl_error(const int ssl_error, const DebugLevel level) const;
/** @brief Print the text string associated with a BearSSL error code */
void m_print_br_error(const unsigned br_error_code, const DebugLevel level) const;
/** @brief debugging print function, only prints if m_debug is true */
template<typename T>
void m_print(const T str, const char* func_name, const DebugLevel level) const {
// check the current debug level and serial status
if (level > m_debug || !Serial) return;
// print prefix
m_print_prefix(func_name, level);
// print the message
Serial.println(str);
}
/** @brief Prints a info message to serial, if info messages are enabled */
template<typename T>
void m_info(const T str, const char* func_name) const { m_print(str, func_name, SSL_INFO); }
template<typename T>
void m_warn(const T str, const char* func_name) const { m_print(str, func_name, SSL_WARN); }
template<typename T>
void m_error(const T str, const char* func_name) const { m_print(str, func_name, SSL_ERROR); }
//============================================
//= Data Members
//============================================
// create a copy of the client
C m_client;
Client& m_client;
// also store an array of SSLSessions, so we can resume communication with multiple websites
SSLSession m_sessions[SessionCache];
std::vector<SSLSession> m_sessions;
// as well as the maximmum number of sessions we can store
const size_t m_max_sessions;
// store the pin to fetch an RNG see from
const int m_analog_pin;
// store whether to enable debug logging
const DebugLevel m_debug;
// store if we are connected in bearssl or not
bool m_is_connected;
// store the context values required for SSL
br_ssl_client_context m_sslctx;
br_x509_minimal_context m_x509ctx;
// use a mono-directional buffer by default to cut memory in half
// can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI
// or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically
// simply edit this value to change the buffer size to the desired value
// additionally, we need to correct buffer size based off of how many sessions we decide to cache
// since SSL takes so much memory if we don't it will cause the stack and heap to collide
/**
* @brief The internal buffer to use with BearSSL.
* This buffer controls how much data BearSSL can encrypt/decrypt at a given time. It can be expanded
* or shrunk to [255, BR_SSL_BUFSIZE_BIDI], depending on the memory and speed needs of your application.
* As a rule of thumb SSLClient will fail if it does not have at least 8000 bytes when starting a
* connection.
*/
unsigned char m_iobuf[2048];
// store the index of where we are writing in the buffer
// so we can send our records all at once to prevent
// weird timing issues
size_t m_write_idx;
};
#endif /** SSLClient_H_ */

View file

@ -1,213 +0,0 @@
/* Copyright 2019 OSU OPEnS Lab
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this
* software and associated documentation files (the "Software"), to deal in the Software
* without restriction, including without limitation the rights to use, copy, modify,
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "bearssl.h"
#include "Arduino.h"
#include "Client.h"
#include "SSLSession.h"
#include "SSLClientParameters.h"
#ifndef SSLClientImpl_H_
#define SSLClientImpl_H_
/**
* @brief Static constants defining the possible errors encountered.
*
* If SSLClient encounters an error, it will generally output
* logs into the serial monitor. If you need a way of programmatically
* checking the errors, you can do so with SSLClient::getWriteError(),
* which will return one of these values.
*/
enum Error {
SSL_OK = 0,
/** The underlying client failed to connect, probably not an issue with SSL */
SSL_CLIENT_CONNECT_FAIL,
/** BearSSL failed to complete the SSL handshake, check logs for bear ssl error output */
SSL_BR_CONNECT_FAIL,
/** The underlying client failed to write a payload, probably not an issue with SSL */
SSL_CLIENT_WRTIE_ERROR,
/** An internal error occurred with BearSSL, check logs for diagnosis. */
SSL_BR_WRITE_ERROR,
/** An internal error occurred with SSLClient, and you probably need to submit an issue on Github. */
SSL_INTERNAL_ERROR,
/** SSLClient detected that there was not enough memory (>8000 bytes) to continue. */
SSL_OUT_OF_MEMORY
};
/**
* @brief Level of verbosity used in logging for SSLClient.
*
* Use these values when initializing SSLClient to set how many logs you
* would like to see in the Serial monitor.
*/
enum DebugLevel {
/** No logging output */
SSL_NONE = 0,
/** Only output errors that result in connection failure */
SSL_ERROR = 1,
/** Output errors and warnings (useful when just starting to develop) */
SSL_WARN = 2,
/** Output errors, warnings, and internal information (very verbose) */
SSL_INFO = 3,
};
/** @brief Implementation code to be inherited by SSLClient */
class SSLClientImpl : public Client {
public:
/** @see SSLClient::SSLClient */
explicit SSLClientImpl(const br_x509_trust_anchor *trust_anchors,
const size_t trust_anchors_num, const int analog_pin,
const DebugLevel debug);
/** @see SSLClient::SSLClient */
explicit SSLClientImpl(const br_x509_trust_anchor *trust_anchors,
const size_t trust_anchors_num, const int analog_pin,
const DebugLevel debug, const SSLClientParameters* mutual_auth_params);
//============================================
//= Functions implemented in SSLClientImpl.cpp
//============================================
/** @see SSLClient::connect(IPAddress, uint16_t) */
int connect_impl(IPAddress ip, uint16_t port);
/** @see SSLClient::connect(const char*, uint16_t) */
int connect_impl(const char *host, uint16_t port);
/** @see SSLClient::write(const uint8_t*, size_t) */
size_t write_impl(const uint8_t *buf, size_t size);
/** @see SSLClient::available */
int available_impl();
/** @see SSLClient::read(uint8_t*, size_t) */
int read_impl(uint8_t *buf, size_t size);
/** @see SSLClient::peek */
int peek_impl();
/** @see SSLClient::flush */
void flush_impl();
/** @see SSLClient::stop */
void stop_impl();
/** @see SSLClient::connected */
uint8_t connected_impl();
/** @see SSLClient::getSession */
SSLSession& get_session_impl(const char* host, const IPAddress& addr);
/** @see SSLClient::removeSession */
void remove_session_impl(const char* host, const IPAddress& addr);
/** @see SSLClient::setMutualAuthParams */
void set_mutual_impl(const SSLClientParameters* params);
//============================================
//= Functions implemented in SSLClient.h
//============================================
/** @see SSLClient::localPort */
virtual uint16_t localPort() = 0;
/** @see SSLClient::remoteIP */
virtual IPAddress remoteIP() = 0;
/** @see SSLClient::localPort */
virtual uint16_t remotePort() = 0;
/** @see SSLClient::getSessionCount */
virtual size_t getSessionCount() const = 0;
protected:
/** @see SSLClient::get_arduino_client */
virtual Client& get_arduino_client() = 0;
virtual const Client& get_arduino_client() const = 0;
/** @see SSLClient::get_session_array */
virtual SSLSession* get_session_array() = 0;
virtual const SSLSession* get_session_array() const = 0;
//============================================
//= Functions implemented in SSLClientImpl.cpp
//============================================
/** @brief Prints a debugging prefix to all logs, so we can attatch them to useful information */
void m_print_prefix(const char* func_name, const DebugLevel level) const;
/** @brief Prints the string associated with a write error */
void m_print_ssl_error(const int ssl_error, const DebugLevel level) const;
/** @brief Print the text string associated with a BearSSL error code */
void m_print_br_error(const unsigned br_error_code, const DebugLevel level) const;
/** @brief debugging print function, only prints if m_debug is true */
template<typename T>
void m_print(const T str, const char* func_name, const DebugLevel level) const {
// check the current debug level and serial status
if (level > m_debug || !Serial) return;
// print prefix
m_print_prefix(func_name, level);
// print the message
Serial.println(str);
}
/** @brief Prints a info message to serial, if info messages are enabled */
template<typename T>
void m_info(const T str, const char* func_name) const { m_print(str, func_name, SSL_INFO); }
template<typename T>
void m_warn(const T str, const char* func_name) const { m_print(str, func_name, SSL_WARN); }
template<typename T>
void m_error(const T str, const char* func_name) const { m_print(str, func_name, SSL_ERROR); }
private:
/** Returns whether or not the engine is connected, without polling the client over SPI or other (as opposed to connected()) */
bool m_soft_connected(const char* func_name);
/** start the ssl engine on the connected client */
int m_start_ssl(const char* host, SSLSession& ssl_ses);
/** run the bearssl engine until a certain state */
int m_run_until(const unsigned target);
/** proxy for available that returns the state */
unsigned m_update_engine();
/** utility function to find a session index based off of a host and IP */
int m_get_session_index(const char* host, const IPAddress& addr) const;
//============================================
//= Data Members
//============================================
// store the pin to fetch an RNG see from
const int m_analog_pin;
// store an index of where a new session can be placed if we don't have any corresponding sessions
size_t m_session_index;
// store whether to enable debug logging
const DebugLevel m_debug;
// store if we are connected in bearssl or not
bool m_is_connected;
// store the context values required for SSL
br_ssl_client_context m_sslctx;
br_x509_minimal_context m_x509ctx;
// use a mono-directional buffer by default to cut memory in half
// can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI
// or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically
// simply edit this value to change the buffer size to the desired value
// additionally, we need to correct buffer size based off of how many sessions we decide to cache
// since SSL takes so much memory if we don't it will cause the stack and heap to collide
/**
* @brief The internal buffer to use with BearSSL.
* This buffer controls how much data BearSSL can encrypt/decrypt at a given time. It can be expanded
* or shrunk to [255, BR_SSL_BUFSIZE_BIDI], depending on the memory and speed needs of your application.
* As a rule of thumb SSLClient will fail if it does not have at least 8000 bytes when starting a
* connection.
*/
unsigned char m_iobuf[2048];
// store the index of where we are writing in the buffer
// so we can send our records all at once to prevent
// weird timing issues
size_t m_write_idx;
};
#endif /* SSLClientImpl_H_ */

View file

@ -1,24 +0,0 @@
#include "SSLSession.h"
/* See SSLSession.h */
void SSLSession::set_parameters(const IPAddress& ip, const char* hostname) {
// copy the hostname
if (hostname != NULL) m_hostname = hostname;
// or if there's no hostname, clear the string
else m_hostname = "";
// and the IP address
m_ip = ip;
// check if both values are valid, and if so set valid to true
if (m_ip != INADDR_NONE && session_id_len > 0
&& (hostname == NULL || m_hostname)) m_valid_session = true;
// else clear
else clear_parameters();
}
/* see SSLSession.h */
void SSLSession::clear_parameters() {
// clear the hostname , ip, and valid session flags
m_hostname = "";
m_ip = INADDR_NONE;
m_valid_session = false;
}

View file

@ -27,7 +27,6 @@
#include "bearssl.h"
#include "Arduino.h"
#include "IPAddress.h"
#ifndef SSLSession_H_
#define SSLSession_H_
@ -57,13 +56,8 @@ public:
*
* Sets all parameters to zero, and invalidates the session
*/
SSLSession()
: m_valid_session(false)
, m_hostname()
, m_ip(INADDR_NONE) {}
/** @brief use clear_parameters or set_parameters instead */
SSLSession& operator=(const SSLSession&) = delete;
SSLSession(const char* hostname)
: m_hostname(hostname) {}
/**
* @brief Get the hostname string associated with this session
@ -75,57 +69,12 @@ public:
*/
const String& get_hostname() const { return m_hostname; }
/**
* @brief Get ::IPAddress associated with this session
*
* @returns A ::IPAddress object, #INADDR_NONE if there is no IP
* @pre must check isValidSession before getting this value,
* as if this session in invalid this value is not guarenteed
* to be reset to #INADDR_NONE.
*/
const IPAddress& get_ip() const { return m_ip; }
bool is_valid_session() const { return m_valid_session; }
/**
* @brief Set the ip address and hostname of the session.
*
* This function stores the ip Address object and hostname object into
* the session object. If hostname is not null or ip address is
* not blank, and the ::br_ssl_session_parameters values are non-zero
* it then validates the session.
*
* @pre You must call
* ::br_ssl_engine_get_session_parameters
* with this session before calling this function. This is because
* there is no way to completely validate the ::br_ssl_session_parameters
* and the session may end up in a corrupted state if this is not observed.
*
* @param ip The IP address of the host associated with the session
* @param hostname The string hostname ("www.google.com") associated with the session.
* Take care that this value is corrent, SSLSession performs no validation
* of the hostname.
*/
void set_parameters(const IPAddress& ip, const char* hostname = NULL);
/**
* @brief Delete the parameters and invalidate the session.
*
* Roughly equivalent to this_session = SSLSession(), however
* this function preserves the String object, allowing it
* to better handle the dynamic memory needed.
*/
void clear_parameters();
/** @brief Returns a pointer to the ::br_ssl_session_parameters component of this class. */
br_ssl_session_parameters* to_br_session() { return (br_ssl_session_parameters *)this; }
private:
bool m_valid_session;
// aparently a hostname has a max length of 256 chars. Go figure.
String m_hostname;
// store the IP Address we connected to
IPAddress m_ip;
};

View file

@ -133,8 +133,8 @@
* returned value (a 'time_t') is an integer that counts time in seconds
* since the Unix Epoch (Jan 1st, 1970, 00:00 UTC).
*
#define BR_USE_UNIX_TIME 1
*/
#define BR_USE_UNIX_TIME 0
/*
* When BR_USE_WIN32_TIME is enabled, the X.509 validation engine obtains
@ -143,9 +143,8 @@
*
* Note: if both BR_USE_UNIX_TIME and BR_USE_WIN32_TIME are defined, the
* former takes precedence.
*
#define BR_USE_WIN32_TIME 1
*/
#define BR_USE_WIN32_TIME 0
/*
* When BR_ARMEL_CORTEXM_GCC is enabled, some operations are replaced with