1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

Refactor RTP encrypt

This commit is contained in:
winlin 2020-04-13 15:24:41 +08:00
parent fa21df7bb8
commit 756826756a
2 changed files with 43 additions and 33 deletions

View file

@ -386,7 +386,7 @@ srs_error_t SrsDtlsSession::protect_rtp(char* out_buf, const char* in_buf, int&
return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect failed"); return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect failed");
} }
srs_error_t SrsDtlsSession::protect_rtp2(char* buf, int* pnn_buf, SrsRtpPacket2* pkt) srs_error_t SrsDtlsSession::protect_rtp2(void* rtp_hdr, int* len_ptr)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
@ -394,14 +394,7 @@ srs_error_t SrsDtlsSession::protect_rtp2(char* buf, int* pnn_buf, SrsRtpPacket2*
return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect"); return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect");
} }
SrsBuffer stream(buf, *pnn_buf); if (srtp_protect(srtp_send, rtp_hdr, len_ptr) != 0) {
if ((err = pkt->encode(&stream)) != srs_success) {
return srs_error_wrap(err, "encode packet");
}
*pnn_buf = stream.pos();
if (srtp_protect(srtp_send, buf, pnn_buf) != 0) {
return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect"); return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect");
} }
@ -587,14 +580,15 @@ srs_error_t SrsRtcSenderThread::cycle()
continue; continue;
} }
int nn = 0;
int nn_rtp_pkts = 0; int nn_rtp_pkts = 0;
if ((err = send_messages(source, msgs.msgs, msg_count, sendonly_ukt, &nn, &nn_rtp_pkts)) != srs_success) { if ((err = send_messages(source, msgs.msgs, msg_count, sendonly_ukt, &nn_rtp_pkts)) != srs_success) {
srs_warn("send err %s", srs_error_summary(err).c_str()); srs_error_reset(err); srs_warn("send err %s", srs_error_summary(err).c_str()); srs_error_reset(err);
} }
int nn = 0;
for (int i = 0; i < msg_count; i++) { for (int i = 0; i < msg_count; i++) {
SrsSharedPtrMessage* msg = msgs.msgs[i]; SrsSharedPtrMessage* msg = msgs.msgs[i];
nn += msg->size;
srs_freep(msg); srs_freep(msg);
} }
@ -608,7 +602,7 @@ srs_error_t SrsRtcSenderThread::cycle()
srs_error_t SrsRtcSenderThread::send_messages( srs_error_t SrsRtcSenderThread::send_messages(
SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs, SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs,
SrsUdpMuxSocket* skt, int* pnn, int* pnn_rtp_pkts SrsUdpMuxSocket* skt, int* pnn_rtp_pkts
) { ) {
srs_error_t err = srs_success; srs_error_t err = srs_success;
@ -618,10 +612,34 @@ srs_error_t SrsRtcSenderThread::send_messages(
// Covert kernel messages to RTP packets. // Covert kernel messages to RTP packets.
vector<SrsRtpPacket2*> packets; vector<SrsRtpPacket2*> packets;
if ((err = messages_to_packets(source, msgs, nb_msgs, packets)) != srs_success) {
for (int j = 0; j < (int)packets.size(); j++) {
SrsRtpPacket2* packet = packets[j];
srs_freep(packet);
}
return err;
}
*pnn_rtp_pkts += (int)packets.size();
for (int j = 0; j < (int)packets.size(); j++) {
SrsRtpPacket2* packet = packets[j];
if ((err = send_packet(packet, skt)) != srs_success) {
srs_warn("send err %s", srs_error_summary(err).c_str()); srs_error_reset(err);
}
srs_freep(packet);
}
return err;
}
srs_error_t SrsRtcSenderThread::messages_to_packets(
SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs, vector<SrsRtpPacket2*>& packets
) {
srs_error_t err = srs_success;
for (int i = 0; i < nb_msgs; i++) { for (int i = 0; i < nb_msgs; i++) {
SrsSharedPtrMessage* msg = msgs[i]; SrsSharedPtrMessage* msg = msgs[i];
*pnn += msg->size;
SrsRtpPacket2* packet = NULL; SrsRtpPacket2* packet = NULL;
if (msg->is_audio()) { if (msg->is_audio()) {
@ -632,7 +650,6 @@ srs_error_t SrsRtcSenderThread::send_messages(
} }
packets.push_back(packet); packets.push_back(packet);
} }
continue; continue;
} }
@ -675,16 +692,6 @@ srs_error_t SrsRtcSenderThread::send_messages(
} }
} }
*pnn_rtp_pkts += (int)packets.size();
for (int j = 0; j < (int)packets.size(); j++) {
SrsRtpPacket2* packet = packets[j];
if ((err = send_packet(packet, skt)) != srs_success) {
srs_warn("send err %s", srs_error_summary(err).c_str()); srs_error_reset(err);
}
srs_freep(packet);
}
return err; return err;
} }
@ -701,15 +708,10 @@ srs_error_t SrsRtcSenderThread::send_packet(SrsRtpPacket2* pkt, SrsUdpMuxSocket*
return srs_error_wrap(err, "fetch msghdr"); return srs_error_wrap(err, "fetch msghdr");
} }
char* buf = (char*)mhdr->msg_hdr.msg_iov->iov_base; char* buf = (char*)mhdr->msg_hdr.msg_iov->iov_base;
// Length of iov, default size.
int length = kRtpPacketSize; int length = kRtpPacketSize;
if (rtc_session->encrypt) { // Marshal packet to bytes.
if ((err = rtc_session->dtls_session->protect_rtp2(buf, &length, pkt)) != srs_success) { if (true) {
return srs_error_wrap(err, "srtp protect");
}
} else {
SrsBuffer stream(buf, length); SrsBuffer stream(buf, length);
if ((err = pkt->encode(&stream)) != srs_success) { if ((err = pkt->encode(&stream)) != srs_success) {
return srs_error_wrap(err, "encode packet"); return srs_error_wrap(err, "encode packet");
@ -717,6 +719,13 @@ srs_error_t SrsRtcSenderThread::send_packet(SrsRtpPacket2* pkt, SrsUdpMuxSocket*
length = stream.pos(); length = stream.pos();
} }
// Whether encrypt the RTP bytes.
if (rtc_session->encrypt) {
if ((err = rtc_session->dtls_session->protect_rtp2(buf, &length)) != srs_success) {
return srs_error_wrap(err, "srtp protect");
}
}
sockaddr_in* addr = (sockaddr_in*)skt->peer_addr(); sockaddr_in* addr = (sockaddr_in*)skt->peer_addr();
socklen_t addrlen = (socklen_t)skt->peer_addrlen(); socklen_t addrlen = (socklen_t)skt->peer_addrlen();

View file

@ -105,7 +105,7 @@ public:
srs_error_t on_dtls_application_data(const char* data, const int len); srs_error_t on_dtls_application_data(const char* data, const int len);
public: public:
srs_error_t protect_rtp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); srs_error_t protect_rtp(char* protected_buf, const char* ori_buf, int& nb_protected_buf);
srs_error_t protect_rtp2(char* buf, int* pnn_buf, SrsRtpPacket2* pkt); srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr);
srs_error_t unprotect_rtp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); srs_error_t unprotect_rtp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf);
srs_error_t protect_rtcp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); srs_error_t protect_rtcp(char* protected_buf, const char* ori_buf, int& nb_protected_buf);
srs_error_t unprotect_rtcp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); srs_error_t unprotect_rtcp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf);
@ -152,7 +152,8 @@ public:
public: public:
virtual srs_error_t cycle(); virtual srs_error_t cycle();
private: private:
srs_error_t send_messages(SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs, SrsUdpMuxSocket* skt, int* pnn, int* pnn_rtp_pkts); srs_error_t send_messages(SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs, SrsUdpMuxSocket* skt, int* pnn_rtp_pkts);
srs_error_t messages_to_packets(SrsSource* source, SrsSharedPtrMessage** msgs, int nb_msgs, std::vector<SrsRtpPacket2*>& packets);
srs_error_t send_packet(SrsRtpPacket2* pkt, SrsUdpMuxSocket* skt); srs_error_t send_packet(SrsRtpPacket2* pkt, SrsUdpMuxSocket* skt);
private: private:
srs_error_t packet_opus(SrsSample* sample, SrsRtpPacket2** ppacket); srs_error_t packet_opus(SrsSample* sample, SrsRtpPacket2** ppacket);