debugged stack overflow error, fixed connection timeout and some error flow problems
This commit is contained in:
		
							parent
							
								
									c212e355a4
								
							
						
					
					
						commit
						7f72073fa6
					
				
					 5 changed files with 129 additions and 51 deletions
				
			
		|  | @ -67,6 +67,7 @@ class SSLClient : public SSLClientImpl { | |||
|  */ | ||||
| static_assert(std::is_base_of<Client, C>::value, "C must be a Client Class!"); | ||||
| static_assert(SessionCache > 0 && SessionCache < 255, "There can be no less than one and no more than 255 sessions in the cache!"); | ||||
| static_assert(SessionCache <= 3, "You need to decrease the size of m_iobuf in order to have more than 3 sessions at once, otherwise memory issues will occur."); | ||||
| // static_assert(std::is_function<decltype(C::status)>::value, "C must have a status() function!");
 | ||||
| 
 | ||||
| public: | ||||
|  | @ -90,12 +91,16 @@ public: | |||
|     explicit SSLClient(const C& client, const br_x509_trust_anchor *trust_anchors, const size_t trust_anchors_num, const int analog_pin, const DebugLevel debug = SSL_ERROR) | ||||
|     : SSLClientImpl(NULL, trust_anchors, trust_anchors_num, analog_pin, debug)  | ||||
|     , m_client(client) | ||||
|     , m_sessions{} | ||||
|     , m_sessions{SSLSession()} | ||||
|     , m_index(0) | ||||
|     { | ||||
|         // for (uint8_t i = 0; i < SessionCache; i++) m_sessions[i] = SSLSession();
 | ||||
|         // since we are copying the client in the ctor, we have to set
 | ||||
|         // the client pointer after the class is constructed
 | ||||
|         set_client(&m_client); | ||||
|         // set the timeout to a reasonable number (it can always be changes later)
 | ||||
|         // SSL Connections take a really long time so we don't want to time out a legitimate thing
 | ||||
|         setTimeout(10 * 1000); | ||||
|     } | ||||
|      | ||||
|     /*
 | ||||
|  | @ -114,38 +119,13 @@ public: | |||
|     //! get the client object
 | ||||
|     C& getClient() { return m_client; } | ||||
| 
 | ||||
|     virtual SSLSession& getSession(const char* host, const IPAddress& addr) { | ||||
|         // search for a matching session with the IP
 | ||||
|         int temp_index = -1; | ||||
|         for (size_t i = 0; i < SessionCache; i++) { | ||||
|             // if we're looking at a real session
 | ||||
|             if (m_sessions[i].is_valid_session()  | ||||
|                 && ( | ||||
|                     // and the hostname matches, or
 | ||||
|                     (host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0) | ||||
|                     // there is no hostname and the IP address matches    
 | ||||
|                     || (host == NULL && addr == m_sessions[i].get_ip()) | ||||
|                 )) { | ||||
|     virtual SSLSession& getSession(const char* host, const IPAddress& addr); | ||||
| 
 | ||||
|                 temp_index = i; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         // if none are availible, use m_index
 | ||||
|         if (temp_index == -1) { | ||||
|             temp_index = m_index; | ||||
|             // reset the session so we don't try to send one sites session to another
 | ||||
|             m_sessions[temp_index] = SSLSession(); | ||||
|         } | ||||
|         // increment m_index so the session cache is a circular buffer
 | ||||
|         if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; | ||||
|         // return the pointed to value
 | ||||
|         m_info("Using session index: ", __func__); | ||||
|         Serial.println(temp_index); | ||||
|         return m_sessions[temp_index]; | ||||
|     } | ||||
|     virtual void removeSession(const char* host, const IPAddress& addr); | ||||
| 
 | ||||
| private: | ||||
|     // utility function to find a session index based off of a host and IP
 | ||||
|     int m_getSessionIndex(const char* host, const IPAddress& addr) const; | ||||
|     // create a copy of the client
 | ||||
|     C m_client; | ||||
|     // also store an array of SSLSessions, so we can resume communication with multiple websites
 | ||||
|  | @ -154,4 +134,56 @@ private: | |||
|     size_t m_index; | ||||
| }; | ||||
| 
 | ||||
| template <class C, size_t SessionCache> | ||||
| SSLSession& SSLClient<C, SessionCache>::getSession(const char* host, const IPAddress& addr) { | ||||
|     const char* func_name = __func__; | ||||
|     // search for a matching session with the IP
 | ||||
|     int temp_index = m_getSessionIndex(host, addr); | ||||
|     // if none are availible, use m_index
 | ||||
|     if (temp_index == -1) { | ||||
|         temp_index = m_index; | ||||
|         // reset the session so we don't try to send one sites session to another
 | ||||
|         m_sessions[temp_index] = SSLSession(); | ||||
|     } | ||||
|     // increment m_index so the session cache is a circular buffer
 | ||||
|     if (temp_index == m_index && ++m_index >= SessionCache) m_index = 0; | ||||
|     // return the pointed to value
 | ||||
|     m_info("Using session index: ", func_name); | ||||
|     Serial.println(temp_index); | ||||
|     return m_sessions[temp_index]; | ||||
| } | ||||
| 
 | ||||
| template <class C, size_t SessionCache> | ||||
| void SSLClient<C, SessionCache>::removeSession(const char* host, const IPAddress& addr) { | ||||
|     const char* func_name = __func__; | ||||
|     int temp_index = m_getSessionIndex(host, addr); | ||||
|     if (temp_index != -1) { | ||||
|         m_info(" Deleted session ", func_name); | ||||
|         m_info(temp_index, func_name); | ||||
|         m_sessions[temp_index] = SSLSession(); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template <class C, size_t SessionCache> | ||||
| int SSLClient<C, SessionCache>::m_getSessionIndex(const char* host, const IPAddress& addr) const { | ||||
|     const char* func_name = __func__; | ||||
|     // search for a matching session with the IP
 | ||||
|     for (uint8_t i = 0; i < SessionCache; i++) { | ||||
|         // if we're looking at a real session
 | ||||
|         if (m_sessions[i].is_valid_session()  | ||||
|             && ( | ||||
|                 // and the hostname matches, or
 | ||||
|                 (host != NULL && strcmp(host, m_sessions[i].get_hostname()) == 0) | ||||
|                 // there is no hostname and the IP address matches    
 | ||||
|                 || (host == NULL && addr == m_sessions[i].get_ip()) | ||||
|             )) { | ||||
|             m_info("Found session match: ", func_name); | ||||
|             m_info(m_sessions[i].get_hostname(), func_name); | ||||
|             return i; | ||||
|         } | ||||
|     } | ||||
|     // none found
 | ||||
|     return -1; | ||||
| } | ||||
| 
 | ||||
| #endif /** SSLClient_H_ */ | ||||
|  | @ -45,7 +45,7 @@ SSLClientImpl::SSLClientImpl(Client *client, const br_x509_trust_anchor *trust_a | |||
| int SSLClientImpl::connect(IPAddress ip, uint16_t port) { | ||||
|     const char* func_name = __func__; | ||||
|     // connection check
 | ||||
|     if (connected()) { | ||||
|     if (m_client->connected()) { | ||||
|         m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); | ||||
|         return -1; | ||||
|     } | ||||
|  | @ -68,7 +68,7 @@ int SSLClientImpl::connect(IPAddress ip, uint16_t port) { | |||
| int SSLClientImpl::connect(const char *host, uint16_t port) { | ||||
|     const char* func_name = __func__; | ||||
|     // connection check
 | ||||
|     if (connected()) { | ||||
|     if (m_client->connected()) { | ||||
|         m_error("Cannot have two connections at the same time! Please create another SSLClient instance.", func_name); | ||||
|         return -1; | ||||
|     } | ||||
|  | @ -148,7 +148,7 @@ int SSLClientImpl::available() { | |||
|         br_ssl_engine_recvapp_buf(&m_sslctx.eng, &alen); | ||||
|         return (int)(alen); | ||||
|     } | ||||
|     else if (state == BR_SSL_CLOSED) m_warn("Engine closed after update", func_name); | ||||
|     else if (state == BR_SSL_CLOSED) m_info("Engine closed after update", func_name); | ||||
|     // flush the buffer if it's stuck in the SENDAPP state
 | ||||
|     else if (state & BR_SSL_SENDAPP) br_ssl_engine_flush(&m_sslctx.eng, 0); | ||||
|     // other state, or client is closed
 | ||||
|  | @ -194,8 +194,6 @@ void SSLClientImpl::flush() { | |||
| void SSLClientImpl::stop() { | ||||
|     // tell the SSL connection to gracefully close
 | ||||
|     br_ssl_engine_close(&m_sslctx.eng); | ||||
|     // info about the socket connection
 | ||||
|     if (br_ssl_engine_current_state(&m_sslctx.eng) == BR_SSL_CLOSED) m_info("Socket was terminated before graceful closure (probably fine)", __func__); | ||||
|     // if the engine isn't closed, and the socket is still open
 | ||||
|     while (br_ssl_engine_current_state(&m_sslctx.eng) != BR_SSL_CLOSED | ||||
|         && m_run_until(BR_SSL_RECVAPP) == 0) { | ||||
|  | @ -220,8 +218,14 @@ uint8_t SSLClientImpl::connected() { | |||
|     const auto wr_ok = getWriteError() == 0; | ||||
|     // if we're in an error state, close the connection and set a write error
 | ||||
|     if (br_con && !c_con) { | ||||
|         m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); | ||||
|         m_error(m_client->getWriteError(), func_name); | ||||
|         // If we've got a write error, the client probably failed for some reason
 | ||||
|         if (m_client->getWriteError()) { | ||||
|             m_error("Socket was unexpectedly interrupted. m_client error: ", func_name); | ||||
|             m_error(m_client->getWriteError(), func_name); | ||||
|         } | ||||
|         // Else tell the user the endpoint closed the socket on us (ouch)
 | ||||
|         else m_warn("Socket was dropped unexpectedly (this can be an alternative to closing the connection)", func_name); | ||||
|         // set the write error so the engine doesn't try to close the connection
 | ||||
|         setWriteError(SSL_CLIENT_WRTIE_ERROR); | ||||
|         stop(); | ||||
|     } | ||||
|  | @ -280,11 +284,18 @@ int SSLClientImpl::m_start_ssl(const char* host, SSLSession& ssl_ses) { | |||
|         m_print_br_error(br_ssl_engine_last_error(&m_sslctx.eng), SSL_ERROR); | ||||
|         return 0; | ||||
| 	} | ||||
|     m_info("Connection successful!", func_name); | ||||
|     // all good to go! the SSL socket should be up and running
 | ||||
|     // overwrite the session we got with new parameters
 | ||||
|     br_ssl_engine_get_session_parameters(&m_sslctx.eng, ssl_ses.to_br_session()); | ||||
|     // set the hostname and ip in the session as well
 | ||||
|     ssl_ses.set_parameters(remoteIP(), host); | ||||
|     // print the handshake cipher chioce
 | ||||
|     m_info("Cipher suite: ", func_name); | ||||
|     if (m_debug >= SSL_INFO) { | ||||
|         m_print_prefix(func_name, SSL_INFO); | ||||
|         Serial.println(ssl_ses.cipher_suite, HEX); | ||||
|     } | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
|  | @ -293,6 +304,7 @@ int SSLClientImpl::m_run_until(const unsigned target) { | |||
|     const char* func_name = __func__; | ||||
|     unsigned lastState = 0; | ||||
|     size_t lastLen = 0; | ||||
|     const unsigned long start = millis(); | ||||
|     for (;;) { | ||||
|         unsigned state = m_update_engine(); | ||||
| 		// error check
 | ||||
|  | @ -300,11 +312,20 @@ int SSLClientImpl::m_run_until(const unsigned target) { | |||
|             m_warn("Tried to run_until when the engine is closed", func_name); | ||||
|             return -1; | ||||
|         } | ||||
|         // timeout check
 | ||||
|         if (millis() - start > getTimeout()) { | ||||
|             m_error("SSL internals timed out! This could be an internal error or bad data sent from the server", func_name); | ||||
|             setWriteError(SSL_BR_WRITE_ERROR); | ||||
|             stop(); | ||||
|             return -1; | ||||
|         } | ||||
|         // debug
 | ||||
|         if (state != lastState) { | ||||
|             lastState = state; | ||||
|             m_info("m_run waiting:", func_name); | ||||
|             m_info("m_run changed state:", func_name); | ||||
|             printState(state); | ||||
|             m_info("Memory: ", func_name); | ||||
|             m_info(freeMemory(), func_name); | ||||
|         } | ||||
|         if (state & BR_SSL_RECVREC) { | ||||
|             size_t len; | ||||
|  | @ -455,7 +476,8 @@ unsigned SSLClientImpl::m_update_engine() { | |||
|                 m_info("Read bytes from client: ", func_name); | ||||
|                 m_info(avail, func_name); | ||||
|                 m_info(len, func_name); | ||||
|                  | ||||
|                 m_info("Memory: ", func_name); | ||||
|                 m_info(freeMemory(), func_name); | ||||
|                 // I suppose so!
 | ||||
|                 int rlen = m_client->read(buf, len); | ||||
|                 if (rlen <= 0) { | ||||
|  | @ -495,20 +517,20 @@ void SSLClientImpl::m_print_prefix(const char* func_name, const DebugLevel level | |||
|     Serial.print("(SSLClient)"); | ||||
|     // print the debug level
 | ||||
|     switch (level) { | ||||
|         case SSL_INFO: Serial.print("SSL_INFO"); break; | ||||
|         case SSL_WARN: Serial.print("SSL_WARN"); break; | ||||
|         case SSL_ERROR: Serial.print("SSL_ERROR"); break; | ||||
|         default: Serial.print("Unknown level"); | ||||
|         case SSL_INFO: Serial.print("(SSL_INFO)"); break; | ||||
|         case SSL_WARN: Serial.print("(SSL_WARN)"); break; | ||||
|         case SSL_ERROR: Serial.print("(SSL_ERROR)"); break; | ||||
|         default: Serial.print("(Unknown level)"); | ||||
|     } | ||||
|     // print the function name
 | ||||
|     Serial.print("("); | ||||
|     Serial.print(func_name); | ||||
|     // get ready
 | ||||
|     Serial.print(": "); | ||||
|     Serial.print("): "); | ||||
| } | ||||
| 
 | ||||
| /** See SSLClientImpl.h */ | ||||
| void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel level) const { | ||||
|     if (level < m_debug) return; | ||||
|     if (level > m_debug) return; | ||||
|     m_print_prefix(__func__, level); | ||||
|     switch(ssl_error) { | ||||
|         case SSL_OK: Serial.println("SSL_OK"); break; | ||||
|  | @ -522,7 +544,7 @@ void SSLClientImpl::m_print_ssl_error(const int ssl_error, const DebugLevel leve | |||
| 
 | ||||
| /* See SSLClientImpl.h */ | ||||
| void SSLClientImpl::m_print_br_error(const unsigned br_error_code, const DebugLevel level) const { | ||||
|   if (level < m_debug) return; | ||||
|   if (level > m_debug) return; | ||||
|   m_print_prefix(__func__, level); | ||||
|   switch (br_error_code) { | ||||
|     case BR_ERR_BAD_PARAM: Serial.println("Caller-provided parameter is incorrect."); break; | ||||
|  |  | |||
|  | @ -45,11 +45,31 @@ enum Error { | |||
|  */ | ||||
| enum DebugLevel { | ||||
|     SSL_NONE = 0, | ||||
|     SSL_INFO = 1, | ||||
|     SSL_ERROR = 1, | ||||
|     SSL_WARN = 2, | ||||
|     SSL_ERROR = 3 | ||||
|     SSL_INFO = 3, | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| #ifdef __arm__ | ||||
| // should use uinstd.h to define sbrk but Due causes a conflict
 | ||||
| extern "C" char* sbrk(int incr); | ||||
| #else  // __ARM__
 | ||||
| extern char *__brkval; | ||||
| #endif  // __arm__
 | ||||
|   | ||||
| static int freeMemory() { | ||||
|   char top; | ||||
| #ifdef __arm__ | ||||
|   return &top - reinterpret_cast<char*>(sbrk(0)); | ||||
| #elif defined(CORE_TEENSY) || (ARDUINO > 103 && ARDUINO != 151) | ||||
|   return &top - __brkval; | ||||
| #else  // __arm__
 | ||||
|   return __brkval ? &top - __brkval : &top - __malloc_heap_start; | ||||
| #endif  // __arm__
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| /** TODO: Write what this is */ | ||||
| 
 | ||||
| class SSLClientImpl : public Client { | ||||
|  | @ -170,7 +190,7 @@ protected: | |||
|     template<typename T> | ||||
|     void m_print(const T str, const char* func_name, const DebugLevel level) const {  | ||||
|         // check the current debug level
 | ||||
|         if (level < m_debug) return; | ||||
|         if (level > m_debug) return; | ||||
|         // print prefix
 | ||||
|         m_print_prefix(func_name, level); | ||||
|         // print the message
 | ||||
|  | @ -226,7 +246,9 @@ private: | |||
|     // can expand to a bi-directional buffer with maximum of BR_SSL_BUFSIZE_BIDI
 | ||||
|     // or shrink to below BR_SSL_BUFSIZE_MONO, and bearSSL will adapt automatically
 | ||||
|     // simply edit this value to change the buffer size to the desired value
 | ||||
|     unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO]; | ||||
|     // additionally, we need to correct buffer size based off of how many sessions we decide to cache
 | ||||
|     // since SSL takes so much memory if we don't it will cause the stack and heap to collide 
 | ||||
|     unsigned char m_iobuf[BR_SSL_BUFSIZE_MONO / 4]; | ||||
|     static_assert(sizeof m_iobuf <= BR_SSL_BUFSIZE_BIDI, "m_iobuf must be below maximum buffer size"); | ||||
|     // store the index of where we are writing in the buffer
 | ||||
|     // so we can send our records all at once to prevent
 | ||||
|  |  | |||
|  | @ -418,6 +418,7 @@ br_client_init_TLS12_only(br_ssl_client_context *cc, | |||
| 	 * supported hash function is appropriate; here we use SHA-256. | ||||
| 	 * The trust an | ||||
| 	 */ | ||||
| 	memset(xc, 0, sizeof *xc); | ||||
| 	br_x509_minimal_init(xc, &br_sha256_vtable, | ||||
| 		trust_anchors, trust_anchors_num); | ||||
| 
 | ||||
|  |  | |||
|  | @ -119,6 +119,7 @@ br_ssl_client_init_full(br_ssl_client_context *cc, | |||
| 	 * to TLS-1.2 (inclusive). | ||||
| 	 */ | ||||
| 	br_ssl_client_zero(cc); | ||||
| 	memset(xc, 0, sizeof *xc); | ||||
| 	br_ssl_engine_set_versions(&cc->eng, BR_TLS10, BR_TLS12); | ||||
| 
 | ||||
| 	/*
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Noah Laptop
						Noah Laptop