mirror of
https://github.com/EndPositive/slipstream.git
synced 2025-10-08 12:25:04 +00:00
Add LLM-based encoding
(cherry picked from commit 907ffa35af90bd7fa7c7ccf49dcc99146dc1b42d)
This commit is contained in:
parent
bee0de0c26
commit
99c8b24f64
5 changed files with 177 additions and 37 deletions
|
|
@ -39,6 +39,7 @@ add_executable(slipstream
|
|||
src/slipstream.c
|
||||
src/slipstream_client.c
|
||||
src/slipstream_inline_dots.c
|
||||
src/slipstream_llm.c
|
||||
src/slipstream_resolver_addresses.c
|
||||
src/slipstream_server.c
|
||||
src/slipstream_server_cc.c
|
||||
|
|
@ -46,6 +47,7 @@ add_executable(slipstream
|
|||
src/slipstream_utils.c
|
||||
include/slipstream.h
|
||||
include/slipstream_inline_dots.h
|
||||
include/slipstream_llm.h
|
||||
include/slipstream_resolver_addresses.h
|
||||
include/slipstream_server_cc.h
|
||||
include/slipstream_slot.h
|
||||
|
|
|
|||
17
include/slipstream_llm.h
Normal file
17
include/slipstream_llm.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#ifndef SLIPSTREAM_LLM_H
|
||||
#define SLIPSTREAM_LLM_H
|
||||
|
||||
#include <stdio.h>
|
||||
#include <netinet/in.h>
|
||||
|
||||
|
||||
typedef struct st_llm_connection_t llm_connection_t;
|
||||
|
||||
typedef struct st_llm_request_t llm_request_t;
|
||||
|
||||
ssize_t llm_create_connection(llm_connection_t** conn_p, int port);
|
||||
|
||||
ssize_t llm_encode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
|
||||
ssize_t llm_decode(llm_connection_t* conn, char* dest, size_t dest_len, const char* src, size_t src_len);
|
||||
|
||||
#endif //SLIPSTREAM_LLM_H
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
#include <autoqlog.h>
|
||||
#endif
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <picoquic_internal.h>
|
||||
#include <pthread.h>
|
||||
#include <slipstream_sockloop.h>
|
||||
|
|
@ -21,6 +22,7 @@
|
|||
#include "picoquic_config.h"
|
||||
#include "slipstream.h"
|
||||
#include "slipstream_inline_dots.h"
|
||||
#include "slipstream_llm.h"
|
||||
#include "slipstream_resolver_addresses.h"
|
||||
#include "slipstream_utils.h"
|
||||
#include "SPCDNS/src/dns.h"
|
||||
|
|
@ -46,15 +48,40 @@ typedef struct st_slipstream_client_ctx_t {
|
|||
bool closed;
|
||||
int listen_sock;
|
||||
size_t ready_mtu;
|
||||
llm_connection_t* llm_conn;
|
||||
} slipstream_client_ctx_t;
|
||||
|
||||
char* client_domain_name = NULL;
|
||||
size_t client_domain_name_len = 0;
|
||||
|
||||
ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
|
||||
ssize_t client_encode_segment(slipstream_client_ctx_t* client_ctx, dns_packet_t* packet, size_t* packet_len, const unsigned char* src_buf, size_t src_buf_len) {
|
||||
char name[255];
|
||||
const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
|
||||
const size_t encoded_len = slipstream_inline_dotify(name, 255, len);
|
||||
size_t space = sizeof(name) - 1 - client_domain_name_len - 2;
|
||||
ssize_t encoded_len;
|
||||
if (client_ctx->llm_conn != NULL) {
|
||||
encoded_len = llm_encode(client_ctx->llm_conn, &name[0], space, (const char*) src_buf, src_buf_len);
|
||||
if (encoded_len < 0) {
|
||||
DBG_PRINTF("error encoding with LLM: %lu (%s)", encoded_len, strerror(errno));
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
// regular base32 encoding
|
||||
const size_t expected_len = ceil(src_buf_len * 1.6);
|
||||
if (expected_len > space) {
|
||||
DBG_PRINTF("encoded length %lu > %lu", expected_len, space);
|
||||
return -1;
|
||||
}
|
||||
const size_t len = b32_encode(&name[0], (const char*) src_buf, src_buf_len, true, false);
|
||||
if (len != expected_len) {
|
||||
DBG_PRINTF("unexpected base32 encoded length %lu != %lu", len, expected_len);
|
||||
}
|
||||
encoded_len = (ssize_t) slipstream_inline_dotify(name, 255, len);
|
||||
if (encoded_len < 0) {
|
||||
DBG_PRINTF("error encoding base32", NULL);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
name[encoded_len] = '.';
|
||||
|
||||
memcpy(&name[encoded_len + 1], client_domain_name, client_domain_name_len);
|
||||
|
|
@ -94,6 +121,8 @@ ssize_t client_encode_segment(dns_packet_t* packet, size_t* packet_len, const un
|
|||
}
|
||||
|
||||
ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
|
||||
slipstream_client_ctx_t* client_ctx = callback_ctx;
|
||||
|
||||
// optimize path for single segment
|
||||
if (src_buf_len <= *segment_len) {
|
||||
#ifdef NOENCODE
|
||||
|
|
@ -104,7 +133,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
|
|||
#endif
|
||||
size_t packet_len = MAX_DNS_QUERY_SIZE;
|
||||
unsigned char* packet = malloc(packet_len);
|
||||
const ssize_t ret = client_encode_segment((dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
|
||||
const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) packet, &packet_len, src_buf, src_buf_len);
|
||||
if (ret < 0) {
|
||||
free(packet);
|
||||
return -1;
|
||||
|
|
@ -128,7 +157,7 @@ ssize_t client_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
|
|||
size_t first_packet_len = 0;
|
||||
for (size_t i = 0; i < num_segments; i++) {
|
||||
size_t packet_len = MAX_DNS_QUERY_SIZE;
|
||||
const ssize_t ret = client_encode_segment((dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
|
||||
const ssize_t ret = client_encode_segment(client_ctx, (dns_packet_t*) current_packet, &packet_len, segment, *segment_len);
|
||||
if (ret < 0) {
|
||||
free(packets);
|
||||
return -1;
|
||||
|
|
@ -823,6 +852,12 @@ int picoquic_slipstream_client(int listen_port, char const* resolver_addresses_f
|
|||
param.decode = client_decode;
|
||||
param.encode = client_encode;
|
||||
|
||||
if (llm_create_connection(&client_ctx.llm_conn, 8000) < 0) {
|
||||
perror("unable to create LLM connection");
|
||||
close(client_ctx.listen_sock);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
picoquic_network_thread_ctx_t thread_ctx = {0};
|
||||
thread_ctx.quic = quic;
|
||||
thread_ctx.param = ¶m;
|
||||
|
|
@ -854,6 +889,7 @@ int picoquic_slipstream_client(int listen_port, char const* resolver_addresses_f
|
|||
/* And finish. */
|
||||
printf("Client exit, ret = %d\n", ret);
|
||||
|
||||
free(client_ctx.llm_conn);
|
||||
picoquic_free(quic);
|
||||
|
||||
return ret;
|
||||
|
|
|
|||
65
src/slipstream_llm.c
Normal file
65
src/slipstream_llm.c
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
#include "slipstream_llm.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
|
||||
|
||||
typedef struct st_llm_connection_t {
|
||||
int sockfd;
|
||||
struct sockaddr_in addr;
|
||||
} llm_connection_t;
|
||||
|
||||
typedef struct st_llm_request_t {
|
||||
uint8_t query_data_length;
|
||||
uint8_t should_decode;
|
||||
uint8_t query_data[255];
|
||||
} llm_request_t;
|
||||
|
||||
ssize_t llm_create_connection(llm_connection_t** conn_p, int port) {
|
||||
llm_connection_t* conn = malloc(sizeof(llm_connection_t));
|
||||
if (conn == NULL) {
|
||||
return -1;
|
||||
}
|
||||
memset(conn, 0, sizeof(llm_connection_t));
|
||||
|
||||
conn->sockfd = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (conn->sockfd < 0) {
|
||||
free(conn);
|
||||
return -1;
|
||||
}
|
||||
|
||||
conn->addr.sin_family = AF_INET; // IPv4
|
||||
conn->addr.sin_port = htons(port); // port number
|
||||
conn->addr.sin_addr.s_addr = inet_addr("127.0.0.1"); // localhost
|
||||
|
||||
*conn_p = conn;
|
||||
return 0;
|
||||
}
|
||||
|
||||
ssize_t llm_send_recv(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len, int should_decode) {
|
||||
llm_request_t req;
|
||||
req.query_data_length = src_len;
|
||||
req.should_decode = should_decode;
|
||||
memcpy(req.query_data, src, src_len);
|
||||
|
||||
const size_t request_len = sizeof(req.query_data_length) + sizeof(req.should_decode) + req.query_data_length;
|
||||
ssize_t n = sendto(conn->sockfd, &req, request_len, 0, (const struct sockaddr*)&conn->addr, sizeof(conn->addr));
|
||||
if (n < 0) {
|
||||
return n;
|
||||
}
|
||||
|
||||
socklen_t addr_len = sizeof(conn->addr);
|
||||
return recvfrom(conn->sockfd, dest, dest_len, 0, (struct sockaddr*)&conn->addr, &addr_len);
|
||||
}
|
||||
|
||||
ssize_t llm_encode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
|
||||
return llm_send_recv(conn, dest, dest_len, src, src_len, 0);
|
||||
}
|
||||
|
||||
ssize_t llm_decode(llm_connection_t* conn, char* dest, const size_t dest_len, const char* src, const size_t src_len) {
|
||||
return llm_send_recv(conn, dest, dest_len, src, src_len, 1);
|
||||
}
|
||||
|
|
@ -21,7 +21,8 @@
|
|||
#include "picoquic_logger.h"
|
||||
#include "slipstream.h"
|
||||
#include "slipstream_inline_dots.h"
|
||||
#include "../include/slipstream_server_cc.h"
|
||||
#include "slipstream_llm.h"
|
||||
#include "slipstream_server_cc.h"
|
||||
#include "slipstream_slot.h"
|
||||
#include "slipstream_utils.h"
|
||||
#include "SPCDNS/src/dns.h"
|
||||
|
|
@ -30,6 +31,26 @@
|
|||
char* server_domain_name = NULL;
|
||||
size_t server_domain_name_len = 0;
|
||||
|
||||
typedef struct st_slipstream_server_stream_ctx_t {
|
||||
struct st_slipstream_server_stream_ctx_t* next_stream;
|
||||
struct st_slipstream_server_stream_ctx_t* previous_stream;
|
||||
int fd;
|
||||
uint64_t stream_id;
|
||||
volatile sig_atomic_t set_active;
|
||||
int syn_received;
|
||||
int syn_sent;
|
||||
} slipstream_server_stream_ctx_t;
|
||||
|
||||
typedef struct st_slipstream_server_ctx_t {
|
||||
picoquic_cnx_t* cnx;
|
||||
slipstream_server_stream_ctx_t* first_stream;
|
||||
picoquic_network_thread_ctx_t* thread_ctx;
|
||||
struct sockaddr_storage upstream_addr;
|
||||
struct st_slipstream_server_ctx_t* prev_ctx;
|
||||
struct st_slipstream_server_ctx_t* next_ctx;
|
||||
llm_connection_t* llm_conn;
|
||||
} slipstream_server_ctx_t;
|
||||
|
||||
ssize_t server_encode(void* slot_p, void* callback_ctx, unsigned char** dest_buf, const unsigned char* src_buf, size_t src_buf_len, size_t* segment_len, struct sockaddr_storage* peer_addr, struct sockaddr_storage* local_addr) {
|
||||
// we don't support segmentation in the server
|
||||
assert(segment_len == NULL || *segment_len == 0 || *segment_len == src_buf_len);
|
||||
|
|
@ -106,6 +127,7 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
|
|||
*dest_buf = NULL;
|
||||
|
||||
slot_t* slot = slot_p;
|
||||
slipstream_server_ctx_t* default_ctx = callback_ctx;
|
||||
|
||||
// DNS packets arrive from random source ports, so:
|
||||
// * save the original address in the dns query slot
|
||||
|
|
@ -159,45 +181,37 @@ ssize_t server_decode(void* slot_p, void* callback_ctx, unsigned char** dest_buf
|
|||
return 0;
|
||||
}
|
||||
|
||||
// copy the subdomain from name to a new buffer
|
||||
char data_buf[data_len];
|
||||
memcpy(data_buf, question->name, data_len);
|
||||
data_buf[data_len] = '\0';
|
||||
const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
|
||||
char* decoded_buf = malloc(data_len);
|
||||
ssize_t decoded_len;
|
||||
if (default_ctx->llm_conn) {
|
||||
decoded_len = llm_decode(default_ctx->llm_conn, decoded_buf, data_len, (const char*) question->name, data_len);
|
||||
if (decoded_len < 0) {
|
||||
// DBG_PRINTF socket errono
|
||||
DBG_PRINTF("error decoding with LLM: %lu (%s)", decoded_len, strerror(errno));
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
// copy the subdomain from name to a new buffer
|
||||
char data_buf[data_len];
|
||||
memcpy(data_buf, question->name, data_len);
|
||||
data_buf[data_len] = '\0';
|
||||
|
||||
char* decoded_buf = malloc(encoded_len);
|
||||
const size_t decoded_len = b32_decode(decoded_buf, data_buf, encoded_len, false);
|
||||
if (decoded_len == (size_t) -1) {
|
||||
free(decoded_buf);
|
||||
DBG_PRINTF("error decoding base32: %lu", decoded_len);
|
||||
slot->error = RCODE_SERVER_FAILURE;
|
||||
return 0;
|
||||
const size_t encoded_len = slipstream_inline_undotify(data_buf, data_len);
|
||||
decoded_len = b32_decode(decoded_buf, data_buf, encoded_len, false);
|
||||
if (decoded_len == (size_t) -1) {
|
||||
free(decoded_buf);
|
||||
DBG_PRINTF("error decoding base32: %lu", decoded_len);
|
||||
slot->error = RCODE_SERVER_FAILURE;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
*dest_buf = decoded_buf;
|
||||
|
||||
return decoded_len;
|
||||
}
|
||||
|
||||
typedef struct st_slipstream_server_stream_ctx_t {
|
||||
struct st_slipstream_server_stream_ctx_t* next_stream;
|
||||
struct st_slipstream_server_stream_ctx_t* previous_stream;
|
||||
int fd;
|
||||
uint64_t stream_id;
|
||||
volatile sig_atomic_t set_active;
|
||||
int syn_received;
|
||||
int syn_sent;
|
||||
} slipstream_server_stream_ctx_t;
|
||||
|
||||
typedef struct st_slipstream_server_ctx_t {
|
||||
picoquic_cnx_t* cnx;
|
||||
slipstream_server_stream_ctx_t* first_stream;
|
||||
picoquic_network_thread_ctx_t* thread_ctx;
|
||||
struct sockaddr_storage upstream_addr;
|
||||
struct st_slipstream_server_ctx_t* prev_ctx;
|
||||
struct st_slipstream_server_ctx_t* next_ctx;
|
||||
} slipstream_server_ctx_t;
|
||||
|
||||
slipstream_server_stream_ctx_t* slipstream_server_create_stream_ctx(slipstream_server_ctx_t* server_ctx,
|
||||
uint64_t stream_id) {
|
||||
slipstream_server_stream_ctx_t* stream_ctx = malloc(sizeof(slipstream_server_stream_ctx_t));
|
||||
|
|
@ -652,6 +666,11 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
|
|||
param.encode = server_encode;
|
||||
// param.delay_max = 5000;
|
||||
|
||||
if (llm_create_connection(&default_context.llm_conn, 8001) < 0) {
|
||||
perror("unable to create LLM connection");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
picoquic_network_thread_ctx_t thread_ctx = {0};
|
||||
thread_ctx.quic = quic;
|
||||
thread_ctx.param = ¶m;
|
||||
|
|
@ -671,6 +690,7 @@ int picoquic_slipstream_server(int server_port, const char* server_cert, const c
|
|||
/* And finish. */
|
||||
printf("Server exit, ret = %d\n", ret);
|
||||
|
||||
free(default_context.llm_conn);
|
||||
picoquic_free(quic);
|
||||
|
||||
return ret;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue