diff --git a/src/network/tcpsocket.c b/src/network/tcpsocket.c index 3eb266a..b6b1899 100644 --- a/src/network/tcpsocket.c +++ b/src/network/tcpsocket.c @@ -82,30 +82,48 @@ static void client_to_server_state(struct ustream *s) { static void client_read_cb(struct ustream *s, int bytes) { char *str; int len = 0; - uint32_t final_len = sizeof(uint32_t); - str = malloc(final_len); + uint32_t final_len; + int max_retry = 3, tried = 0; - if ((len = ustream_read(s, str, final_len)) < final_len){//ensure recv sizeof(uint32_t). - fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); - goto out; - } - - final_len = ntohl(*(uint32_t *)str) - sizeof(uint32_t);//the final_len in headder includes header itself - str = realloc(str, final_len); - if ((len = ustream_read(s, str, final_len)) < final_len) {//ensure recv final_len bytes. - fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); - goto out; - } + do { + str = ustream_get_read_buf(s, &len); + if (!str) + break; - if (network_config.use_symm_enc) { - char *dec = gcrypt_decrypt_msg(str, final_len);//len of str is final_len - handle_network_msg(dec); - free(dec); - } else { - handle_network_msg(str);//len of str is final_len - } -out: - free(str); + if (network_config.use_symm_enc) { + final_len = ntohl(*(uint32_t *)str); + if(len < final_len) {//not complete msg, wait for next recv + fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); + if (tried++ == max_retry) { + ustream_consume(s, len); + return;//drop package + } + continue; + } + char *dec = gcrypt_decrypt_msg(str+sizeof(final_len), final_len-sizeof(final_len)); + + handle_network_msg(dec); + free(dec); + ustream_consume(s, final_len);//one msg is processed + tried = 0; + } else { + final_len = ntohl(*(uint32_t *)str); + if(len < final_len){ + fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); + if (tried++ == max_retry) { + ustream_consume(s, len); + return; + } + continue; + } + char* msg = malloc(final_len); + memcpy(msg, str+sizeof(final_len), final_len-sizeof(final_len)); + handle_network_msg(msg); + ustream_consume(s, final_len);//one msg is processed + free(msg); + tried = 0; + } + } while(1); } static void server_cb(struct uloop_fd *fd, unsigned int events) {