mirror of
https://github.com/ossrs/srs.git
synced 2025-03-09 15:49:59 +00:00
RTC: Refine DTLS impl, extract client and server
This commit is contained in:
parent
5589120dc8
commit
9416fddd2b
3 changed files with 360 additions and 195 deletions
|
@ -312,7 +312,7 @@ ISrsDtlsCallback::~ISrsDtlsCallback()
|
|||
{
|
||||
}
|
||||
|
||||
SrsDtls::SrsDtls(ISrsDtlsCallback* callback)
|
||||
ISrsDtlsImpl::ISrsDtlsImpl(ISrsDtlsCallback* callback)
|
||||
{
|
||||
dtls_ctx = NULL;
|
||||
dtls = NULL;
|
||||
|
@ -323,19 +323,11 @@ SrsDtls::SrsDtls(ISrsDtlsCallback* callback)
|
|||
last_outgoing_packet_cache = new uint8_t[kRtpPacketSize];
|
||||
nn_last_outgoing_packet = 0;
|
||||
|
||||
role_ = SrsDtlsRoleServer;
|
||||
version_ = SrsDtlsVersionAuto;
|
||||
|
||||
trd = NULL;
|
||||
state_ = SrsDtlsStateInit;
|
||||
arq_first = 50 * SRS_UTIME_MILLISECONDS;
|
||||
arq_interval = 100 * SRS_UTIME_MILLISECONDS;
|
||||
}
|
||||
|
||||
SrsDtls::~SrsDtls()
|
||||
ISrsDtlsImpl::~ISrsDtlsImpl()
|
||||
{
|
||||
srs_freep(trd);
|
||||
|
||||
if (dtls_ctx) {
|
||||
SSL_CTX_free(dtls_ctx);
|
||||
dtls_ctx = NULL;
|
||||
|
@ -350,15 +342,10 @@ SrsDtls::~SrsDtls()
|
|||
srs_freepa(last_outgoing_packet_cache);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::initialize(std::string role, std::string version)
|
||||
srs_error_t ISrsDtlsImpl::initialize(std::string version)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
role_ = SrsDtlsRoleServer;
|
||||
if (role == "active") {
|
||||
role_ = SrsDtlsRoleClient;
|
||||
}
|
||||
|
||||
if (version == "dtls1.0") {
|
||||
version_ = SrsDtlsVersion1_0;
|
||||
} else if (version == "dtls1.2") {
|
||||
|
@ -373,15 +360,6 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version)
|
|||
return srs_error_new(ERROR_OpenSslCreateSSL, "SSL_new dtls");
|
||||
}
|
||||
|
||||
if (role_ == SrsDtlsRoleClient) {
|
||||
// Dtls setup active, as client role.
|
||||
SSL_set_connect_state(dtls);
|
||||
SSL_set_max_send_fragment(dtls, kRtpPacketSize);
|
||||
} else {
|
||||
// Dtls setup passive, as server role.
|
||||
SSL_set_accept_state(dtls);
|
||||
}
|
||||
|
||||
if ((bio_in = BIO_new(BIO_s_mem())) == NULL) {
|
||||
return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in");
|
||||
}
|
||||
|
@ -396,29 +374,10 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version)
|
|||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::start_active_handshake()
|
||||
srs_error_t ISrsDtlsImpl::on_dtls(char* data, int nb_data)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
if (role_ == SrsDtlsRoleClient) {
|
||||
return do_handshake();
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::on_dtls(char* data, int nb_data)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
// When got packet, stop the ARQ if server in the first ARQ state SrsDtlsStateServerHello.
|
||||
// @note But for ARQ state, we should never stop the ARQ, for example, we are in the second ARQ sate
|
||||
// SrsDtlsStateServerDone, but we got previous late wrong packet ServeHello, which is not the expect
|
||||
// packet SessionNewTicket, we should never stop the ARQ thread.
|
||||
if (role_ == SrsDtlsRoleClient && state_ == SrsDtlsStateServerHello) {
|
||||
stop_arq();
|
||||
}
|
||||
|
||||
if ((err = do_on_dtls(data, nb_data)) != srs_success) {
|
||||
return srs_error_wrap(err, "on_dtls size=%u, data=[%s]", nb_data,
|
||||
srs_string_dumps_hex(data, nb_data, 32).c_str());
|
||||
|
@ -427,7 +386,40 @@ srs_error_t SrsDtls::on_dtls(char* data, int nb_data)
|
|||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data)
|
||||
const int SRTP_MASTER_KEY_KEY_LEN = 16;
|
||||
const int SRTP_MASTER_KEY_SALT_LEN = 14;
|
||||
srs_error_t ISrsDtlsImpl::get_srtp_key(std::string& recv_key, std::string& send_key)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
unsigned char material[SRTP_MASTER_KEY_LEN * 2] = {0}; // client(SRTP_MASTER_KEY_KEY_LEN + SRTP_MASTER_KEY_SALT_LEN) + server
|
||||
static const string dtls_srtp_lable = "EXTRACTOR-dtls_srtp";
|
||||
if (!SSL_export_keying_material(dtls, material, sizeof(material), dtls_srtp_lable.c_str(), dtls_srtp_lable.size(), NULL, 0, 0)) {
|
||||
return srs_error_new(ERROR_RTC_SRTP_INIT, "SSL export key r0=%u", ERR_get_error());
|
||||
}
|
||||
|
||||
size_t offset = 0;
|
||||
|
||||
std::string client_master_key(reinterpret_cast<char*>(material), SRTP_MASTER_KEY_KEY_LEN);
|
||||
offset += SRTP_MASTER_KEY_KEY_LEN;
|
||||
std::string server_master_key(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_KEY_LEN);
|
||||
offset += SRTP_MASTER_KEY_KEY_LEN;
|
||||
std::string client_master_salt(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_SALT_LEN);
|
||||
offset += SRTP_MASTER_KEY_SALT_LEN;
|
||||
std::string server_master_salt(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_SALT_LEN);
|
||||
|
||||
if (is_dtls_client()) {
|
||||
recv_key = server_master_key + server_master_salt;
|
||||
send_key = client_master_key + client_master_salt;
|
||||
} else {
|
||||
recv_key = client_master_key + client_master_salt;
|
||||
send_key = server_master_key + server_master_salt;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_error_t ISrsDtlsImpl::do_on_dtls(char* data, int nb_data)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
|
@ -470,7 +462,7 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data)
|
|||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::do_handshake()
|
||||
srs_error_t ISrsDtlsImpl::do_handshake()
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
|
@ -492,16 +484,9 @@ srs_error_t SrsDtls::do_handshake()
|
|||
uint8_t* data = NULL;
|
||||
int size = BIO_get_mem_data(bio_out, &data);
|
||||
|
||||
// If outgoing packet is empty, we use the last cache.
|
||||
// @remark Only for DTLS server, because DTLS client use ARQ thread to send cached packet.
|
||||
// Callback when got SSL original data.
|
||||
bool cache = false;
|
||||
if (role_ != SrsDtlsRoleClient && size <= 0 && nn_last_outgoing_packet) {
|
||||
size = nn_last_outgoing_packet;
|
||||
data = last_outgoing_packet_cache;
|
||||
cache = true;
|
||||
}
|
||||
|
||||
// Trace the detail of DTLS packet.
|
||||
on_ssl_out_data(data, size, cache);
|
||||
state_trace((uint8_t*)data, size, false, r0, r1, cache, false);
|
||||
|
||||
// Update the packet cache.
|
||||
|
@ -510,29 +495,9 @@ srs_error_t SrsDtls::do_handshake()
|
|||
nn_last_outgoing_packet = size;
|
||||
}
|
||||
|
||||
// Driven ARQ and state for DTLS client.
|
||||
if (role_ == SrsDtlsRoleClient) {
|
||||
// If we are sending client hello, change from init to new state.
|
||||
if (state_ == SrsDtlsStateInit && size > 14 && data[13] == 1) {
|
||||
state_ = SrsDtlsStateClientHello;
|
||||
}
|
||||
// If we are sending certificate, change from SrsDtlsStateServerHello to new state.
|
||||
if (state_ == SrsDtlsStateServerHello && size > 14 && data[13] == 11) {
|
||||
state_ = SrsDtlsStateClientCertificate;
|
||||
}
|
||||
|
||||
// Try to start the ARQ for client.
|
||||
if ((state_ == SrsDtlsStateClientHello || state_ == SrsDtlsStateClientCertificate)) {
|
||||
if (state_ == SrsDtlsStateClientHello) {
|
||||
state_ = SrsDtlsStateServerHello;
|
||||
} else if (state_ == SrsDtlsStateClientCertificate) {
|
||||
state_ = SrsDtlsStateServerDone;
|
||||
}
|
||||
|
||||
if ((err = start_arq()) != srs_success) {
|
||||
return srs_error_wrap(err, "start arq");
|
||||
}
|
||||
}
|
||||
// Callback for the final output data, before send-out.
|
||||
if ((err = on_final_out_data(data, size)) != srs_success) {
|
||||
return srs_error_wrap(err, "handle");
|
||||
}
|
||||
|
||||
if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) {
|
||||
|
@ -541,22 +506,171 @@ srs_error_t SrsDtls::do_handshake()
|
|||
}
|
||||
|
||||
if (handshake_done_for_us) {
|
||||
// When handshake done, stop the ARQ.
|
||||
if (role_ == SrsDtlsRoleClient) {
|
||||
state_ = SrsDtlsStateClientDone;
|
||||
stop_arq();
|
||||
}
|
||||
|
||||
// Notify connection the DTLS is done.
|
||||
if (((err = callback_->on_dtls_handshake_done()) != srs_success)) {
|
||||
return srs_error_wrap(err, "dtls done");
|
||||
if (((err = on_handshake_done()) != srs_success)) {
|
||||
return srs_error_wrap(err, "done");
|
||||
}
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::cycle()
|
||||
void ISrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq)
|
||||
{
|
||||
uint8_t content_type = 0;
|
||||
if (length >= 1) {
|
||||
content_type = (uint8_t)data[0];
|
||||
}
|
||||
|
||||
uint16_t size = 0;
|
||||
if (length >= 13) {
|
||||
size = uint16_t(data[11])<<8 | uint16_t(data[12]);
|
||||
}
|
||||
|
||||
uint8_t handshake_type = 0;
|
||||
if (length >= 14) {
|
||||
handshake_type = (uint8_t)data[13];
|
||||
}
|
||||
|
||||
srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u",
|
||||
(is_dtls_client()? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq,
|
||||
r0, r1, length, content_type, size, handshake_type);
|
||||
}
|
||||
|
||||
SrsDtlsClientImpl::SrsDtlsClientImpl(ISrsDtlsCallback* callback) : ISrsDtlsImpl(callback)
|
||||
{
|
||||
trd = NULL;
|
||||
state_ = SrsDtlsStateInit;
|
||||
arq_first = 50 * SRS_UTIME_MILLISECONDS;
|
||||
arq_interval = 100 * SRS_UTIME_MILLISECONDS;
|
||||
}
|
||||
|
||||
SrsDtlsClientImpl::~SrsDtlsClientImpl()
|
||||
{
|
||||
srs_freep(trd);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::initialize(std::string version)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
if ((err = ISrsDtlsImpl::initialize(version)) != srs_success) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// Dtls setup active, as client role.
|
||||
SSL_set_connect_state(dtls);
|
||||
SSL_set_max_send_fragment(dtls, kRtpPacketSize);
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::start_active_handshake()
|
||||
{
|
||||
return do_handshake();
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::on_dtls(char* data, int nb_data)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
// When got packet, stop the ARQ if server in the first ARQ state SrsDtlsStateServerHello.
|
||||
// @note But for ARQ state, we should never stop the ARQ, for example, we are in the second ARQ sate
|
||||
// SrsDtlsStateServerDone, but we got previous late wrong packet ServeHello, which is not the expect
|
||||
// packet SessionNewTicket, we should never stop the ARQ thread.
|
||||
if (state_ == SrsDtlsStateServerHello) {
|
||||
stop_arq();
|
||||
}
|
||||
|
||||
if ((err = ISrsDtlsImpl::on_dtls(data, nb_data)) != srs_success) {
|
||||
return err;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
void SrsDtlsClientImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached)
|
||||
{
|
||||
// DTLS client use ARQ thread to send cached packet.
|
||||
cached = false;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::on_final_out_data(uint8_t* data, int size)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
// Driven ARQ and state for DTLS client.
|
||||
// If we are sending client hello, change from init to new state.
|
||||
if (state_ == SrsDtlsStateInit && size > 14 && data[13] == 1) {
|
||||
state_ = SrsDtlsStateClientHello;
|
||||
}
|
||||
// If we are sending certificate, change from SrsDtlsStateServerHello to new state.
|
||||
if (state_ == SrsDtlsStateServerHello && size > 14 && data[13] == 11) {
|
||||
state_ = SrsDtlsStateClientCertificate;
|
||||
}
|
||||
|
||||
// Try to start the ARQ for client.
|
||||
if ((state_ == SrsDtlsStateClientHello || state_ == SrsDtlsStateClientCertificate)) {
|
||||
if (state_ == SrsDtlsStateClientHello) {
|
||||
state_ = SrsDtlsStateServerHello;
|
||||
} else if (state_ == SrsDtlsStateClientCertificate) {
|
||||
state_ = SrsDtlsStateServerDone;
|
||||
}
|
||||
|
||||
if ((err = start_arq()) != srs_success) {
|
||||
return srs_error_wrap(err, "start arq");
|
||||
}
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::on_handshake_done()
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
// When handshake done, stop the ARQ.
|
||||
state_ = SrsDtlsStateClientDone;
|
||||
stop_arq();
|
||||
|
||||
// Notify connection the DTLS is done.
|
||||
if (((err = callback_->on_dtls_handshake_done()) != srs_success)) {
|
||||
return srs_error_wrap(err, "dtls done");
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
bool SrsDtlsClientImpl::is_dtls_client()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::start_arq()
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
srs_info("start arq, state=%u", state_);
|
||||
|
||||
// Dispose the previous ARQ thread.
|
||||
srs_freep(trd);
|
||||
trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id());
|
||||
|
||||
// We should start the ARQ thread for DTLS client.
|
||||
if ((err = trd->start()) != srs_success) {
|
||||
return srs_error_wrap(err, "arq start");
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
void SrsDtlsClientImpl::stop_arq()
|
||||
{
|
||||
srs_info("stop arq, state=%u", state_);
|
||||
srs_freep(trd);
|
||||
srs_info("stop arq, done");
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsClientImpl::cycle()
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
|
@ -603,90 +717,104 @@ srs_error_t SrsDtls::cycle()
|
|||
return err;
|
||||
}
|
||||
|
||||
void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq)
|
||||
SrsDtlsServerImpl::SrsDtlsServerImpl(ISrsDtlsCallback* callback) : ISrsDtlsImpl(callback)
|
||||
{
|
||||
uint8_t content_type = 0;
|
||||
if (length >= 1) {
|
||||
content_type = (uint8_t)data[0];
|
||||
}
|
||||
|
||||
uint16_t size = 0;
|
||||
if (length >= 13) {
|
||||
size = uint16_t(data[11])<<8 | uint16_t(data[12]);
|
||||
}
|
||||
|
||||
uint8_t handshake_type = 0;
|
||||
if (length >= 14) {
|
||||
handshake_type = (uint8_t)data[13];
|
||||
}
|
||||
|
||||
srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u",
|
||||
(role_ == SrsDtlsRoleClient? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq,
|
||||
state_, r0, r1, length, content_type, size, handshake_type);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::start_arq()
|
||||
SrsDtlsServerImpl::~SrsDtlsServerImpl()
|
||||
{
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsServerImpl::initialize(std::string version)
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
if (role_ != SrsDtlsRoleClient) {
|
||||
if ((err = ISrsDtlsImpl::initialize(version)) != srs_success) {
|
||||
return err;
|
||||
}
|
||||
|
||||
srs_info("start arq, state=%u", state_);
|
||||
|
||||
// Dispose the previous ARQ thread.
|
||||
srs_freep(trd);
|
||||
trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id());
|
||||
|
||||
// We should start the ARQ thread for DTLS client.
|
||||
if ((err = trd->start()) != srs_success) {
|
||||
return srs_error_wrap(err, "arq start");
|
||||
}
|
||||
// Dtls setup passive, as server role.
|
||||
SSL_set_accept_state(dtls);
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
void SrsDtls::stop_arq()
|
||||
srs_error_t SrsDtlsServerImpl::start_active_handshake()
|
||||
{
|
||||
srs_info("stop arq, state=%u", state_);
|
||||
srs_freep(trd);
|
||||
srs_info("stop arq, done");
|
||||
return srs_success;
|
||||
}
|
||||
|
||||
const int SRTP_MASTER_KEY_KEY_LEN = 16;
|
||||
const int SRTP_MASTER_KEY_SALT_LEN = 14;
|
||||
srs_error_t SrsDtls::get_srtp_key(std::string& recv_key, std::string& send_key)
|
||||
void SrsDtlsServerImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached)
|
||||
{
|
||||
// If outgoing packet is empty, we use the last cache.
|
||||
// @remark Only for DTLS server, because DTLS client use ARQ thread to send cached packet.
|
||||
if (size <= 0 && nn_last_outgoing_packet) {
|
||||
size = nn_last_outgoing_packet;
|
||||
data = last_outgoing_packet_cache;
|
||||
cached = true;
|
||||
}
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsServerImpl::on_final_out_data(uint8_t* data, int size)
|
||||
{
|
||||
return srs_success;
|
||||
}
|
||||
|
||||
srs_error_t SrsDtlsServerImpl::on_handshake_done()
|
||||
{
|
||||
srs_error_t err = srs_success;
|
||||
|
||||
unsigned char material[SRTP_MASTER_KEY_LEN * 2] = {0}; // client(SRTP_MASTER_KEY_KEY_LEN + SRTP_MASTER_KEY_SALT_LEN) + server
|
||||
static const string dtls_srtp_lable = "EXTRACTOR-dtls_srtp";
|
||||
if (!SSL_export_keying_material(dtls, material, sizeof(material), dtls_srtp_lable.c_str(), dtls_srtp_lable.size(), NULL, 0, 0)) {
|
||||
return srs_error_new(ERROR_RTC_SRTP_INIT, "SSL export key r0=%u", ERR_get_error());
|
||||
}
|
||||
|
||||
size_t offset = 0;
|
||||
|
||||
std::string client_master_key(reinterpret_cast<char*>(material), SRTP_MASTER_KEY_KEY_LEN);
|
||||
offset += SRTP_MASTER_KEY_KEY_LEN;
|
||||
std::string server_master_key(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_KEY_LEN);
|
||||
offset += SRTP_MASTER_KEY_KEY_LEN;
|
||||
std::string client_master_salt(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_SALT_LEN);
|
||||
offset += SRTP_MASTER_KEY_SALT_LEN;
|
||||
std::string server_master_salt(reinterpret_cast<char*>(material + offset), SRTP_MASTER_KEY_SALT_LEN);
|
||||
|
||||
if (role_ == SrsDtlsRoleClient) {
|
||||
recv_key = server_master_key + server_master_salt;
|
||||
send_key = client_master_key + client_master_salt;
|
||||
} else {
|
||||
recv_key = client_master_key + client_master_salt;
|
||||
send_key = server_master_key + server_master_salt;
|
||||
// Notify connection the DTLS is done.
|
||||
if (((err = callback_->on_dtls_handshake_done()) != srs_success)) {
|
||||
return srs_error_wrap(err, "dtls done");
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
bool SrsDtlsServerImpl::is_dtls_client()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
SrsDtls::SrsDtls(ISrsDtlsCallback* callback)
|
||||
{
|
||||
callback_ = callback;
|
||||
impl = new SrsDtlsServerImpl(callback);
|
||||
}
|
||||
|
||||
SrsDtls::~SrsDtls()
|
||||
{
|
||||
srs_freep(impl);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::initialize(std::string role, std::string version)
|
||||
{
|
||||
srs_freep(impl);
|
||||
if (role == "active") {
|
||||
impl = new SrsDtlsClientImpl(callback_);
|
||||
} else {
|
||||
impl = new SrsDtlsServerImpl(callback_);
|
||||
}
|
||||
|
||||
return impl->initialize(version);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::start_active_handshake()
|
||||
{
|
||||
return impl->start_active_handshake();
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::on_dtls(char* data, int nb_data)
|
||||
{
|
||||
return impl->on_dtls(data, nb_data);
|
||||
}
|
||||
|
||||
srs_error_t SrsDtls::get_srtp_key(std::string& recv_key, std::string& send_key)
|
||||
{
|
||||
return impl->get_srtp_key(recv_key, send_key);
|
||||
}
|
||||
|
||||
SrsSRTP::SrsSRTP()
|
||||
{
|
||||
recv_ctx_ = NULL;
|
||||
|
|
|
@ -105,21 +105,45 @@ enum SrsDtlsState {
|
|||
SrsDtlsStateClientDone, // Done.
|
||||
};
|
||||
|
||||
class SrsDtls : public ISrsCoroutineHandler
|
||||
class ISrsDtlsImpl
|
||||
{
|
||||
private:
|
||||
protected:
|
||||
SSL_CTX* dtls_ctx;
|
||||
SSL* dtls;
|
||||
BIO* bio_in;
|
||||
BIO* bio_out;
|
||||
ISrsDtlsCallback* callback_;
|
||||
private:
|
||||
// @remark: dtls_version_ default value is SrsDtlsVersionAuto.
|
||||
SrsDtlsVersion version_;
|
||||
protected:
|
||||
// Whether the handhshake is done, for us only.
|
||||
// @remark For us only, means peer maybe not done, we also need to handle the DTLS packet.
|
||||
bool handshake_done_for_us;
|
||||
// DTLS packet cache, only last out-going packet.
|
||||
uint8_t* last_outgoing_packet_cache;
|
||||
int nn_last_outgoing_packet;
|
||||
public:
|
||||
ISrsDtlsImpl(ISrsDtlsCallback* callback);
|
||||
virtual ~ISrsDtlsImpl();
|
||||
public:
|
||||
virtual srs_error_t initialize(std::string version);
|
||||
virtual srs_error_t start_active_handshake() = 0;
|
||||
virtual srs_error_t on_dtls(char* data, int nb_data);
|
||||
srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key);
|
||||
protected:
|
||||
srs_error_t do_on_dtls(char* data, int nb_data);
|
||||
srs_error_t do_handshake();
|
||||
void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq);
|
||||
protected:
|
||||
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached) = 0;
|
||||
virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0;
|
||||
virtual srs_error_t on_handshake_done() = 0;
|
||||
virtual bool is_dtls_client() = 0;
|
||||
};
|
||||
|
||||
class SrsDtlsClientImpl : virtual public ISrsDtlsImpl, virtual public ISrsCoroutineHandler
|
||||
{
|
||||
private:
|
||||
// ARQ thread, for role active(DTLS client).
|
||||
// @note If passive(DTLS server), the ARQ is driven by DTLS client.
|
||||
SrsCoroutine* trd;
|
||||
|
@ -128,11 +152,45 @@ private:
|
|||
// The timeout for ARQ.
|
||||
srs_utime_t arq_first;
|
||||
srs_utime_t arq_interval;
|
||||
public:
|
||||
SrsDtlsClientImpl(ISrsDtlsCallback* callback);
|
||||
virtual ~SrsDtlsClientImpl();
|
||||
public:
|
||||
virtual srs_error_t initialize(std::string version);
|
||||
virtual srs_error_t start_active_handshake();
|
||||
virtual srs_error_t on_dtls(char* data, int nb_data);
|
||||
protected:
|
||||
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached);
|
||||
virtual srs_error_t on_final_out_data(uint8_t* data, int size);
|
||||
virtual srs_error_t on_handshake_done();
|
||||
virtual bool is_dtls_client();
|
||||
private:
|
||||
// @remark: dtls_role_ default value is DTLS_SERVER.
|
||||
SrsDtlsRole role_;
|
||||
// @remark: dtls_version_ default value is SrsDtlsVersionAuto.
|
||||
SrsDtlsVersion version_;
|
||||
srs_error_t start_arq();
|
||||
void stop_arq();
|
||||
public:
|
||||
virtual srs_error_t cycle();
|
||||
};
|
||||
|
||||
class SrsDtlsServerImpl : public ISrsDtlsImpl
|
||||
{
|
||||
public:
|
||||
SrsDtlsServerImpl(ISrsDtlsCallback* callback);
|
||||
virtual ~SrsDtlsServerImpl();
|
||||
public:
|
||||
virtual srs_error_t initialize(std::string version);
|
||||
virtual srs_error_t start_active_handshake();
|
||||
protected:
|
||||
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached);
|
||||
virtual srs_error_t on_final_out_data(uint8_t* data, int size);
|
||||
virtual srs_error_t on_handshake_done();
|
||||
virtual bool is_dtls_client();
|
||||
};
|
||||
|
||||
class SrsDtls
|
||||
{
|
||||
private:
|
||||
ISrsDtlsImpl* impl;
|
||||
ISrsDtlsCallback* callback_;
|
||||
public:
|
||||
SrsDtls(ISrsDtlsCallback* callback);
|
||||
virtual ~SrsDtls();
|
||||
|
@ -144,17 +202,6 @@ public:
|
|||
// When got DTLS packet, may handshake packets or application data.
|
||||
// @remark When we are passive(DTLS server), we start handshake when got DTLS packet.
|
||||
srs_error_t on_dtls(char* data, int nb_data);
|
||||
private:
|
||||
srs_error_t do_on_dtls(char* data, int nb_data);
|
||||
srs_error_t do_handshake();
|
||||
// interface ISrsCoroutineHandler
|
||||
public:
|
||||
virtual srs_error_t cycle();
|
||||
private:
|
||||
void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq);
|
||||
private:
|
||||
srs_error_t start_arq();
|
||||
void stop_arq();
|
||||
public:
|
||||
srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key);
|
||||
};
|
||||
|
|
|
@ -36,7 +36,7 @@ using namespace std;
|
|||
|
||||
extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version);
|
||||
|
||||
class MockDtls : public ISrsCoroutineHandler
|
||||
class MockDtls
|
||||
{
|
||||
public:
|
||||
SSL_CTX* dtls_ctx;
|
||||
|
@ -54,7 +54,6 @@ public:
|
|||
srs_error_t start_active_handshake();
|
||||
srs_error_t on_dtls(char* data, int nb_data);
|
||||
srs_error_t do_handshake();
|
||||
virtual srs_error_t cycle();
|
||||
};
|
||||
|
||||
MockDtls::MockDtls(ISrsDtlsCallback* callback)
|
||||
|
@ -180,11 +179,6 @@ srs_error_t MockDtls::do_handshake()
|
|||
return err;
|
||||
}
|
||||
|
||||
srs_error_t MockDtls::cycle()
|
||||
{
|
||||
return srs_success;
|
||||
}
|
||||
|
||||
class MockDtlsCallback : virtual public ISrsDtlsCallback, virtual public ISrsCoroutineHandler
|
||||
{
|
||||
public:
|
||||
|
@ -355,14 +349,10 @@ class MockBridgeDtlsIO
|
|||
private:
|
||||
MockDtlsCallback* io_;
|
||||
public:
|
||||
MockBridgeDtlsIO(MockDtlsCallback* io, ISrsCoroutineHandler* dtls) {
|
||||
MockBridgeDtlsIO(MockDtlsCallback* io, SrsDtls* peer, MockDtls* peer2) {
|
||||
io_ = io;
|
||||
if (dynamic_cast<SrsDtls*>(dtls)) {
|
||||
io->peer = dynamic_cast<SrsDtls*>(dtls);
|
||||
}
|
||||
if (dynamic_cast<MockDtls*>(dtls)) {
|
||||
io->peer2 = dynamic_cast<MockDtls*>(dtls);
|
||||
}
|
||||
io->peer = peer;
|
||||
io->peer2 = peer2;
|
||||
}
|
||||
virtual ~MockBridgeDtlsIO() {
|
||||
io_->peer = NULL;
|
||||
|
@ -424,13 +414,13 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
|
|||
if (true) {
|
||||
MockDtlsCallback cio; SrsDtls client(&cio);
|
||||
MockDtlsCallback sio; SrsDtls server(&sio);
|
||||
MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client);
|
||||
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
|
||||
|
||||
// Use very short interval for utest.
|
||||
client.arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
client.arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
|
||||
// Lost 2 packets, total packets should be 3.
|
||||
// Note that only one server hello.
|
||||
|
@ -456,13 +446,13 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest)
|
|||
if (true) {
|
||||
MockDtlsCallback cio; SrsDtls client(&cio);
|
||||
MockDtlsCallback sio; SrsDtls server(&sio);
|
||||
MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client);
|
||||
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
|
||||
|
||||
// Use very short interval for utest.
|
||||
client.arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
client.arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
|
||||
// Lost 2 packets, total packets should be 3.
|
||||
// Note that only one server NewSessionTicket.
|
||||
|
@ -517,13 +507,13 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
|
|||
if (true) {
|
||||
MockDtlsCallback cio; SrsDtls client(&cio);
|
||||
MockDtlsCallback sio; SrsDtls server(&sio);
|
||||
MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client);
|
||||
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
|
||||
|
||||
// Use very short interval for utest.
|
||||
client.arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
client.arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
|
||||
// Lost 2 packets, total packets should be 3.
|
||||
sio.nn_server_hello_lost = 2;
|
||||
|
@ -548,13 +538,13 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest)
|
|||
if (true) {
|
||||
MockDtlsCallback cio; SrsDtls client(&cio);
|
||||
MockDtlsCallback sio; SrsDtls server(&sio);
|
||||
MockBridgeDtlsIO b0(&cio, &server); MockBridgeDtlsIO b1(&sio, &client);
|
||||
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0"));
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0"));
|
||||
|
||||
// Use very short interval for utest.
|
||||
client.arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
client.arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_first = 1 * SRS_UTIME_MILLISECONDS;
|
||||
dynamic_cast<SrsDtlsClientImpl*>(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS;
|
||||
|
||||
// Lost 2 packets, total packets should be 3.
|
||||
sio.nn_new_session_lost = 2;
|
||||
|
@ -604,7 +594,7 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest)
|
|||
|
||||
MockDtlsCallback cio; SrsDtls client(&cio);
|
||||
MockDtlsCallback sio; MockDtls server(&sio);
|
||||
cio.peer2 = &server; sio.peer = &client;
|
||||
MockBridgeDtlsIO b0(&cio, NULL, &server); MockBridgeDtlsIO b1(&sio, &client, NULL);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", c.ClientVersion)) << c;
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c;
|
||||
|
||||
|
@ -648,7 +638,7 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest)
|
|||
|
||||
MockDtlsCallback cio; MockDtls client(&cio);
|
||||
MockDtlsCallback sio; SrsDtls server(&sio);
|
||||
cio.peer = &server; sio.peer2 = &client;
|
||||
MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, NULL, &client);
|
||||
HELPER_EXPECT_SUCCESS(client.initialize("active", c.ClientVersion)) << c;
|
||||
HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue