network/tcpsocket: make sure every msg is complete before handle

This commit is contained in:
twy_2000 2020-05-28 14:15:19 +08:00 committed by Polynomialdivision
parent 50d347c233
commit 653ce9fa56

View file

@ -81,7 +81,9 @@ static void client_to_server_state(struct ustream *s) {
static void client_read_cb(struct ustream *s, int bytes) {
char *str;
int len;
int len = 0;
size_t final_len;
int max_retry = 3, tried = 0;
do {
str = ustream_get_read_buf(s, &len);
@ -89,25 +91,39 @@ static void client_read_cb(struct ustream *s, int bytes) {
break;
if (network_config.use_symm_enc) {
char *base64_dec_str = malloc(B64_DECODE_LEN(strlen(str)));
int base64_dec_length = b64_decode(str, base64_dec_str, B64_DECODE_LEN(strlen(str)));
char *dec = gcrypt_decrypt_msg(base64_dec_str, base64_dec_length);
final_len = *(size_t*)str;
if(len < final_len) {//not complete msg, wait for next recv
fprintf(stderr,"not complete msg, len:%d, expected len:%d\n", len, final_len);
if (tried++ == max_retry) {
ustream_consume(s, len);
return;//drop package
}
continue;
}
char *dec = gcrypt_decrypt_msg(str+sizeof(size_t), final_len-sizeof(size_t));
free(base64_dec_str);
handle_network_msg(dec);
free(dec);
ustream_consume(s, final_len);//one msg is processed
tried = 0;
} else {
handle_network_msg(str);
final_len = *(size_t*)str;
if(len < final_len){
fprintf(stderr,"not complete msg, len:%d, expected len:%d\n", len, final_len);
if (tried++ == max_retry) {
ustream_consume(s, len);
return;
}
continue;
}
char* msg = malloc(final_len);
memcpy(msg, str+sizeof(size_t), final_len-sizeof(size_t));
handle_network_msg(msg);
ustream_consume(s, final_len);//one msg is processed
free(msg);
tried = 0;
}
ustream_consume(s, len);
} while (1);
if (s->w.data_bytes > 256 && !ustream_read_blocked(s)) {
fprintf(stderr, "Block read, bytes: %d\n", s->w.data_bytes);
ustream_set_read_blocked(s, true);
}
} while(1);
}
static void server_cb(struct uloop_fd *fd, unsigned int events) {
@ -227,16 +243,19 @@ void send_tcp(char *msg) {
print_tcp_array();
if (network_config.use_symm_enc) {
int length_enc;
size_t msglen = strlen(msg);
char *enc = gcrypt_encrypt_msg(msg, msglen + 1, &length_enc);
size_t msglen = strlen(msg)+1;
char *enc = gcrypt_encrypt_msg(msg, msglen, &length_enc);
char *base64_enc_str = malloc(B64_ENCODE_LEN(length_enc));
size_t base64_enc_length = b64_encode(enc, length_enc, base64_enc_str, B64_ENCODE_LEN(length_enc));
struct network_con_s *con;
size_t final_len = length_enc + sizeof(size_t);
char *final_str = malloc(final_len);
size_t *msg_header = (size_t*)final_str;
*msg_header = final_len;
memcpy(final_str+sizeof(size_t), enc, length_enc);
list_for_each_entry(con, &tcp_sock_list, list)
{
if (con->connected) {
int len_ustream = ustream_write(&con->stream.stream, base64_enc_str, base64_enc_length, 0);
int len_ustream = ustream_write(&con->stream.stream, final_str, final_len, 0);
printf("Ustream send: %d\n", len_ustream);
if (len_ustream <= 0) {
fprintf(stderr,"Ustream error!\n");
@ -246,20 +265,29 @@ void send_tcp(char *msg) {
}
free(base64_enc_str);
free(final_str);
free(enc);
} else {
size_t msglen = strlen(msg) + 1;
size_t final_len = msglen + sizeof(size_t);
char *final_str = malloc(final_len);
size_t *msg_header = (size_t*)final_str;
*msg_header = final_len;
memcpy(final_str+sizeof(size_t), msg, msglen);
struct network_con_s *con;
list_for_each_entry(con, &tcp_sock_list, list)
{
if (con->connected) {
if (ustream_printf(&con->stream.stream, "%s", msg) == 0) {
int len_ustream = ustream_write(&con->stream.stream, final_str, final_len, 0);
printf("Ustream send: %d\n", len_ustream);
if (len_ustream <= 0) {
//TODO: ERROR HANDLING!
fprintf(stderr,"Ustream error!\n");
}
}
}
free(final_str);
}
}