1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-02-13 11:51:57 +00:00

Dispose session when DTLS alert

This commit is contained in:
winlin 2020-09-14 10:47:06 +08:00
parent 86a80396de
commit dd7587c497
7 changed files with 92 additions and 37 deletions

View file

@ -118,6 +118,11 @@ srs_error_t SrsSecurityTransport::on_dtls(char* data, int nb_data)
return dtls_->on_dtls(data, nb_data);
}
srs_error_t SrsSecurityTransport::on_dtls_alert(std::string type, std::string desc)
{
return session_->on_dtls_alert(type, desc);
}
srs_error_t SrsSecurityTransport::on_dtls_handshake_done()
{
srs_error_t err = srs_success;
@ -237,6 +242,11 @@ srs_error_t SrsPlaintextTransport::on_dtls(char* data, int nb_data)
return srs_success;
}
srs_error_t SrsPlaintextTransport::on_dtls_alert(std::string type, std::string desc)
{
return srs_success;
}
srs_error_t SrsPlaintextTransport::on_dtls_handshake_done()
{
srs_trace("RTC: DTLS handshake done.");
@ -2112,6 +2122,26 @@ srs_error_t SrsRtcConnection::on_connection_established()
return err;
}
srs_error_t SrsRtcConnection::on_dtls_alert(std::string type, std::string desc)
{
srs_error_t err = srs_success;
SrsRtcConnection* session = this;
// CN(Close Notify) is sent when client close the PeerConnection.
if (type == "warning" && desc == "CN") {
SrsContextRestore(_srs_context->get_id());
session->switch_to_context();
string username = session->username();
srs_trace("RTC: session DTLS alert, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str());
server_->dispose(session);
}
return err;
}
srs_error_t SrsRtcConnection::start_play(string stream_uri)
{
srs_error_t err = srs_success;

View file

@ -94,6 +94,7 @@ public:
virtual srs_error_t initialize(SrsSessionConfig* cfg) = 0;
virtual srs_error_t start_active_handshake() = 0;
virtual srs_error_t on_dtls(char* data, int nb_data) = 0;
virtual srs_error_t on_dtls_alert(std::string type, std::string desc) = 0;
public:
virtual srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) = 0;
virtual srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) = 0;
@ -118,6 +119,7 @@ public:
// When play role of dtls client, it send handshake.
srs_error_t start_active_handshake();
srs_error_t on_dtls(char* data, int nb_data);
srs_error_t on_dtls_alert(std::string type, std::string desc);
public:
// Encrypt the input plaintext to output cipher with nb_cipher bytes.
// @remark Note that the nb_cipher is the size of input plaintext, and
@ -165,6 +167,7 @@ public:
virtual srs_error_t initialize(SrsSessionConfig* cfg);
virtual srs_error_t start_active_handshake();
virtual srs_error_t on_dtls(char* data, int nb_data);
virtual srs_error_t on_dtls_alert(std::string type, std::string desc);
virtual srs_error_t on_dtls_handshake_done();
virtual srs_error_t on_dtls_application_data(const char* data, const int len);
virtual srs_error_t write_dtls_data(void* data, int size);
@ -508,6 +511,7 @@ public:
void set_hijacker(ISrsRtcConnectionHijacker* h);
public:
srs_error_t on_connection_established();
srs_error_t on_dtls_alert(std::string type, std::string desc);
srs_error_t start_play(std::string stream_uri);
srs_error_t start_publish(std::string stream_uri);
bool is_stun_timeout();

View file

@ -57,6 +57,9 @@ int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
// Print the information of SSL, DTLS alert as such.
void ssl_on_info(const SSL* dtls, int where, int ret)
{
SrsDtlsImpl* dtls_impl = (SrsDtlsImpl*)SSL_get_ex_data(dtls, 0);
srs_assert(dtls_impl);
const char* method;
int w = where& ~SSL_ST_MASK;
if (w & SSL_ST_CONNECT) {
@ -85,6 +88,9 @@ void ssl_on_info(const SSL* dtls, int where, int ret)
srs_error("DTLS: SSL3 alert method=%s type=%s, desc=%s(%s), where=%d, ret=%d, r1=%d", method, alert_type.c_str(),
alert_desc.c_str(), SSL_alert_desc_string_long(ret), where, ret, r1);
}
// Notify the DTLS to handle the ALERT message, which maybe means media connection disconnect.
dtls_impl->callback_by_ssl(alert_type, alert_desc);
} else if (where & SSL_CB_EXIT) {
if (ret == 0) {
srs_warn("DTLS: Fail method=%s state=%s(%s), where=%d, ret=%d, r1=%d", method, SSL_state_string(dtls),
@ -588,6 +594,15 @@ srs_error_t SrsDtlsImpl::get_srtp_key(std::string& recv_key, std::string& send_k
return err;
}
void SrsDtlsImpl::callback_by_ssl(std::string type, std::string desc)
{
srs_error_t err = srs_success;
if ((err = callback_->on_dtls_alert(type, desc)) != srs_success) {
srs_warn2("DTLSAlert", "DTLS: handler alert err %s", srs_error_desc(err).c_str());
srs_freep(err);
}
}
SrsDtlsClientImpl::SrsDtlsClientImpl(ISrsDtlsCallback* callback) : SrsDtlsImpl(callback)
{
trd = NULL;

View file

@ -93,6 +93,8 @@ public:
virtual srs_error_t on_dtls_application_data(const char* data, const int len) = 0;
// DTLS write dtls data.
virtual srs_error_t write_dtls_data(void* data, int size) = 0;
// Callback when DTLS Alert message.
virtual srs_error_t on_dtls_alert(std::string type, std::string desc) = 0;
};
// The state for DTLS client.
@ -135,6 +137,7 @@ protected:
void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq);
public:
srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key);
void callback_by_ssl(std::string type, std::string desc);
protected:
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached) = 0;
virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0;

