implemented session cache of size n, need to figure out failure cases and account for them

This commit is contained in:
Noah Laptop 2019-03-07 18:15:39 -08:00
parent 257a61e0f3
commit ab0cf9d52b
4 changed files with 76 additions and 35 deletions

View file

@ -70,7 +70,7 @@ enum Error {
* from the client side, however from the developer side it can be a bit confusing.
*/
template <class C>
template <class C, size_t SessionCache = 1>
class SSLClient : public SSLClientImpl {
/** static type checks
* I'm a java developer, so I want to ensure that my inheritance is safe.
@ -79,6 +79,7 @@ class SSLClient : public SSLClientImpl {
* class inherits from Client, and then that it contains a status() function.
*/
static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!");
static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!");
// static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
public:
@ -102,6 +103,8 @@ public:
explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug = true)
: SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)
, m_client(client)
, m_sessions{}
, m_index(0)
{
// since we are copying the client in the ctor, we have to set
// the client pointer after the class is constructed
@ -124,9 +127,44 @@ public:
//! get the client object
C& getClient() { return m_client; }
virtual SSLSession& getSession(const char* host, const IPAddress& addr) {
// search for a matching session with the IP
int temp_index = -1;
for (size_t i = 0; i < SessionCache; i++) {
// if we're looking at a real session
if (m_sessions[i].is_valid_session()
&& (
// and the hostname matches, or
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
// there is no hostname and the IP address matches
|| (host == NULL && addr == m_sessions[i].get_ip())
)) {
temp_index = i;
break;
}
}
// if none are availible, use m_index
if (temp_index == -1) {
temp_index = m_index;
// reset the session so we don't try to send one sites session to another
m_sessions[temp_index] = SSLSession();
}
// increment m_index so the session cache is a circular buffer
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
// return the pointed to value
m_print("Using index: ");
m_print(temp_index);
return m_sessions[temp_index];
}
private:
// create a copy of the client
C m_client;
// also store an array of SSLSessions, so we can resume communication with multiple websites
SSLSession m_sessions[SessionCache];
// store an index of where a new session can be placed if we don't have any corresponding sessions
size_t m_index;
};
#endif /** SSLClient_H_ */

View file

