Add LLM-based encoding

(cherry picked from commit 907ffa35af90bd7fa7c7ccf49dcc99146dc1b42d)
This commit is contained in:
Jop Zitman 2025-03-21 21:15:04 +08:00
parent bee0de0c26
commit 99c8b24f64
5 changed files with 177 additions and 37 deletions

View file

@ -39,6 +39,7 @@ add_executable(slipstream
src/slipstream.c
src/slipstream_client.c
src/slipstream_inline_dots.c
src/slipstream_llm.c
src/slipstream_resolver_addresses.c
src/slipstream_server.c
src/slipstream_server_cc.c
@ -46,6 +47,7 @@ add_executable(slipstream
src/slipstream_utils.c
include/slipstream.h
include/slipstream_inline_dots.h
include/slipstream_llm.h
include/slipstream_resolver_addresses.h
include/slipstream_server_cc.h
include/slipstream_slot.h

17
include/slipstream_llm.h Normal file
View file

@ -0,0 +1,17 @@
#ifndef SLIPSTREAM_LLM_H
#define SLIPSTREAM_LLM_H
#include <stdio.h>
#include <netinet/in.h>
typedef struct st_llm_connection_t llm_connection_t;
typedef struct st_llm_request_t llm_request_t;
ssize_t llm_create_connection(llm_connection_t** conn_p, int port);
ssize_t llm_encode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
ssize_t llm_decode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
#endif //SLIPSTREAM_LLM_H

View file

@ -8,6 +8,7 @@
#include <autoqlog.h>
#endif
#include <assert.h>
#include <math.h>
#include <picoquic_internal.h>
#include <pthread.h>
#include <slipstream_sockloop.h>
@ -21,6 +22,7 @@
#include "picoquic_config.h"
#include "slipstream.h"
#include "slipstream_inline_dots.h"
#include "slipstream_llm.h"
#include "slipstream_resolver_addresses.h"
#include "slipstream_utils.h"
#include "SPCDNS/src/dns.h"
@ -46,15 +48,40 @@ typedef struct st_slipstream_client_ctx_t {
bool closed;
int listen_sock;
size_t ready_mtu;
llm_connection_t* llm_conn;
} slipstream_client_ctx_t;
char* client_domain_name = NULL;
size_t client_domain_name_len = 0;
ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
ssize_t client_encode_segment(slipstream_client_ctx_t* client_ctx, dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
char name[255];
const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
const size_t encoded_len = slipstream_inline_dotify(name, 255, len);
size_t space = sizeof(name) - 1 - client_domain_name_len - 2;
ssize_t encoded_len;
if (client_ctx->llm_conn != NULL) {
encoded_len = llm_encode(client_ctx->llm_conn, &name[0], space, (const char*) src_buf, src_buf_len);
if (encoded_len < 0) {
DBG_PRINTF("error encoding with LLM: %lu (%s)", encoded_len, strerror(errno));
return -1;
}
} else {
// regular base32 encoding
const size_t expected_len = ceil(src_buf_len * 1.6);
if (expected_len > space) {
DBG_PRINTF("encoded length %lu > %lu", expected_len, space);
return -1;
}
const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
if (len != expected_len) {
DBG_PRINTF("unexpected base32 encoded length %lu != %lu", len, expected_len);
}
encoded_len = (ssize_t) slipstream_inline_dotify(name, 255, len);
if (encoded_len < 0) {
DBG_PRINTF("error encoding base32", NULL);
return -1;
}
}
name[encoded_len] = '.';
memcpy(&name[encoded_len + 1], client_domain_name, client_domain_name_len);
@ -94,6 +121,8 @@ ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const un
}
ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
slipstream_client_ctx_t* client_ctx = callback_ctx;
// optimize path for single segment
if (src_buf_len <= *segment_len) {
#ifdef NOENCODE
@ -104,7 +133,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
#endif
size_t packet_len = MAX_DNS_QUERY_SIZE;
unsigned char* packet = malloc(packet_len);
const ssize_t ret = client_encode_segment((dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
if (ret < 0) {
free(packet);
return -1;
@ -128,7 +157,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
size_t first_packet_len = 0;
for (size_t i = 0; i < num_segments; i++) {
size_t packet_len = MAX_DNS_QUERY_SIZE;
const ssize_t ret = client_encode_segment((dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
if (ret < 0) {
free(packets);
return -1;
@ -823,6 +852,12 @@ int picoquic_slipstream_client(int listen_port, char const* resolver_addresses_f
param.decode = client_decode;
param.encode = client_encode;
if (llm_create_connection(&client_ctx.llm_conn, 8000) < 0) {
perror("unable to create LLM connection");
close(client_ctx.listen_sock);
exit(EXIT_FAILURE);
}
picoquic_network_thread_ctx_t thread_ctx = {0};
thread_ctx.quic = quic;
thread_ctx.param = &param;
@ -854,6 +889,7 @@ int picoquic_slipstream_client(int listen_port, char const* resolver_addresses_f
/* And finish. */
printf("Client exit, ret = %d\n", ret);
free(client_ctx.llm_conn);
picoquic_free(quic);
return ret;

65
src/slipstream_llm.c Normal file
View file

@ -0,0 +1,65 @@
#include "slipstream_llm.h"
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
typedef struct st_llm_connection_t {
int sockfd;
struct sockaddr_in addr;
} llm_connection_t;
typedef struct st_llm_request_t {
uint8_t query_data_length;
uint8_t should_decode;
uint8_t query_data[255];
} llm_request_t;
ssize_t llm_create_connection(llm_connection_t** conn_p, int port) {
llm_connection_t* conn = malloc(sizeof(llm_connection_t));
if (conn == NULL) {
return -1;
}
memset(conn, 0, sizeof(llm_connection_t));
conn->sockfd = socket(AF_INET, SOCK_DGRAM, 0);
if (conn->sockfd < 0) {
free(conn);
return -1;
}
conn->addr.sin_family = AF_INET; // IPv4
conn->addr.sin_port = htons(port); // port number
conn->addr.sin_addr.s_addr = inet_addr("127.0.0.1"); // localhost
*conn_p = conn;
return 0;
}
ssize_t llm_send_recv(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len, int should_decode) {
llm_request_t req;
req.query_data_length = src_len;
req.should_decode = should_decode;
memcpy(req.query_data, src, src_len);
const size_t request_len = sizeof(req.query_data_length) + sizeof(req.should_decode) + req.query_data_length;
ssize_t n = sendto(conn->sockfd, &req, request_len, 0, (const struct sockaddr*)&conn->addr, sizeof(conn->addr));
if (n < 0) {
return n;
}
socklen_t addr_len = sizeof(conn->addr);
return recvfrom(conn->sockfd, dest, dest_len, 0, (struct sockaddr*)&conn->addr, &addr_len);
}
ssize_t llm_encode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
return llm_send_recv(conn, dest, dest_len, src, src_len, 0);
}
ssize_t llm_decode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
return llm_send_recv(conn, dest, dest_len, src, src_len, 1);
}

View file

@ -21,7 +21,8 @@
#include "picoquic_logger.h"
#include "slipstream.h"
#include "slipstream_inline_dots.h"
#include "../include/slipstream_server_cc.h"
#include "slipstream_llm.h"
#include "slipstream_server_cc.h"
#include "slipstream_slot.h"
#include "slipstream_utils.h"
#include "SPCDNS/src/dns.h"
@ -30,6 +31,26 @@
char* server_domain_name = NULL;
size_t server_domain_name_len = 0;
typedef struct st_slipstream_server_stream_ctx_t {
struct st_slipstream_server_stream_ctx_t* next_stream;
struct st_slipstream_server_stream_ctx_t* previous_stream;
int fd;
uint64_t stream_id;
volatile sig_atomic_t set_active;
int syn_received;
int syn_sent;
} slipstream_server_stream_ctx_t;
typedef struct st_slipstream_server_ctx_t {
picoquic_cnx_t* cnx;
slipstream_server_stream_ctx_t* first_stream;
picoquic_network_thread_ctx_t* thread_ctx;
struct sockaddr_storage upstream_addr;
struct st_slipstream_server_ctx_t* prev_ctx;
struct st_slipstream_server_ctx_t* next_ctx;
llm_connection_t* llm_conn;
} slipstream_server_ctx_t;
ssize_t server_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
// we don't support segmentation in the server
assert(segment_len == NULL || *segment_len == 0 || *segment_len == src_buf_len);
@ -106,6 +127,7 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
*dest_buf = NULL;
slot_t* slot = slot_p;
slipstream_server_ctx_t* default_ctx = callback_ctx;
// DNS packets arrive from random source ports, so:
// * save the original address in the dns query slot
@ -159,45 +181,37 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
return 0;
}
// copy the subdomain from name to a new buffer
char data_buf[data_len];
memcpy(data_buf, question->name, data_len);
data_buf[data_len] = '\0';
const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
char* decoded_buf = malloc(data_len);
ssize_t decoded_len;
if (default_ctx->llm_conn) {
decoded_len = llm_decode(default_ctx->llm_conn, decoded_buf, data_len, (const char*) question->name, data_len);
if (decoded_len < 0) {
// DBG_PRINTF socket errono
DBG_PRINTF("error decoding with LLM: %lu (%s)", decoded_len, strerror(errno));
return -1;
}
} else {
// copy the subdomain from name to a new buffer
char data_buf[data_len];
memcpy(data_buf, question->name, data_len);
data_buf[data_len] = '\0';
char* decoded_buf = malloc(encoded_len);
const size_t decoded_len = b32_decode(decoded_buf, data_buf, encoded_len, false);
if (decoded_len == (size_t) -1) {
free(decoded_buf);
DBG_PRINTF("error decoding base32: %lu", decoded_len);
slot->error = RCODE_SERVER_FAILURE;
return 0;
const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
decoded_len = b32_decode(decoded_buf, data_buf, encoded_len, false);
if (decoded_len == (size_t) -1) {
free(decoded_buf);
DBG_PRINTF("error decoding base32: %lu", decoded_len);
slot->error = RCODE_SERVER_FAILURE;
return 0;
}
}
*dest_buf = decoded_buf;
return decoded_len;
}
typedef struct st_slipstream_server_stream_ctx_t {
struct st_slipstream_server_stream_ctx_t* next_stream;
struct st_slipstream_server_stream_ctx_t* previous_stream;
int fd;
uint64_t stream_id;
volatile sig_atomic_t set_active;
int syn_received;
int syn_sent;
} slipstream_server_stream_ctx_t;
typedef struct st_slipstream_server_ctx_t {
picoquic_cnx_t* cnx;
slipstream_server_stream_ctx_t* first_stream;
picoquic_network_thread_ctx_t* thread_ctx;
struct sockaddr_storage upstream_addr;
struct st_slipstream_server_ctx_t* prev_ctx;
struct st_slipstream_server_ctx_t* next_ctx;
} slipstream_server_ctx_t;
slipstream_server_stream_ctx_t* slipstream_server_create_stream_ctx(slipstream_server_ctx_t* server_ctx,
uint64_t stream_id) {
slipstream_server_stream_ctx_t* stream_ctx = malloc(sizeof(slipstream_server_stream_ctx_t));
@ -652,6 +666,11 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
param.encode = server_encode;
// param.delay_max = 5000;
if (llm_create_connection(&default_context.llm_conn, 8001) < 0) {
perror("unable to create LLM connection");
exit(EXIT_FAILURE);
}
picoquic_network_thread_ctx_t thread_ctx = {0};
thread_ctx.quic = quic;
thread_ctx.param = &param;
@ -671,6 +690,7 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
/* And finish. */
printf("Server exit, ret = %d\n", ret);
free(default_context.llm_conn);
picoquic_free(quic);
return ret;