View file

@ -575,6 +575,19 @@ srs_error_t SrsRtcServer::setup_session2(SrsRtcConnection* session, SrsRequest*
return err;
}
void SrsRtcServer::dispose(SrsRtcConnection* session)
{
if (session->disposing_) {
return;
}
destroy(session);
if (handler) {
handler->on_timeout(session);
}
}
void SrsRtcServer::destroy(SrsRtcConnection* session)
{
if (session->disposing_) {
@ -582,12 +595,6 @@ void SrsRtcServer::destroy(SrsRtcConnection* session)
}
session->disposing_ = true;
SrsContextRestore(_srs_context->get_id());
session->switch_to_context();
string username = session->username();
srs_trace("RTC: session destroy, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str());
manager->remove(session);
}
@ -596,31 +603,6 @@ void SrsRtcServer::insert_into_id_sessions(const string& peer_id, SrsRtcConnecti
manager->add_with_id(peer_id, session);
}
void SrsRtcServer::check_and_clean_timeout_session()
{
for (int i = 0; i < (int)manager->size(); i++) {
SrsRtcConnection* session = dynamic_cast<SrsRtcConnection*>(manager->at(i));
srs_assert(session);
if (!session->is_stun_timeout()) {
continue;
}
// Now, we got the RTC session to cleanup, switch to its context
// to make all logs write to the "correct" pid+cid.
session->switch_to_context();
string username = session->username();
srs_trace("RTC: session STUN timeout, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str());
session->disposing_ = true;
manager->remove(session);
if (handler) {
handler->on_timeout(session);
}
}
}
SrsRtcConnection* SrsRtcServer::find_session_by_username(const std::string& username)
{
ISrsConnection* conn = manager->find_by_name(username);
@ -631,9 +613,24 @@ srs_error_t SrsRtcServer::notify(int type, srs_utime_t interval, srs_utime_t tic
{
srs_error_t err = srs_success;
// TODO: FIXME: Merge small functions.
// Check session timeout, put to zombies queue.
check_and_clean_timeout_session();
// Check all sessions and dispose the dead sessions.
for (int i = 0; i < (int)manager->size(); i++) {
SrsRtcConnection* session = dynamic_cast<SrsRtcConnection*>(manager->at(i));
srs_assert(session);
if (!session->is_stun_timeout()) {
continue;
}
SrsContextRestore(_srs_context->get_id());
session->switch_to_context();
string username = session->username();
srs_trace("RTC: session STUN timeout, username=%s, summary: %s", username.c_str(), session->stat_->summary().c_str());
// Destroy session and notify the handler.
dispose(session);
}
return err;
}

View file

@ -123,12 +123,12 @@ public:
// We start offering, create_session2 to generate offer, setup_session2 to handle answer.
srs_error_t create_session2(SrsRequest* req, SrsSdp& local_sdp, const std::string& mock_eip, bool unified_plan, SrsRtcConnection** psession);
srs_error_t setup_session2(SrsRtcConnection* session, SrsRequest* req, const SrsSdp& remote_sdp);
// Destroy the session from server.
// Destroy the session and notify the callback.
void dispose(SrsRtcConnection* session);
// Destroy the session from server, without notify callback.
void destroy(SrsRtcConnection* session);
public:
void insert_into_id_sessions(const std::string& peer_id, SrsRtcConnection* session);
private:
void check_and_clean_timeout_session();
public:
SrsRtcConnection* find_session_by_username(const std::string& ufrag);
// interface ISrsHourGlass

View file

@ -298,6 +298,7 @@ public:
virtual srs_error_t on_dtls_handshake_done();
virtual srs_error_t on_dtls_application_data(const char* data, const int len);
virtual srs_error_t write_dtls_data(void* data, int size);
virtual srs_error_t on_dtls_alert(std::string type, std::string desc);
virtual srs_error_t cycle();
};
@ -391,6 +392,11 @@ srs_error_t MockDtlsCallback::write_dtls_data(void* data, int size)
return srs_success;
}
srs_error_t MockDtlsCallback::on_dtls_alert(std::string type, std::string desc)
{
return srs_success;
}
srs_error_t MockDtlsCallback::cycle()
{
srs_error_t err = srs_success;