diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index f44422109..d5657e923 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -257,10 +257,15 @@ SrsDtls::SrsDtls(ISrsDtlsCallback* cb) role_ = SrsDtlsRoleServer; version_ = SrsDtlsVersionAuto; + + trd = NULL; + client_state_ = SrsDtlsStateInit; } SrsDtls::~SrsDtls() { + srs_freep(trd); + if (dtls_ctx) { SSL_CTX_free(dtls_ctx); dtls_ctx = NULL; @@ -391,17 +396,24 @@ SSL_CTX* SrsDtls::build_dtls_ctx() srs_error_t SrsDtls::start_active_handshake() { + srs_error_t err = srs_success; + if (role_ == SrsDtlsRoleClient) { return do_handshake(); } - return srs_success; + return err; } srs_error_t SrsDtls::on_dtls(char* data, int nb_data) { srs_error_t err = srs_success; + // When got packet, always stop the ARQ. + if (role_ == SrsDtlsRoleClient && client_state_ == SrsDtlsStateServerHello) { + stop_arq(); + } + if ((err = do_on_dtls(data, nb_data)) != srs_success) { return srs_error_wrap(err, "on_dtls size=%u, data=[%s]", nb_data, srs_string_dumps_hex(data, nb_data, 32).c_str()); @@ -423,7 +435,7 @@ srs_error_t SrsDtls::do_on_dtls(char* data, int nb_data) } // Trace the detail of DTLS packet. - trace((uint8_t*)data, nb_data, true, SSL_ERROR_NONE, false); + state_trace((uint8_t*)data, nb_data, true, SSL_ERROR_NONE, false, false); if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) { // TODO: 0 or -1 maybe block, use BIO_should_retry to check. @@ -484,15 +496,16 @@ srs_error_t SrsDtls::do_handshake() } // If outgoing packet is empty, we use the last cache. + // @remark Only for DTLS server, because DTLS client use ARQ thread to send cached packet. bool cache = false; - if (size <= 0 && nn_last_outgoing_packet) { + if (role_ != SrsDtlsRoleClient && size <= 0 && nn_last_outgoing_packet) { size = nn_last_outgoing_packet; data = last_outgoing_packet_cache; cache = true; } // Trace the detail of DTLS packet. - trace((uint8_t*)data, size, false, ssl_err, cache); + state_trace((uint8_t*)data, size, false, ssl_err, cache, false); // Update the packet cache. if (size > 0 && data != last_outgoing_packet_cache && size < kRtpPacketSize) { @@ -500,12 +513,44 @@ srs_error_t SrsDtls::do_handshake() nn_last_outgoing_packet = size; } + // Driven ARQ and state for DTLS client. + if (role_ == SrsDtlsRoleClient) { + // If we are sending client hello, change from init to new state. + if (client_state_ == SrsDtlsStateInit && size > 14 && data[13] == 1) { + client_state_ = SrsDtlsStateClientHello; + } + // If we are sending certificate, change from SrsDtlsStateServerHello to new state. + if (client_state_ == SrsDtlsStateServerHello && size > 14 && data[13] == 11) { + client_state_ = SrsDtlsStateClientCertificate; + } + + // Try to start the ARQ for client. + if ((client_state_ == SrsDtlsStateClientHello || client_state_ == SrsDtlsStateClientCertificate)) { + if (client_state_ == SrsDtlsStateClientHello) { + client_state_ = SrsDtlsStateServerHello; + } else if (client_state_ == SrsDtlsStateClientCertificate) { + client_state_ = SrsDtlsStateServerDone; + } + + if ((err = start_arq()) != srs_success) { + return srs_error_wrap(err, "start arq"); + } + } + } + if (size > 0 && (err = callback->write_dtls_data(data, size)) != srs_success) { return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, srs_string_dumps_hex((char*)data, size, 32).c_str()); } if (handshake_done_for_us) { + // When handshake done, stop the ARQ. + if (role_ == SrsDtlsRoleClient) { + client_state_ = SrsDtlsStateClientDone; + stop_arq(); + } + + // Notify connection the DTLS is done. if (((err = callback->on_dtls_handshake_done()) != srs_success)) { return srs_error_wrap(err, "dtls done"); } @@ -514,7 +559,54 @@ srs_error_t SrsDtls::do_handshake() return err; } -void SrsDtls::trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache) +srs_error_t SrsDtls::cycle() +{ + srs_error_t err = srs_success; + + // The first ARQ delay. + srs_usleep(50 * SRS_UTIME_MILLISECONDS); + + while (true) { + srs_info("arq cycle, state=%u", client_state_); + + // We ignore any error for ARQ thread. + if ((err = trd->pull()) != srs_success) { + srs_freep(err); + return err; + } + + // If done, should stop ARQ. + if (handshake_done_for_us) { + return err; + } + + // For DTLS client ARQ, the state should be specified. + if (client_state_ != SrsDtlsStateServerHello && client_state_ != SrsDtlsStateServerDone) { + return err; + } + + // Try to retransmit the packet. + uint8_t* data = last_outgoing_packet_cache; + int size = nn_last_outgoing_packet; + + if (size) { + // Trace the detail of DTLS packet. + state_trace((uint8_t*)data, size, false, SSL_ERROR_NONE, true, true); + + if ((err = callback->write_dtls_data(data, size)) != srs_success) { + return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, + srs_string_dumps_hex((char*)data, size, 32).c_str()); + } + } + + // TODO: Use ARQ step timeouts. + srs_usleep(100 * SRS_UTIME_MILLISECONDS); + } + + return err; +} + +void SrsDtls::state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq) { uint8_t content_type = 0; if (length >= 1) { @@ -531,8 +623,38 @@ void SrsDtls::trace(uint8_t* data, int length, bool incoming, int ssl_err, bool handshake_type = (uint8_t)data[13]; } - srs_trace("DTLS: %s done=%u, cache=%u, ssl-err=%d, length=%u, content-type=%u, size=%u, handshake-type=%u", (incoming? "RECV":"SEND"), - handshake_done_for_us, cache, ssl_err, length, content_type, size, handshake_type); + srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u, state=%u, ssl-err=%d, length=%u, content=%u, size=%u, handshake=%u", + (role_ == SrsDtlsRoleClient? "Client":"Server"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq, + client_state_, ssl_err, length, content_type, size, handshake_type); +} + +srs_error_t SrsDtls::start_arq() +{ + srs_error_t err = srs_success; + + if (role_ != SrsDtlsRoleClient) { + return err; + } + + srs_info("start arq, state=%u", client_state_); + + // Dispose the previous ARQ thread. + srs_freep(trd); + trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id()); + + // We should start the ARQ thread for DTLS client. + if ((err = trd->start()) != srs_success) { + return srs_error_wrap(err, "arq start"); + } + + return err; +} + +void SrsDtls::stop_arq() +{ + srs_info("stop arq, state=%u", client_state_); + srs_freep(trd); + srs_info("stop arq, done"); } const int SRTP_MASTER_KEY_KEY_LEN = 16; diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 2106e1464..03342c16e 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -27,12 +27,15 @@ #include #include - -class SrsRequest; +#include #include #include +#include + +class SrsRequest; + class SrsDtlsCertificate { private: @@ -92,7 +95,17 @@ public: virtual srs_error_t write_dtls_data(void* data, int size) = 0; }; -class SrsDtls +// The state for DTLS client. +enum SrsDtlsState { + SrsDtlsStateInit, + SrsDtlsStateClientHello, + SrsDtlsStateServerHello, + SrsDtlsStateClientCertificate, + SrsDtlsStateServerDone, + SrsDtlsStateClientDone, +}; + +class SrsDtls : public ISrsCoroutineHandler { private: SSL_CTX* dtls_ctx; @@ -110,6 +123,12 @@ private: uint8_t* last_outgoing_packet_cache; int nn_last_outgoing_packet; + // ARQ thread, for role active(DTLS client). + // @note If passive(DTLS server), the ARQ is driven by DTLS client. + SrsCoroutine* trd; + // The DTLS-client state to drive the ARQ thread. + SrsDtlsState client_state_; + // @remark: dtls_role_ default value is DTLS_SERVER. SrsDtlsRole role_; // @remark: dtls_version_ default value is SrsDtlsVersionAuto. @@ -130,7 +149,14 @@ public: private: srs_error_t do_on_dtls(char* data, int nb_data); srs_error_t do_handshake(); - void trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache); +// interface ISrsCoroutineHandler +public: + virtual srs_error_t cycle(); +private: + void state_trace(uint8_t* data, int length, bool incoming, int ssl_err, bool cache, bool arq); +private: + srs_error_t start_arq(); + void stop_arq(); public: srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key); };