diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index 8a70b772d..c3f32ef2c 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -47,15 +47,83 @@ using namespace std; // can however retrieve the error code of the last verification error using SSL_get_verify_result(3) or by maintaining // its own error storage managed by verify_callback. // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html -static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) +int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { // Always OK, we don't check the certificate of client, // because we allow client self-sign certificate. return 1; } +SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version) +{ + SSL_CTX* dtls_ctx; +#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2 + dtls_ctx = SSL_CTX_new(DTLSv1_method()); +#else + if (version == SrsDtlsVersion1_0) { + dtls_ctx = SSL_CTX_new(DTLSv1_method()); + } else if (version == SrsDtlsVersion1_2) { + dtls_ctx = SSL_CTX_new(DTLSv1_2_method()); + } else { + // SrsDtlsVersionAuto, use version-flexible DTLS methods + dtls_ctx = SSL_CTX_new(DTLS_method()); + } +#endif + + if (_srs_rtc_dtls_certificate->is_ecdsa()) { // By ECDSA, https://stackoverflow.com/a/6006898 +#if OPENSSL_VERSION_NUMBER >= 0x10002000L // v1.0.2 + // For ECDSA, we could set the curves list. + // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set1_curves_list.html + SSL_CTX_set1_curves_list(dtls_ctx, "P-521:P-384:P-256"); +#endif + + // For openssl <1.1, we must set the ECDH manually. + // @see https://stackoverrun.com/cn/q/10791887 +#if OPENSSL_VERSION_NUMBER < 0x10100000L // v1.1.x + #if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2 + SSL_CTX_set_tmp_ecdh(dtls_ctx, _srs_rtc_dtls_certificate->get_ecdsa_key()); + #else + SSL_CTX_set_ecdh_auto(dtls_ctx, 1); + #endif +#endif + } + + // Setup DTLS context. + if (true) { + // We use "ALL", while you can use "DEFAULT" means "ALL:!EXPORT:!LOW:!aNULL:!eNULL:!SSLv2" + // @see https://www.openssl.org/docs/man1.0.2/man1/ciphers.html + srs_assert(SSL_CTX_set_cipher_list(dtls_ctx, "ALL") == 1); + + // Setup the certificate. + srs_assert(SSL_CTX_use_certificate(dtls_ctx, _srs_rtc_dtls_certificate->get_cert()) == 1); + srs_assert(SSL_CTX_use_PrivateKey(dtls_ctx, _srs_rtc_dtls_certificate->get_public_key()) == 1); + + // Server will send Certificate Request. + // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html + // TODO: FIXME: Config it, default to off to make the packet smaller. + SSL_CTX_set_verify(dtls_ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, srs_verify_callback); + // The depth count is "level 0:peer certificate", "level 1: CA certificate", + // "level 2: higher level CA certificate", and so on. + // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html + SSL_CTX_set_verify_depth(dtls_ctx, 4); + + // Whether we should read as many input bytes as possible (for non-blocking reads) or not. + // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_read_ahead.html + SSL_CTX_set_read_ahead(dtls_ctx, 1); + + // TODO: Maybe we can use SRTP-GCM in future. + // @see https://bugs.chromium.org/p/chromium/issues/detail?id=713701 + // @see https://groups.google.com/forum/#!topic/discuss-webrtc/PvCbWSetVAQ + // @remark Only support SRTP_AES128_CM_SHA1_80, please read ssl/d1_srtp.c + srs_assert(SSL_CTX_set_tlsext_use_srtp(dtls_ctx, "SRTP_AES128_CM_SHA1_80") == 0); + } + + return dtls_ctx; +} + SrsDtlsCertificate::SrsDtlsCertificate() { + ecdsa_mode = true; dtls_cert = NULL; dtls_pkey = NULL; eckey = NULL; @@ -244,12 +312,12 @@ ISrsDtlsCallback::~ISrsDtlsCallback() { } -SrsDtls::SrsDtls(ISrsDtlsCallback* cb) +SrsDtls::SrsDtls(ISrsDtlsCallback* callback) { dtls_ctx = NULL; dtls = NULL; - callback = cb; + callback_ = callback; handshake_done_for_us = false; last_outgoing_packet_cache = new uint8_t[kRtpPacketSize]; @@ -297,13 +365,13 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version) version_ = SrsDtlsVersionAuto; } - dtls_ctx = build_dtls_ctx(); + dtls_ctx = srs_build_dtls_ctx(version_); if ((dtls = SSL_new(dtls_ctx)) == NULL) { return srs_error_new(ERROR_OpenSslCreateSSL, "SSL_new dtls"); } - if (role == "active") { + if (role_ == SrsDtlsRoleClient) { // Dtls setup active, as client role. SSL_set_connect_state(dtls); SSL_set_max_send_fragment(dtls, kRtpPacketSize); @@ -326,73 +394,6 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version) return err; } -SSL_CTX* SrsDtls::build_dtls_ctx() -{ - SSL_CTX* dtls_ctx; -#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2 - dtls_ctx = SSL_CTX_new(DTLSv1_method()); -#else - if (version_ == SrsDtlsVersion1_0) { - dtls_ctx = SSL_CTX_new(DTLSv1_method()); - } else if (version_ == SrsDtlsVersion1_2) { - dtls_ctx = SSL_CTX_new(DTLSv1_2_method()); - } else { - // SrsDtlsVersionAuto, use version-flexible DTLS methods - dtls_ctx = SSL_CTX_new(DTLS_method()); - } -#endif - - if (_srs_rtc_dtls_certificate->is_ecdsa()) { // By ECDSA, https://stackoverflow.com/a/6006898 -#if OPENSSL_VERSION_NUMBER >= 0x10002000L // v1.0.2 - // For ECDSA, we could set the curves list. - // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set1_curves_list.html - SSL_CTX_set1_curves_list(dtls_ctx, "P-521:P-384:P-256"); -#endif - - // For openssl <1.1, we must set the ECDH manually. - // @see https://stackoverrun.com/cn/q/10791887 -#if OPENSSL_VERSION_NUMBER < 0x10100000L // v1.1.x - #if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2 - SSL_CTX_set_tmp_ecdh(dtls_ctx, _srs_rtc_dtls_certificate->get_ecdsa_key()); - #else - SSL_CTX_set_ecdh_auto(dtls_ctx, 1); - #endif -#endif - } - - // Setup DTLS context. - if (true) { - // We use "ALL", while you can use "DEFAULT" means "ALL:!EXPORT:!LOW:!aNULL:!eNULL:!SSLv2" - // @see https://www.openssl.org/docs/man1.0.2/man1/ciphers.html - srs_assert(SSL_CTX_set_cipher_list(dtls_ctx, "ALL") == 1); - - // Setup the certificate. - srs_assert(SSL_CTX_use_certificate(dtls_ctx, _srs_rtc_dtls_certificate->get_cert()) == 1); - srs_assert(SSL_CTX_use_PrivateKey(dtls_ctx, _srs_rtc_dtls_certificate->get_public_key()) == 1); - - // Server will send Certificate Request. - // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html - // TODO: FIXME: Config it, default to off to make the packet smaller. - SSL_CTX_set_verify(dtls_ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, verify_callback); - // The depth count is "level 0:peer certificate", "level 1: CA certificate", - // "level 2: higher level CA certificate", and so on. - // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html - SSL_CTX_set_verify_depth(dtls_ctx, 4); - - // Whether we should read as many input bytes as possible (for non-blocking reads) or not. - // @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_read_ahead.html - SSL_CTX_set_read_ahead(dtls_ctx, 1); - - // TODO: Maybe we can use SRTP-GCM in future. - // @see https://bugs.chromium.org/p/chromium/issues/detail?id=713701 - // @see https://groups.google.com/forum/#!topic/discuss-webrtc/PvCbWSetVAQ - // @remark Only support SRTP_AES128_CM_SHA1_80, please read ssl/d1_srtp.c - srs_assert(SSL_CTX_set_tlsext_use_srtp(dtls_ctx, "SRTP_AES128_CM_SHA1_80") == 0); - } - - return dtls_ctx; -} - srs_error_t SrsDtls::start_active_handshake() { srs_error_t err = srs_success; @@ -437,7 +438,7 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data) } // Trace the detail of DTLS packet. - state_trace((uint8_t*)data, nb_data, true, SSL_ERROR_NONE, false, false); + state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false, false); if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) { // TODO: 0 or -1 maybe block, use BIO_should_retry to check. @@ -453,13 +454,12 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data) while (BIO_ctrl_pending(bio_in) > 0) { char buf[8092]; int nb = SSL_read(dtls, buf, sizeof(buf)); - - if (!callback || nb <= 0) { + if (nb <= 0) { continue; } srs_trace("DTLS: read nb=%d, data=[%s]", nb, srs_string_dumps_hex(buf, nb, 32).c_str()); - if ((err = callback->on_dtls_application_data(buf, nb)) != srs_success) { + if ((err = callback_->on_dtls_application_data(buf, nb)) != srs_success) { return srs_error_wrap(err, "on DTLS data, size=%u, data=[%s]", nb, srs_string_dumps_hex(buf, nb, 32).c_str()); } @@ -476,7 +476,12 @@ srs_error_t SrsDtls::do_handshake() int r0 = SSL_do_handshake(dtls); int r1 = SSL_get_error(dtls, r0); - // TODO: What about SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE? + // Fatal SSL error, for example, no available suite when peer is DTLS 1.0 while we are DTLS 1.2. + if (r0 < 0 && (r1 != SSL_ERROR_NONE && r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE)) { + return srs_error_new(ERROR_OpenSslBIOWrite, "handshake r0=%d, r1=%d", r0, r1); + } + + // OK, Handshake is done, note that it maybe done many times. if (r1 == SSL_ERROR_NONE) { handshake_done_for_us = true; } @@ -495,7 +500,7 @@ srs_error_t SrsDtls::do_handshake() } // Trace the detail of DTLS packet. - state_trace((uint8_t*)data, size, false, r1, cache, false); + state_trace((uint8_t*)data, size, false, r0, r1, cache, false); // Update the packet cache. if (size > 0 && data != last_outgoing_packet_cache && size < kRtpPacketSize) { @@ -528,7 +533,7 @@ srs_error_t SrsDtls::do_handshake() } } - if (size > 0 && (err = callback->write_dtls_data(data, size)) != srs_success) { + if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) { return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, srs_string_dumps_hex((char*)data, size, 32).c_str()); } @@ -541,7 +546,7 @@ srs_error_t SrsDtls::do_handshake() } // Notify connection the DTLS is done. - if (((err = callback->on_dtls_handshake_done()) != srs_success)) { + if (((err = callback_->on_dtls_handshake_done()) != srs_success)) { return srs_error_wrap(err, "dtls done"); } } @@ -581,9 +586,9 @@ srs_error_t SrsDtls::cycle() if (size) { // Trace the detail of DTLS packet. - state_trace((uint8_t*)data, size, false, SSL_ERROR_NONE, true, true); + state_trace((uint8_t*)data, size, false, 1, SSL_ERROR_NONE, true, true); - if ((err = callback->write_dtls_data(data, size)) != srs_success) { + if ((err = callback_->write_dtls_data(data, size)) != srs_success) { return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, srs_string_dumps_hex((char*)data, size, 32).c_str()); } @@ -596,7 +601,7 @@ srs_error_t SrsDtls::cycle() return err; } -void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq) +void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq) { uint8_t content_type = 0; if (length >= 1) { @@ -613,9 +618,9 @@ void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int ssl_err, handshake_type = (uint8_t)data[13]; } - srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, r0=%d, len=%u, cnt=%u, size=%u, hs=%u", + srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u", (role_ == SrsDtlsRoleClient? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq, - state_, ssl_err, length, content_type, size, handshake_type); + state_, r0, r1, length, content_type, size, handshake_type); } srs_error_t SrsDtls::start_arq() diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 32cb76025..b15da73c1 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -112,23 +112,20 @@ private: SSL* dtls; BIO* bio_in; BIO* bio_out; - - ISrsDtlsCallback* callback; - + ISrsDtlsCallback* callback_; +private: // Whether the handhshake is done, for us only. // @remark For us only, means peer maybe not done, we also need to handle the DTLS packet. bool handshake_done_for_us; - // DTLS packet cache, only last out-going packet. uint8_t* last_outgoing_packet_cache; int nn_last_outgoing_packet; - // ARQ thread, for role active(DTLS client). // @note If passive(DTLS server), the ARQ is driven by DTLS client. SrsCoroutine* trd; // The DTLS-client state to drive the ARQ thread. SrsDtlsState state_; - +private: // @remark: dtls_role_ default value is DTLS_SERVER. SrsDtlsRole role_; // @remark: dtls_version_ default value is SrsDtlsVersionAuto. @@ -138,8 +135,6 @@ public: virtual ~SrsDtls(); public: srs_error_t initialize(std::string role, std::string version); -private: - SSL_CTX* build_dtls_ctx(); public: // As DTLS client, start handshake actively, send the ClientHello packet. srs_error_t start_active_handshake(); @@ -153,7 +148,7 @@ private: public: virtual srs_error_t cycle(); private: - void state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq); + void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq); private: srs_error_t start_arq(); void stop_arq(); diff --git a/trunk/src/kernel/srs_kernel_codec.cpp b/trunk/src/kernel/srs_kernel_codec.cpp index e3d3ca4a3..a937662f5 100644 --- a/trunk/src/kernel/srs_kernel_codec.cpp +++ b/trunk/src/kernel/srs_kernel_codec.cpp @@ -368,6 +368,13 @@ SrsSample::SrsSample() bframe = false; } +SrsSample::SrsSample(char* b, int s) +{ + size = s; + bytes = b; + bframe = false; +} + SrsSample::~SrsSample() { } diff --git a/trunk/src/kernel/srs_kernel_codec.hpp b/trunk/src/kernel/srs_kernel_codec.hpp index a5bb960cb..cc2bcb711 100644 --- a/trunk/src/kernel/srs_kernel_codec.hpp +++ b/trunk/src/kernel/srs_kernel_codec.hpp @@ -538,6 +538,7 @@ public: bool bframe; public: SrsSample(); + SrsSample(char* b, int s); ~SrsSample(); public: // If we need to know whether sample is bframe, we have to parse the NALU payload. diff --git a/trunk/src/utest/srs_utest.cpp b/trunk/src/utest/srs_utest.cpp index 286b925e6..8e400f1f0 100644 --- a/trunk/src/utest/srs_utest.cpp +++ b/trunk/src/utest/srs_utest.cpp @@ -28,6 +28,7 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. #include #include #include +#include #include using namespace std; @@ -57,6 +58,10 @@ srs_error_t prepare_main() { return srs_error_wrap(err, "init st"); } + if ((err = _srs_rtc_dtls_certificate->initialize()) != srs_success) { + return srs_error_wrap(err, "rtc dtls certificate initialize"); + } + srs_freep(_srs_context); _srs_context = new SrsThreadContext(); diff --git a/trunk/src/utest/srs_utest.hpp b/trunk/src/utest/srs_utest.hpp index 252c9928f..3d63388fb 100644 --- a/trunk/src/utest/srs_utest.hpp +++ b/trunk/src/utest/srs_utest.hpp @@ -56,19 +56,24 @@ extern int _srs_tmp_port; extern srs_utime_t _srs_tmp_timeout; // For errors. +// @remark we directly delete the err, because we allow user to append message if fail. #define HELPER_EXPECT_SUCCESS(x) \ if ((err = x) != srs_success) fprintf(stderr, "err %s", srs_error_desc(err).c_str()); \ - EXPECT_TRUE(srs_success == err); \ - srs_freep(err) -#define HELPER_EXPECT_FAILED(x) EXPECT_TRUE(srs_success != (err = x)); srs_freep(err) + if (err != srs_success) delete err; \ + EXPECT_TRUE(srs_success == err) +#define HELPER_EXPECT_FAILED(x) \ + if ((err = x) != srs_success) delete err; \ + EXPECT_TRUE(srs_success != err) // For errors, assert. -// @remark The err is leak when error, but it's ok in utest. -#define HELPER_ASSERT_SUCCESS(x) \ +// @remark we directly delete the err, because we allow user to append message if fail. +#define HELPER_EXPECT_SUCCESS(x) \ if ((err = x) != srs_success) fprintf(stderr, "err %s", srs_error_desc(err).c_str()); \ - ASSERT_TRUE(srs_success == err); \ - srs_freep(err) -#define HELPER_ASSERT_FAILED(x) ASSERT_TRUE(srs_success != (err = x)); srs_freep(err) + if (err != srs_success) delete err; \ + ASSERT_TRUE(srs_success == err) +#define HELPER_ASSERT_FAILED(x) \ + if ((err = x) != srs_success) delete err; \ + ASSERT_TRUE(srs_success != err) // For init array data. #define HELPER_ARRAY_INIT(buf, sz, val) \ diff --git a/trunk/src/utest/srs_utest_rtc.cpp b/trunk/src/utest/srs_utest_rtc.cpp index 31308693f..e1cbf7bc8 100644 --- a/trunk/src/utest/srs_utest_rtc.cpp +++ b/trunk/src/utest/srs_utest_rtc.cpp @@ -29,10 +29,314 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. #include #include #include +#include #include using namespace std; +extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version); + +class MockDtls +{ +public: + SSL_CTX* dtls_ctx; + SSL* dtls; + BIO* bio_in; + BIO* bio_out; + ISrsDtlsCallback* callback_; + bool handshake_done_for_us; + SrsDtlsRole role_; + SrsDtlsVersion version_; +public: + MockDtls(ISrsDtlsCallback* callback); + virtual ~MockDtls(); + srs_error_t initialize(std::string role, std::string version); + srs_error_t start_active_handshake(); + srs_error_t on_dtls(char* data, int nb_data); + srs_error_t do_handshake(); +}; + +MockDtls::MockDtls(ISrsDtlsCallback* callback) +{ + dtls_ctx = NULL; + dtls = NULL; + + callback_ = callback; + handshake_done_for_us = false; + + role_ = SrsDtlsRoleServer; + version_ = SrsDtlsVersionAuto; +} + +MockDtls::~MockDtls() +{ + if (dtls_ctx) { + SSL_CTX_free(dtls_ctx); + dtls_ctx = NULL; + } + + if (dtls) { + SSL_free(dtls); + dtls = NULL; + } +} + +srs_error_t MockDtls::initialize(std::string role, std::string version) +{ + role_ = SrsDtlsRoleServer; + if (role == "active") { + role_ = SrsDtlsRoleClient; + } + + if (version == "dtls1.0") { + version_ = SrsDtlsVersion1_0; + } else if (version == "dtls1.2") { + version_ = SrsDtlsVersion1_2; + } else { + version_ = SrsDtlsVersionAuto; + } + + dtls_ctx = srs_build_dtls_ctx(version_); + dtls = SSL_new(dtls_ctx); + srs_assert(dtls); + + if (role_ == SrsDtlsRoleClient) { + SSL_set_connect_state(dtls); + SSL_set_max_send_fragment(dtls, kRtpPacketSize); + } else { + SSL_set_accept_state(dtls); + } + + bio_in = BIO_new(BIO_s_mem()); + srs_assert(bio_in); + + bio_out = BIO_new(BIO_s_mem()); + srs_assert(bio_out); + + SSL_set_bio(dtls, bio_in, bio_out); + return srs_success; +} + +srs_error_t MockDtls::start_active_handshake() +{ + if (role_ == SrsDtlsRoleClient) { + return do_handshake(); + } + return srs_success; +} + +srs_error_t MockDtls::on_dtls(char* data, int nb_data) +{ + srs_error_t err = srs_success; + + srs_assert(BIO_reset(bio_in) == 1); + srs_assert(BIO_reset(bio_out) == 1); + srs_assert(BIO_write(bio_in, data, nb_data) > 0); + + if ((err = do_handshake()) != srs_success) { + return srs_error_wrap(err, "do handshake"); + } + + while (BIO_ctrl_pending(bio_in) > 0) { + char buf[8092]; + int nb = SSL_read(dtls, buf, sizeof(buf)); + if (nb <= 0) { + continue; + } + + if ((err = callback_->on_dtls_application_data(buf, nb)) != srs_success) { + return srs_error_wrap(err, "on DTLS data, size=%u", nb); + } + } + + return err; +} + +srs_error_t MockDtls::do_handshake() +{ + srs_error_t err = srs_success; + + int r0 = SSL_do_handshake(dtls); + int r1 = SSL_get_error(dtls, r0); + if (r0 < 0 && (r1 != SSL_ERROR_NONE && r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE)) { + return srs_error_new(ERROR_OpenSslBIOWrite, "handshake r0=%d, r1=%d", r0, r1); + } + if (r1 == SSL_ERROR_NONE) { + handshake_done_for_us = true; + } + + uint8_t* data = NULL; + int size = BIO_get_mem_data(bio_out, &data); + + if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) { + return srs_error_wrap(err, "dtls send size=%u", size); + } + + if (handshake_done_for_us) { + return callback_->on_dtls_handshake_done(); + } + + return err; +} + +class MockDtlsCallback : virtual public ISrsDtlsCallback, virtual public ISrsCoroutineHandler +{ +public: + SrsDtls* peer; + MockDtls* peer2; + SrsCoroutine* trd; + srs_error_t r0; + bool done; + std::vector samples; + MockDtlsCallback(); + virtual ~MockDtlsCallback(); + virtual srs_error_t on_dtls_handshake_done(); + virtual srs_error_t on_dtls_application_data(const char* data, const int len); + virtual srs_error_t write_dtls_data(void* data, int size); + virtual srs_error_t cycle(); +}; + +MockDtlsCallback::MockDtlsCallback() +{ + peer = NULL; + peer2 = NULL; + r0 = srs_success; + done = false; + trd = new SrsSTCoroutine("mock", this); + srs_assert(trd->start() == srs_success); +} + +MockDtlsCallback::~MockDtlsCallback() +{ + srs_freep(trd); + srs_freep(r0); + for (vector::iterator it = samples.begin(); it != samples.end(); ++it) { + delete[] it->bytes; + } +} + +srs_error_t MockDtlsCallback::on_dtls_handshake_done() +{ + done = true; + return srs_success; +} + +srs_error_t MockDtlsCallback::on_dtls_application_data(const char* data, const int len) +{ + return srs_success; +} + +srs_error_t MockDtlsCallback::write_dtls_data(void* data, int size) +{ + char* cp = new char[size]; + memcpy(cp, data, size); + + samples.push_back(SrsSample((char*)cp, size)); + return srs_success; +} + +srs_error_t MockDtlsCallback::cycle() +{ + srs_error_t err = srs_success; + + while (err == srs_success) { + if ((err = trd->pull()) != srs_success) { + break; + } + + if (samples.empty()) { + srs_usleep(0); + continue; + } + + SrsSample p = *samples.erase(samples.begin()); + if (peer) { + err = peer->on_dtls((char*)p.bytes, p.size); + } else { + err = peer2->on_dtls((char*)p.bytes, p.size); + } + + srs_freepa(p.bytes); + } + + // Copy it for utest to check it. + r0 = srs_error_copy(err); + + return err; +} + +// Wait for mock io to done, try to switch to coroutine many times. +#define mock_wait_dtls_io_done(cio, sio) \ + for (int i = 0; i < 100 && (!cio.samples.empty() || !sio.samples.empty()); i++) { \ + srs_usleep(0 * SRS_UTIME_MILLISECONDS); \ + } + +struct DTLSServerFlowCase +{ + int id; + + string ClientVersion; + string ServerVersion; + + bool ClientDone; + bool ServerDone; + + bool ClientError; + bool ServerError; +}; + +std::ostream& operator<< (std::ostream& stream, const DTLSServerFlowCase& c) +{ + stream << "Case #" << c.id + << ", client(" << c.ClientVersion << ",done=" << c.ClientDone << ",err=" << c.ClientError << ")" + << ", server(" << c.ServerVersion << ",done=" << c.ServerDone << ",err=" << c.ServerError << ")"; + return stream; +} + +VOID TEST(KernelRTCTest, DTLSServerFlowTest) +{ + srs_error_t err = srs_success; + + DTLSServerFlowCase cases[] = { + // OK, Client, Server: DTLS v1.0 + {0, "dtls1.0", "dtls1.0", true, true, false, false}, + // OK, Client, Server: DTLS v1.2 + {1, "dtls1.2", "dtls1.2", true, true, false, false}, + // OK, Client: DTLS v1.0, Server: DTLS auto(v1.0 or v1.2). + {2, "dtls1.0", "auto", true, true, false, false}, + // OK, Client: DTLS v1.2, Server: DTLS auto(v1.0 or v1.2). + {3, "dtls1.2", "auto", true, true, false, false}, + // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 + {4, "auto", "dtls1.0", true, true, false, false}, + // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 + {5, "auto", "dtls1.2", true, true, false, false}, + // Fail, Client: DTLS v1.0, Server: DTLS v1.2 + {6, "dtls1.0", "dtls1.2", false, false, false, true}, + // Fail, Client: DTLS v1.2, Server: DTLS v1.0 + {7, "dtls1.2", "dtls1.0", false, false, true, false}, + }; + + for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSServerFlowCase)); i++) { + DTLSServerFlowCase c = cases[i]; + + MockDtlsCallback cio; MockDtls client(&cio); + MockDtlsCallback sio; SrsDtls server(&sio); + cio.peer = &server; sio.peer2 = &client; + HELPER_EXPECT_SUCCESS(client.initialize("active", c.ClientVersion)) << c; + HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; + + HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; + mock_wait_dtls_io_done(cio, sio); + + // Note that the cio error is generated from server, vice versa. + EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; + EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c; + + EXPECT_EQ(c.ClientDone, cio.done) << c; + EXPECT_EQ(c.ServerDone, sio.done) << c; + } +} + VOID TEST(KernelRTCTest, SequenceCompare) { if (true) {