1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

RTC: Add utest for DTLS

This commit is contained in:
winlin 2020-08-18 19:20:07 +08:00
parent 9ca6b2e50f
commit c33dfd2313
7 changed files with 424 additions and 102 deletions

View file

@ -47,15 +47,83 @@ using namespace std;
// can however retrieve the error code of the last verification error using SSL_get_verify_result(3) or by maintaining
// its own error storage managed by verify_callback.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
{
// Always OK, we don't check the certificate of client,
// because we allow client self-sign certificate.
return 1;
}
SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version)
{
SSL_CTX* dtls_ctx;
#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2
dtls_ctx = SSL_CTX_new(DTLSv1_method());
#else
if (version == SrsDtlsVersion1_0) {
dtls_ctx = SSL_CTX_new(DTLSv1_method());
} else if (version == SrsDtlsVersion1_2) {
dtls_ctx = SSL_CTX_new(DTLSv1_2_method());
} else {
// SrsDtlsVersionAuto, use version-flexible DTLS methods
dtls_ctx = SSL_CTX_new(DTLS_method());
}
#endif
if (_srs_rtc_dtls_certificate->is_ecdsa()) { // By ECDSA, https://stackoverflow.com/a/6006898
#if OPENSSL_VERSION_NUMBER >= 0x10002000L // v1.0.2
// For ECDSA, we could set the curves list.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set1_curves_list.html
SSL_CTX_set1_curves_list(dtls_ctx, "P-521:P-384:P-256");
#endif
// For openssl <1.1, we must set the ECDH manually.
// @see https://stackoverrun.com/cn/q/10791887
#if OPENSSL_VERSION_NUMBER < 0x10100000L // v1.1.x
#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2
SSL_CTX_set_tmp_ecdh(dtls_ctx, _srs_rtc_dtls_certificate->get_ecdsa_key());
#else
SSL_CTX_set_ecdh_auto(dtls_ctx, 1);
#endif
#endif
}
// Setup DTLS context.
if (true) {
// We use "ALL", while you can use "DEFAULT" means "ALL:!EXPORT:!LOW:!aNULL:!eNULL:!SSLv2"
// @see https://www.openssl.org/docs/man1.0.2/man1/ciphers.html
srs_assert(SSL_CTX_set_cipher_list(dtls_ctx, "ALL") == 1);
// Setup the certificate.
srs_assert(SSL_CTX_use_certificate(dtls_ctx, _srs_rtc_dtls_certificate->get_cert()) == 1);
srs_assert(SSL_CTX_use_PrivateKey(dtls_ctx, _srs_rtc_dtls_certificate->get_public_key()) == 1);
// Server will send Certificate Request.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
// TODO: FIXME: Config it, default to off to make the packet smaller.
SSL_CTX_set_verify(dtls_ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, srs_verify_callback);
// The depth count is "level 0:peer certificate", "level 1: CA certificate",
// "level 2: higher level CA certificate", and so on.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
SSL_CTX_set_verify_depth(dtls_ctx, 4);
// Whether we should read as many input bytes as possible (for non-blocking reads) or not.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_read_ahead.html
SSL_CTX_set_read_ahead(dtls_ctx, 1);
// TODO: Maybe we can use SRTP-GCM in future.
// @see https://bugs.chromium.org/p/chromium/issues/detail?id=713701
// @see https://groups.google.com/forum/#!topic/discuss-webrtc/PvCbWSetVAQ
// @remark Only support SRTP_AES128_CM_SHA1_80, please read ssl/d1_srtp.c
srs_assert(SSL_CTX_set_tlsext_use_srtp(dtls_ctx, "SRTP_AES128_CM_SHA1_80") == 0);
}
return dtls_ctx;
}
SrsDtlsCertificate::SrsDtlsCertificate()
{
ecdsa_mode = true;
dtls_cert = NULL;
dtls_pkey = NULL;
eckey = NULL;
@ -244,12 +312,12 @@ ISrsDtlsCallback::~ISrsDtlsCallback()
{
}
SrsDtls::SrsDtls(ISrsDtlsCallback* cb)
SrsDtls::SrsDtls(ISrsDtlsCallback* callback)
{
dtls_ctx = NULL;
dtls = NULL;
callback = cb;
callback_ = callback;
handshake_done_for_us = false;
last_outgoing_packet_cache = new uint8_t[kRtpPacketSize];
@ -297,13 +365,13 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version)
version_ = SrsDtlsVersionAuto;
}
dtls_ctx = build_dtls_ctx();
dtls_ctx = srs_build_dtls_ctx(version_);
if ((dtls = SSL_new(dtls_ctx)) == NULL) {
return srs_error_new(ERROR_OpenSslCreateSSL, "SSL_new dtls");
}
if (role == "active") {
if (role_ == SrsDtlsRoleClient) {
// Dtls setup active, as client role.
SSL_set_connect_state(dtls);
SSL_set_max_send_fragment(dtls, kRtpPacketSize);
@ -326,73 +394,6 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version)
return err;
}
SSL_CTX* SrsDtls::build_dtls_ctx()
{
SSL_CTX* dtls_ctx;
#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2
dtls_ctx = SSL_CTX_new(DTLSv1_method());
#else
if (version_ == SrsDtlsVersion1_0) {
dtls_ctx = SSL_CTX_new(DTLSv1_method());
} else if (version_ == SrsDtlsVersion1_2) {
dtls_ctx = SSL_CTX_new(DTLSv1_2_method());
} else {
// SrsDtlsVersionAuto, use version-flexible DTLS methods
dtls_ctx = SSL_CTX_new(DTLS_method());
}
#endif
if (_srs_rtc_dtls_certificate->is_ecdsa()) { // By ECDSA, https://stackoverflow.com/a/6006898
#if OPENSSL_VERSION_NUMBER >= 0x10002000L // v1.0.2
// For ECDSA, we could set the curves list.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set1_curves_list.html
SSL_CTX_set1_curves_list(dtls_ctx, "P-521:P-384:P-256");
#endif
// For openssl <1.1, we must set the ECDH manually.
// @see https://stackoverrun.com/cn/q/10791887
#if OPENSSL_VERSION_NUMBER < 0x10100000L // v1.1.x
#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2
SSL_CTX_set_tmp_ecdh(dtls_ctx, _srs_rtc_dtls_certificate->get_ecdsa_key());
#else
SSL_CTX_set_ecdh_auto(dtls_ctx, 1);
#endif
#endif
}
// Setup DTLS context.
if (true) {
// We use "ALL", while you can use "DEFAULT" means "ALL:!EXPORT:!LOW:!aNULL:!eNULL:!SSLv2"
// @see https://www.openssl.org/docs/man1.0.2/man1/ciphers.html
srs_assert(SSL_CTX_set_cipher_list(dtls_ctx, "ALL") == 1);
// Setup the certificate.
srs_assert(SSL_CTX_use_certificate(dtls_ctx, _srs_rtc_dtls_certificate->get_cert()) == 1);
srs_assert(SSL_CTX_use_PrivateKey(dtls_ctx, _srs_rtc_dtls_certificate->get_public_key()) == 1);
// Server will send Certificate Request.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
// TODO: FIXME: Config it, default to off to make the packet smaller.
SSL_CTX_set_verify(dtls_ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, verify_callback);
// The depth count is "level 0:peer certificate", "level 1: CA certificate",
// "level 2: higher level CA certificate", and so on.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
SSL_CTX_set_verify_depth(dtls_ctx, 4);
// Whether we should read as many input bytes as possible (for non-blocking reads) or not.
// @see https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_read_ahead.html
SSL_CTX_set_read_ahead(dtls_ctx, 1);
// TODO: Maybe we can use SRTP-GCM in future.
// @see https://bugs.chromium.org/p/chromium/issues/detail?id=713701
// @see https://groups.google.com/forum/#!topic/discuss-webrtc/PvCbWSetVAQ
// @remark Only support SRTP_AES128_CM_SHA1_80, please read ssl/d1_srtp.c
srs_assert(SSL_CTX_set_tlsext_use_srtp(dtls_ctx, "SRTP_AES128_CM_SHA1_80") == 0);
}
return dtls_ctx;
}
srs_error_t SrsDtls::start_active_handshake()
{
srs_error_t err = srs_success;
@ -437,7 +438,7 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data)
}
// Trace the detail of DTLS packet.
state_trace((uint8_t*)data, nb_data, true, SSL_ERROR_NONE, false, false);
state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false, false);
if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) {
// TODO: 0 or -1 maybe block, use BIO_should_retry to check.
@ -453,13 +454,12 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data)
while (BIO_ctrl_pending(bio_in) > 0) {
char buf[8092];
int nb = SSL_read(dtls, buf, sizeof(buf));
if (!callback || nb <= 0) {
if (nb <= 0) {
continue;
}
srs_trace("DTLS: read nb=%d, data=[%s]", nb, srs_string_dumps_hex(buf, nb, 32).c_str());
if ((err = callback->on_dtls_application_data(buf, nb)) != srs_success) {
if ((err = callback_->on_dtls_application_data(buf, nb)) != srs_success) {
return srs_error_wrap(err, "on DTLS data, size=%u, data=[%s]", nb,
srs_string_dumps_hex(buf, nb, 32).c_str());
}
@ -476,7 +476,12 @@ srs_error_t SrsDtls::do_handshake()
int r0 = SSL_do_handshake(dtls);
int r1 = SSL_get_error(dtls, r0);
// TODO: What about SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE?
// Fatal SSL error, for example, no available suite when peer is DTLS 1.0 while we are DTLS 1.2.
if (r0 < 0 && (r1 != SSL_ERROR_NONE && r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE)) {
return srs_error_new(ERROR_OpenSslBIOWrite, "handshake r0=%d, r1=%d", r0, r1);
}
// OK, Handshake is done, note that it maybe done many times.
if (r1 == SSL_ERROR_NONE) {
handshake_done_for_us = true;
}
@ -495,7 +500,7 @@ srs_error_t SrsDtls::do_handshake()
}
// Trace the detail of DTLS packet.
state_trace((uint8_t*)data, size, false, r1, cache, false);
state_trace((uint8_t*)data, size, false, r0, r1, cache, false);
// Update the packet cache.
if (size > 0 && data != last_outgoing_packet_cache && size < kRtpPacketSize) {
@ -528,7 +533,7 @@ srs_error_t SrsDtls::do_handshake()
}
}
if (size > 0 && (err = callback->write_dtls_data(data, size)) != srs_success) {
if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) {
return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size,
srs_string_dumps_hex((char*)data, size, 32).c_str());
}
@ -541,7 +546,7 @@ srs_error_t SrsDtls::do_handshake()
}
// Notify connection the DTLS is done.
if (((err = callback->on_dtls_handshake_done()) != srs_success)) {
if (((err = callback_->on_dtls_handshake_done()) != srs_success)) {
return srs_error_wrap(err, "dtls done");
}
}
@ -581,9 +586,9 @@ srs_error_t SrsDtls::cycle()
if (size) {
// Trace the detail of DTLS packet.
state_trace((uint8_t*)data, size, false, SSL_ERROR_NONE, true, true);
state_trace((uint8_t*)data, size, false, 1, SSL_ERROR_NONE, true, true);
if ((err = callback->write_dtls_data(data, size)) != srs_success) {
if ((err = callback_->write_dtls_data(data, size)) != srs_success) {
return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size,
srs_string_dumps_hex((char*)data, size, 32).c_str());
}
@ -596,7 +601,7 @@ srs_error_t SrsDtls::cycle()
return err;
}
void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq)
void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq)
{
uint8_t content_type = 0;
if (length >= 1) {
@ -613,9 +618,9 @@ void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int ssl_err,
handshake_type = (uint8_t)data[13];
}
srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, r0=%d, len=%u, cnt=%u, size=%u, hs=%u",
srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u",
(role_ == SrsDtlsRoleClient? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq,
state_, ssl_err, length, content_type, size, handshake_type);
state_, r0, r1, length, content_type, size, handshake_type);
}
srs_error_t SrsDtls::start_arq()

