From dd7587c4975060bcf240bc517e65bfa3cdf38062 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 14 Sep 2020 10:47:06 +0800 Subject: [PATCH] Dispose session when DTLS alert --- trunk/src/app/srs_app_rtc_conn.cpp | 30 +++++++++++++ trunk/src/app/srs_app_rtc_conn.hpp | 4 ++ trunk/src/app/srs_app_rtc_dtls.cpp | 15 +++++++ trunk/src/app/srs_app_rtc_dtls.hpp | 3 ++ trunk/src/app/srs_app_rtc_server.cpp | 65 +++++++++++++--------------- trunk/src/app/srs_app_rtc_server.hpp | 6 +-- trunk/src/utest/srs_utest_rtc.cpp | 6 +++ 7 files changed, 92 insertions(+), 37 deletions(-) diff --git a/trunk/src/app/srs_app_rtc_conn.cpp b/trunk/src/app/srs_app_rtc_conn.cpp index 2b185b390..a6646304e 100644 --- a/trunk/src/app/srs_app_rtc_conn.cpp +++ b/trunk/src/app/srs_app_rtc_conn.cpp @@ -118,6 +118,11 @@ srs_error_t SrsSecurityTransport::on_dtls(char* data, int nb_data) return dtls_->on_dtls(data, nb_data); } +srs_error_t SrsSecurityTransport::on_dtls_alert(std::string type, std::string desc) +{ + return session_->on_dtls_alert(type, desc); +} + srs_error_t SrsSecurityTransport::on_dtls_handshake_done() { srs_error_t err = srs_success; @@ -237,6 +242,11 @@ srs_error_t SrsPlaintextTransport::on_dtls(char* data, int nb_data) return srs_success; } +srs_error_t SrsPlaintextTransport::on_dtls_alert(std::string type, std::string desc) +{ + return srs_success; +} + srs_error_t SrsPlaintextTransport::on_dtls_handshake_done() { srs_trace("RTC: DTLS handshake done."); @@ -2112,6 +2122,26 @@ srs_error_t SrsRtcConnection::on_connection_established() return err; } +srs_error_t SrsRtcConnection::on_dtls_alert(std::string type, std::string desc) +{ + srs_error_t err = srs_success; + + SrsRtcConnection* session = this; + + // CN(Close Notify) is sent when client close the PeerConnection. + if (type == "warning" && desc == "CN") { + SrsContextRestore(_srs_context->get_id()); + session->switch_to_context(); + + string username = session->username(); + srs_trace("RTC: session DTLS alert, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str()); + + server_->dispose(session); + } + + return err; +} + srs_error_t SrsRtcConnection::start_play(string stream_uri) { srs_error_t err = srs_success; diff --git a/trunk/src/app/srs_app_rtc_conn.hpp b/trunk/src/app/srs_app_rtc_conn.hpp index e640de19f..561ad2459 100644 --- a/trunk/src/app/srs_app_rtc_conn.hpp +++ b/trunk/src/app/srs_app_rtc_conn.hpp @@ -94,6 +94,7 @@ public: virtual srs_error_t initialize(SrsSessionConfig* cfg) = 0; virtual srs_error_t start_active_handshake() = 0; virtual srs_error_t on_dtls(char* data, int nb_data) = 0; + virtual srs_error_t on_dtls_alert(std::string type, std::string desc) = 0; public: virtual srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) = 0; virtual srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) = 0; @@ -118,6 +119,7 @@ public: // When play role of dtls client, it send handshake. srs_error_t start_active_handshake(); srs_error_t on_dtls(char* data, int nb_data); + srs_error_t on_dtls_alert(std::string type, std::string desc); public: // Encrypt the input plaintext to output cipher with nb_cipher bytes. // @remark Note that the nb_cipher is the size of input plaintext, and @@ -165,6 +167,7 @@ public: virtual srs_error_t initialize(SrsSessionConfig* cfg); virtual srs_error_t start_active_handshake(); virtual srs_error_t on_dtls(char* data, int nb_data); + virtual srs_error_t on_dtls_alert(std::string type, std::string desc); virtual srs_error_t on_dtls_handshake_done(); virtual srs_error_t on_dtls_application_data(const char* data, const int len); virtual srs_error_t write_dtls_data(void* data, int size); @@ -508,6 +511,7 @@ public: void set_hijacker(ISrsRtcConnectionHijacker* h); public: srs_error_t on_connection_established(); + srs_error_t on_dtls_alert(std::string type, std::string desc); srs_error_t start_play(std::string stream_uri); srs_error_t start_publish(std::string stream_uri); bool is_stun_timeout(); diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index 20ce37b2a..6c17d133b 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -57,6 +57,9 @@ int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) // Print the information of SSL, DTLS alert as such. void ssl_on_info(const SSL* dtls, int where, int ret) { + SrsDtlsImpl* dtls_impl = (SrsDtlsImpl*)SSL_get_ex_data(dtls, 0); + srs_assert(dtls_impl); + const char* method; int w = where& ~SSL_ST_MASK; if (w & SSL_ST_CONNECT) { @@ -85,6 +88,9 @@ void ssl_on_info(const SSL* dtls, int where, int ret) srs_error("DTLS: SSL3 alert method=%s type=%s, desc=%s(%s), where=%d, ret=%d, r1=%d", method, alert_type.c_str(), alert_desc.c_str(), SSL_alert_desc_string_long(ret), where, ret, r1); } + + // Notify the DTLS to handle the ALERT message, which maybe means media connection disconnect. + dtls_impl->callback_by_ssl(alert_type, alert_desc); } else if (where & SSL_CB_EXIT) { if (ret == 0) { srs_warn("DTLS: Fail method=%s state=%s(%s), where=%d, ret=%d, r1=%d", method, SSL_state_string(dtls), @@ -588,6 +594,15 @@ srs_error_t SrsDtlsImpl::get_srtp_key(std::string& recv_key, std::string& send_k return err; } +void SrsDtlsImpl::callback_by_ssl(std::string type, std::string desc) +{ + srs_error_t err = srs_success; + if ((err = callback_->on_dtls_alert(type, desc)) != srs_success) { + srs_warn2("DTLSAlert", "DTLS: handler alert err %s", srs_error_desc(err).c_str()); + srs_freep(err); + } +} + SrsDtlsClientImpl::SrsDtlsClientImpl(ISrsDtlsCallback* callback) : SrsDtlsImpl(callback) { trd = NULL; diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 019c360ec..089a42264 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -93,6 +93,8 @@ public: virtual srs_error_t on_dtls_application_data(const char* data, const int len) = 0; // DTLS write dtls data. virtual srs_error_t write_dtls_data(void* data, int size) = 0; + // Callback when DTLS Alert message. + virtual srs_error_t on_dtls_alert(std::string type, std::string desc) = 0; }; // The state for DTLS client. @@ -135,6 +137,7 @@ protected: void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq); public: srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key); + void callback_by_ssl(std::string type, std::string desc); protected: virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached) = 0; virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0; diff --git a/trunk/src/app/srs_app_rtc_server.cpp b/trunk/src/app/srs_app_rtc_server.cpp index 1cdc9fc16..855ad1309 100644 --- a/trunk/src/app/srs_app_rtc_server.cpp +++ b/trunk/src/app/srs_app_rtc_server.cpp @@ -575,6 +575,19 @@ srs_error_t SrsRtcServer::setup_session2(SrsRtcConnection* session, SrsRequest* return err; } +void SrsRtcServer::dispose(SrsRtcConnection* session) +{ + if (session->disposing_) { + return; + } + + destroy(session); + + if (handler) { + handler->on_timeout(session); + } +} + void SrsRtcServer::destroy(SrsRtcConnection* session) { if (session->disposing_) { @@ -582,12 +595,6 @@ void SrsRtcServer::destroy(SrsRtcConnection* session) } session->disposing_ = true; - SrsContextRestore(_srs_context->get_id()); - session->switch_to_context(); - - string username = session->username(); - srs_trace("RTC: session destroy, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str()); - manager->remove(session); } @@ -596,31 +603,6 @@ void SrsRtcServer::insert_into_id_sessions(const string& peer_id, SrsRtcConnecti manager->add_with_id(peer_id, session); } -void SrsRtcServer::check_and_clean_timeout_session() -{ - for (int i = 0; i < (int)manager->size(); i++) { - SrsRtcConnection* session = dynamic_cast(manager->at(i)); - srs_assert(session); - - if (!session->is_stun_timeout()) { - continue; - } - - // Now, we got the RTC session to cleanup, switch to its context - // to make all logs write to the "correct" pid+cid. - session->switch_to_context(); - string username = session->username(); - srs_trace("RTC: session STUN timeout, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str()); - - session->disposing_ = true; - manager->remove(session); - - if (handler) { - handler->on_timeout(session); - } - } -} - SrsRtcConnection* SrsRtcServer::find_session_by_username(const std::string& username) { ISrsConnection* conn = manager->find_by_name(username); @@ -631,9 +613,24 @@ srs_error_t SrsRtcServer::notify(int type, srs_utime_t interval, srs_utime_t tic { srs_error_t err = srs_success; - // TODO: FIXME: Merge small functions. - // Check session timeout, put to zombies queue. - check_and_clean_timeout_session(); + // Check all sessions and dispose the dead sessions. + for (int i = 0; i < (int)manager->size(); i++) { + SrsRtcConnection* session = dynamic_cast(manager->at(i)); + srs_assert(session); + + if (!session->is_stun_timeout()) { + continue; + } + + SrsContextRestore(_srs_context->get_id()); + session->switch_to_context(); + + string username = session->username(); + srs_trace("RTC: session STUN timeout, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str()); + + // Destroy session and notify the handler. + dispose(session); + } return err; } diff --git a/trunk/src/app/srs_app_rtc_server.hpp b/trunk/src/app/srs_app_rtc_server.hpp index 3110c92c3..d1bf3890e 100644 --- a/trunk/src/app/srs_app_rtc_server.hpp +++ b/trunk/src/app/srs_app_rtc_server.hpp @@ -123,12 +123,12 @@ public: // We start offering, create_session2 to generate offer, setup_session2 to handle answer. srs_error_t create_session2(SrsRequest* req, SrsSdp& local_sdp, const std::string& mock_eip, bool unified_plan, SrsRtcConnection** psession); srs_error_t setup_session2(SrsRtcConnection* session, SrsRequest* req, const SrsSdp& remote_sdp); - // Destroy the session from server. + // Destroy the session and notify the callback. + void dispose(SrsRtcConnection* session); + // Destroy the session from server, without notify callback. void destroy(SrsRtcConnection* session); public: void insert_into_id_sessions(const std::string& peer_id, SrsRtcConnection* session); -private: - void check_and_clean_timeout_session(); public: SrsRtcConnection* find_session_by_username(const std::string& ufrag); // interface ISrsHourGlass diff --git a/trunk/src/utest/srs_utest_rtc.cpp b/trunk/src/utest/srs_utest_rtc.cpp index 3d7d511b4..b1da5accc 100644 --- a/trunk/src/utest/srs_utest_rtc.cpp +++ b/trunk/src/utest/srs_utest_rtc.cpp @@ -298,6 +298,7 @@ public: virtual srs_error_t on_dtls_handshake_done(); virtual srs_error_t on_dtls_application_data(const char* data, const int len); virtual srs_error_t write_dtls_data(void* data, int size); + virtual srs_error_t on_dtls_alert(std::string type, std::string desc); virtual srs_error_t cycle(); }; @@ -391,6 +392,11 @@ srs_error_t MockDtlsCallback::write_dtls_data(void* data, int size) return srs_success; } +srs_error_t MockDtlsCallback::on_dtls_alert(std::string type, std::string desc) +{ + return srs_success; +} + srs_error_t MockDtlsCallback::cycle() { srs_error_t err = srs_success;