From ef279a8b1ea2c003dd0d5a84ca164b8c6b3bebe0 Mon Sep 17 00:00:00 2001 From: winlin Date: Sun, 7 Feb 2021 16:57:48 +0800 Subject: [PATCH] RTC: Refine the SRTP protect api --- trunk/src/app/srs_app_rtc_conn.cpp | 66 ++++++++++-------------------- trunk/src/app/srs_app_rtc_conn.hpp | 31 +++++++------- trunk/src/app/srs_app_rtc_dtls.cpp | 29 ++----------- trunk/src/app/srs_app_rtc_dtls.hpp | 10 +---- 4 files changed, 41 insertions(+), 95 deletions(-) diff --git a/trunk/src/app/srs_app_rtc_conn.cpp b/trunk/src/app/srs_app_rtc_conn.cpp index 89a691477..74a232d3a 100644 --- a/trunk/src/app/srs_app_rtc_conn.cpp +++ b/trunk/src/app/srs_app_rtc_conn.cpp @@ -170,20 +170,14 @@ srs_error_t SrsSecurityTransport::srtp_initialize() return err; } -srs_error_t SrsSecurityTransport::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsSecurityTransport::protect_rtp(void* packet, int* nb_cipher) { - return srtp_->protect_rtp(plaintext, cipher, nb_cipher); + return srtp_->protect_rtp(packet, nb_cipher); } -srs_error_t SrsSecurityTransport::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsSecurityTransport::protect_rtcp(void* packet, int* nb_cipher) { - return srtp_->protect_rtcp(plaintext, cipher, nb_cipher); -} - -// TODO: FIXME: Merge with protect_rtp. -srs_error_t SrsSecurityTransport::protect_rtp2(void* rtp_hdr, int* len_ptr) -{ - return srtp_->protect_rtp2(rtp_hdr, len_ptr); + return srtp_->protect_rtcp(packet, nb_cipher); } srs_error_t SrsSecurityTransport::unprotect_rtp(void* packet, int* nb_plaintext) @@ -204,17 +198,12 @@ SrsSemiSecurityTransport::~SrsSemiSecurityTransport() { } -srs_error_t SrsSemiSecurityTransport::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsSemiSecurityTransport::protect_rtp(void* packet, int* nb_cipher) { return srs_success; } -srs_error_t SrsSemiSecurityTransport::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) -{ - return srs_success; -} - -srs_error_t SrsSemiSecurityTransport::protect_rtp2(void* rtp_hdr, int* len_ptr) +srs_error_t SrsSemiSecurityTransport::protect_rtcp(void* packet, int* nb_cipher) { return srs_success; } @@ -264,19 +253,12 @@ srs_error_t SrsPlaintextTransport::write_dtls_data(void* data, int size) return srs_success; } -srs_error_t SrsPlaintextTransport::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsPlaintextTransport::protect_rtp(void* packet, int* nb_cipher) { - memcpy(cipher, plaintext, nb_cipher); return srs_success; } -srs_error_t SrsPlaintextTransport::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) -{ - memcpy(cipher, plaintext, nb_cipher); - return srs_success; -} - -srs_error_t SrsPlaintextTransport::protect_rtp2(void* rtp_hdr, int* len_ptr) +srs_error_t SrsPlaintextTransport::protect_rtcp(void* packet, int* nb_cipher) { return srs_success; } @@ -1308,12 +1290,11 @@ srs_error_t SrsRtcPublishStream::send_periodic_twcc() } int nb_protected_buf = buffer->pos(); - char protected_buf[kRtpPacketSize]; - if ((err = session_->transport_->protect_rtcp(pkt, protected_buf, nb_protected_buf)) != srs_success) { + if ((err = session_->transport_->protect_rtcp(pkt, &nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp, size=%u", nb_protected_buf); } - return session_->sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); + return session_->sendonly_skt->sendto(pkt, nb_protected_buf, 0); } srs_error_t SrsRtcPublishStream::on_rtcp(SrsRtcpCommon* rtcp) @@ -2274,12 +2255,11 @@ srs_error_t SrsRtcConnection::send_rtcp(char *data, int nb_data) srs_error_t err = srs_success; int nb_buf = nb_data; - char protected_buf[kRtpPacketSize]; - if ((err = transport_->protect_rtcp(data, protected_buf, nb_buf)) != srs_success) { + if ((err = transport_->protect_rtcp(data, &nb_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp"); } - if ((err = sendonly_skt->sendto(protected_buf, nb_buf, 0)) != srs_success) { + if ((err = sendonly_skt->sendto(data, nb_buf, 0)) != srs_success) { return srs_error_wrap(err, "send"); } @@ -2305,12 +2285,11 @@ void SrsRtcConnection::check_send_nacks(SrsRtpNackForReceiver* nack, uint32_t ss rtcpNack.encode(&stream); // TODO: FIXME: Check error. - char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf); + transport_->protect_rtcp(stream.data(), &nb_protected_buf); // TODO: FIXME: Check error. - sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); + sendonly_skt->sendto(stream.data(), nb_protected_buf, 0); } srs_error_t SrsRtcConnection::send_rtcp_rr(uint32_t ssrc, SrsRtpRingBuffer* rtp_queue, const uint64_t& last_send_systime, const SrsNtp& last_send_ntp) @@ -2350,13 +2329,12 @@ srs_error_t SrsRtcConnection::send_rtcp_rr(uint32_t ssrc, SrsRtpRingBuffer* rtp_ srs_info("RR ssrc=%u, fraction_lost=%u, cumulative_number_of_packets_lost=%u, extended_highest_sequence=%u, interarrival_jitter=%u", ssrc, fraction_lost, cumulative_number_of_packets_lost, extended_highest_sequence, interarrival_jitter); - char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { + if ((err = transport_->protect_rtcp(stream.data(), &nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp rr"); } - return sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); + return sendonly_skt->sendto(stream.data(), nb_protected_buf, 0); } srs_error_t SrsRtcConnection::send_rtcp_xr_rrtr(uint32_t ssrc) @@ -2403,13 +2381,12 @@ srs_error_t SrsRtcConnection::send_rtcp_xr_rrtr(uint32_t ssrc) stream.write_4bytes(cur_ntp.ntp_second_); stream.write_4bytes(cur_ntp.ntp_fractions_); - char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { + if ((err = transport_->protect_rtcp(stream.data(), &nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp xr"); } - return sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); + return sendonly_skt->sendto(stream.data(), nb_protected_buf, 0); } srs_error_t SrsRtcConnection::send_rtcp_fb_pli(uint32_t ssrc, const SrsContextId& cid_of_subscriber) @@ -2434,13 +2411,12 @@ srs_error_t SrsRtcConnection::send_rtcp_fb_pli(uint32_t ssrc, const SrsContextId _srs_blackhole->sendto(stream.data(), stream.pos()); } - char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { + if ((err = transport_->protect_rtcp(stream.data(), &nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp psfb pli"); } - return sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); + return sendonly_skt->sendto(stream.data(), nb_protected_buf, 0); } void SrsRtcConnection::simulate_nack_drop(int nn) @@ -2491,7 +2467,7 @@ srs_error_t SrsRtcConnection::do_send_packets(const std::vector& // Cipher RTP to SRTP packet. if (true) { int nn_encrypt = (int)iov->iov_len; - if ((err = transport_->protect_rtp2(iov->iov_base, &nn_encrypt)) != srs_success) { + if ((err = transport_->protect_rtp(iov->iov_base, &nn_encrypt)) != srs_success) { return srs_error_wrap(err, "srtp protect"); } iov->iov_len = (size_t)nn_encrypt; diff --git a/trunk/src/app/srs_app_rtc_conn.hpp b/trunk/src/app/srs_app_rtc_conn.hpp index 29e348359..fadee0059 100644 --- a/trunk/src/app/srs_app_rtc_conn.hpp +++ b/trunk/src/app/srs_app_rtc_conn.hpp @@ -97,9 +97,12 @@ public: virtual srs_error_t on_dtls(char* data, int nb_data) = 0; virtual srs_error_t on_dtls_alert(std::string type, std::string desc) = 0; public: - virtual srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) = 0; - virtual srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) = 0; - virtual srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr) = 0; + // Encrypt the packet(paintext) to cipher, which is aso the packet ptr. + // The nb_cipher should be initialized to the size of cipher, with some paddings. + virtual srs_error_t protect_rtp(void* packet, int* nb_cipher) = 0; + virtual srs_error_t protect_rtcp(void* packet, int* nb_cipher) = 0; + // Decrypt the packet(cipher) to plaintext, which is also the packet ptr. + // The nb_plaintext should be initialized to the size of cipher. virtual srs_error_t unprotect_rtp(void* packet, int* nb_plaintext) = 0; virtual srs_error_t unprotect_rtcp(void* packet, int* nb_plaintext) = 0; }; @@ -122,14 +125,10 @@ public: srs_error_t on_dtls(char* data, int nb_data); srs_error_t on_dtls_alert(std::string type, std::string desc); public: - // Encrypt the input plaintext to output cipher with nb_cipher bytes. - // @remark Note that the nb_cipher is the size of input plaintext, and - // it also is the length of output cipher when return. - srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); - srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); - // Encrypt the input rtp_hdr with *len_ptr bytes. - // @remark the input plaintext and out cipher reuse rtp_hdr. - srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); + // Encrypt the packet(paintext) to cipher, which is aso the packet ptr. + // The nb_cipher should be initialized to the size of cipher, with some paddings. + srs_error_t protect_rtp(void* packet, int* nb_cipher); + srs_error_t protect_rtcp(void* packet, int* nb_cipher); // Decrypt the packet(cipher) to plaintext, which is also the packet ptr. // The nb_plaintext should be initialized to the size of cipher. srs_error_t unprotect_rtp(void* packet, int* nb_plaintext); @@ -150,9 +149,8 @@ public: SrsSemiSecurityTransport(SrsRtcConnection* s); virtual ~SrsSemiSecurityTransport(); public: - virtual srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); - virtual srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); - virtual srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); + srs_error_t protect_rtp(void* packet, int* nb_cipher); + srs_error_t protect_rtcp(void* packet, int* nb_cipher); }; // Plaintext transport, without DTLS or SRTP. @@ -172,9 +170,8 @@ public: virtual srs_error_t on_dtls_application_data(const char* data, const int len); virtual srs_error_t write_dtls_data(void* data, int size); public: - virtual srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); - virtual srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); - virtual srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); + srs_error_t protect_rtp(void* packet, int* nb_cipher); + srs_error_t protect_rtcp(void* packet, int* nb_cipher); srs_error_t unprotect_rtp(void* packet, int* nb_plaintext); srs_error_t unprotect_rtcp(void* packet, int* nb_plaintext); }; diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index ba335c0c6..d2d6bc364 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -955,7 +955,7 @@ srs_error_t SrsSRTP::initialize(string recv_key, std::string send_key) return err; } -srs_error_t SrsSRTP::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsSRTP::protect_rtp(void* packet, int* nb_cipher) { srs_error_t err = srs_success; @@ -964,17 +964,15 @@ srs_error_t SrsSRTP::protect_rtp(const char* plaintext, char* cipher, int& nb_ci return srs_error_new(ERROR_RTC_SRTP_PROTECT, "not ready"); } - memcpy(cipher, plaintext, nb_cipher); - srtp_err_status_t r0 = srtp_err_status_ok; - if ((r0 = srtp_protect(send_ctx_, cipher, &nb_cipher)) != srtp_err_status_ok) { + if ((r0 = srtp_protect(send_ctx_, packet, nb_cipher)) != srtp_err_status_ok) { return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect r0=%u", r0); } return err; } -srs_error_t SrsSRTP::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) +srs_error_t SrsSRTP::protect_rtcp(void* packet, int* nb_cipher) { srs_error_t err = srs_success; @@ -983,33 +981,14 @@ srs_error_t SrsSRTP::protect_rtcp(const char* plaintext, char* cipher, int& nb_c return srs_error_new(ERROR_RTC_SRTP_PROTECT, "not ready"); } - memcpy(cipher, plaintext, nb_cipher); - srtp_err_status_t r0 = srtp_err_status_ok; - if ((r0 = srtp_protect_rtcp(send_ctx_, cipher, &nb_cipher)) != srtp_err_status_ok) { + if ((r0 = srtp_protect_rtcp(send_ctx_, packet, nb_cipher)) != srtp_err_status_ok) { return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtcp protect r0=%u", r0); } return err; } -srs_error_t SrsSRTP::protect_rtp2(void* rtp_hdr, int* len_ptr) -{ - srs_error_t err = srs_success; - - // If DTLS/SRTP is not ready, fail. - if (!send_ctx_) { - return srs_error_new(ERROR_RTC_SRTP_PROTECT, "not ready"); - } - - srtp_err_status_t r0 = srtp_err_status_ok; - if ((r0 = srtp_protect(send_ctx_, rtp_hdr, len_ptr)) != srtp_err_status_ok) { - return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect r0=%u", r0); - } - - return err; -} - srs_error_t SrsSRTP::unprotect_rtp(void* packet, int* nb_plaintext) { srs_error_t err = srs_success; diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 47e509989..75e099f83 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -224,14 +224,8 @@ public: // Intialize srtp context with recv_key and send_key. srs_error_t initialize(std::string recv_key, std::string send_key); public: - // Encrypt the input plaintext to output cipher with nb_cipher bytes. - // @remark Note that the nb_cipher is the size of input plaintext, and - // it also is the length of output cipher when return. - srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); - srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); - // Encrypt the input rtp_hdr with *len_ptr bytes. - // @remark the input plaintext and out cipher reuse rtp_hdr. - srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); + srs_error_t protect_rtp(void* packet, int* nb_cipher); + srs_error_t protect_rtcp(void* packet, int* nb_cipher); srs_error_t unprotect_rtp(void* packet, int* nb_plaintext); srs_error_t unprotect_rtcp(void* packet, int* nb_plaintext); };