View file

@ -112,23 +112,20 @@ private:
SSL* dtls;
BIO* bio_in;
BIO* bio_out;
ISrsDtlsCallback* callback;
ISrsDtlsCallback* callback_;
private:
// Whether the handhshake is done, for us only.
// @remark For us only, means peer maybe not done, we also need to handle the DTLS packet.
bool handshake_done_for_us;
// DTLS packet cache, only last out-going packet.
uint8_t* last_outgoing_packet_cache;
int nn_last_outgoing_packet;
// ARQ thread, for role active(DTLS client).
// @note If passive(DTLS server), the ARQ is driven by DTLS client.
SrsCoroutine* trd;
// The DTLS-client state to drive the ARQ thread.
SrsDtlsState state_;
private:
// @remark: dtls_role_ default value is DTLS_SERVER.
SrsDtlsRole role_;
// @remark: dtls_version_ default value is SrsDtlsVersionAuto.
@ -138,8 +135,6 @@ public:
virtual ~SrsDtls();
public:
srs_error_t initialize(std::string role, std::string version);
private:
SSL_CTX* build_dtls_ctx();
public:
// As DTLS client, start handshake actively, send the ClientHello packet.
srs_error_t start_active_handshake();
@ -153,7 +148,7 @@ private:
public:
virtual srs_error_t cycle();
private:
void state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq);
void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq);
private:
srs_error_t start_arq();
void stop_arq();

