1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00
This commit is contained in:
Jacob Su 2025-02-20 14:29:46 +07:00 committed by GitHub
commit c542fe76e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 449 additions and 11 deletions

31
Dockerfile.proxy Normal file
View file

@ -0,0 +1,31 @@
ARG ARCH
FROM ${ARCH}ossrs/srs:ubuntu20 AS build
COPY ./proxy /proxy
WORKDIR /proxy
RUN make clean && make
############################################################
# dist
############################################################
FROM ${ARCH}ubuntu:focal AS dist
WORKDIR /proxy
COPY --from=build /proxy/srs-proxy /proxy/
COPY ./trunk/research /proxy/static
ENV PROXY_STATIC_FILES="/proxy/static"
ENV PROXY_LOAD_BALANCER_TYPE="memory"
ENV PROXY_RTMP_SERVER=1935
ENV PROXY_HTTP_SERVER=8080
ENV PROXY_HTTP_API=1985
ENV PROXY_WEBRTC_SERVER=8000
ENV PROXY_SRT_SERVER=10080
ENV PROXY_SYSTEM_API=12025
EXPOSE 1935 8080 1985 12025 8000/udp 10080/udp
CMD ["./srs-proxy"]

View file

@ -82,7 +82,7 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr)
mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
apiError(ctx, w, r, err, http.StatusInternalServerError)
}
})
@ -90,10 +90,15 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr)
mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
apiError(ctx, w, r, err, http.StatusInternalServerError)
}
})
logger.Df(ctx, "Proxy /api/ to srs")
mux.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
srsLoadBalancer.ProxyHTTPAPI(ctx, w, r)
})
// Run HTTP API server.
v.wg.Add(1)
go func() {
@ -239,7 +244,7 @@ func (v *systemAPI) Run(ctx context.Context) error {
logger.Df(ctx, "Register SRS media server, %+v", server)
return nil
}(); err != nil {
apiError(ctx, w, r, err)
apiError(ctx, w, r, err, http.StatusInternalServerError)
}
type Response struct {

View file

@ -198,7 +198,7 @@ func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request)
ctx := logger.WithContext(v.ctx)
if err := v.serve(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
apiError(ctx, w, r, err, http.StatusInternalServerError)
} else {
logger.Df(ctx, "HTTP client done")
}
@ -318,7 +318,7 @@ func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := v.serve(v.ctx, w, r); err != nil {
apiError(v.ctx, w, r, err)
apiError(v.ctx, w, r, err, http.StatusInternalServerError)
} else {
logger.Df(v.ctx, "HLS client %v for %v with %v done",
v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path)

308
proxy/srs-api-proxy.go Normal file
View file

@ -0,0 +1,308 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/json"
"io"
"net/http"
"srs-proxy/errors"
"srs-proxy/logger"
"strings"
)
type SrsClient struct {
Id string `json:"id"`
Vhost string `json:"vhost"`
Stream string `json:"stream"`
Ip string `json:"ip"`
PageUrl string `json:"pageUrl"`
SwfUrl string `json:"swfUrl"`
TcUrl string `json:"tcUrl"`
Url string `json:"url"`
Name string `json:"name"`
Type string `json:"type"`
Publish bool `json:"publish"`
Alive float32 `json:"alive"`
SendBytes int `json:"send_bytes"`
RecvBytes int `json:"recv_bytes"`
}
type SrsApiCodeResponse struct {
Code int `json:"code"`
}
type SrsAPICommonResponse struct {
SrsApiCodeResponse
Server string `json:"server"`
Service string `json:"service"`
Pid string `json:"pid"`
}
type SrsClientResponse struct {
SrsAPICommonResponse
Client SrsClient `json:"client"`
}
type SrsClientsResponse struct {
SrsAPICommonResponse
Clients []SrsClient `json:"clients"`
}
type SrsKbps struct {
Recv_30s uint32 `json:"recv_30s"`
Send_30s uint32 `json:"send_30s"`
}
type SrsPublish struct {
Active bool `json:"active"`
Cid string `json:"cid"`
}
type SrsVideo struct {
Codec string `json:"codec"`
Profile string `json:"profile"`
Level string `json:"level"`
Width uint32 `json:"width"`
Height uint32 `json:"height"`
}
type SrsAudio struct {
Codec string `json:"codec"`
Sample_rate uint32 `json:"sample_rate"`
Channel uint8 `json:"channel"`
Profile string `json:"profile"`
}
type SrsStream struct {
Id string `json:"id"`
Name string `json:"name"`
Vhost string `json:"vhost"`
App string `json:"app"`
TcUrl string `json:"tcUrl"`
Url string `json:"url"`
Live_ms uint64 `json:"live_ms"`
Clients uint32 `json:"clients"`
Frames uint32 `json:"frames"`
Send_bytes uint32 `json:"send_bytes"`
Recv_bytes uint32 `json:"recv_bytes"`
Kbps SrsKbps `json:"kbps"`
Publish SrsPublish `json:"publish"`
Video SrsVideo `json:"video"`
Audio SrsAudio `json:"audio"`
}
type SrsStreamResponse struct {
SrsAPICommonResponse
Stream SrsStream `json:"stream"`
}
type SrsStreamsResponse struct {
SrsAPICommonResponse
Streams []SrsStream `json:"streams"`
}
type SrsHTTPApi struct {
Enabled bool `json:"enabled"`
Listen string `json:"listen"`
Crossdomain bool `json:"crossdomain"`
Raw_api SrsRawApi `json:"raw_api"`
}
type SrsRawApi struct {
Enabled bool `json:"enabled"`
Allow_reload bool `json:"allow_reload"`
Allow_query bool `json:"allow_query"`
Allow_update bool `json:"allow_update"`
}
type SrsRawResponse struct {
SrsApiCodeResponse
Http_api SrsHTTPApi `json:"http_api"`
}
type SrsRawReloadResponse struct {
SrsApiCodeResponse
}
type SrsRawReloadFetchData struct {
Err int `json:"err"`
Msg string `json:"msg"`
State int `json:"state"`
Rid string `json:"rid"`
}
type SrsRawReloadFetchResponse struct {
SrsApiCodeResponse
Data SrsRawReloadFetchData `json:"data"`
}
type SrsApiProxy struct {
}
func (v *SrsApiProxy) proxySrsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error {
if strings.HasPrefix(r.URL.Path, "/api/v1/clients") {
return proxySrsClientsAPI(ctx, servers, w, r)
} else if strings.HasPrefix(r.URL.Path, "/api/v1/streams") {
return proxySrsStreamsAPI(ctx, servers, w, r)
} else if strings.HasPrefix(r.URL.Path, "/api/v1/raw") {
return proxySrsRawAPI(ctx, servers, w, r)
}
return nil
}
// handle srs clients api /api/v1/clients
func proxySrsClientsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
clientId := ""
if strings.HasPrefix(r.URL.Path, "/api/v1/clients/") {
clientId = r.URL.Path[len("/api/v1/clients/"):]
}
logger.Df(ctx, "%v %v clientId=%v", r.Method, r.URL.Path, clientId)
body, err := io.ReadAll(r.Body)
if err != nil {
apiError(ctx, w, r, err, http.StatusInternalServerError)
return errors.Wrapf(err, "read request body err")
}
switch r.Method {
case http.MethodDelete:
for _, server := range servers {
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
logger.Df(ctx, "response %v", string(ret))
var res SrsApiCodeResponse
if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 {
apiResponse(ctx, w, r, res)
return nil
}
}
}
err := errors.Errorf("clientId %v not found in server", clientId)
apiError(ctx, w, r, err, http.StatusNotFound)
return err
case http.MethodGet:
if len(clientId) > 0 {
for _, server := range servers {
var client SrsClientResponse
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
if err := json.Unmarshal(ret, &client); err == nil && client.Code == 0 {
apiResponse(ctx, w, r, client)
return nil
}
}
}
} else { // get all clients
var clients SrsClientsResponse
for _, server := range servers {
var res SrsClientsResponse
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 {
clients.Clients = append(clients.Clients, res.Clients...)
}
}
}
apiResponse(ctx, w, r, clients)
return nil
}
default:
logger.Df(ctx, "/api/v1/clients %v", r.Method)
}
return nil
}
func proxySrsStreamsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
streamId := ""
if strings.HasPrefix(r.URL.Path, "/api/v1/streams/") {
streamId = r.URL.Path[len("/api/v1/streams/"):]
}
logger.Df(ctx, "%v %v streamId=%v", r.Method, r.URL.Path, streamId)
body, err := io.ReadAll(r.Body)
if err != nil {
apiError(ctx, w, r, err, http.StatusInternalServerError)
return errors.Wrapf(err, "read request body err")
}
if r.Method != http.MethodGet {
err := errors.Errorf("Unsupported http method type %v", r.Method)
apiError(ctx, w, r, err, http.StatusBadRequest)
return err
}
if len(streamId) > 0 {
var stream SrsStreamResponse
for _, server := range servers {
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
if err := json.Unmarshal(ret, &stream); err == nil && stream.Code == 0 {
apiResponse(ctx, w, r, stream)
return nil
}
}
}
ret := SrsApiCodeResponse{
Code: 2048,
}
apiResponse(ctx, w, r, ret)
return nil
} else {
var streams SrsStreamsResponse
for _, server := range servers {
var res SrsStreamsResponse
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 {
streams.Streams = append(streams.Streams, res.Streams...)
}
}
}
apiResponse(ctx, w, r, streams)
return nil
}
}
func proxySrsRawAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
rpc := r.URL.Query().Get("rpc")
logger.Df(ctx, "%v, rpc=%v", r.URL.Path, rpc)
body, err := io.ReadAll(r.Body)
if err != nil {
apiError(ctx, w, r, err, http.StatusInternalServerError)
return errors.Wrapf(err, "read request body err")
}
for _, server := range servers {
if ret, err := server.ApiRequest(ctx, r, body); err == nil {
if rpc == "raw" {
// return the first success response
var raw SrsRawResponse
if err := json.Unmarshal(ret, &raw); err == nil && raw.Code == 0 {
raw.Http_api.Listen = envHttpAPI()
apiResponse(ctx, w, r, raw)
return nil
}
} else if rpc == "reload" {
var res SrsRawReloadResponse
err := json.Unmarshal(ret, &res)
logger.Df(ctx, "%v %v %v %v", server.IP, r.URL.Path, res, err)
} else if rpc == "reload-fetch" {
var res SrsRawReloadFetchResponse
err := json.Unmarshal(ret, &res)
logger.Df(ctx, "%v %v %v %v", server.IP, r.URL.Path, res, err)
} else {
var code SrsApiCodeResponse
if err := json.Unmarshal(ret, &code); err == nil {
logger.Df(ctx, "%v %v", r.URL.Path, code)
}
}
}
}
return nil
}

