1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

Support proxy to webrtc media.

This commit is contained in:
winlin 2024-09-04 16:36:41 +08:00
parent 17f836a886
commit 5b6c9df785
4 changed files with 356 additions and 21 deletions

View file

@ -258,6 +258,8 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite
type HLSStreaming struct {
// The context for HLS streaming.
ctx context.Context
// The context ID for recovering the context.
ContextID string `json:"cid"`
// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
@ -265,8 +267,6 @@ type HLSStreaming struct {
StreamURL string `json:"stream_url"`
// The full request URL for HLS streaming
FullURL string `json:"full_url"`
// The context ID for recovering the context.
ContextID string `json:"cid"`
}
func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming {

View file

@ -5,24 +5,34 @@ package main
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
stdSync "sync"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
type rtcServer struct {
// The UDP listener for WebRTC server.
listener *net.UDPConn
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]
// The wait group for server.
wg sync.WaitGroup
wg stdSync.WaitGroup
}
func newRTCServer(opts ...func(*rtcServer)) *rtcServer {
@ -173,22 +183,26 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r
}
// Fetch the ice-ufrag and ice-pwd from local SDP answer.
var iceUfrag, icePwd string
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(localSDPAnswer)
if len(ufragMatch) <= 1 {
return errors.Errorf("no ice-ufrag in local sdp answer %v", localSDPAnswer)
}
iceUfrag = ufragMatch[1]
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
if err != nil {
return errors.Wrapf(err, "parse remote sdp offer")
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(localSDPAnswer)
if len(pwdMatch) <= 1 {
return errors.Errorf("no ice-pwd in local sdp answer %v", localSDPAnswer)
}
icePwd = pwdMatch[1]
localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
if err != nil {
return errors.Wrapf(err, "parse local sdp answer")
}
// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if _, err := srsLoadBalancer.LoadOrStoreWebRTC(ctx, streamURL, icePair.Ufrag(), NewRTCStreaming(func(s *RTCConnection) {
s.StreamURL, s.listenerUDP = streamURL, v.listener
s.BuildContext(ctx)
})); err != nil {
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
}
// Response client with local answer.
@ -197,7 +211,7 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r
}
logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
len(localSDPAnswer), iceUfrag, len(icePwd))
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
return nil
}
@ -220,5 +234,253 @@ func (v *rtcServer) Run(ctx context.Context) error {
v.listener = listener
logger.Df(ctx, "WebRTC server listen at %v", addr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, addr, err := listener.ReadFromUDP(buf)
if err != nil {
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%v", err)
continue
}
if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%v", n, addr, err)
}
}
}()
return nil
}
func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var stream *RTCConnection
// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) {
return nil
}
var pkt RTCStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}
// Search the stream in fast cache.
if s, ok := v.usernames.Load(pkt.Username); ok {
stream = s
return nil
}
// Load stream by username.
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
stream = s
}
// Cache stream for fast search.
if stream != nil {
v.usernames.Store(pkt.Username, stream)
}
return nil
}(); err != nil {
return err
}
// Search the stream by addr.
if s, ok := v.addresses.Load(addr.String()); ok {
stream = s
} else if stream != nil {
// Cache the address for fast search.
v.addresses.Store(addr.String(), stream)
}
// If stream is not found, ignore the packet.
if stream == nil {
// TODO: Should logging the dropped packet, only logging the first one for each address.
return nil
}
// Proxy the packet to backend.
if err := stream.Proxy(addr, data); err != nil {
return errors.Wrapf(err, "proxy %vB for %v", len(data), stream.StreamURL)
}
return nil
}
type RTCConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The context ID for recovering the context.
ContextID string `json:"cid"`
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
}
func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
v.clientUDP = addr
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx); err != nil {
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
}
// Proxy client message to backend.
if v.backendUDP != nil {
if _, err := v.backendUDP.Write(data); err != nil {
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
}
}
return nil
}
func (v *RTCConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}
// Pick a backend SRS server to proxy the RTC stream.
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}
// Parse UDP port from backend.
if len(backend.RTC) == 0 {
return errors.Errorf("no udp server")
}
var udpPort int
if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse udp port %v", backend.RTC[0])
} else {
udpPort = int(iv)
}
// Connect to backend SRS server via UDP client.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
}
// Proxy all messages from backend to client.
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
break
}
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
break
}
}
}()
return nil
}
func (v *RTCConnection) BuildContext(ctx context.Context) {
if v.ContextID == "" {
v.ContextID = logger.GenerateContextID()
}
v.ctx = logger.WithContextID(ctx, v.ContextID)
}
type RTCICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
RemoteICEPwd string `json:"remote_pwd"`
// The local ufrag, used for ICE username and session id.
LocalICEUfrag string `json:"local_ufrag"`
// The local pwd, used for ICE password.
LocalICEPwd string `json:"local_pwd"`
}
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}
type RTCStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}
p := data
v.MessageType = binary.BigEndian.Uint16(p)
messageLen := binary.BigEndian.Uint16(p[2:])
//magicCookie := p[:8]
//transactionID := p[:20]
p = p[20:]
if len(p) != int(messageLen) {
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
}
for len(p) > 0 {
typ := binary.BigEndian.Uint16(p)
length := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(length) {
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
}
value := p[:length]
p = p[length:]
if length%4 != 0 {
p = p[4-length%4:]
}
switch typ {
case 0x0006:
v.Username = string(value)
}
}
return nil
}