@ -28,13 +28,13 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
, m_trust_anchors_num(trust_anchors_num)
, m_analog_pin(analog_pin)
, m_debug(debug)
, m_write_idx(0)
, m_session() {
, m_write_idx(0) {
// zero the iobuf just in case it's still garbage
memset(m_iobuf, 0, sizeof m_iobuf);
// initlalize the various bearssl libraries so they're ready to go when we connect
br_client_init_TLS12_only(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
// comment the above line and uncomment the line below if you're having trouble connecting over SSL
// br_ssl_client_init_full(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
// check if the buffer size is half or full duplex
constexpr auto duplex = sizeof m_iobuf <= BR_SSL_BUFSIZE_MONO ? 0 : 1;
@ -43,6 +43,11 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
/* see SSLClientImpl.h*/
int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
// connection check
if (connected()) {
m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance.");
return -1;
}
// reset indexs for saftey
m_write_idx = 0;
// Warning for security
@ -55,21 +60,26 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
return 0;
}
m_print("Base ethernet client connected!");
return m_start_ssl();
return m_start_ssl(NULL, getSession(NULL, ip));
}
/* see SSLClientImpl.h*/
int SSLClientImpl::connect(const char *host, uint16_t port) {
// connection check
if (connected()) {
m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance.");
return -1;
}
// reset indexs for saftey
m_write_idx = 0;
// first, if we have a session, check if we're trying to resolve the same host
// as before
bool connect_ok;
if (m_session.is_valid_session()
&& strcmp(m_session.get_hostname(), host) == 0) {
SSLSession& ses = getSession(host, INADDR_NONE);
if (ses.is_valid_session()) {
// if so, then connect using the stored session
m_print("Connecting using a cached IP");
connect_ok = m_client->connect(m_session.get_ip(), port);
connect_ok = m_client->connect(ses.get_ip(), port);
}
// else connect with the provided hostname
else connect_ok = m_client->connect(host, port);
@ -82,7 +92,7 @@ int SSLClientImpl::connect(const char *host, uint16_t port) {
}
m_print("Base ethernet client connected!");
// start ssl!
return m_start_ssl(host);
return m_start_ssl(host, ses);
}
/** see SSLClientImpl.h*/
@ -229,7 +239,7 @@ uint8_t SSLClientImpl::connected() {
}
/** see SSLClientImpl.h */
int SSLClientImpl::m_start_ssl(const char* host) {
int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
// clear the write error
setWriteError(SSL_OK);
// get some random data by reading the analog pin we've been handed
@ -239,8 +249,8 @@ int SSLClientImpl::m_start_ssl(const char* host) {
for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast<uint8_t>(analogRead(m_analog_pin));
br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds);
// inject session parameters for faster reconnection, if we have any
if(m_session.is_valid_session()) {
br_ssl_engine_set_session_parameters(&m_sslctx.eng, m_session.to_br_session());
if(ssl_ses.is_valid_session()) {
br_ssl_engine_set_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
m_print("Set session!");
}
// reset the engine, but make sure that it reset successfully
@ -259,17 +269,18 @@ int SSLClientImpl::m_start_ssl(const char* host) {
return 0;
}
// all good to go! the SSL socket should be up and running
// debug print the session parameters to see if they exist
br_ssl_engine_get_session_parameters(&m_sslctx.eng, m_session.to_br_session());
// overwrite the session we got with new parameters
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
// set the hostname and ip in the session as well
m_session.set_parameters(remoteIP(), host);
ssl_ses.set_parameters(remoteIP(), host);
// print the session details
m_print("Session:");
for (uint8_t i = 0; i < m_session.session_id_len; i++) {
for (uint8_t i = 0; i < ssl_ses.session_id_len; i++) {
Serial.print(", 0x");
Serial.print(m_session.session_id[i], HEX);
Serial.print(ssl_ses.session_id[i], HEX);
}
Serial.println();
Serial.println(m_session.cipher_suite, HEX);
Serial.println(ssl_ses.cipher_suite, HEX);
return 1;
}
@ -292,7 +303,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
}
if (state & BR_SSL_RECVREC) {
size_t len;
unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len);
br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len);
if (lastLen != len) {
m_print("Expected bytes count: ");
m_print(lastLen = len);
@ -355,15 +366,6 @@ unsigned SSLClientImpl::m_update_engine() {
int wlen;
buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len);
Serial.print("Payload: ");
for (int i = 0; i < len; i++) {
if (buf[i] <= 0x0f) Serial.print("0x0");
else Serial.print("0x");
Serial.print(buf[i], HEX);
Serial.print(", ");
}
Serial.println();
//delay(100);
wlen = m_client->write(buf, len);
// let the chip recover
if (wlen < 0) {

View file

@ -120,6 +120,9 @@ public:
virtual uint16_t localPort() = 0;
virtual IPAddress remoteIP() = 0;
virtual uint16_t remotePort() = 0;
// as well as store and retrieve session data
virtual SSLSession& getSession(const char* host, const IPAddress& addr) = 0;
protected:
/**
* @brief set the pointer to the Client class that we wil use
@ -130,8 +133,6 @@ protected:
*/
void set_client(Client* c) { m_client = c; }
private:
/** @brief debugging print function, only prints if m_debug is true */
template<typename T>
constexpr void m_print(const T str) const {
@ -141,6 +142,8 @@ private:
}
}
private:
void printState(unsigned state) const {
if(m_debug) {
m_print("State: ");
@ -155,7 +158,7 @@ private:
}
}
/** start the ssl engine on the connected client */
int m_start_ssl(const char* host = NULL);
int m_start_ssl(const char* host, SSLSession& ssl_ses);
/** run the bearssl engine until a certain state */
int m_run_until(const unsigned target);
/** proxy for availble that returns the state */
@ -183,8 +186,6 @@ private:
// so we can send our records all at once to prevent
// weird timing issues
size_t m_write_idx;
// store the last SSL session, so reconnection later is speedy fast
SSLSession m_session;
};
#endif /* SSLClientImpl_H_ */

View file

@ -52,7 +52,7 @@ class SSLSession : public br_ssl_session_parameters {
public:
explicit SSLSession()
: m_valid_session(false)
, m_hostname({})
, m_hostname{}
, m_ip(INADDR_NONE) {}
/**
@ -75,14 +75,14 @@ public:
/**
* \pre must check isValidSession
*/
const char* const get_hostname() const { return m_hostname; }
const char* get_hostname() const { return m_hostname; }
/**
* \pre must check isValidSession
*/
const IPAddress& get_ip() const { return m_ip; }
const bool is_valid_session() const { return m_valid_session; }
bool is_valid_session() const { return m_valid_session; }
private:
bool m_valid_session;
// aparently a hostname has a max length of 256 chars. Go figure.