View file

@ -4,10 +4,13 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strconv"
"strings"
@ -97,6 +100,35 @@ func (v *SRSServer) Format(f fmt.State, c rune) {
}
}
func (v *SRSServer) ApiRequest(ctx context.Context, r *http.Request, body []byte) ([]byte, error) {
var url string
// if the v.API[0] contains ip address, e.g. 127.0.0.1:1985, then use it as the ip address
if strings.Contains(v.API[0], ":") && strings.Index(v.API[0], ":") > 0 {
url = "http://" + v.API[0] + r.URL.Path
} else {
url = "http://" + v.IP + ":" + v.API[0] + r.URL.Path
}
if r.URL.RawQuery != "" {
url += "?" + r.URL.RawQuery
}
if req, err := http.NewRequestWithContext(ctx, r.Method, url, bytes.NewReader(body)); err != nil {
return nil, errors.Wrapf(err, "create request to %v", url)
} else if res, err := http.DefaultClient.Do(req); err != nil {
return nil, errors.Wrapf(err, "send request to %v", url)
} else {
defer res.Body.Close()
if ret, err := io.ReadAll(res.Body); err != nil {
return nil, errors.Wrapf(err, "read http respose error")
} else if !isHttpStatusOK(res.StatusCode) {
return ret, errors.Errorf("http response status code %v", res.StatusCode)
} else {
return ret, nil
}
}
}
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
v := &SRSServer{}
for _, opt := range opts {
@ -158,6 +190,8 @@ type SRSLoadBalancer interface {
StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
// proxy http api to srs
ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error
}
// srsLoadBalancer is the global SRS load balancer.
@ -165,6 +199,7 @@ var srsLoadBalancer SRSLoadBalancer
// srsMemoryLoadBalancer stores state in memory.
type srsMemoryLoadBalancer struct {
*SrsApiProxy
// All available SRS servers, key is server ID.
servers sync.Map[string, *SRSServer]
// The picked server to servce client by specified stream URL, key is stream url.
@ -287,7 +322,17 @@ func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag str
}
}
func (v *srsMemoryLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
services := make([]*SRSServer, v.servers.Size())
v.servers.Range(func(_ string, value *SRSServer) bool {
services = append(services, value)
return true
})
return v.proxySrsAPI(ctx, services, w, r)
}
type srsRedisLoadBalancer struct {
*SrsApiProxy
// The redis client sdk.
rdb *redis.Client
}
@ -528,6 +573,40 @@ func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag stri
return &actual, nil
}
func (v *srsRedisLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// No server found, failed.
if len(serverKeys) == 0 {
err := errors.New("servers empty")
apiError(ctx, w, r, err, http.StatusInternalServerError)
return err
}
// TODO get all SRSServer
var srsServers []*SRSServer
for _, key := range serverKeys {
var server SRSServer
if b, err := v.rdb.Get(ctx, key).Bytes(); err == nil {
if err := json.Unmarshal(b, &server); err != nil {
return errors.Wrapf(err, "unmarshal servers %v, %v", key, string(b))
}
srsServers = append(srsServers, &server)
logger.Df(ctx, "srsServer: %v", server)
}
}
return v.proxySrsAPI(ctx, srsServers, w, r)
}
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
}
@ -549,5 +628,5 @@ func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
}
func (v *srsRedisLoadBalancer) redisKeyServers() string {
return fmt.Sprintf("srs-proxy-all-servers")
return "srs-proxy-all-servers"
}

