1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

make code easy, wrap udp remux socket

This commit is contained in:
xiaozhihong 2020-03-08 00:30:31 +08:00
parent b730458d51
commit c62901a3ac
11 changed files with 245 additions and 294 deletions

View file

@ -279,7 +279,7 @@ srs_error_t SrsSdp::parse_attr(const string& line)
return err;
}
SrsDtlsSession::SrsDtlsSession(srs_netfd_t lfd, const sockaddr* f, int fl)
SrsDtlsSession::SrsDtlsSession()
{
dtls = NULL;
bio_in = NULL;
@ -291,10 +291,6 @@ SrsDtlsSession::SrsDtlsSession(srs_netfd_t lfd, const sockaddr* f, int fl)
srtp_send = NULL;
srtp_recv = NULL;
fd = lfd;
from = f;
fromlen = fl;
handshake_done = false;
}
@ -302,7 +298,7 @@ SrsDtlsSession::~SrsDtlsSession()
{
}
srs_error_t SrsDtlsSession::handshake()
srs_error_t SrsDtlsSession::handshake(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
@ -314,11 +310,7 @@ srs_error_t SrsDtlsSession::handshake()
int ssl_err = SSL_get_error(dtls, ret);
switch(ssl_err) {
case SSL_ERROR_NONE: {
srs_trace("dtls handshake done");
handshake_done = true;
srtp_init();
srtp_sender_side_init();
srtp_receiver_side_init();
err = on_dtls_handshake_done();
}
break;
@ -337,25 +329,25 @@ srs_error_t SrsDtlsSession::handshake()
if (out_bio_len) {
srs_trace("send dtls handshake data");
srs_sendto(fd, out_bio_data, out_bio_len, from, fromlen, 0);
udp_remux_socket->sendto(out_bio_data, out_bio_len, 0);
}
return err;
}
srs_error_t SrsDtlsSession::on_dtls(const char* data, const int len)
srs_error_t SrsDtlsSession::on_dtls(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
if (! handshake_done) {
BIO_reset(bio_in);
BIO_reset(bio_out);
BIO_write(bio_in, data, len);
BIO_write(bio_in, udp_remux_socket->data(), udp_remux_socket->size());
handshake();
handshake(udp_remux_socket);
} else {
BIO_reset(bio_in);
BIO_reset(bio_out);
BIO_write(bio_in, data, len);
BIO_write(bio_in, udp_remux_socket->data(), udp_remux_socket->size());
while (BIO_ctrl_pending(bio_in) > 0) {
char dtls_read_buf[8092];
@ -370,6 +362,14 @@ srs_error_t SrsDtlsSession::on_dtls(const char* data, const int len)
return err;
}
srs_error_t SrsDtlsSession::on_dtls_handshake_done()
{
srs_trace("dtls handshake done");
handshake_done = true;
return srtp_init();
}
srs_error_t SrsDtlsSession::on_dtls_application_data(const char* buf, const int nb_buf)
{
srs_error_t err = srs_success;
@ -377,8 +377,7 @@ srs_error_t SrsDtlsSession::on_dtls_application_data(const char* buf, const int
return err;
}
void SrsDtlsSession::send_client_hello()
void SrsDtlsSession::send_client_hello(SrsUdpRemuxSocket* udp_remux_socket)
{
if (dtls == NULL) {
srs_trace("send client hello");
@ -391,7 +390,7 @@ void SrsDtlsSession::send_client_hello()
SSL_set_bio(dtls, bio_in, bio_out);
handshake();
handshake(udp_remux_socket);
}
}
@ -400,8 +399,8 @@ srs_error_t SrsDtlsSession::srtp_init()
srs_error_t err = srs_success;
unsigned char material[SRTP_MASTER_KEY_LEN * 2] = {0}; // client(SRTP_MASTER_KEY_KEY_LEN + SRTP_MASTER_KEY_SALT_LEN) + server
char dtls_srtp_lable[] = "EXTRACTOR-dtls_srtp";
if (! SSL_export_keying_material(dtls, material, sizeof(material), dtls_srtp_lable, strlen(dtls_srtp_lable), NULL, 0, 0)) {
static string dtls_srtp_lable = "EXTRACTOR-dtls_srtp";
if (! SSL_export_keying_material(dtls, material, sizeof(material), dtls_srtp_lable.c_str(), dtls_srtp_lable.size(), NULL, 0, 0)) {
return srs_error_wrap(err, "SSL_export_keying_material failed");
}
@ -418,8 +417,15 @@ srs_error_t SrsDtlsSession::srtp_init()
client_key = sClientMasterKey + sClientMasterSalt;
server_key = sServerMasterKey + sServerMasterSalt;
srtp_sender_side_init();
srtp_receiver_side_init();
if (srtp_sender_side_init() != srs_success) {
return srs_error_wrap(err, "srtp sender size init failed");
}
if (srtp_receiver_side_init() != srs_success) {
return srs_error_wrap(err, "srtp receiver size init failed");
}
return err;
}
srs_error_t SrsDtlsSession::srtp_sender_side_init()
@ -435,7 +441,8 @@ srs_error_t SrsDtlsSession::srtp_sender_side_init()
policy.ssrc.type = ssrc_any_outbound;
policy.ssrc.value = 0;
policy.window_size = 8192; // seq 相差8192认为无效
// TODO: adjust window_size
policy.window_size = 8192;
policy.allow_repeat_tx = 1;
policy.next = NULL;
@ -444,6 +451,7 @@ srs_error_t SrsDtlsSession::srtp_sender_side_init()
policy.key = key;
if (srtp_create(&srtp_send, &policy) != 0) {
delete [] key;
return srs_error_wrap(err, "srtp_create failed");
}
@ -465,7 +473,8 @@ srs_error_t SrsDtlsSession::srtp_receiver_side_init()
policy.ssrc.type = ssrc_any_inbound;
policy.ssrc.value = 0;
policy.window_size = 8192; // seq 相差8192认为无效
// TODO: adjust window_size
policy.window_size = 8192;
policy.allow_repeat_tx = 1;
policy.next = NULL;
@ -474,6 +483,7 @@ srs_error_t SrsDtlsSession::srtp_receiver_side_init()
policy.key = key;
if (srtp_create(&srtp_recv, &policy) != 0) {
delete [] key;
return srs_error_wrap(err, "srtp_create failed");
}
@ -482,8 +492,9 @@ srs_error_t SrsDtlsSession::srtp_receiver_side_init()
return err;
}
SrsRtcSession::SrsRtcSession()
SrsRtcSession::SrsRtcSession(SrsRtcServer* svr)
{
rtc_server = svr;
session_state = INIT;
dtls_session = NULL;
}
@ -492,42 +503,73 @@ SrsRtcSession::~SrsRtcSession()
{
}
srs_error_t SrsRtcSession::on_binding_request(const SrsStunPacket& stun_packet, const string& peer_ip, const uint16_t peer_port,
SrsStunPacket& stun_binding_response)
srs_error_t SrsRtcSession::on_stun(SrsUdpRemuxSocket* udp_remux_socket, SrsStunPacket* stun_req)
{
srs_error_t err = srs_success;
stun_binding_response.set_message_type(BindingResponse);
stun_binding_response.set_local_ufrag(stun_packet.get_remote_ufrag());
stun_binding_response.set_remote_ufrag(stun_packet.get_local_ufrag());
stun_binding_response.set_transcation_id(stun_packet.get_transcation_id());
stun_binding_response.set_mapped_address(be32toh(inet_addr(peer_ip.c_str())));
stun_binding_response.set_mapped_port(peer_port);
if (stun_req->is_binding_request()) {
if (on_binding_request(udp_remux_socket, stun_req) != srs_success) {
return srs_error_wrap(err, "stun binding request failed");
}
}
return err;
}
srs_error_t SrsRtcSession::send_client_hello(srs_netfd_t fd, const sockaddr* from, int fromlen)
srs_error_t SrsRtcSession::on_binding_request(SrsUdpRemuxSocket* udp_remux_socket, SrsStunPacket* stun_req)
{
if (dtls_session == NULL) {
dtls_session = new SrsDtlsSession(fd, from, fromlen);
srs_error_t err = srs_success;
SrsStunPacket stun_binding_response;
char buf[1460];
SrsBuffer* stream = new SrsBuffer(buf, sizeof(buf));
SrsAutoFree(SrsBuffer, stream);
stun_binding_response.set_message_type(BindingResponse);
stun_binding_response.set_local_ufrag(stun_req->get_remote_ufrag());
stun_binding_response.set_remote_ufrag(stun_req->get_local_ufrag());
stun_binding_response.set_transcation_id(stun_req->get_transcation_id());
// FIXME: inet_addr is deprecated, IPV6 support
stun_binding_response.set_mapped_address(be32toh(inet_addr(udp_remux_socket->get_peer_ip().c_str())));
stun_binding_response.set_mapped_port(udp_remux_socket->get_peer_port());
if (stun_binding_response.encode(get_local_sdp()->get_ice_pwd(), stream) != srs_success) {
return srs_error_wrap(err, "stun binding response encode failed");
}
dtls_session->send_client_hello();
if (udp_remux_socket->sendto(stream->data(), stream->pos(), 0) <= 0) {
return srs_error_wrap(err, "stun binding response send failed");
}
if (get_session_state() == WAITING_STUN) {
set_session_state(DOING_DTLS_HANDSHAKE);
send_client_hello(udp_remux_socket);
string peer_id = udp_remux_socket->get_peer_id();
rtc_server->insert_into_id_sessions(peer_id, this);
}
// TODO: dtls send client retry
return err;
}
srs_error_t SrsRtcSession::on_dtls(const char* buf, const int nb_buf)
srs_error_t SrsRtcSession::send_client_hello(SrsUdpRemuxSocket* udp_remux_socket)
{
dtls_session->on_dtls(buf, nb_buf);
if (dtls_session == NULL) {
dtls_session = new SrsDtlsSession();
}
dtls_session->send_client_hello(udp_remux_socket);
}
srs_error_t SrsRtcSession::send_packet()
srs_error_t SrsRtcSession::on_dtls(SrsUdpRemuxSocket* udp_remux_socket)
{
return dtls_session->on_dtls(udp_remux_socket);
}
SrsRtcServer::SrsRtcServer(SrsServer* svr)
SrsRtcServer::SrsRtcServer()
{
server = svr;
}
SrsRtcServer::~SrsRtcServer()
@ -541,32 +583,32 @@ srs_error_t SrsRtcServer::initialize()
return err;
}
srs_error_t SrsRtcServer::on_udp_packet(srs_netfd_t fd, const string& peer_ip, const int peer_port,
const sockaddr* from, const int fromlen, const char* data, const int size)
srs_error_t SrsRtcServer::on_udp_packet(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
if (is_stun(data, size)) {
return on_stun(fd, peer_ip, peer_port, from, fromlen, data, size);
} else if (is_dtls(data, size)) {
srs_trace("dtls");
return on_dtls(fd, peer_ip, peer_port, from, fromlen, data, size);
} else if (is_rtp_or_rtcp(data, size)) {
return on_rtp_or_rtcp(fd, peer_ip, peer_port, from, fromlen, data, size);
if (is_stun(udp_remux_socket->data(), udp_remux_socket->size())) {
return on_stun(udp_remux_socket);
} else if (is_dtls(udp_remux_socket->data(), udp_remux_socket->size())) {
return on_dtls(udp_remux_socket);
} else if (is_rtp_or_rtcp(udp_remux_socket->data(), udp_remux_socket->size())) {
return on_rtp_or_rtcp(udp_remux_socket);
}
return srs_error_wrap(err, "unknown packet type");
return srs_error_wrap(err, "unknown udp packet type");
}
SrsRtcSession* SrsRtcServer::create_rtc_session(const SrsSdp& remote_sdp, SrsSdp& local_sdp)
{
SrsRtcSession* session = new SrsRtcSession();
SrsRtcSession* session = new SrsRtcSession(this);
std::string local_ufrag = gen_random_str(8);
std::string local_pwd = gen_random_str(32);
std::string local_ufrag = "";
while (true) {
bool ret = map_ufrag_sessions.insert(make_pair(remote_sdp.get_ice_ufrag(), session)).second;
local_ufrag = gen_random_str(8);
std::string username = local_ufrag + ":" + remote_sdp.get_ice_ufrag();
bool ret = map_username_session.insert(make_pair(username, session)).second;
if (ret) {
break;
}
@ -583,105 +625,71 @@ SrsRtcSession* SrsRtcServer::create_rtc_session(const SrsSdp& remote_sdp, SrsSdp
return session;
}
SrsRtcSession* SrsRtcServer::find_rtc_session_by_ip_port(const string& peer_ip, const uint16_t peer_port)
SrsRtcSession* SrsRtcServer::find_rtc_session_by_peer_id(const string& peer_id)
{
ostringstream os;
os << peer_ip << ":" << peer_port;
string key = os.str();
map<string, SrsRtcSession*>::iterator iter = map_ip_port_sessions.find(key);
if (iter == map_ip_port_sessions.end()) {
map<string, SrsRtcSession*>::iterator iter = map_id_session.find(peer_id);
if (iter == map_id_session.end()) {
return NULL;
}
return iter->second;
}
srs_error_t SrsRtcServer::on_stun(srs_netfd_t fd, const string& peer_ip, const int peer_port,
const sockaddr* from, const int fromlen, const char* data, const int size)
srs_error_t SrsRtcServer::on_stun(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
srs_trace("peer %s:%d stun", peer_ip.c_str(), peer_port);
srs_trace("recv stun packet from %s", udp_remux_socket->get_peer_id().c_str());
SrsStunPacket stun_req;
if (stun_req.decode(data, size) != srs_success) {
return srs_error_wrap(err, "decode stun failed");
if (stun_req.decode(udp_remux_socket->data(), udp_remux_socket->size()) != srs_success) {
return srs_error_wrap(err, "decode stun packet failed");
}
std::string remote_ufrag = stun_req.get_remote_ufrag();
SrsRtcSession* rtc_session = find_rtc_session_by_ufrag(remote_ufrag);
std::string username = stun_req.get_username();
SrsRtcSession* rtc_session = find_rtc_session_by_username(username);
if (rtc_session == NULL) {
return srs_error_wrap(err, "can not find rtc_session, ufrag=%s", remote_ufrag.c_str());
return srs_error_wrap(err, "can not find rtc_session, stun username=%s", username.c_str());
}
SrsStunPacket stun_rsp;
char buf[1460];
SrsBuffer* stream = new SrsBuffer(buf, sizeof(buf));
SrsAutoFree(SrsBuffer, stream);
if (stun_req.is_binding_request()) {
if (rtc_session->on_binding_request(stun_req, peer_ip, peer_port, stun_rsp) != srs_success) {
return srs_error_wrap(err, "stun binding request failed");
}
}
if (stun_rsp.encode(rtc_session->get_local_sdp()->get_ice_pwd(), stream) != srs_success) {
return srs_error_wrap(err, "stun rsp encode failed");
}
srs_sendto(fd, stream->data(), stream->pos(), from, fromlen, 0);
if (rtc_session->get_session_state() == WAITING_STUN) {
rtc_session->set_session_state(DOING_DTLS_HANDSHAKE);
rtc_session->send_client_hello(fd, from, fromlen);
insert_into_ip_port_sessions(peer_ip, peer_port, rtc_session);
}
return err;
return rtc_session->on_stun(udp_remux_socket, &stun_req);
}
srs_error_t SrsRtcServer::on_dtls(srs_netfd_t fd, const string& peer_ip, const int peer_port,
const sockaddr* from, const int fromlen, const char* data, const int size)
srs_error_t SrsRtcServer::on_dtls(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
srs_trace("on dtls");
// FIXME
SrsRtcSession* rtc_session = find_rtc_session_by_ip_port(peer_ip, peer_port);
SrsRtcSession* rtc_session = find_rtc_session_by_peer_id(udp_remux_socket->get_peer_id());
if (rtc_session == NULL) {
return srs_error_wrap(err, "can not find rtc session by ip=%s, port=%u", peer_ip.c_str(), peer_port);
return srs_error_wrap(err, "can not find rtc session by peer_id=%s", udp_remux_socket->get_peer_id().c_str());
}
rtc_session->on_dtls(data, size);
rtc_session->on_dtls(udp_remux_socket);
return err;
}
srs_error_t SrsRtcServer::on_rtp_or_rtcp(srs_netfd_t fd, const string& peer_ip, const int peer_port,
const sockaddr* from, const int fromlen, const char* data, const int size)
srs_error_t SrsRtcServer::on_rtp_or_rtcp(SrsUdpRemuxSocket* udp_remux_socket)
{
srs_error_t err = srs_success;
srs_trace("on rtp/rtcp");
return err;
}
SrsRtcSession* SrsRtcServer::find_rtc_session_by_ufrag(const std::string& ufrag)
SrsRtcSession* SrsRtcServer::find_rtc_session_by_username(const std::string& username)
{
map<string, SrsRtcSession*>::iterator iter = map_ufrag_sessions.find(ufrag);
if (iter == map_ufrag_sessions.end()) {
map<string, SrsRtcSession*>::iterator iter = map_username_session.find(username);
if (iter == map_username_session.end()) {
return NULL;
}
return iter->second;
}
bool SrsRtcServer::insert_into_ip_port_sessions(const string& peer_ip, const uint16_t peer_port, SrsRtcSession* rtc_session)
bool SrsRtcServer::insert_into_id_sessions(const string& peer_id, SrsRtcSession* rtc_session)
{
ostringstream os;
os << peer_ip << ":" << peer_port;
string key = os.str();
return map_ip_port_sessions.insert(make_pair(key, rtc_session)).second;
return map_id_session.insert(make_pair(peer_id, rtc_session)).second;
}