implemented session cache of size n, need to figure out failure cases and account for them
This commit is contained in:
parent
257a61e0f3
commit
ab0cf9d52b
4 changed files with 76 additions and 35 deletions
|
@ -70,7 +70,7 @@ enum Error {
|
||||||
* from the client side, however from the developer side it can be a bit confusing.
|
* from the client side, however from the developer side it can be a bit confusing.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
template <class C>
|
template <class C, size_t SessionCache = 1>
|
||||||
class SSLClient : public SSLClientImpl {
|
class SSLClient : public SSLClientImpl {
|
||||||
/** static type checks
|
/** static type checks
|
||||||
* I'm a java developer, so I want to ensure that my inheritance is safe.
|
* I'm a java developer, so I want to ensure that my inheritance is safe.
|
||||||
|
@ -79,6 +79,7 @@ class SSLClient : public SSLClientImpl {
|
||||||
* class inherits from Client, and then that it contains a status() function.
|
* class inherits from Client, and then that it contains a status() function.
|
||||||
*/
|
*/
|
||||||
static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!");
|
static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!");
|
||||||
|
static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!");
|
||||||
// static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
|
// static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -102,6 +103,8 @@ public:
|
||||||
explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug = true)
|
explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const bool debug = true)
|
||||||
: SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)
|
: SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)
|
||||||
, m_client(client)
|
, m_client(client)
|
||||||
|
, m_sessions{}
|
||||||
|
, m_index(0)
|
||||||
{
|
{
|
||||||
// since we are copying the client in the ctor, we have to set
|
// since we are copying the client in the ctor, we have to set
|
||||||
// the client pointer after the class is constructed
|
// the client pointer after the class is constructed
|
||||||
|
@ -124,9 +127,44 @@ public:
|
||||||
//! get the client object
|
//! get the client object
|
||||||
C& getClient() { return m_client; }
|
C& getClient() { return m_client; }
|
||||||
|
|
||||||
|
virtual SSLSession& getSession(const char* host, const IPAddress& addr) {
|
||||||
|
// search for a matching session with the IP
|
||||||
|
int temp_index = -1;
|
||||||
|
for (size_t i = 0; i < SessionCache; i++) {
|
||||||
|
// if we're looking at a real session
|
||||||
|
if (m_sessions[i].is_valid_session()
|
||||||
|
&& (
|
||||||
|
// and the hostname matches, or
|
||||||
|
(host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0)
|
||||||
|
// there is no hostname and the IP address matches
|
||||||
|
|| (host == NULL && addr == m_sessions[i].get_ip())
|
||||||
|
)) {
|
||||||
|
|
||||||
|
temp_index = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if none are availible, use m_index
|
||||||
|
if (temp_index == -1) {
|
||||||
|
temp_index = m_index;
|
||||||
|
// reset the session so we don't try to send one sites session to another
|
||||||
|
m_sessions[temp_index] = SSLSession();
|
||||||
|
}
|
||||||
|
// increment m_index so the session cache is a circular buffer
|
||||||
|
if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0;
|
||||||
|
// return the pointed to value
|
||||||
|
m_print("Using index: ");
|
||||||
|
m_print(temp_index);
|
||||||
|
return m_sessions[temp_index];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// create a copy of the client
|
// create a copy of the client
|
||||||
C m_client;
|
C m_client;
|
||||||
|
// also store an array of SSLSessions, so we can resume communication with multiple websites
|
||||||
|
SSLSession m_sessions[SessionCache];
|
||||||
|
// store an index of where a new session can be placed if we don't have any corresponding sessions
|
||||||
|
size_t m_index;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif /** SSLClient_H_ */
|
#endif /** SSLClient_H_ */
|
|
@ -28,13 +28,13 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
|
||||||
, m_trust_anchors_num(trust_anchors_num)
|
, m_trust_anchors_num(trust_anchors_num)
|
||||||
, m_analog_pin(analog_pin)
|
, m_analog_pin(analog_pin)
|
||||||
, m_debug(debug)
|
, m_debug(debug)
|
||||||
, m_write_idx(0)
|
, m_write_idx(0) {
|
||||||
, m_session() {
|
|
||||||
|
|
||||||
// zero the iobuf just in case it's still garbage
|
// zero the iobuf just in case it's still garbage
|
||||||
memset(m_iobuf, 0, sizeof m_iobuf);
|
memset(m_iobuf, 0, sizeof m_iobuf);
|
||||||
// initlalize the various bearssl libraries so they're ready to go when we connect
|
// initlalize the various bearssl libraries so they're ready to go when we connect
|
||||||
br_client_init_TLS12_only(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
|
br_client_init_TLS12_only(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
|
||||||
|
// comment the above line and uncomment the line below if you're having trouble connecting over SSL
|
||||||
// br_ssl_client_init_full(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
|
// br_ssl_client_init_full(&m_sslctx, &m_x509ctx, m_trust_anchors, m_trust_anchors_num);
|
||||||
// check if the buffer size is half or full duplex
|
// check if the buffer size is half or full duplex
|
||||||
constexpr auto duplex = sizeof m_iobuf <= BR_SSL_BUFSIZE_MONO ? 0 : 1;
|
constexpr auto duplex = sizeof m_iobuf <= BR_SSL_BUFSIZE_MONO ? 0 : 1;
|
||||||
|
@ -43,6 +43,11 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a
|
||||||
|
|
||||||
/* see SSLClientImpl.h*/
|
/* see SSLClientImpl.h*/
|
||||||
int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
|
int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
|
||||||
|
// connection check
|
||||||
|
if (connected()) {
|
||||||
|
m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance.");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
// reset indexs for saftey
|
// reset indexs for saftey
|
||||||
m_write_idx = 0;
|
m_write_idx = 0;
|
||||||
// Warning for security
|
// Warning for security
|
||||||
|
@ -55,21 +60,26 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
m_print("Base ethernet client connected!");
|
m_print("Base ethernet client connected!");
|
||||||
return m_start_ssl();
|
return m_start_ssl(NULL, getSession(NULL, ip));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* see SSLClientImpl.h*/
|
/* see SSLClientImpl.h*/
|
||||||
int SSLClientImpl::connect(const char *host, uint16_t port) {
|
int SSLClientImpl::connect(const char *host, uint16_t port) {
|
||||||
|
// connection check
|
||||||
|
if (connected()) {
|
||||||
|
m_print("Error: cannot have two connections at the same time! Please create another SSLClient instance.");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
// reset indexs for saftey
|
// reset indexs for saftey
|
||||||
m_write_idx = 0;
|
m_write_idx = 0;
|
||||||
// first, if we have a session, check if we're trying to resolve the same host
|
// first, if we have a session, check if we're trying to resolve the same host
|
||||||
// as before
|
// as before
|
||||||
bool connect_ok;
|
bool connect_ok;
|
||||||
if (m_session.is_valid_session()
|
SSLSession& ses = getSession(host, INADDR_NONE);
|
||||||
&& strcmp(m_session.get_hostname(), host) == 0) {
|
if (ses.is_valid_session()) {
|
||||||
// if so, then connect using the stored session
|
// if so, then connect using the stored session
|
||||||
m_print("Connecting using a cached IP");
|
m_print("Connecting using a cached IP");
|
||||||
connect_ok = m_client->connect(m_session.get_ip(), port);
|
connect_ok = m_client->connect(ses.get_ip(), port);
|
||||||
}
|
}
|
||||||
// else connect with the provided hostname
|
// else connect with the provided hostname
|
||||||
else connect_ok = m_client->connect(host, port);
|
else connect_ok = m_client->connect(host, port);
|
||||||
|
@ -82,7 +92,7 @@ int SSLClientImpl::connect(const char *host, uint16_t port) {
|
||||||
}
|
}
|
||||||
m_print("Base ethernet client connected!");
|
m_print("Base ethernet client connected!");
|
||||||
// start ssl!
|
// start ssl!
|
||||||
return m_start_ssl(host);
|
return m_start_ssl(host, ses);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** see SSLClientImpl.h*/
|
/** see SSLClientImpl.h*/
|
||||||
|
@ -229,7 +239,7 @@ uint8_t SSLClientImpl::connected() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** see SSLClientImpl.h */
|
/** see SSLClientImpl.h */
|
||||||
int SSLClientImpl::m_start_ssl(const char* host) {
|
int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) {
|
||||||
// clear the write error
|
// clear the write error
|
||||||
setWriteError(SSL_OK);
|
setWriteError(SSL_OK);
|
||||||
// get some random data by reading the analog pin we've been handed
|
// get some random data by reading the analog pin we've been handed
|
||||||
|
@ -239,8 +249,8 @@ int SSLClientImpl::m_start_ssl(const char* host) {
|
||||||
for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast<uint8_t>(analogRead(m_analog_pin));
|
for (uint8_t i = 0; i < sizeof rng_seeds; i++) rng_seeds[i] = static_cast<uint8_t>(analogRead(m_analog_pin));
|
||||||
br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds);
|
br_ssl_engine_inject_entropy(&m_sslctx.eng, rng_seeds, sizeof rng_seeds);
|
||||||
// inject session parameters for faster reconnection, if we have any
|
// inject session parameters for faster reconnection, if we have any
|
||||||
if(m_session.is_valid_session()) {
|
if(ssl_ses.is_valid_session()) {
|
||||||
br_ssl_engine_set_session_parameters(&m_sslctx.eng, m_session.to_br_session());
|
br_ssl_engine_set_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
|
||||||
m_print("Set session!");
|
m_print("Set session!");
|
||||||
}
|
}
|
||||||
// reset the engine, but make sure that it reset successfully
|
// reset the engine, but make sure that it reset successfully
|
||||||
|
@ -259,17 +269,18 @@ int SSLClientImpl::m_start_ssl(const char* host) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
// all good to go! the SSL socket should be up and running
|
// all good to go! the SSL socket should be up and running
|
||||||
// debug print the session parameters to see if they exist
|
// overwrite the session we got with new parameters
|
||||||
br_ssl_engine_get_session_parameters(&m_sslctx.eng, m_session.to_br_session());
|
br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session());
|
||||||
// set the hostname and ip in the session as well
|
// set the hostname and ip in the session as well
|
||||||
m_session.set_parameters(remoteIP(), host);
|
ssl_ses.set_parameters(remoteIP(), host);
|
||||||
|
// print the session details
|
||||||
m_print("Session:");
|
m_print("Session:");
|
||||||
for (uint8_t i = 0; i < m_session.session_id_len; i++) {
|
for (uint8_t i = 0; i < ssl_ses.session_id_len; i++) {
|
||||||
Serial.print(", 0x");
|
Serial.print(", 0x");
|
||||||
Serial.print(m_session.session_id[i], HEX);
|
Serial.print(ssl_ses.session_id[i], HEX);
|
||||||
}
|
}
|
||||||
Serial.println();
|
Serial.println();
|
||||||
Serial.println(m_session.cipher_suite, HEX);
|
Serial.println(ssl_ses.cipher_suite, HEX);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,7 +303,7 @@ int SSLClientImpl::m_run_until(const unsigned target) {
|
||||||
}
|
}
|
||||||
if (state & BR_SSL_RECVREC) {
|
if (state & BR_SSL_RECVREC) {
|
||||||
size_t len;
|
size_t len;
|
||||||
unsigned char * buf = br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len);
|
br_ssl_engine_recvrec_buf(&m_sslctx.eng, &len);
|
||||||
if (lastLen != len) {
|
if (lastLen != len) {
|
||||||
m_print("Expected bytes count: ");
|
m_print("Expected bytes count: ");
|
||||||
m_print(lastLen = len);
|
m_print(lastLen = len);
|
||||||
|
@ -355,15 +366,6 @@ unsigned SSLClientImpl::m_update_engine() {
|
||||||
int wlen;
|
int wlen;
|
||||||
|
|
||||||
buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len);
|
buf = br_ssl_engine_sendrec_buf(&m_sslctx.eng, &len);
|
||||||
Serial.print("Payload: ");
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
if (buf[i] <= 0x0f) Serial.print("0x0");
|
|
||||||
else Serial.print("0x");
|
|
||||||
Serial.print(buf[i], HEX);
|
|
||||||
Serial.print(", ");
|
|
||||||
}
|
|
||||||
Serial.println();
|
|
||||||
//delay(100);
|
|
||||||
wlen = m_client->write(buf, len);
|
wlen = m_client->write(buf, len);
|
||||||
// let the chip recover
|
// let the chip recover
|
||||||
if (wlen < 0) {
|
if (wlen < 0) {
|
||||||
|
|
|
@ -120,6 +120,9 @@ public:
|
||||||
virtual uint16_t localPort() = 0;
|
virtual uint16_t localPort() = 0;
|
||||||
virtual IPAddress remoteIP() = 0;
|
virtual IPAddress remoteIP() = 0;
|
||||||
virtual uint16_t remotePort() = 0;
|
virtual uint16_t remotePort() = 0;
|
||||||
|
|
||||||
|
// as well as store and retrieve session data
|
||||||
|
virtual SSLSession& getSession(const char* host, const IPAddress& addr) = 0;
|
||||||
protected:
|
protected:
|
||||||
/**
|
/**
|
||||||
* @brief set the pointer to the Client class that we wil use
|
* @brief set the pointer to the Client class that we wil use
|
||||||
|
@ -130,8 +133,6 @@ protected:
|
||||||
*/
|
*/
|
||||||
void set_client(Client* c) { m_client = c; }
|
void set_client(Client* c) { m_client = c; }
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
/** @brief debugging print function, only prints if m_debug is true */
|
/** @brief debugging print function, only prints if m_debug is true */
|
||||||
template<typename T>
|
template<typename T>
|
||||||
constexpr void m_print(const T str) const {
|
constexpr void m_print(const T str) const {
|
||||||
|
@ -141,6 +142,8 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
void printState(unsigned state) const {
|
void printState(unsigned state) const {
|
||||||
if(m_debug) {
|
if(m_debug) {
|
||||||
m_print("State: ");
|
m_print("State: ");
|
||||||
|
@ -155,7 +158,7 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/** start the ssl engine on the connected client */
|
/** start the ssl engine on the connected client */
|
||||||
int m_start_ssl(const char* host = NULL);
|
int m_start_ssl(const char* host, SSLSession& ssl_ses);
|
||||||
/** run the bearssl engine until a certain state */
|
/** run the bearssl engine until a certain state */
|
||||||
int m_run_until(const unsigned target);
|
int m_run_until(const unsigned target);
|
||||||
/** proxy for availble that returns the state */
|
/** proxy for availble that returns the state */
|
||||||
|
@ -183,8 +186,6 @@ private:
|
||||||
// so we can send our records all at once to prevent
|
// so we can send our records all at once to prevent
|
||||||
// weird timing issues
|
// weird timing issues
|
||||||
size_t m_write_idx;
|
size_t m_write_idx;
|
||||||
// store the last SSL session, so reconnection later is speedy fast
|
|
||||||
SSLSession m_session;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif /* SSLClientImpl_H_ */
|
#endif /* SSLClientImpl_H_ */
|
|
@ -52,7 +52,7 @@ class SSLSession : public br_ssl_session_parameters {
|
||||||
public:
|
public:
|
||||||
explicit SSLSession()
|
explicit SSLSession()
|
||||||
: m_valid_session(false)
|
: m_valid_session(false)
|
||||||
, m_hostname({})
|
, m_hostname{}
|
||||||
, m_ip(INADDR_NONE) {}
|
, m_ip(INADDR_NONE) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -75,14 +75,14 @@ public:
|
||||||
/**
|
/**
|
||||||
* \pre must check isValidSession
|
* \pre must check isValidSession
|
||||||
*/
|
*/
|
||||||
const char* const get_hostname() const { return m_hostname; }
|
const char* get_hostname() const { return m_hostname; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \pre must check isValidSession
|
* \pre must check isValidSession
|
||||||
*/
|
*/
|
||||||
const IPAddress& get_ip() const { return m_ip; }
|
const IPAddress& get_ip() const { return m_ip; }
|
||||||
|
|
||||||
const bool is_valid_session() const { return m_valid_session; }
|
bool is_valid_session() const { return m_valid_session; }
|
||||||
private:
|
private:
|
||||||
bool m_valid_session;
|
bool m_valid_session;
|
||||||
// aparently a hostname has a max length of 256 chars. Go figure.
|
// aparently a hostname has a max length of 256 chars. Go figure.
|
||||||
|
|
Loading…
Reference in a new issue