View file

@ -27,6 +27,9 @@ const srsServerAliveDuration = 300 * time.Second
// If HLS streaming update in this duration, it's alive.
const srsHLSAliveDuration = 120 * time.Second
// If WebRTC streaming update in this duration, it's alive.
const srsRTCAliveDuration = 120 * time.Second
type SRSServer struct {
// The server IP.
IP string `json:"ip,omitempty"`
@ -148,6 +151,10 @@ type SRSLoadBalancer interface {
LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error)
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error)
// Load or store the WebRTC streaming for the specified stream URL.
LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error)
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
}
// srsLoadBalancer is the global SRS load balancer.
@ -163,6 +170,10 @@ type srsMemoryLoadBalancer struct {
hlsStreamURL sync.Map[string, *HLSStreaming]
// The HLS streaming, key is SPBHID.
hlsSPBHID sync.Map[string, *HLSStreaming]
// The WebRTC streaming, key is stream URL.
rtcStreamURL sync.Map[string, *RTCConnection]
// The WebRTC streaming, key is ufrag.
rtcUfrag sync.Map[string, *RTCConnection]
}
func NewMemoryLoadBalancer() SRSLoadBalancer {
@ -255,6 +266,26 @@ func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL st
return actual, nil
}
func (v *srsMemoryLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) {
// Update the WebRTC streaming for the stream URL.
actual, _ := v.rtcStreamURL.LoadOrStore(streamURL, value)
if actual == nil {
return nil, errors.Errorf("load or store WebRTC streaming for %v failed", streamURL)
}
// Update the WebRTC streaming for the ufrag.
v.rtcUfrag.Store(ufrag, value)
return nil, nil
}
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
} else {
return actual, nil
}
}
type srsRedisLoadBalancer struct {
// The redis client sdk.
rdb *redis.Client
@ -462,6 +493,14 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str
return &actualHLS, nil
}
func (v *srsRedisLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) {
return nil, nil
}
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
return nil, nil
}
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
}

View file

@ -16,6 +16,7 @@ import (
"os"
"path"
"reflect"
"regexp"
"strings"
"syscall"
"time"
@ -176,3 +177,36 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt)
return
}
// rtc_is_stun returns true if data of UDP payload is a STUN packet.
func rtc_is_stun(data []byte) bool {
return len(data) > 0 && (data[0] == 0 || data[0] == 1)
}
// rtc_is_rtp_or_rtcp returns true if data of UDP payload is a RTP or RTCP packet.
func rtc_is_rtp_or_rtcp(data []byte) bool {
return len(data) >= 12 && (data[0]&0xC0) == 0x80
}
// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
var iceUfrag, icePwd string
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(sdp)
if len(ufragMatch) <= 1 {
return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
}
iceUfrag = ufragMatch[1]
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(sdp)
if len(pwdMatch) <= 1 {
return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
}
icePwd = pwdMatch[1]
}
return iceUfrag, icePwd, nil
}