diff --git a/src/network/tcpsocket.c b/src/network/tcpsocket.c index b6b1899..e21aaac 100644 --- a/src/network/tcpsocket.c +++ b/src/network/tcpsocket.c @@ -82,48 +82,37 @@ static void client_to_server_state(struct ustream *s) { static void client_read_cb(struct ustream *s, int bytes) { char *str; int len = 0; - uint32_t final_len; - int max_retry = 3, tried = 0; + uint32_t final_len = sizeof(uint32_t); + str = malloc(final_len); - do { - str = ustream_get_read_buf(s, &len); - if (!str) - break; + if (ustream_read(s, str, final_len) < final_len) {//ensure recv sizeof(uint32_t). + free(str); + return; + } - if (network_config.use_symm_enc) { - final_len = ntohl(*(uint32_t *)str); - if(len < final_len) {//not complete msg, wait for next recv - fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); - if (tried++ == max_retry) { - ustream_consume(s, len); - return;//drop package - } - continue; - } - char *dec = gcrypt_decrypt_msg(str+sizeof(final_len), final_len-sizeof(final_len)); - - handle_network_msg(dec); - free(dec); - ustream_consume(s, final_len);//one msg is processed - tried = 0; - } else { - final_len = ntohl(*(uint32_t *)str); - if(len < final_len){ - fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); - if (tried++ == max_retry) { - ustream_consume(s, len); - return; - } - continue; - } - char* msg = malloc(final_len); - memcpy(msg, str+sizeof(final_len), final_len-sizeof(final_len)); - handle_network_msg(msg); - ustream_consume(s, final_len);//one msg is processed - free(msg); - tried = 0; + if (network_config.use_symm_enc) { + final_len = ntohl(*(uint32_t *)str); + str = realloc(str, final_len); + if ((len = ustream_read(s, str, final_len)) < final_len) {//ensure recv final_len bytes. + fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); + free(str); + return; } - } while(1); + char *dec = gcrypt_decrypt_msg(str, final_len);//len of str is final_len + handle_network_msg(dec); + free(dec); + free(str); + } else { + final_len = ntohl(*(uint32_t *)str); + str = realloc(str, final_len); + if ((len = ustream_read(s, str, final_len)) < final_len) {//ensure recv final_len bytes. + fprintf(stderr,"not complete msg, len:%d, expected len:%u\n", len, final_len); + free(str); + return; + } + handle_network_msg(str);//len of str is final_len + free(str); + } } static void server_cb(struct uloop_fd *fd, unsigned int events) {