View file

@ -368,6 +368,13 @@ SrsSample::SrsSample()
bframe = false;
}
SrsSample::SrsSample(char* b, int s)
{
size = s;
bytes = b;
bframe = false;
}
SrsSample::~SrsSample()
{
}

View file

@ -538,6 +538,7 @@ public:
bool bframe;
public:
SrsSample();
SrsSample(char* b, int s);
~SrsSample();
public:
// If we need to know whether sample is bframe, we have to parse the NALU payload.

View file

@ -28,6 +28,7 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include <srs_app_server.hpp>
#include <srs_app_config.hpp>
#include <srs_app_log.hpp>
#include <srs_app_rtc_dtls.hpp>
#include <string>
using namespace std;
@ -57,6 +58,10 @@ srs_error_t prepare_main() {
return srs_error_wrap(err, "init st");
}
if ((err = _srs_rtc_dtls_certificate->initialize()) != srs_success) {
return srs_error_wrap(err, "rtc dtls certificate initialize");
}
srs_freep(_srs_context);
_srs_context = new SrsThreadContext();

View file

@ -56,19 +56,24 @@ extern int _srs_tmp_port;
extern srs_utime_t _srs_tmp_timeout;
// For errors.
// @remark we directly delete the err, because we allow user to append message if fail.
#define HELPER_EXPECT_SUCCESS(x) \
if ((err = x) != srs_success) fprintf(stderr, "err %s", srs_error_desc(err).c_str()); \
EXPECT_TRUE(srs_success == err); \
srs_freep(err)
#define HELPER_EXPECT_FAILED(x) EXPECT_TRUE(srs_success != (err = x)); srs_freep(err)
if (err != srs_success) delete err; \
EXPECT_TRUE(srs_success == err)
#define HELPER_EXPECT_FAILED(x) \
if ((err = x) != srs_success) delete err; \
EXPECT_TRUE(srs_success != err)
// For errors, assert.
// @remark The err is leak when error, but it's ok in utest.
#define HELPER_ASSERT_SUCCESS(x) \
// @remark we directly delete the err, because we allow user to append message if fail.
#define HELPER_EXPECT_SUCCESS(x) \
if ((err = x) != srs_success) fprintf(stderr, "err %s", srs_error_desc(err).c_str()); \
ASSERT_TRUE(srs_success == err); \
srs_freep(err)
#define HELPER_ASSERT_FAILED(x) ASSERT_TRUE(srs_success != (err = x)); srs_freep(err)
if (err != srs_success) delete err; \
ASSERT_TRUE(srs_success == err)
#define HELPER_ASSERT_FAILED(x) \
if ((err = x) != srs_success) delete err; \
ASSERT_TRUE(srs_success != err)
// For init array data.
#define HELPER_ARRAY_INIT(buf, sz, val) \

