From 5b6c9df785a1157c7cf9ac345b7ca5c558720d83 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 4 Sep 2024 16:36:41 +0800 Subject: [PATCH] Support proxy to webrtc media. --- proxy/http.go | 4 +- proxy/rtc.go | 300 +++++++++++++++++++++++++++++++++++++++++++++---- proxy/srs.go | 39 +++++++ proxy/utils.go | 34 ++++++ 4 files changed, 356 insertions(+), 21 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index d38ffb789..ed89c4acb 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -258,6 +258,8 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite type HLSStreaming struct { // The context for HLS streaming. ctx context.Context + // The context ID for recovering the context. + ContextID string `json:"cid"` // The spbhid, used to identify the backend server. SRSProxyBackendHLSID string `json:"spbhid"` @@ -265,8 +267,6 @@ type HLSStreaming struct { StreamURL string `json:"stream_url"` // The full request URL for HLS streaming FullURL string `json:"full_url"` - // The context ID for recovering the context. - ContextID string `json:"cid"` } func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { diff --git a/proxy/rtc.go b/proxy/rtc.go index 13865212c..3799b7dbf 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -5,24 +5,34 @@ package main import ( "context" + "encoding/binary" "fmt" "io/ioutil" "net" "net/http" - "regexp" "strconv" "strings" - "sync" + stdSync "sync" "srs-proxy/errors" "srs-proxy/logger" + "srs-proxy/sync" ) type rtcServer struct { // The UDP listener for WebRTC server. listener *net.UDPConn + + // Fast cache for the username to identify the connection. + // The key is username, the value is the RTC connection. + usernames sync.Map[string, *RTCConnection] + // Fast cache for the udp address to identify the connection. + // The key is UDP address, the value is the RTC connection. + // TODO: Support fast search by uint64 address. + addresses sync.Map[string, *RTCConnection] + + // The wait group for server. 
- wg sync.WaitGroup + wg stdSync.WaitGroup } func newRTCServer(opts ...func(*rtcServer)) *rtcServer { @@ -173,22 +183,26 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r } // Fetch the ice-ufrag and ice-pwd from local SDP answer. - var iceUfrag, icePwd string - if true { - ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) - ufragMatch := ufragRe.FindStringSubmatch(localSDPAnswer) - if len(ufragMatch) <= 1 { - return errors.Errorf("no ice-ufrag in local sdp answer %v", localSDPAnswer) - } - iceUfrag = ufragMatch[1] + remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) + if err != nil { + return errors.Wrapf(err, "parse remote sdp offer") } - if true { - pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) - pwdMatch := pwdRe.FindStringSubmatch(localSDPAnswer) - if len(pwdMatch) <= 1 { - return errors.Errorf("no ice-pwd in local sdp answer %v", localSDPAnswer) - } - icePwd = pwdMatch[1] + + localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) + if err != nil { + return errors.Wrapf(err, "parse local sdp answer") + } + + // Save the new WebRTC connection to LB. + icePair := &RTCICEPair{ + RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, + LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, + } + if _, err := srsLoadBalancer.LoadOrStoreWebRTC(ctx, streamURL, icePair.Ufrag(), NewRTCStreaming(func(s *RTCConnection) { + s.StreamURL, s.listenerUDP = streamURL, v.listener + s.BuildContext(ctx) + })); err != nil { + return errors.Wrapf(err, "load or store webrtc %v", streamURL) } // Response client with local answer. 
@@ -197,7 +211,7 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r } logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB", - len(localSDPAnswer), iceUfrag, len(icePwd)) + len(localSDPAnswer), localICEUfrag, len(localICEPwd)) return nil } @@ -220,5 +234,253 @@ func (v *rtcServer) Run(ctx context.Context) error { v.listener = listener logger.Df(ctx, "WebRTC server listen at %v", addr) + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, addr, err := listener.ReadFromUDP(buf) + if err != nil { + // TODO: If WebRTC server closed unexpectedly, we should notify the main loop to quit. For now every read error is only logged and the read is retried at once. + logger.Wf(ctx, "read from udp failed, err=%v", err) + continue + } + + if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%v", n, addr, err) + } + } + }() + + return nil +} + +func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var stream *RTCConnection + + // If STUN binding request, parse the ufrag and identify the connection. + if err := func() error { + if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) { + return nil + } + + var pkt RTCStunPacket + if err := pkt.UnmarshalBinary(data); err != nil { + return errors.Wrapf(err, "unmarshal stun packet") + } + + // Search the stream in fast cache. + if s, ok := v.usernames.Load(pkt.Username); ok { + stream = s + return nil + } + + // Load stream by username. + if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) + } else { + stream = s + } + + // Cache stream for fast search. + if stream != nil { + v.usernames.Store(pkt.Username, stream) + } + return nil + }(); err != nil { + return err + } + + // Search the stream by addr. A cached address takes precedence over the stream found by STUN username. 
+ if s, ok := v.addresses.Load(addr.String()); ok { + stream = s + } else if stream != nil { + // Cache the address for fast search. + v.addresses.Store(addr.String(), stream) + } + + // If stream is not found, ignore the packet. + if stream == nil { + // TODO: Should log the dropped packet, but only the first one for each address. + return nil + } + + // Proxy the packet to backend. + if err := stream.Proxy(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), stream.StreamURL) + } + + return nil +} + +type RTCConnection struct { + // The stream context for WebRTC streaming. + ctx context.Context + // The context ID for recovering the context. + ContextID string `json:"cid"` + + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The client UDP address. Note that it may change: Proxy overwrites it on every packet while the goroutine started by connectBackend reads it, with no lock in between. + clientUDP *net.UDPAddr + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn +} + +func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection { + v := &RTCConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error { + ctx := v.ctx + + // Update the current UDP address. + v.clientUDP = addr + + // Start the UDP proxy to backend; this returns at once when backendUDP is already set. + if err := v.connectBackend(ctx); err != nil { + return errors.Wrapf(err, "connect backend for %v", v.StreamURL) + } + + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) + } + } + + return nil +} + +func (v *RTCConnection) connectBackend(ctx context.Context) error { + if v.backendUDP != nil { + return nil + } + + // Pick a backend SRS server to proxy the RTC stream. 
+ backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) + if err != nil { + return errors.Wrapf(err, "pick backend") + } + + // Parse UDP port from backend. + if len(backend.RTC) == 0 { + return errors.Errorf("no udp server") + } + + var udpPort int + if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse udp port %v", backend.RTC[0]) + } else { + udpPort = int(iv) + } + + // Connect to backend SRS server via UDP client. NOTE(review): net.ParseIP yields nil when backend.IP is not an IP literal, and the result is dialed without a check. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v", backendAddr) + } else { + v.backendUDP = backendUDP + } + + // Proxy all messages from backend to client. NOTE(review): this goroutine is not added to a wait group, and after it exits backendUDP stays set, so connectBackend will not reconnect. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notify the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If writing to client failed, we should notify the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } + } + }() + + return nil +} + +func (v *RTCConnection) BuildContext(ctx context.Context) { + if v.ContextID == "" { + v.ContextID = logger.GenerateContextID() + } + v.ctx = logger.WithContextID(ctx, v.ContextID) +} + +type RTCICEPair struct { + // The remote ufrag, used for ICE username and session id. + RemoteICEUfrag string `json:"remote_ufrag"` + // The remote pwd, used for ICE password. + RemoteICEPwd string `json:"remote_pwd"` + // The local ufrag, used for ICE username and session id. + LocalICEUfrag string `json:"local_ufrag"` + // The local pwd, used for ICE password. 
+ LocalICEPwd string `json:"local_pwd"` +} + +// Generate the ICE ufrag for the WebRTC streaming, format is local-ufrag:remote-ufrag. +func (v *RTCICEPair) Ufrag() string { + return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) +} + +type RTCStunPacket struct { + // The stun message type. + MessageType uint16 + // The stun username, or ufrag. + Username string +} + +func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { // Parse the STUN header and attributes, only MessageType and USERNAME are kept. + if len(data) < 20 { + return errors.Errorf("stun packet too short %v", len(data)) + } + + p := data + v.MessageType = binary.BigEndian.Uint16(p) + messageLen := binary.BigEndian.Uint16(p[2:]) + // Bytes 4-7 carry the magic cookie, which is not verified here. + // Bytes 8-19 carry the transaction ID, which is skipped as well. + p = p[20:] + + if len(p) != int(messageLen) { + return errors.Errorf("stun packet length invalid %v != %v", len(p), messageLen) + } + + for len(p) >= 4 { // Each attribute starts with a 4-byte header, so a shorter tail ends the scan instead of panicking. + typ := binary.BigEndian.Uint16(p) + length := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(length) { + return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) + } + + value := p[:length] + p = p[length:] + + if pad := (4 - int(length)%4) % 4; pad <= len(p) { // Skip the padding to the next 4-byte boundary, tolerating a truncated tail. + p = p[pad:] + } + + switch typ { + case 0x0006: // USERNAME attribute. + v.Username = string(value) + } + } + return nil } diff --git a/proxy/srs.go b/proxy/srs.go index b19234178..15e9418de 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -27,6 +27,9 @@ const srsServerAliveDuration = 300 * time.Second // If HLS streaming update in this duration, it's alive. const srsHLSAliveDuration = 120 * time.Second +// If WebRTC streaming update in this duration, it's alive. +const srsRTCAliveDuration = 120 * time.Second + type SRSServer struct { // The server IP. IP string `json:"ip,omitempty"` @@ -148,6 +151,10 @@ type SRSLoadBalancer interface { LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. 
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) + // Load or store the WebRTC streaming for the specified stream URL. + LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) + // Load the WebRTC streaming by ufrag, the ICE username. + LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) } // srsLoadBalancer is the global SRS load balancer. @@ -163,6 +170,10 @@ type srsMemoryLoadBalancer struct { hlsStreamURL sync.Map[string, *HLSStreaming] // The HLS streaming, key is SPBHID. hlsSPBHID sync.Map[string, *HLSStreaming] + // The WebRTC streaming, key is stream URL. + rtcStreamURL sync.Map[string, *RTCConnection] + // The WebRTC streaming, key is ufrag. + rtcUfrag sync.Map[string, *RTCConnection] } func NewMemoryLoadBalancer() SRSLoadBalancer { @@ -255,6 +266,26 @@ func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL st return actual, nil } +func (v *srsMemoryLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { + // Update the WebRTC streaming for the stream URL. + actual, _ := v.rtcStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store WebRTC streaming for %v failed", streamURL) + } + + // Update the WebRTC streaming for the ufrag, which is always bound to the new value. + v.rtcUfrag.Store(ufrag, value) + return actual, nil // Same as LoadOrStoreHLS: hand back the entry held for the stream URL instead of a nil result. +} + +func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + if actual, ok := v.rtcUfrag.Load(ufrag); !ok { + return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) + } else { + return actual, nil + } +} + type srsRedisLoadBalancer struct { // The redis client sdk. 
rdb *redis.Client @@ -462,6 +493,14 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str return &actualHLS, nil } +func (v *srsRedisLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { + return nil, nil // NOTE(review): stub, nothing is stored in redis and no error is returned. +} + +func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + return nil, nil // NOTE(review): stub, always reports no connection and no error. +} + func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) } diff --git a/proxy/utils.go b/proxy/utils.go index 42bf2eb04..c2f41ed1f 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -16,6 +16,7 @@ import ( "os" "path" "reflect" + "regexp" "strings" "syscall" "time" @@ -176,3 +177,36 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) return } + +// rtc_is_stun returns true if data of UDP payload is a STUN packet, that is, its first byte is 0 or 1. +func rtc_is_stun(data []byte) bool { + return len(data) > 0 && (data[0] == 0 || data[0] == 1) +} + +// rtc_is_rtp_or_rtcp returns true if data of UDP payload is a RTP or RTCP packet, that is, at least 12 bytes whose first byte has the top two bits set to 10. +func rtc_is_rtp_or_rtcp(data []byte) bool { + return len(data) >= 12 && (data[0]&0xC0) == 0x80 +} + +// parseIceUfragPwd parses the ice-ufrag and ice-pwd from the SDP, using the first match of each, and fails if either is missing. +func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { + var iceUfrag, icePwd string + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(sdp) + if len(ufragMatch) <= 1 { + return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) + } + iceUfrag = ufragMatch[1] + } + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(sdp) + if len(pwdMatch) <= 1 { + return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) + } + icePwd = pwdMatch[1] + } + + return iceUfrag, icePwd, nil +}