diff --git a/CMakeLists.txt b/CMakeLists.txt index 26ebd38..778d741 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,11 +43,13 @@ add_executable(slipstream src/slipstream_server.c src/slipstream_dns_request_buffer.c src/slipstream_inline_dots.c + src/slipstream_packet.c src/slipstream_resolver_addresses.c src/slipstream_utils.c include/slipstream.h include/slipstream_dns_request_buffer.h include/slipstream_inline_dots.h + include/slipstream_packet.h include/slipstream_resolver_addresses.h include/slipstream_utils.h diff --git a/extern/picoquic b/extern/picoquic index d2404eb..7ac027d 160000 --- a/extern/picoquic +++ b/extern/picoquic @@ -1 +1 @@ -Subproject commit d2404ebedc9bff468eb005601e4ec91d7db2cb93 +Subproject commit 7ac027d6e16ef344a19f8e32d49c93a6caef15da diff --git a/include/slipstream_packet.h b/include/slipstream_packet.h new file mode 100644 index 0000000..d777abf --- /dev/null +++ b/include/slipstream_packet.h @@ -0,0 +1,15 @@ +#ifndef SLIPSTREAM_PACKET +#define SLIPSTREAM_PACKET + +#include +#include "picoquic.h" + +#define PICOQUIC_SHORT_HEADER_CONNECTION_ID_SIZE 8 + +bool slipstream_packet_is_long_header(const uint8_t first_byte); + +int slipstream_packet_create_poll(uint8_t** dest_buf, size_t* dest_buf_len, picoquic_connection_id_t dst_connection_id); + +int slipstream_packet_parse(uint8_t* src_buf, size_t src_buf_len, size_t short_header_conn_id_len, picoquic_connection_id_t* src_connection_id, picoquic_connection_id_t* dst_connection_id, bool* is_poll_packet); + +#endif // SLIPSTREAM_PACKET diff --git a/src/slipstream_client.c b/src/slipstream_client.c index 1271e9e..7a14d19 100644 --- a/src/slipstream_client.c +++ b/src/slipstream_client.c @@ -19,6 +19,7 @@ #include "picoquic_config.h" #include "slipstream.h" #include "slipstream_inline_dots.h" +#include "slipstream_packet.h" #include "slipstream_resolver_addresses.h" #include "SPCDNS/src/dns.h" #include "SPCDNS/src/mappings.h" @@ -68,7 +69,7 @@ ssize_t client_encode_segment(picoquic_quic_t* quic, dns_packet_t* packet, size_ return 0; } -ssize_t client_encode(picoquic_quic_t* quic, picoquic_cnx_t* last_cnx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr) { +ssize_t client_encode(picoquic_quic_t* quic, picoquic_cnx_t* cnx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr) { // optimize path for single segment if (src_buf_len <= *segment_len) { size_t packet_len = MAX_DNS_QUERY_SIZE; @@ -119,7 +120,7 @@ ssize_t client_encode(picoquic_quic_t* quic, picoquic_cnx_t* last_cnx, unsigned return current_packet - packets; } -ssize_t client_decode(picoquic_quic_t* quic, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, struct sockaddr_storage* peer_addr) { +ssize_t client_decode(picoquic_quic_t* quic, picoquic_socket_ctx_t* s_ctx, size_t s_ctx_len, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) { *dest_buf = NULL; size_t bufsize = DNS_DECODEBUF_4K * sizeof(dns_decoded_t); @@ -156,6 +157,66 @@ ssize_t client_decode(picoquic_quic_t* quic, unsigned char** dest_buf, const uns *dest_buf = malloc(answer_txt->len); memcpy((void*)*dest_buf, answer_txt->text, answer_txt->len); + + picoquic_connection_id_t incoming_src_connection_id = {0}; + picoquic_connection_id_t incoming_dest_connection_id; // sure to be set by parser + bool is_poll_packet = false; + int ret = slipstream_packet_parse(*dest_buf, answer_txt->len, PICOQUIC_SHORT_HEADER_CONNECTION_ID_SIZE, + &incoming_src_connection_id, &incoming_dest_connection_id, &is_poll_packet); + if (ret != 0 || is_poll_packet) { + fprintf(stderr, "error parsing slipstream packet: %d\n", ret); + return answer_txt->len; + } + + const SOCKET_TYPE send_socket = picoquic_socket_get_send_socket(s_ctx, s_ctx_len, peer_addr, local_addr); + if (send_socket == INVALID_SOCKET) { + fprintf(stderr, "no valid socket found for poll packet\n"); + return answer_txt->len; + } + + // get active destination connection id on this ctx + picoquic_cnx_t* cnx = picoquic_cnx_by_id_(quic, incoming_dest_connection_id); + picoquic_connection_id_t outgoing_dest_connection_id = cnx->path[0]->p_remote_cnxid->cnx_id; + if (outgoing_dest_connection_id.id_len == 0) { + // p_remote_cnxid is not set yet when we are receiving the first server response + outgoing_dest_connection_id = incoming_src_connection_id; + } + + const int poll_ratio = 1; + for (int j = 0; j < poll_ratio; ++j) { + uint8_t* poll_packet_buf; + size_t poll_packet_len; + ret = slipstream_packet_create_poll(&poll_packet_buf, &poll_packet_len, outgoing_dest_connection_id); + if (ret < 0) { + fprintf(stderr, "error creating poll packet\n"); + return answer_txt->len; + } + + unsigned char* encoded; + ssize_t encoded_len = client_encode(quic, cnx, &encoded, poll_packet_buf, poll_packet_len, &poll_packet_len, + peer_addr); + if (encoded_len <= 0) { + fprintf(stderr, "error encoding poll packet\n"); + free(poll_packet_buf); + return answer_txt->len; + } + + int sock_err = 0; + ret = picoquic_sendmsg(send_socket, + (struct sockaddr*)peer_addr, (struct sockaddr*)local_addr, 0, + (const char*)encoded, encoded_len, 0, &sock_err); + if (ret < 0) { + fprintf(stderr, "Error sending poll packet, ret=%d, sock_err=%d %s\n", ret, sock_err, strerror(sock_err)); + free(poll_packet_buf); + free(encoded); + return answer_txt->len; + } + + free(poll_packet_buf); + free(encoded); + } + + return answer_txt->len; } diff --git a/src/slipstream_packet.c b/src/slipstream_packet.c new file mode 100644 index 0000000..cce17f3 --- /dev/null +++ b/src/slipstream_packet.c @@ -0,0 +1,95 @@ +#include +#include + +#include "slipstream_packet.h" +#include "picoquic_utils.h" + +const int num_padding_for_poll = 5; + +#define PICOQUIC_SHORT_HEADER_CONNECTION_ID_SIZE 8 + +bool slipstream_packet_is_long_header(const uint8_t first_byte) { + return first_byte & 0x80; +} + +int slipstream_packet_create_poll(uint8_t** dest_buf, size_t* dest_buf_len, picoquic_connection_id_t dst_connection_id) { + *dest_buf = NULL; + + if (num_padding_for_poll < 5) { + return -1; + } + + // Allocate a num_padding_for_poll + dst_connection_id.id_len + len marker + dst_connection_id len marker + size_t packet_len = num_padding_for_poll + dst_connection_id.id_len + 1 + 1; + uint8_t* packet = malloc(packet_len); + + // Write random padding bytes to the entire packet + for (int i = 0; i < packet_len; i++) { + packet[i] = rand() % 256; + } + + packet[0] |= 0x80; // Set bit 7 (long header format) + + // Write destination connection ID + packet[5] = dst_connection_id.id_len; + memcpy(&packet[6], dst_connection_id.id, dst_connection_id.id_len); + + // Ensure the source connection ID len marker byte is larger than PICOQUIC_CONNECTION_ID_MAX_SIZE + int randomly_written_src_connection_id = packet[5+1+dst_connection_id.id_len]; + if (randomly_written_src_connection_id <= PICOQUIC_CONNECTION_ID_MAX_SIZE) { + packet[5+1+dst_connection_id.id_len] = PICOQUIC_CONNECTION_ID_MAX_SIZE + 1; + } + + // The rest of the payload (including pretend second connection ID) is random padding + + *dest_buf = packet; + *dest_buf_len = packet_len; + + return packet_len; +} + +int slipstream_packet_parse(uint8_t* src_buf, size_t src_buf_len, size_t short_header_conn_id_len, picoquic_connection_id_t* src_connection_id, picoquic_connection_id_t* dst_connection_id, bool* is_poll_packet) { + if (src_buf_len < 1) { + return -1; + } + + // Short header packet + if (!slipstream_packet_is_long_header(src_buf[0])) { + // Short header packets can't be poll packets + if (src_buf_len < short_header_conn_id_len + 1) { + return -1; + } + + picoquic_parse_connection_id(&src_buf[1], short_header_conn_id_len, dst_connection_id); + return 0; + } + + // Read destination connection ID + if (src_buf_len < 5+1) { + return -1; + } + const size_t dst_connection_id_len = src_buf[5]; + if (dst_connection_id_len > PICOQUIC_CONNECTION_ID_MAX_SIZE) { + return -1; + } + if (src_buf_len < 5+1+dst_connection_id_len) { + return -1; + } + picoquic_parse_connection_id(&src_buf[5+1], dst_connection_id_len, dst_connection_id); + + // Read source connection ID + if (src_buf_len < 5+1+dst_connection_id_len+1) { + return -1; + } + const size_t src_connection_id_len = src_buf[5+1+dst_connection_id_len]; + if (src_connection_id_len > PICOQUIC_CONNECTION_ID_MAX_SIZE) { + *is_poll_packet = true; + return 0; + } + if (src_buf_len < 5+1+dst_connection_id_len+1+src_connection_id_len) { + return -1; + } + picoquic_parse_connection_id(&src_buf[5+1+dst_connection_id_len+1], src_connection_id_len, src_connection_id); + + return 0; +} diff --git a/src/slipstream_server.c b/src/slipstream_server.c index 5c44481..8b39953 100644 --- a/src/slipstream_server.c +++ b/src/slipstream_server.c @@ -99,7 +99,7 @@ ssize_t server_encode(picoquic_quic_t* quic, picoquic_cnx_t* cnx, unsigned char* return packet_len; } -ssize_t server_decode(picoquic_quic_t* quic, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, struct sockaddr_storage *peer_addr) { +ssize_t server_decode(picoquic_quic_t* quic, picoquic_socket_ctx_t* s_ctx, size_t s_ctx_len, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, struct sockaddr_storage *peer_addr, struct sockaddr_storage *local_addr) { *dest_buf = NULL; slot_t* slot = slipstream_dns_request_buffer_get_write_slot(&slipstream_server_dns_request_buffer);