View file

@ -29,10 +29,314 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include <srs_kernel_rtc_rtp.hpp>
#include <srs_app_rtc_source.hpp>
#include <srs_app_rtc_conn.hpp>
#include <srs_kernel_codec.hpp>
#include <vector>
using namespace std;
extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version);
class MockDtls
{
public:
SSL_CTX* dtls_ctx;
SSL* dtls;
BIO* bio_in;
BIO* bio_out;
ISrsDtlsCallback* callback_;
bool handshake_done_for_us;
SrsDtlsRole role_;
SrsDtlsVersion version_;
public:
MockDtls(ISrsDtlsCallback* callback);
virtual ~MockDtls();
srs_error_t initialize(std::string role, std::string version);
srs_error_t start_active_handshake();
srs_error_t on_dtls(char* data, int nb_data);
srs_error_t do_handshake();
};
MockDtls::MockDtls(ISrsDtlsCallback* callback)
{
dtls_ctx = NULL;
dtls = NULL;
callback_ = callback;
handshake_done_for_us = false;
role_ = SrsDtlsRoleServer;
version_ = SrsDtlsVersionAuto;
}
MockDtls::~MockDtls()
{
if (dtls_ctx) {
SSL_CTX_free(dtls_ctx);
dtls_ctx = NULL;
}
if (dtls) {
SSL_free(dtls);
dtls = NULL;
}
}
srs_error_t MockDtls::initialize(std::string role, std::string version)
{
role_ = SrsDtlsRoleServer;
if (role == "active") {
role_ = SrsDtlsRoleClient;
}
if (version == "dtls1.0") {
version_ = SrsDtlsVersion1_0;
} else if (version == "dtls1.2") {
version_ = SrsDtlsVersion1_2;
} else {
version_ = SrsDtlsVersionAuto;
}
dtls_ctx = srs_build_dtls_ctx(version_);
dtls = SSL_new(dtls_ctx);
srs_assert(dtls);
if (role_ == SrsDtlsRoleClient) {
SSL_set_connect_state(dtls);
SSL_set_max_send_fragment(dtls, kRtpPacketSize);
} else {
SSL_set_accept_state(dtls);
}
bio_in = BIO_new(BIO_s_mem());
srs_assert(bio_in);
bio_out = BIO_new(BIO_s_mem());
srs_assert(bio_out);
SSL_set_bio(dtls, bio_in, bio_out);
return srs_success;
}
srs_error_t MockDtls::start_active_handshake()
{
if (role_ == SrsDtlsRoleClient) {
return do_handshake();
}
return srs_success;
}
srs_error_t MockDtls::on_dtls(char* data, int nb_data)
{
srs_error_t err = srs_success;
srs_assert(BIO_reset(bio_in) == 1);
srs_assert(BIO_reset(bio_out) == 1);
srs_assert(BIO_write(bio_in, data, nb_data) > 0);
if ((err = do_handshake()) != srs_success) {
return srs_error_wrap(err, "do handshake");
}
while (BIO_ctrl_pending(bio_in) > 0) {
char buf[8092];
int nb = SSL_read(dtls, buf, sizeof(buf));
if (nb <= 0) {
continue;
}
if ((err = callback_->on_dtls_application_data(buf, nb)) != srs_success) {
return srs_error_wrap(err, "on DTLS data, size=%u", nb);
}
}
return err;
}
srs_error_t MockDtls::do_handshake()
{
srs_error_t err = srs_success;
int r0 = SSL_do_handshake(dtls);
int r1 = SSL_get_error(dtls, r0);
if (r0 < 0 && (r1 != SSL_ERROR_NONE && r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE)) {
return srs_error_new(ERROR_OpenSslBIOWrite, "handshake r0=%d, r1=%d", r0, r1);
}
if (r1 == SSL_ERROR_NONE) {
handshake_done_for_us = true;
}
uint8_t* data = NULL;
int size = BIO_get_mem_data(bio_out, &data);
if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) {
return srs_error_wrap(err, "dtls send size=%u", size);
}
if (handshake_done_for_us) {
return callback_->on_dtls_handshake_done();
}
return err;
}
class MockDtlsCallback : virtual public ISrsDtlsCallback, virtual public ISrsCoroutineHandler
{
public:
SrsDtls* peer;
MockDtls* peer2;
SrsCoroutine* trd;
srs_error_t r0;
bool done;
std::vector<SrsSample> samples;
MockDtlsCallback();
virtual ~MockDtlsCallback();
virtual srs_error_t on_dtls_handshake_done();
virtual srs_error_t on_dtls_application_data(const char* data, const int len);
virtual srs_error_t write_dtls_data(void* data, int size);
virtual srs_error_t cycle();
};
MockDtlsCallback::MockDtlsCallback()
{
peer = NULL;
peer2 = NULL;
r0 = srs_success;
done = false;
trd = new SrsSTCoroutine("mock", this);
srs_assert(trd->start() == srs_success);
}
MockDtlsCallback::~MockDtlsCallback()
{
srs_freep(trd);
srs_freep(r0);
for (vector<SrsSample>::iterator it = samples.begin(); it != samples.end(); ++it) {
delete[] it->bytes;
}
}
srs_error_t MockDtlsCallback::on_dtls_handshake_done()
{
done = true;
return srs_success;
}
srs_error_t MockDtlsCallback::on_dtls_application_data(const char* data, const int len)
{
return srs_success;
}
srs_error_t MockDtlsCallback::write_dtls_data(void* data, int size)
{
char* cp = new char[size];
memcpy(cp, data, size);
samples.push_back(SrsSample((char*)cp, size));
return srs_success;
}
srs_error_t MockDtlsCallback::cycle()
{
srs_error_t err = srs_success;
while (err == srs_success) {
if ((err = trd->pull()) != srs_success) {
break;
}
if (samples.empty()) {
srs_usleep(0);
continue;
}
SrsSample p = *samples.erase(samples.begin());
if (peer) {
err = peer->on_dtls((char*)p.bytes, p.size);
} else {
err = peer2->on_dtls((char*)p.bytes, p.size);
}
srs_freepa(p.bytes);
}
// Copy it for utest to check it.
r0 = srs_error_copy(err);
return err;
}
// Wait for mock io to done, try to switch to coroutine many times.
#define mock_wait_dtls_io_done(cio, sio) \
for (int i = 0; i < 100 && (!cio.samples.empty() || !sio.samples.empty()); i++) { \
srs_usleep(0 * SRS_UTIME_MILLISECONDS); \
}
struct DTLSServerFlowCase
{
int id;
string ClientVersion;
string ServerVersion;
bool ClientDone;
bool ServerDone;
bool ClientError;
bool ServerError;
};
std::ostream& operator<< (std::ostream& stream, const DTLSServerFlowCase& c)
{
stream << "Case #" << c.id
<< ", client(" << c.ClientVersion << ",done=" << c.ClientDone << ",err=" << c.ClientError << ")"
<< ", server(" << c.ServerVersion << ",done=" << c.ServerDone << ",err=" << c.ServerError << ")";
return stream;
}
VOID TEST(KernelRTCTest, DTLSServerFlowTest)
{
srs_error_t err = srs_success;
DTLSServerFlowCase cases[] = {
// OK, Client, Server: DTLS v1.0
{0, "dtls1.0", "dtls1.0", true, true, false, false},
// OK, Client, Server: DTLS v1.2
{1, "dtls1.2", "dtls1.2", true, true, false, false},
// OK, Client: DTLS v1.0, Server: DTLS auto(v1.0 or v1.2).
{2, "dtls1.0", "auto", true, true, false, false},
// OK, Client: DTLS v1.2, Server: DTLS auto(v1.0 or v1.2).
{3, "dtls1.2", "auto", true, true, false, false},
// OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0
{4, "auto", "dtls1.0", true, true, false, false},
// OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0
{5, "auto", "dtls1.2", true, true, false, false},
// Fail, Client: DTLS v1.0, Server: DTLS v1.2
{6, "dtls1.0", "dtls1.2", false, false, false, true},
// Fail, Client: DTLS v1.2, Server: DTLS v1.0
{7, "dtls1.2", "dtls1.0", false, false, true, false},
};
for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSServerFlowCase)); i++) {
DTLSServerFlowCase c = cases[i];
MockDtlsCallback cio; MockDtls client(&cio);
MockDtlsCallback sio; SrsDtls server(&sio);
cio.peer = &server; sio.peer2 = &client;
HELPER_EXPECT_SUCCESS(client.initialize("active", c.ClientVersion)) << c;
HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c;
HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c;
mock_wait_dtls_io_done(cio, sio);
// Note that the cio error is generated from server, vice versa.
EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c;
EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c;
EXPECT_EQ(c.ClientDone, cio.done) << c;
EXPECT_EQ(c.ServerDone, sio.done) << c;
}
}
VOID TEST(KernelRTCTest, SequenceCompare)
{
if (true) {