View file

@ -43,3 +43,13 @@ func (m *Map[K, V]) Range(f func(key K, value V) bool) {
func (m *Map[K, V]) Store(key K, value V) {
m.m.Store(key, value)
}
func (m *Map[K, V]) Size() uint32 {
size := uint32(0)
m.m.Range(func(_, _ any) bool {
size++
return true
})
return size
}

View file

@ -32,7 +32,7 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da
b, err := json.Marshal(data)
if err != nil {
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data), http.StatusInternalServerError)
return
}
@ -41,10 +41,10 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da
w.Write(b)
}
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error, code int) {
logger.Wf(ctx, "HTTP API error %+v", err)
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(code)
fmt.Fprintln(w, fmt.Sprintf("%v", err))
}
@ -69,6 +69,10 @@ func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
return false
}
func isHttpStatusOK(v int) bool {
return v >= 200 && v < 300
}
func parseGracefullyQuitTimeout() (time.Duration, error) {
if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
@ -250,8 +254,9 @@ func parseSRTStreamID(sid string) (host, resource string, err error) {
}
// parseListenEndpoint parse the listen endpoint as:
// port The tcp listen port, like 1935.
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
//
// port The tcp listen port, like 1935.
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
// If no colon in ep, it's port in string.
if !strings.Contains(ep, ":") {