From e4b0dd56f0b4c7c80fafc72b3d54c41fc337d48d Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 19 Aug 2020 17:22:34 +0800 Subject: [PATCH] RTC: Covert server ARQ for DTLS --- trunk/src/app/srs_app_rtc_dtls.cpp | 6 +- trunk/src/app/srs_app_rtc_dtls.hpp | 3 + trunk/src/utest/srs_utest_rtc.cpp | 206 +++++++++++++++++++++++++++-- 3 files changed, 205 insertions(+), 10 deletions(-) diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index a54d8ddc6..284f3b715 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -328,6 +328,8 @@ SrsDtls::SrsDtls(ISrsDtlsCallback* callback) trd = NULL; state_ = SrsDtlsStateInit; + arq_first = 50 * SRS_UTIME_MILLISECONDS; + arq_interval = 100 * SRS_UTIME_MILLISECONDS; } SrsDtls::~SrsDtls() @@ -559,7 +561,7 @@ srs_error_t SrsDtls::cycle() srs_error_t err = srs_success; // The first ARQ delay. - srs_usleep(50 * SRS_UTIME_MILLISECONDS); + srs_usleep(arq_first); while (true) { srs_info("arq cycle, state=%u", state_); @@ -595,7 +597,7 @@ srs_error_t SrsDtls::cycle() } // TODO: Use ARQ step timeouts. - srs_usleep(100 * SRS_UTIME_MILLISECONDS); + srs_usleep(arq_interval); } return err; diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index b15da73c1..3832475e8 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -125,6 +125,9 @@ private: SrsCoroutine* trd; // The DTLS-client state to drive the ARQ thread. SrsDtlsState state_; + // The timeout for ARQ. + srs_utime_t arq_first; + srs_utime_t arq_interval; private: // @remark: dtls_role_ default value is DTLS_SERVER. SrsDtlsRole role_; diff --git a/trunk/src/utest/srs_utest_rtc.cpp b/trunk/src/utest/srs_utest_rtc.cpp index 8ed9c9e00..03fbd8079 100644 --- a/trunk/src/utest/srs_utest_rtc.cpp +++ b/trunk/src/utest/srs_utest_rtc.cpp @@ -36,7 +36,7 @@ using namespace std; extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version); -class MockDtls +class MockDtls : public ISrsCoroutineHandler { public: SSL_CTX* dtls_ctx; @@ -54,6 +54,7 @@ public: srs_error_t start_active_handshake(); srs_error_t on_dtls(char* data, int nb_data); srs_error_t do_handshake(); + virtual srs_error_t cycle(); }; MockDtls::MockDtls(ISrsDtlsCallback* callback) @@ -179,6 +180,11 @@ srs_error_t MockDtls::do_handshake() return err; } +srs_error_t MockDtls::cycle() +{ + return srs_success; +} + class MockDtlsCallback : virtual public ISrsDtlsCallback, virtual public ISrsCoroutineHandler { public: @@ -188,6 +194,23 @@ public: srs_error_t r0; bool done; std::vector samples; +public: + int nn_client_hello_lost; + int nn_server_hello_lost; + int nn_certificate_lost; + int nn_new_session_lost; + int nn_change_cipher_lost; +public: + // client -> server + int nn_client_hello; + // server -> client + int nn_server_hello; + // client -> server + int nn_certificate; + // server -> client + int nn_new_session; + int nn_change_cipher; +public: MockDtlsCallback(); virtual ~MockDtlsCallback(); virtual srs_error_t on_dtls_handshake_done(); @@ -204,6 +227,18 @@ MockDtlsCallback::MockDtlsCallback() done = false; trd = new SrsSTCoroutine("mock", this); srs_assert(trd->start() == srs_success); + + nn_client_hello_lost = 0; + nn_server_hello_lost = 0; + nn_certificate_lost = 0; + nn_new_session_lost = 0; + nn_change_cipher_lost = 0; + + nn_client_hello = 0; + nn_server_hello = 0; + nn_certificate = 0; + nn_new_session = 0; + nn_change_cipher = 0; } MockDtlsCallback::~MockDtlsCallback() @@ -228,10 +263,49 @@ srs_error_t MockDtlsCallback::on_dtls_application_data(const char* data, const i srs_error_t MockDtlsCallback::write_dtls_data(void* data, int size) { + int nn_lost = 0; + if (true) { + uint8_t content_type = 0; + if (size >= 1) { + content_type = (uint8_t)((uint8_t*)data)[0]; + } + + uint8_t handshake_type = 0; + if (size >= 14) { + handshake_type = (uint8_t)((uint8_t*)data)[13]; + } + + if (content_type == 22) { + if (handshake_type == 1) { + nn_lost = nn_client_hello_lost--; + nn_client_hello++; + } else if (handshake_type == 2) { + nn_lost = nn_server_hello_lost--; + nn_server_hello++; + } else if (handshake_type == 11) { + nn_lost = nn_certificate_lost--; + nn_certificate++; + } else if (handshake_type == 4) { + nn_lost = nn_new_session_lost--; + nn_new_session++; + } + } else if (content_type == 20) { + nn_lost = nn_change_cipher_lost--; + nn_change_cipher++; + } + } + + // Simulate to drop packet. + if (nn_lost > 0) { + return srs_success; + } + + // Send out it. char* cp = new char[size]; memcpy(cp, data, size); samples.push_back(SrsSample((char*)cp, size)); + return srs_success; } @@ -249,10 +323,12 @@ srs_error_t MockDtlsCallback::cycle() continue; } - SrsSample p = *samples.erase(samples.begin()); + SrsSample p = samples.at(0); + samples.erase(samples.begin()); + if (peer) { err = peer->on_dtls((char*)p.bytes, p.size); - } else { + } else if (peer2) { err = peer2->on_dtls((char*)p.bytes, p.size); } @@ -266,10 +342,33 @@ srs_error_t MockDtlsCallback::cycle() } // Wait for mock io to done, try to switch to coroutine many times. -#define mock_wait_dtls_io_done(cio, sio) \ - for (int i = 0; i < 100 && (!cio.samples.empty() || !sio.samples.empty()); i++) { \ - srs_usleep(0 * SRS_UTIME_MILLISECONDS); \ +void mock_wait_dtls_io_done(int count = 100, int interval = 0) +{ + for (int i = 0; i < count; i++) { + srs_usleep(interval * SRS_UTIME_MILLISECONDS); } +} + +// To avoid the crash when peer or peer2 is freed before io. +class MockBridgeDtlsIO +{ +private: + MockDtlsCallback* io_; +public: + MockBridgeDtlsIO(MockDtlsCallback* io, ISrsCoroutineHandler* dtls) { + io_ = io; + if (dynamic_cast(dtls)) { + io->peer = dynamic_cast(dtls); + } + if (dynamic_cast(dtls)) { + io->peer2 = dynamic_cast(dtls); + } + } + virtual ~MockBridgeDtlsIO() { + io_->peer = NULL; + io_->peer2 = NULL; + } +}; struct DTLSServerFlowCase { @@ -293,6 +392,97 @@ std::ostream& operator<< (std::ostream& stream, const DTLSServerFlowCase& c) return stream; } +VOID TEST(KernelRTCTest, DTLSServerARQTest) +{ + srs_error_t err = srs_success; + + // No ARQ, check the number of packets. + if (true) { + MockDtlsCallback cio; SrsDtls client(&cio); + MockDtlsCallback sio; SrsDtls server(&sio); + cio.peer = &server; sio.peer = &client; + HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); + HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); + + HELPER_EXPECT_SUCCESS(client.start_active_handshake()); + mock_wait_dtls_io_done(30, 1); + + EXPECT_TRUE(sio.r0 == srs_success); + EXPECT_TRUE(cio.r0 == srs_success); + + EXPECT_TRUE(cio.done); + EXPECT_TRUE(sio.done); + + EXPECT_EQ(1, cio.nn_client_hello); + EXPECT_EQ(1, sio.nn_server_hello); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); + EXPECT_EQ(0, sio.nn_change_cipher); + } + + // ServerHello lost, client retransmit the ClientHello. + if (true) { + MockDtlsCallback cio; SrsDtls client(&cio); + MockDtlsCallback sio; SrsDtls server(&sio); + MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client); + HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); + HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); + + // Use very short interval for utest. + client.arq_first = 1 * SRS_UTIME_MILLISECONDS; + client.arq_interval = 1 * SRS_UTIME_MILLISECONDS; + + // Lost 2 packets, total packets should be 3. + sio.nn_server_hello_lost = 2; + + HELPER_EXPECT_SUCCESS(client.start_active_handshake()); + mock_wait_dtls_io_done(10, 3); + + EXPECT_TRUE(sio.r0 == srs_success); + EXPECT_TRUE(cio.r0 == srs_success); + + EXPECT_TRUE(cio.done); + EXPECT_TRUE(sio.done); + + EXPECT_EQ(3, cio.nn_client_hello); + EXPECT_EQ(3, sio.nn_server_hello); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); + EXPECT_EQ(0, sio.nn_change_cipher); + } + + // NewSessionTicket lost, client retransmit the Certificate. + if (true) { + MockDtlsCallback cio; SrsDtls client(&cio); + MockDtlsCallback sio; SrsDtls server(&sio); + MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client); + HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); + HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); + + // Use very short interval for utest. + client.arq_first = 1 * SRS_UTIME_MILLISECONDS; + client.arq_interval = 1 * SRS_UTIME_MILLISECONDS; + + // Lost 2 packets, total packets should be 3. + sio.nn_new_session_lost = 2; + + HELPER_EXPECT_SUCCESS(client.start_active_handshake()); + mock_wait_dtls_io_done(10, 3); + + EXPECT_TRUE(sio.r0 == srs_success); + EXPECT_TRUE(cio.r0 == srs_success); + + EXPECT_TRUE(cio.done); + EXPECT_TRUE(sio.done); + + EXPECT_EQ(1, cio.nn_client_hello); + EXPECT_EQ(1, sio.nn_server_hello); + EXPECT_EQ(3, cio.nn_certificate); + EXPECT_EQ(3, sio.nn_new_session); + EXPECT_EQ(0, sio.nn_change_cipher); + } +} + VOID TEST(KernelRTCTest, DTLSClientFlowTest) { srs_error_t err = srs_success; @@ -326,7 +516,7 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; - mock_wait_dtls_io_done(cio, sio); + mock_wait_dtls_io_done(); // Note that the cio error is generated from server, vice versa. EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; @@ -370,7 +560,7 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; - mock_wait_dtls_io_done(cio, sio); + mock_wait_dtls_io_done(); // Note that the cio error is generated from server, vice versa. EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;