From 99c8b24f64c0900ed07273eba5dd480b7a92c7d8 Mon Sep 17 00:00:00 2001
From: Jop Zitman
Date: Fri, 21 Mar 2025 21:15:04 +0800
Subject: [PATCH] Add LLM-based encoding

(cherry picked from commit 907ffa35af90bd7fa7c7ccf49dcc99146dc1b42d)
---
 CMakeLists.txt           |  2 +
 include/slipstream_llm.h | 28 +++++++++++++
 src/slipstream_client.c  | 46 +++++++++++++++++++---
 src/slipstream_llm.c     | 65 +++++++++++++++++++++++++++++++
 src/slipstream_server.c  | 84 +++++++++++++++++++++++++---------------
 5 files changed, 188 insertions(+), 37 deletions(-)
 create mode 100644 include/slipstream_llm.h
 create mode 100644 src/slipstream_llm.c

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4e6a751..e1ad713 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,7 @@ add_executable(slipstream
     src/slipstream.c
     src/slipstream_client.c
     src/slipstream_inline_dots.c
+    src/slipstream_llm.c
     src/slipstream_resolver_addresses.c
     src/slipstream_server.c
     src/slipstream_server_cc.c
@@ -46,6 +47,7 @@ add_executable(slipstream
     src/slipstream_utils.c
     include/slipstream.h
     include/slipstream_inline_dots.h
+    include/slipstream_llm.h
     include/slipstream_resolver_addresses.h
     include/slipstream_server_cc.h
     include/slipstream_slot.h
diff --git a/include/slipstream_llm.h b/include/slipstream_llm.h
new file mode 100644
index 0000000..32c5e6a
--- /dev/null
+++ b/include/slipstream_llm.h
@@ -0,0 +1,28 @@
+#ifndef SLIPSTREAM_LLM_H
+#define SLIPSTREAM_LLM_H
+
+#include <stddef.h>
+#include <sys/types.h>
+
+
+/* Opaque handle for the UDP socket used to reach the LLM service. */
+typedef struct st_llm_connection_t llm_connection_t;
+
+/* Wire format of one encode/decode request (defined in slipstream_llm.c). */
+typedef struct st_llm_request_t llm_request_t;
+
+/*
+ * Allocate a connection handle with a UDP socket aimed at 127.0.0.1:port.
+ * On success stores the handle in *conn_p and returns 0; returns -1 on failure.
+ */
+ssize_t llm_create_connection(llm_connection_t** conn_p, int port);
+
+/*
+ * Send src[0..src_len) to the service and store its reply in dest (at most
+ * dest_len bytes). llm_encode requests encoding, llm_decode requests decoding.
+ * Returns the number of reply bytes, or a negative value on socket error.
+ */
+ssize_t llm_encode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
+ssize_t llm_decode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
+
+#endif //SLIPSTREAM_LLM_H
diff --git a/src/slipstream_client.c b/src/slipstream_client.c
index 99a99a0..e9887cb 100644
--- a/src/slipstream_client.c
+++
b/src/slipstream_client.c
@@ -8,6 +8,7 @@
 #include
 #endif
 #include
+#include
 #include
 #include
 #include
@@ -21,6 +22,7 @@
 #include "picoquic_config.h"
 #include "slipstream.h"
 #include "slipstream_inline_dots.h"
+#include "slipstream_llm.h"
 #include "slipstream_resolver_addresses.h"
 #include "slipstream_utils.h"
 #include "SPCDNS/src/dns.h"
@@ -46,15 +48,42 @@ typedef struct st_slipstream_client_ctx_t {
     bool closed;
     int listen_sock;
     size_t ready_mtu;
+    llm_connection_t* llm_conn;
 } slipstream_client_ctx_t;
 
 char* client_domain_name = NULL;
 size_t client_domain_name_len = 0;
 
-ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
+ssize_t client_encode_segment(slipstream_client_ctx_t* client_ctx, dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
     char name[255];
-    const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
-    const size_t encoded_len = slipstream_inline_dotify(name, 255, len);
+    size_t space = sizeof(name) - 1 - client_domain_name_len - 2;
+    ssize_t encoded_len;
+    if (client_ctx->llm_conn != NULL) {
+        encoded_len = llm_encode(client_ctx->llm_conn, &name[0], space, (const char*) src_buf, src_buf_len);
+        if (encoded_len < 0) {
+            DBG_PRINTF("error encoding with LLM: %zd (%s)", encoded_len, strerror(errno));
+            return -1;
+        }
+    } else {
+        // regular base32 encoding
+        // expected output is ceil(8 * n / 5) characters; integer math avoids float rounding
+        const size_t expected_len = (src_buf_len * 8 + 4) / 5;
+        if (expected_len > space) {
+            DBG_PRINTF("encoded length %zu > %zu", expected_len, space);
+            return -1;
+        }
+        const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
+        if (len != expected_len) {
+            DBG_PRINTF("unexpected base32 encoded length %zu != %zu", len, expected_len);
+        }
+        encoded_len = (ssize_t) slipstream_inline_dotify(name, 255, len);
+        // dotify may lengthen the string, so re-check that ".<domain>" and a NUL still fit in name
+        if (encoded_len < 0 || (size_t) encoded_len + 1 + client_domain_name_len >= sizeof(name)) {
+            DBG_PRINTF("error encoding base32", NULL);
+            return -1;
+        }
+    }
+
     name[encoded_len] = '.';
memcpy(&name[encoded_len + 1], client_domain_name, client_domain_name_len);
@@ -94,6 +121,8 @@ ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const un
 }
 
 ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
+    slipstream_client_ctx_t* client_ctx = callback_ctx;
+
     // optimize path for single segment
     if (src_buf_len <= *segment_len) {
 #ifdef NOENCODE
@@ -104,7 +133,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
 #endif
         size_t packet_len = MAX_DNS_QUERY_SIZE;
         unsigned char* packet = malloc(packet_len);
-        const ssize_t ret = client_encode_segment((dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
+        const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
         if (ret < 0) {
             free(packet);
             return -1;
@@ -128,7 +157,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
     size_t first_packet_len = 0;
     for (size_t i = 0; i < num_segments; i++) {
         size_t packet_len = MAX_DNS_QUERY_SIZE;
-        const ssize_t ret = client_encode_segment((dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
+        const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
         if (ret < 0) {
             free(packets);
             return -1;
@@ -823,6 +852,13 @@ int picoquic_slipstream_client(int listen_port, char const* resolver_addresses_f
     param.decode = client_decode;
     param.encode = client_encode;
 
+    // client-side LLM service endpoint: UDP 127.0.0.1:8000
+    if (llm_create_connection(&client_ctx.llm_conn, 8000) < 0) {
+        perror("unable to create LLM connection");
+        close(client_ctx.listen_sock);
+        exit(EXIT_FAILURE);
+    }
+
     picoquic_network_thread_ctx_t thread_ctx = {0};
     thread_ctx.quic = quic;
     thread_ctx.param = &param;
@@ -854,6 +890,7 @@ int picoquic_slipstream_client(int listen_port, char const*
resolver_addresses_f
     /* And finish. */
     printf("Client exit, ret = %d\n", ret);
 
+    free(client_ctx.llm_conn);
     picoquic_free(quic);
 
     return ret;
diff --git a/src/slipstream_llm.c b/src/slipstream_llm.c
new file mode 100644
index 0000000..3f8d6bf
--- /dev/null
+++ b/src/slipstream_llm.c
@@ -0,0 +1,89 @@
+#include "slipstream_llm.h"
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <netinet/in.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+
+
+/* UDP socket plus the address of the LLM service that requests are sent to. */
+typedef struct st_llm_connection_t {
+    int sockfd;
+    struct sockaddr_in addr;
+} llm_connection_t;
+
+/* Request datagram: two header bytes followed by query_data_length payload bytes. */
+typedef struct st_llm_request_t {
+    uint8_t query_data_length;
+    uint8_t should_decode;
+    uint8_t query_data[255];
+} llm_request_t;
+
+/*
+ * Allocate a connection handle and open a UDP socket for talking to
+ * 127.0.0.1:port. Stores the handle in *conn_p and returns 0; returns -1 if
+ * the allocation or the socket() call fails (*conn_p is then left untouched).
+ */
+ssize_t llm_create_connection(llm_connection_t** conn_p, int port) {
+    llm_connection_t* conn = malloc(sizeof(llm_connection_t));
+    if (conn == NULL) {
+        return -1;
+    }
+    memset(conn, 0, sizeof(llm_connection_t));
+
+    conn->sockfd = socket(AF_INET, SOCK_DGRAM, 0);
+    if (conn->sockfd < 0) {
+        free(conn);
+        return -1;
+    }
+
+    conn->addr.sin_family = AF_INET; // IPv4
+    conn->addr.sin_port = htons(port); // port number
+    conn->addr.sin_addr.s_addr = inet_addr("127.0.0.1"); // localhost
+
+    *conn_p = conn;
+    return 0;
+}
+
+/*
+ * Send src[0..src_len) with the should_decode flag to the service, then block
+ * until a reply datagram arrives and store it in dest (at most dest_len bytes).
+ * Returns the reply length, or -1 with errno set (EMSGSIZE when src_len does
+ * not fit into a request).
+ */
+static ssize_t llm_send_recv(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len, int should_decode) {
+    llm_request_t req;
+    // query_data_length is one byte and query_data is a fixed buffer: reject
+    // oversized input instead of truncating the length and overflowing req.
+    if (src_len > sizeof(req.query_data)) {
+        errno = EMSGSIZE;
+        return -1;
+    }
+    req.query_data_length = (uint8_t) src_len;
+    req.should_decode = (uint8_t) should_decode;
+    memcpy(req.query_data, src, src_len);
+
+    const size_t request_len = sizeof(req.query_data_length) + sizeof(req.should_decode) + req.query_data_length;
+    ssize_t n = sendto(conn->sockfd, &req, request_len, 0, (const struct sockaddr*)&conn->addr, sizeof(conn->addr));
+    if (n < 0) {
+        return n;
+    }
+
+    // The reply's source address is not needed. Passing NULL also keeps
+    // conn->addr, the destination of every later sendto(), from being replaced.
+    // NOTE(review): no receive timeout is configured, so a lost reply blocks here.
+    return recvfrom(conn->sockfd, dest, dest_len, 0, NULL, NULL);
+}
+
+/* Ask the service to encode src; see llm_send_recv for the return value. */
+ssize_t llm_encode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
+    return llm_send_recv(conn, dest, dest_len, src, src_len, 0);
+}
+
+/* Ask the service to decode src; see llm_send_recv for the return value. */
+ssize_t llm_decode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
+    return llm_send_recv(conn, dest, dest_len, src, src_len, 1);
+}
diff --git a/src/slipstream_server.c b/src/slipstream_server.c
index c9aa9c0..6cd454e 100644
--- a/src/slipstream_server.c
+++ b/src/slipstream_server.c
@@ -21,7 +21,8 @@
 #include "picoquic_logger.h"
 #include "slipstream.h"
 #include "slipstream_inline_dots.h"
-#include "../include/slipstream_server_cc.h"
+#include "slipstream_llm.h"
+#include "slipstream_server_cc.h"
 #include "slipstream_slot.h"
 #include "slipstream_utils.h"
 #include "SPCDNS/src/dns.h"
@@ -30,6 +31,26 @@
 char* server_domain_name = NULL;
 size_t server_domain_name_len = 0;
 
+typedef struct st_slipstream_server_stream_ctx_t {
+    struct st_slipstream_server_stream_ctx_t* next_stream;
+    struct st_slipstream_server_stream_ctx_t* previous_stream;
+    int fd;
+    uint64_t stream_id;
+    volatile sig_atomic_t set_active;
+    int syn_received;
+    int syn_sent;
+} slipstream_server_stream_ctx_t;
+
+typedef struct st_slipstream_server_ctx_t {
+    picoquic_cnx_t* cnx;
+    slipstream_server_stream_ctx_t* first_stream;
+    picoquic_network_thread_ctx_t* thread_ctx;
+    struct sockaddr_storage upstream_addr;
+    struct st_slipstream_server_ctx_t* prev_ctx;
+    struct st_slipstream_server_ctx_t* next_ctx;
+    llm_connection_t* llm_conn;
+} slipstream_server_ctx_t;
+
 ssize_t server_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
     // we don't support segmentation in the server
     assert(segment_len == NULL || *segment_len == 0 || *segment_len == src_buf_len);
@@ -106,6 +127,7 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
     *dest_buf = NULL;
 
     slot_t* slot = slot_p;
+    slipstream_server_ctx_t* default_ctx = callback_ctx;
 
     // DNS packets
arrive from random source ports, so:
     // * save the original address in the dns query slot
@@ -159,45 +181,37 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
         return 0;
     }
 
-    // copy the subdomain from name to a new buffer
-    char data_buf[data_len];
-    memcpy(data_buf, question->name, data_len);
-    data_buf[data_len] = '\0';
-    const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
+    char* decoded_buf = malloc(data_len);
+    ssize_t decoded_len;
+    if (default_ctx->llm_conn) {
+        decoded_len = llm_decode(default_ctx->llm_conn, decoded_buf, data_len, (const char*) question->name, data_len);
+        if (decoded_len < 0) {
+            DBG_PRINTF("error decoding with LLM: %zd (%s)", decoded_len, strerror(errno));
+            free(decoded_buf); // not handed to the caller on this path
+            return -1;
+        }
+    } else {
+        // copy the subdomain from name to a new buffer
+        char data_buf[data_len + 1]; // +1 for the terminator written below
+        memcpy(data_buf, question->name, data_len);
+        data_buf[data_len] = '\0';
 
-    char* decoded_buf = malloc(encoded_len);
-    const size_t decoded_len = b32_decode(decoded_buf, data_buf, encoded_len, false);
-    if (decoded_len == (size_t) -1) {
-        free(decoded_buf);
-        DBG_PRINTF("error decoding base32: %lu", decoded_len);
-        slot->error = RCODE_SERVER_FAILURE;
-        return 0;
+        const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
+        decoded_len = (ssize_t) b32_decode(decoded_buf, data_buf, encoded_len, false);
+        if (decoded_len < 0) {
+            free(decoded_buf);
+            DBG_PRINTF("error decoding base32: %zd", decoded_len);
+            slot->error = RCODE_SERVER_FAILURE;
+            return 0;
+        }
     }
 
+
     *dest_buf = decoded_buf;
 
     return decoded_len;
 }
 
-typedef struct st_slipstream_server_stream_ctx_t {
-    struct st_slipstream_server_stream_ctx_t* next_stream;
-    struct st_slipstream_server_stream_ctx_t* previous_stream;
-    int fd;
-    uint64_t stream_id;
-    volatile sig_atomic_t set_active;
-    int syn_received;
-    int syn_sent;
-} slipstream_server_stream_ctx_t;
-
-typedef struct st_slipstream_server_ctx_t {
-    picoquic_cnx_t* cnx;
-
slipstream_server_stream_ctx_t* first_stream;
-    picoquic_network_thread_ctx_t* thread_ctx;
-    struct sockaddr_storage upstream_addr;
-    struct st_slipstream_server_ctx_t* prev_ctx;
-    struct st_slipstream_server_ctx_t* next_ctx;
-} slipstream_server_ctx_t;
-
 slipstream_server_stream_ctx_t* slipstream_server_create_stream_ctx(slipstream_server_ctx_t* server_ctx, uint64_t stream_id) {
     slipstream_server_stream_ctx_t* stream_ctx = malloc(sizeof(slipstream_server_stream_ctx_t));
@@ -652,6 +666,12 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
     param.encode = server_encode;
     // param.delay_max = 5000;
 
+    // server-side LLM service endpoint: UDP 127.0.0.1:8001
+    if (llm_create_connection(&default_context.llm_conn, 8001) < 0) {
+        perror("unable to create LLM connection");
+        exit(EXIT_FAILURE);
+    }
+
     picoquic_network_thread_ctx_t thread_ctx = {0};
     thread_ctx.quic = quic;
     thread_ctx.param = &param;
@@ -671,6 +691,7 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
     /* And finish. */
     printf("Server exit, ret = %d\n", ret);
 
+    free(default_context.llm_conn);
     picoquic_free(quic);
 
     return ret;