1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

Merge branch 'ossrs:develop' into develop

This commit is contained in:
Laurentiu 2024-10-24 16:11:51 +03:00 committed by GitHub
commit 79773325ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
871 changed files with 21347 additions and 7352 deletions

4
.github/FUNDING.yml vendored
View file

@ -1,7 +1,7 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
github: [winlinvip] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with patreon id.
open_collective: srs-server
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel

View file

@ -7,6 +7,10 @@ assignees: ''
---
!!! Before submitting a new bug report, please ensure you have searched for any existing bugs. Duplicate issues or
questions that are overly simple or already addressed in the documentation will be removed without any
response.
**Describe the bug**
A clear and concise description of what the bug is.

View file

@ -7,6 +7,11 @@ assignees: ''
---
!!! Before submitting a new feature request, please ensure you have searched for any existing features and utilized
the `Ask AI` feature at https://ossrs.io or https://ossrs.net (for users in China). Duplicate issues or
questions that are overly simple or already addressed in the documentation will be removed without any
response.
**What is the business background? Please provide a description.**
Who are the users? How do they utilize this feature? What problem does this feature address?

View file

@ -4,7 +4,7 @@ name: "Release"
on:
push:
tags:
- v6*
- v7*
# For draft, need write permission.
permissions:
@ -75,7 +75,7 @@ jobs:
- name: Run SRS regression-test
run: |
docker run --rm srs:test bash -c 'make && \
./objs/srs -c conf/regression-test.conf && \
./objs/srs -c conf/regression-test.conf && sleep 10 && \
cd 3rdparty/srs-bench && make && ./objs/srs_test -test.v'
runs-on: ubuntu-20.04
@ -108,7 +108,7 @@ jobs:
# See https://github.com/cygwin/cygwin-install-action#parameters
# Note that https://github.com/egor-tensin/setup-cygwin fails to install packages.
- name: Setup Cygwin
uses: cygwin/cygwin-install-action@db475590d56881c6cef7b3f96f6f3dd9532ea1f4 # master
uses: cygwin/cygwin-install-action@006ad0b0946ca6d0a3ea2d4437677fa767392401 # master
with:
platform: x86_64
packages: bash make gcc-g++ cmake automake patch pkg-config tcl unzip
@ -250,7 +250,7 @@ jobs:
echo "SRS_TAG=${{ needs.envs.outputs.SRS_TAG }}" >> $GITHUB_ENV
echo "SRS_VERSION=${{ needs.envs.outputs.SRS_VERSION }}" >> $GITHUB_ENV
echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_MAJOR }}" >> $GITHUB_ENV
echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV
echo "SRS_XYZ=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV
##################################################################################################################
# Git checkout
- name: Checkout repository
@ -276,7 +276,9 @@ jobs:
echo "Release ossrs/srs:$SRS_TAG"
docker buildx build --platform linux/arm/v7,linux/arm64/v8,linux/amd64 \
--output "type=image,push=true" \
-t ossrs/srs:$SRS_TAG --build-arg SRS_AUTO_PACKAGER=$PACKAGER -f Dockerfile .
-t ossrs/srs:$SRS_TAG --build-arg SRS_AUTO_PACKAGER=$PACKAGER \
--build-arg CONFARGS='--sanitizer=off --gb28181=on' \
-f Dockerfile .
# Docker alias images
# TODO: FIXME: If stable, please set the latest from 5.0 to 6.0
- name: Docker alias images for ossrs/srs
@ -304,7 +306,7 @@ jobs:
echo "SRS_TAG=${{ needs.envs.outputs.SRS_TAG }}" >> $GITHUB_ENV
echo "SRS_VERSION=${{ needs.envs.outputs.SRS_VERSION }}" >> $GITHUB_ENV
echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_MAJOR }}" >> $GITHUB_ENV
echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV
echo "SRS_XYZ=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV
# Aliyun ACR
# TODO: FIXME: If stable, please set the latest from 5.0 to 6.0
- name: Login aliyun hub
@ -406,6 +408,7 @@ jobs:
echo "SRS_TAG=${{ needs.envs.outputs.SRS_TAG }}" >> $GITHUB_ENV
echo "SRS_VERSION=${{ needs.envs.outputs.SRS_VERSION }}" >> $GITHUB_ENV
echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_MAJOR }}" >> $GITHUB_ENV
echo "SRS_XYZ=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV
echo "SRS_RELEASE_ID=${{ needs.draft.outputs.SRS_RELEASE_ID }}" >> $GITHUB_ENV
echo "SRS_PACKAGE_ZIP=${{ needs.linux.outputs.SRS_PACKAGE_ZIP }}" >> $GITHUB_ENV
echo "SRS_PACKAGE_MD5=${{ needs.linux.outputs.SRS_PACKAGE_MD5 }}" >> $GITHUB_ENV
@ -446,23 +449,23 @@ jobs:
* Binary: ${{ env.SRS_CYGWIN_MD5 }} [${{ env.SRS_CYGWIN_TAR }}](https://gitee.com/ossrs/srs/releases/download/${{ env.SRS_TAG }}/${{ env.SRS_CYGWIN_TAR }})
## Docker
* [docker pull ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started)
* [docker pull ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started)
* [docker pull ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started)
* [docker pull ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started)
* [docker pull ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started)
* [docker pull ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started)
## Docker Mirror: aliyun.com
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started)
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started)
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started)
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started)
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started)
* [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started)
## Doc: ossrs.io
* [Getting Started](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started)
* [Wiki home](https://ossrs.io/lts/en-us/docs/v5/doc/introduction)
* [Getting Started](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started)
* [Wiki home](https://ossrs.io/lts/en-us/docs/v7/doc/introduction)
* [FAQ](https://ossrs.io/lts/en-us/faq), [Features](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/Features.md#features) or [ChangeLogs](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/CHANGELOG.md#changelog)
## Doc: ossrs.net
* [快速入门](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started)
* [中文Wiki首页](https://ossrs.net/lts/zh-cn/docs/v5/doc/introduction)
* [快速入门](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started)
* [中文Wiki首页](https://ossrs.net/lts/zh-cn/docs/v7/doc/introduction)
* [中文FAQ](https://ossrs.net/lts/zh-cn/faq), [功能列表](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/Features.md#features) 或 [修订历史](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/CHANGELOG.md#changelog)
draft: false
prerelease: true

View file

@ -1,73 +0,0 @@
# This workflow uses actions that are not certified by GitHub. They are provided
# by a third-party and are governed by separate terms of service, privacy
# policy, and support documentation.
name: Scorecard
on:
# For Branch-Protection check. Only the default branch is supported. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
#branch_protection_rule:
# To guarantee Maintained check is occasionally updated. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
#schedule:
# - cron: '00 00 * * 6' # At 00:00 on Saturday, see https://crontab.guru/#00_00_*_*_6
push:
branches: [ "develop" ]
# Declare default permissions as read only.
permissions: read-all
jobs:
analysis:
name: Scorecard analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results and get a badge (see publish_results below).
id-token: write
# Uncomment the permissions below if installing in a private repository.
# contents: read
# actions: read
steps:
- name: "Checkout code"
uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0
with:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@99c53751e09b9529366343771cc321ec74e9bd3d # v2.0.6
with:
results_file: results.sarif
results_format: sarif
# (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
# - you want to enable the Branch-Protection check on a *public* repository, or
# - you are installing Scorecard on a *private* repository
# To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat.
repo_token: ${{ secrets.SCORECARD_TOKEN }}
# Public repositories:
# - Publish results to OpenSSF REST API for easy access by consumers
# - Allows the repository to include the Scorecard badge.
# - See https://github.com/ossf/scorecard-action#publishing-results.
# For private repositories:
# - `publish_results` will always be set to `false`, regardless
# of the value entered here.
publish_results: true
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
#- name: "Upload artifact"
# uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v3.1.0
# with:
# name: SARIF file
# path: results.sarif
# retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@807578363a7869ca324a79039e6db9c843e0e100 # v2.1.27
with:
sarif_file: results.sarif

View file

@ -55,7 +55,7 @@ jobs:
# See https://github.com/cygwin/cygwin-install-action#parameters
# Note that https://github.com/egor-tensin/setup-cygwin fails to install packages.
- name: Setup Cygwin
uses: cygwin/cygwin-install-action@db475590d56881c6cef7b3f96f6f3dd9532ea1f4 # master
uses: cygwin/cygwin-install-action@006ad0b0946ca6d0a3ea2d4437677fa767392401 # master
with:
platform: x86_64
packages: bash make gcc-g++ cmake automake patch pkg-config tcl unzip
@ -190,15 +190,16 @@ jobs:
docker run --rm -w /srs/trunk/3rdparty/srs-bench srs:test \
./objs/srs_blackbox_test -test.v -test.run '^TestFast' -test.parallel 64
docker run --rm -w /srs/trunk/3rdparty/srs-bench srs:test \
./objs/srs_blackbox_test -test.v -test.run '^TestSlow' -test.parallel 4
./objs/srs_blackbox_test -test.v -test.run '^TestSlow' -test.parallel 1
# For utest
- name: Run SRS utest
run: docker run --rm srs:test ./objs/srs_utest
# For regression-test
- name: Run SRS regression-test
run: |
docker run --rm srs:test bash -c './objs/srs -c conf/regression-test.conf && \
cd 3rdparty/srs-bench && ./objs/srs_test -test.v && ./objs/srs_gb28181_test -test.v'
docker run --rm srs:test bash -c './objs/srs -c conf/regression-test.conf && sleep 10 && \
cd 3rdparty/srs-bench && (./objs/srs_test -test.v || (cat ../../objs/srs.log && exit 1)) && \
./objs/srs_gb28181_test -test.v'
runs-on: ubuntu-20.04
coverage:

4
.gitignore vendored
View file

@ -16,8 +16,6 @@
*.pyc
*.swp
.DS_Store
.vscode
.vscode/*
/trunk/Makefile
/trunk/objs
/trunk/src/build-qt-Desktop-Debug
@ -43,3 +41,5 @@ cmake-build-debug
/trunk/ide/srs_clion/cmake_install.cmake
/trunk/ide/srs_clion/srs
/trunk/ide/srs_clion/Testing/
/trunk/ide/vscode-build

View file

@ -1,7 +1,6 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="private" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="-c console.conf" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" WORKING_DIR="file://$CMakeCurrentBuildDir$/../../../" PASS_PARENT_ENVS_2="true" PROJECT_NAME="srs" TARGET_NAME="srs" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="srs" RUN_TARGET_NAME="srs">
<envs>
<env name="SRS_RTC_SERVER_ENABLED" value="on" />
<env name="MallocNanoZone" value="0" />
</envs>
<method v="2">

38
.vscode/README.md vendored Normal file
View file

@ -0,0 +1,38 @@
# Debug with VSCode
Support run and debug with VSCode.
## SRS
Install the following extensions:
- CMake Tools
- CodeLLDB
- C/C++ Extension Pack
Open the folder like `~/git/srs` in VSCode.
Run commmand `> CMake: Configure` to configure the project.
> Note: You can press `Ctrl+R`, then type `CMake: Configure` then select `Clang` as the toolchain.
> Note: The `settings.json` is used to configure the cmake. It will use `${workspaceFolder}/trunk/ide/srs_clion/CMakeLists.txt`
> and `${workspaceFolder}/trunk/ide/vscode-build` as the source file and build directory.
Click the `Run > Run Without Debugging` button to start the server.
> Note: The `launch.json` is used for running and debugging. The build will output the binary to
> `${workspaceFolder}/trunk/ide/vscode-build/srs`.
## Proxy
Install the following extensions:
- Go
Open the folder like `~/git/srs` in VSCode.
Select the `View > Run` and select `Launch srs-proxy` to start the proxy server.
Click the `Run > Run Without Debugging` button to start the server.
> Note: The `launch.json` is used for running and debugging.

36
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,36 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Launch SRS",
"type": "cppdbg",
"request": "launch",
"program": "${workspaceFolder}/trunk/ide/vscode-build/srs",
"args": ["-c", "conf/console.conf"],
"stopAtEntry": false,
"cwd": "${workspaceFolder}/trunk",
"environment": [],
"externalConsole": false,
"MIMode": "lldb",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"text": "-enable-pretty-printing",
"ignoreFailures": true
}
],
"preLaunchTask": "build",
"logging": {
"engineLogging": true
}
},
{
"name": "Launch srs-proxy",
"type": "go",
"request": "launch",
"mode": "auto",
"cwd": "${workspaceFolder}/proxy",
"program": "${workspaceFolder}/proxy"
}
]
}

5
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,5 @@
{
"cmake.sourceDirectory": "${workspaceFolder}/trunk/ide/srs_clion",
"cmake.buildDirectory": "${workspaceFolder}/trunk/ide/vscode-build",
"cmake.configureOnOpen": false
}

17
.vscode/tasks.json vendored Normal file
View file

@ -0,0 +1,17 @@
{
"version": "2.0.0",
"tasks": [
{
"label": "build",
"type": "shell",
"command": "cmake --build ${workspaceFolder}/trunk/ide/vscode-build",
"group": {
"kind": "build",
"isDefault": true
},
"problemMatcher": ["$gcc"],
"detail": "Build SRS by cmake."
}
]
}

View file

@ -11,7 +11,7 @@ ARG SRS_AUTO_PACKAGER
RUN echo "BUILDPLATFORM: $BUILDPLATFORM, TARGETPLATFORM: $TARGETPLATFORM, PACKAGER: ${#SRS_AUTO_PACKAGER}, CONFARGS: ${CONFARGS}, MAKEARGS: ${MAKEARGS}, INSTALLDEPENDS: ${INSTALLDEPENDS}"
# https://serverfault.com/questions/949991/how-to-install-tzdata-on-a-ubuntu-docker-image
ENV DEBIAN_FRONTEND noninteractive
ENV DEBIAN_FRONTEND=noninteractive
# To use if in RUN, see https://github.com/moby/moby/issues/7281#issuecomment-389440503
# Note that only exists issue like "/bin/sh: 1: [[: not found" for Ubuntu20, no such problem in CentOS7.
@ -29,7 +29,7 @@ WORKDIR /srs/trunk
# Build and install SRS.
# Note that SRT is enabled by default, so we configure without --srt=on.
# Note that we have copied all files by make install.
RUN ./configure --sanitizer=off --gb28181=on --h265=on ${CONFARGS} && make ${MAKEARGS} && make install
RUN ./configure ${CONFARGS} && make ${MAKEARGS} && make install
############################################################
# dist

View file

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2013-2023 The SRS Authors
Copyright (c) 2013-2024 The SRS Authors
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in

View file

@ -10,13 +10,9 @@
[![](https://img.shields.io/badge/SRS-YouTube-red)](https://www.youtube.com/@srs_server)
[![](https://badgen.net/discord/members/yZ4BnPmHAd)](https://discord.gg/yZ4BnPmHAd)
[![](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fossrs%2Fsrs.svg?type=small)](https://app.fossa.com/projects/git%2Bgithub.com%2Fossrs%2Fsrs?ref=badge_small)
[![](https://ossrs.net/wiki/images/srs-faq.svg)](https://ossrs.net/lts/zh-cn/faq)
[![](https://badgen.net/badge/srs/stackoverflow/orange?icon=terminal)](https://stackoverflow.com/questions/tagged/simple-realtime-server)
[![](https://opencollective.com/srs-server/tiers/badge.svg)](https://opencollective.com/srs-server)
[![](https://img.shields.io/docker/pulls/ossrs/srs)](https://hub.docker.com/r/ossrs/srs/tags)
[![](https://ossrs.net/wiki/images/do-btn-srs-125x20.svg)](https://cloud.digitalocean.com/droplets/new?appId=133468816&size=s-2vcpu-2gb&region=sgp1&image=ossrs-srs&type=applications)
[![](https://api.securityscorecards.dev/projects/github.com/ossrs/srs/badge)](https://api.securityscorecards.dev/projects/github.com/ossrs/srs)
[![](https://bestpractices.coreinfrastructure.org/projects/5619/badge)](https://bestpractices.coreinfrastructure.org/projects/5619)
SRS/6.0 ([Hang](https://ossrs.io/lts/en-us/product#release-60)) is a simple, high-efficiency, and real-time video server,
supporting RTMP/WebRTC/HLS/HTTP-FLV/SRT/MPEG-DASH/GB28181, Linux/Windows/macOS, X86_64/ARMv7/AARCH64/M1/RISCV/LOONGARCH/MIPS,
@ -126,9 +122,20 @@ distributed under their [licenses](https://ossrs.io/lts/en-us/license).
## Releases
* 2024-09-01, [Release v6.0-a1](https://github.com/ossrs/srs/releases/tag/v6.0-a1), v6.0-a1, 6.0 alpha1, v6.0.155, 169636 lines.
* 2024-07-27, [Release v6.0-a0](https://github.com/ossrs/srs/releases/tag/v6.0-a0), v6.0-a0, 6.0 alpha0, v6.0.145, 169259 lines.
* 2024-07-04, [Release v6.0-d6](https://github.com/ossrs/srs/releases/tag/v6.0-d6), v6.0-d6, 6.0 dev6, v6.0.134, 168904 lines.
* 2024-06-15, [Release v6.0-d5](https://github.com/ossrs/srs/releases/tag/v6.0-d5), v6.0-d5, 6.0 dev5, v6.0.129, 168454 lines.
* 2024-02-15, [Release v6.0-d4](https://github.com/ossrs/srs/releases/tag/v6.0-d4), v6.0-d4, 6.0 dev4, v6.0.113, 167695 lines.
* 2023-11-19, [Release v6.0-d3](https://github.com/ossrs/srs/releases/tag/v6.0-d3), v6.0-d3, 6.0 dev3, v6.0.101, 167560 lines.
* 2023-09-28, [Release v6.0-d2](https://github.com/ossrs/srs/releases/tag/v6.0-d2), v6.0-d2, 6.0 dev2, v6.0.85, 167509 lines.
* 2023-08-31, [Release v6.0-d1](https://github.com/ossrs/srs/releases/tag/v6.0-d1), v6.0-d1, 6.0 dev1, v6.0.72, 167135 lines.
* 2023-07-09, [Release v6.0-d0](https://github.com/ossrs/srs/releases/tag/v6.0-d0), v6.0-d0, 6.0 dev0, v6.0.59, 166739 lines.
* 2024-06-15, [Release v5.0-r3](https://github.com/ossrs/srs/releases/tag/v5.0-r3), v5.0-r3, 5.0 release3, v5.0.213, 163585 lines.
* 2024-04-03, [Release v5.0-r2](https://github.com/ossrs/srs/releases/tag/v5.0-r2), v5.0-r2, 5.0 release2, v5.0.210, 163515 lines.
* 2024-02-15, [Release v5.0-r1](https://github.com/ossrs/srs/releases/tag/v5.0-r1), v5.0-r1, 5.0 release1, v5.0.208, 163441 lines.
* 2023-12-30, [Release v5.0-r0](https://github.com/ossrs/srs/releases/tag/v5.0-r0), v5.0-r0, 5.0 release0, v5.0.205, 163363 lines.
* 2023-11-19, [Release v5.0-b7](https://github.com/ossrs/srs/releases/tag/v5.0-b7), v5.0-b7, 5.0 beta7, v5.0.200, 163305 lines.
* 2023-10-25, [Release v5.0-b6](https://github.com/ossrs/srs/releases/tag/v5.0-b6), v5.0-b6, 5.0 beta6, v5.0.195, 163303 lines.
* 2023-09-28, [Release v5.0-b5](https://github.com/ossrs/srs/releases/tag/v5.0-b5), v5.0-b5, 5.0 beta5, v5.0.185, 163254 lines.
* 2023-08-31, [Release v5.0-b4](https://github.com/ossrs/srs/releases/tag/v5.0-b4), v5.0-b4, 5.0 beta4, v5.0.176, 162919 lines.

4
proxy/.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
.idea
srs-proxy
.env
.go-formarted

23
proxy/Makefile Normal file
View file

@ -0,0 +1,23 @@
.PHONY: all build test fmt clean run
all: build
build: fmt ./srs-proxy
./srs-proxy: *.go
go build -o srs-proxy .
test:
go test ./...
fmt: ./.go-formarted
./.go-formarted: *.go
touch .go-formarted
go fmt ./...
clean:
rm -f srs-proxy .go-formarted
run: fmt
go run .

272
proxy/api.go Normal file
View file

@ -0,0 +1,272 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
// to proxy other HTTP API of SRS like the streams and clients, etc.
type srsHTTPAPIServer struct {
// The underlayer HTTP server.
server *http.Server
// The WebRTC server.
rtc *srsWebRTCServer
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer {
v := &srsHTTPAPIServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPAPIServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The WebRTC WHIP API handler.
logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr)
mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// The WebRTC WHEP API handler.
logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr)
mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// Run HTTP API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP API accept err %+v", err)
} else {
logger.Df(ctx, "HTTP API server done")
}
}
}()
return nil
}
// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service
// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter
// for Prometheus metrics.
type systemAPI struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI {
v := &systemAPI{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *systemAPI) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *systemAPI) Run(ctx context.Context) error {
// Parse address to listen.
addr := envSystemAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "System API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The register service for SRS media servers.
logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr)
mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) {
if err := func() error {
var deviceID, ip, serverID, serviceID, pid string
var rtmp, stream, api, srt, rtc []string
if err := ParseBody(r.Body, &struct {
// The IP of SRS, mandatory.
IP *string `json:"ip"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID *string `json:"server"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID *string `json:"service"`
// The process id of SRS, always change when restarted, mandatory.
PID *string `json:"pid"`
// The RTMP listen endpoints, mandatory.
RTMP *[]string `json:"rtmp"`
// The HTTP Stream listen endpoints, optional.
HTTP *[]string `json:"http"`
// The API listen endpoints, optional.
API *[]string `json:"api"`
// The SRT listen endpoints, optional.
SRT *[]string `json:"srt"`
// The RTC listen endpoints, optional.
RTC *[]string `json:"rtc"`
// The device id of SRS, optional.
DeviceID *string `json:"device_id"`
}{
IP: &ip, DeviceID: &deviceID,
ServerID: &serverID, ServiceID: &serviceID, PID: &pid,
RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc,
}); err != nil {
return errors.Wrapf(err, "parse body")
}
if ip == "" {
return errors.Errorf("empty ip")
}
if serverID == "" {
return errors.Errorf("empty server")
}
if serviceID == "" {
return errors.Errorf("empty service")
}
if pid == "" {
return errors.Errorf("empty pid")
}
if len(rtmp) == 0 {
return errors.Errorf("empty rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP, srs.DeviceID = ip, deviceID
srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid
srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api
srs.SRT, srs.RTC = srt, rtc
srs.UpdatedAt = time.Now()
})
if err := srsLoadBalancer.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update SRS server %+v", server)
}
logger.Df(ctx, "Register SRS media server, %+v", server)
return nil
}(); err != nil {
apiError(ctx, w, r, err)
}
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
}
apiResponse(ctx, w, r, &Response{
Code: 0, PID: fmt.Sprintf("%v", os.Getpid()),
})
})
// Run System API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If System API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "System API accept err %+v", err)
} else {
logger.Df(ctx, "System API server done")
}
}
}()
return nil
}

20
proxy/debug.go Normal file
View file

@ -0,0 +1,20 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"net/http"
"srs-proxy/logger"
)
func handleGoPprof(ctx context.Context) {
if addr := envGoPprof(); addr != "" {
go func() {
logger.Df(ctx, "Start Go pprof at %v", addr)
http.ListenAndServe(addr, nil)
}()
}
}

226
proxy/env.go Normal file
View file

@ -0,0 +1,226 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"io"
"os"
"path"
"strings"
"srs-proxy/errors"
"srs-proxy/logger"
)
// loadEnvFile loads the environment variables from file. Note that we only use .env file.
func loadEnvFile(ctx context.Context) error {
workDir, err := os.Getwd()
if err != nil {
return errors.Wrapf(err, "getpwd")
}
envFile := path.Join(workDir, ".env")
if _, err := os.Stat(envFile); err != nil {
return nil
}
file, err := os.Open(envFile)
if err != nil {
return errors.Wrapf(err, "open %v", envFile)
}
defer file.Close()
b, err := io.ReadAll(file)
if err != nil {
return errors.Wrapf(err, "read %v", envFile)
}
lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n")
logger.Df(ctx, "load env file %v, lines=%v", envFile, len(lines))
for _, line := range lines {
if strings.HasPrefix(strings.TrimSpace(line), "#") {
continue
}
if pos := strings.IndexByte(line, '='); pos > 0 {
key := strings.TrimSpace(line[:pos])
value := strings.TrimSpace(line[pos+1:])
if v := os.Getenv(key); v != "" {
continue
}
os.Setenv(key, value)
}
}
return nil
}
// buildDefaultEnvironmentVariables setups the default environment variables.
func buildDefaultEnvironmentVariables(ctx context.Context) {
// Whether enable the Go pprof.
setEnvDefault("GO_PPROF", "")
// Force shutdown timeout.
setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s")
// Graceful quit timeout.
setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s")
// The HTTP API server.
setEnvDefault("PROXY_HTTP_API", "11985")
// The HTTP web server.
setEnvDefault("PROXY_HTTP_SERVER", "18080")
// The RTMP media server.
setEnvDefault("PROXY_RTMP_SERVER", "11935")
// The WebRTC media server, via UDP protocol.
setEnvDefault("PROXY_WEBRTC_SERVER", "18000")
// The SRT media server, via UDP protocol.
setEnvDefault("PROXY_SRT_SERVER", "20080")
// The API server of proxy itself.
setEnvDefault("PROXY_SYSTEM_API", "12025")
// The static directory for web server.
setEnvDefault("PROXY_STATIC_FILES", "../trunk/research")
// The load balancer, use redis or memory.
setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory")
// The redis server host.
setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1")
// The redis server port.
setEnvDefault("PROXY_REDIS_PORT", "6379")
// The redis server password.
setEnvDefault("PROXY_REDIS_PASSWORD", "")
// The redis server db.
setEnvDefault("PROXY_REDIS_DB", "0")
// Whether enable the default backend server, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off")
// Default backend server IP, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1")
// Default backend server port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935")
// Default backend api port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985")
// Default backend udp rtc port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000")
// Default backend udp srt port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080")
logger.Df(ctx, "load .env as GO_PPROF=%v, "+
"PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+
"PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+
"PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+
"PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+
"PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+
"PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+
"PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+
"PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+
"PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v",
envGoPprof(),
envForceQuitTimeout(), envGraceQuitTimeout(),
envHttpAPI(), envHttpServer(), envRtmpServer(),
envWebRTCServer(), envSRTServer(),
envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(),
envDefaultBackendIP(), envDefaultBackendRTMP(),
envDefaultBackendHttp(), envDefaultBackendAPI(),
envDefaultBackendRTC(), envDefaultBackendSRT(),
envLoadBalancerType(), envRedisHost(), envRedisPort(),
envRedisPassword(), envRedisDB(),
)
}
func envStaticFiles() string {
return os.Getenv("PROXY_STATIC_FILES")
}
func envDefaultBackendSRT() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_SRT")
}
func envDefaultBackendRTC() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTC")
}
func envDefaultBackendAPI() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_API")
}
func envSRTServer() string {
return os.Getenv("PROXY_SRT_SERVER")
}
func envWebRTCServer() string {
return os.Getenv("PROXY_WEBRTC_SERVER")
}
func envDefaultBackendHttp() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP")
}
func envRedisDB() string {
return os.Getenv("PROXY_REDIS_DB")
}
func envRedisPassword() string {
return os.Getenv("PROXY_REDIS_PASSWORD")
}
func envRedisPort() string {
return os.Getenv("PROXY_REDIS_PORT")
}
func envRedisHost() string {
return os.Getenv("PROXY_REDIS_HOST")
}
func envLoadBalancerType() string {
return os.Getenv("PROXY_LOAD_BALANCER_TYPE")
}
func envDefaultBackendRTMP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP")
}
func envDefaultBackendIP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_IP")
}
func envDefaultBackendEnabled() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED")
}
func envGraceQuitTimeout() string {
return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT")
}
func envForceQuitTimeout() string {
return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT")
}
func envGoPprof() string {
return os.Getenv("GO_PPROF")
}
func envSystemAPI() string {
return os.Getenv("PROXY_SYSTEM_API")
}
func envRtmpServer() string {
return os.Getenv("PROXY_RTMP_SERVER")
}
func envHttpServer() string {
return os.Getenv("PROXY_HTTP_SERVER")
}
func envHttpAPI() string {
return os.Getenv("PROXY_HTTP_API")
}
// setEnvDefault set env key=value if not set.
func setEnvDefault(key, value string) {
if os.Getenv(key) == "" {
os.Setenv(key, value)
}
}

270
proxy/errors/errors.go Normal file
View file

@ -0,0 +1,270 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

187
proxy/errors/stack.go Normal file
View file

@ -0,0 +1,187 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

10
proxy/go.mod Normal file
View file

@ -0,0 +1,10 @@
module srs-proxy
go 1.18
require github.com/go-redis/redis/v8 v8.11.5
require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

15
proxy/go.sum Normal file
View file

@ -0,0 +1,15 @@
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

419
proxy/http.go Normal file
View file

@ -0,0 +1,419 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy
// the request to the origin server.
type srsHTTPStreamServer struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg stdSync.WaitGroup
}
func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer {
v := &srsHTTPStreamServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPStreamServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpServer()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP Stream server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
Data struct {
Major int `json:"major"`
Minor int `json:"minor"`
Revision int `json:"revision"`
Version string `json:"version"`
} `json:"data"`
}
res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())}
res.Data.Major = VersionMajor()
res.Data.Minor = VersionMinor()
res.Data.Revision = VersionRevision()
res.Data.Version = Version()
apiResponse(ctx, w, r, &res)
})
// The static web server, for the web pages.
var staticServer http.Handler
if staticFiles := envStaticFiles(); staticFiles != "" {
if _, err := os.Stat(staticFiles); err != nil {
return errors.Wrapf(err, "invalid static files %v", staticFiles)
}
staticServer = http.FileServer(http.Dir(staticFiles))
logger.Df(ctx, "Handle static files at %v", staticFiles)
}
// The default handler, for both static web server and streaming server.
logger.Df(ctx, "Handle / by %v", addr)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// For HLS streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".m3u8") {
unifiedURL, fullURL := convertURLToStreamURL(r)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest)
return
}
stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) {
s.SRSProxyBackendHLSID = logger.GenerateContextID()
s.StreamURL, s.FullURL = streamURL, fullURL
}))
stream.Initialize(ctx).ServeHTTP(w, r)
return
}
// For HTTP streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".flv") ||
strings.HasSuffix(r.URL.Path, ".ts") {
// If SPBHID is specified, it must be a HLS stream client.
if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" {
if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
} else {
stream.Initialize(ctx).ServeHTTP(w, r)
}
return
}
// Use HTTP pseudo streaming to proxy the request.
NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) {
c.ctx = ctx
}).ServeHTTP(w, r)
return
}
// Serve by static server.
if staticServer != nil {
staticServer.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
})
// Run HTTP server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP Stream accept err %+v", err)
} else {
logger.Df(ctx, "HTTP Stream server done")
}
}
}()
return nil
}
// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
// connection. There is no state need to be sync between proxy servers.
//
// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request,
// then proxy to the corresponding backend server. All state is in the HTTP request, so this
// connection is stateless.
type HTTPFlvTsConnection struct {
// The context for HTTP streaming.
ctx context.Context
}
func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection {
v := &HTTPFlvTsConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := logger.WithContext(v.ctx)
if err := v.serve(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
} else {
logger.Df(ctx, "HTTP client done")
}
}
func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no http stream server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Wrapf(err, "do request to %v", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
logger.Df(ctx, "HTTP start streaming")
// Proxy the stream from backend to client.
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
// clients will share this object, and they do not use the same ctx among proxy servers.
//
// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections.
// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create
// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert
// to the stream URL and then query the backend server to serve it.
type HLSPlayStream struct {
// The context for HLS streaming.
ctx context.Context
// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The full request URL for HLS streaming
FullURL string `json:"full_url"`
}
func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream {
v := &HLSPlayStream{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
return v
}
func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := v.serve(v.ctx, w, r); err != nil {
apiError(v.ctx, w, r, err)
} else {
logger.Df(v.ctx, "HLS client %v for %v with %v done",
v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path)
}
}
func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no rtmp server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// For TS file, directly copy it.
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts
// URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
m3u8 := string(b)
if strings.Contains(m3u8, ".ts?") {
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
} else {
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
}
if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil {
return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL)
}
return nil
}

43
proxy/logger/context.go Normal file
View file

@ -0,0 +1,43 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
)
type key string
var cidKey key = "cid.proxy.ossrs.org"
// generateContextID generates a random context id in string.
func GenerateContextID() string {
randomBytes := make([]byte, 32)
_, _ = rand.Read(randomBytes)
hash := sha256.Sum256(randomBytes)
hashString := hex.EncodeToString(hash[:])
cid := hashString[:7]
return cid
}
// WithContext creates a new context with cid, which will be used for log.
func WithContext(ctx context.Context) context.Context {
return WithContextID(ctx, GenerateContextID())
}
// WithContextID creates a new context with cid, which will be used for log.
func WithContextID(ctx context.Context, cid string) context.Context {
return context.WithValue(ctx, cidKey, cid)
}
// ContextID returns the cid in context, or empty string if not set.
func ContextID(ctx context.Context) string {
if cid, ok := ctx.Value(cidKey).(string); ok {
return cid
}
return ""
}

87
proxy/logger/log.go Normal file
View file

@ -0,0 +1,87 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"io/ioutil"
stdLog "log"
"os"
)
type logger interface {
Printf(ctx context.Context, format string, v ...any)
}
type loggerPlus struct {
logger *stdLog.Logger
level string
}
func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
v := &loggerPlus{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) {
format, args := f, a
if cid := ContextID(ctx); cid != "" {
format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...)
}
v.logger.Printf(format, args...)
}
var verboseLogger logger
func Vf(ctx context.Context, format string, a ...interface{}) {
verboseLogger.Printf(ctx, format, a...)
}
var debugLogger logger
func Df(ctx context.Context, format string, a ...interface{}) {
debugLogger.Printf(ctx, format, a...)
}
var warnLogger logger
func Wf(ctx context.Context, format string, a ...interface{}) {
warnLogger.Printf(ctx, format, a...)
}
var errorLogger logger
func Ef(ctx context.Context, format string, a ...interface{}) {
errorLogger.Printf(ctx, format, a...)
}
const (
logVerboseLabel = "verb"
logDebugLabel = "debug"
logWarnLabel = "warn"
logErrorLabel = "error"
)
func init() {
verboseLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logVerboseLabel
})
debugLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logDebugLabel
})
warnLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logWarnLabel
})
errorLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logErrorLabel
})
}

121
proxy/main.go Normal file
View file

@ -0,0 +1,121 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"srs-proxy/errors"
"srs-proxy/logger"
)
func main() {
ctx := logger.WithContext(context.Background())
logger.Df(ctx, "%v/%v started", Signature(), Version())
// Install signals.
ctx, cancel := context.WithCancel(ctx)
installSignals(ctx, cancel)
// Start the main loop, ignore the user cancel error.
err := doMain(ctx)
if err != nil && ctx.Err() != context.Canceled {
logger.Ef(ctx, "main: %+v", err)
os.Exit(-1)
}
logger.Df(ctx, "%v done", Signature())
}
func doMain(ctx context.Context) error {
// Setup the environment variables.
if err := loadEnvFile(ctx); err != nil {
return errors.Wrapf(err, "load env")
}
buildDefaultEnvironmentVariables(ctx)
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
// because the main thread exits after the context is cancelled. However, sometimes the main thread
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
if err := installForceQuit(ctx); err != nil {
return errors.Wrapf(err, "install force quit")
}
// Start the Go pprof if enabled.
handleGoPprof(ctx)
// Initialize SRS load balancers.
switch lbType := envLoadBalancerType(); lbType {
case "memory":
srsLoadBalancer = NewMemoryLoadBalancer()
case "redis":
srsLoadBalancer = NewRedisLoadBalancer()
default:
return errors.Errorf("invalid load balancer %v", lbType)
}
if err := srsLoadBalancer.Initialize(ctx); err != nil {
return errors.Wrapf(err, "initialize srs load balancer")
}
// Parse the gracefully quit timeout.
gracefulQuitTimeout, err := parseGracefullyQuitTimeout()
if err != nil {
return errors.Wrapf(err, "parse gracefully quit timeout")
}
// Start the RTMP server.
srsRTMPServer := NewSRSRTMPServer()
defer srsRTMPServer.Close()
if err := srsRTMPServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtmp server")
}
// Start the WebRTC server.
srsWebRTCServer := NewSRSWebRTCServer()
defer srsWebRTCServer.Close()
if err := srsWebRTCServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtc server")
}
// Start the HTTP API server.
srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) {
server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer
})
defer srsHTTPAPIServer.Close()
if err := srsHTTPAPIServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http api server")
}
// Start the SRT server.
srsSRTServer := NewSRSSRTServer()
defer srsSRTServer.Close()
if err := srsSRTServer.Run(ctx); err != nil {
return errors.Wrapf(err, "srt server")
}
// Start the System API server.
systemAPI := NewSystemAPI(func(server *systemAPI) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer systemAPI.Close()
if err := systemAPI.Run(ctx); err != nil {
return errors.Wrapf(err, "system api server")
}
// Start the HTTP web server.
srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer srsHTTPStreamServer.Close()
if err := srsHTTPStreamServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http server")
}
// Wait for the main loop to quit.
<-ctx.Done()
return nil
}

515
proxy/rtc.go Normal file
View file

@ -0,0 +1,515 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
stdSync "sync"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the
// SDP answer.
type srsWebRTCServer struct {
// The UDP listener for WebRTC server.
listener *net.UDPConn
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer {
v := &srsWebRTCServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsWebRTCServer) Close() error {
if v.listener != nil {
_ = v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) proxyApiToBackend(
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer,
remoteSDPOffer string, streamURL string,
) error {
// Parse HTTP port from backend.
if len(backend.API) == 0 {
return errors.Errorf("no http api server")
}
var apiPort int
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.API[0])
} else {
apiPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// Parse the local SDP answer from backend.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
// Replace the WebRTC UDP port in answer.
localSDPAnswer := string(b)
for _, endpoint := range backend.RTC {
_, _, port, err := parseListenEndpoint(endpoint)
if err != nil {
return errors.Wrapf(err, "parse endpoint %v", endpoint)
}
from := fmt.Sprintf(" %v typ host", port)
to := fmt.Sprintf(" %v typ host", envWebRTCServer())
localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1)
}
// Fetch the ice-ufrag and ice-pwd from local SDP answer.
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
if err != nil {
return errors.Wrapf(err, "parse remote sdp offer")
}
localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
if err != nil {
return errors.Wrapf(err, "parse local sdp answer")
}
// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) {
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
c.Initialize(ctx, v.listener)
// Cache the connection for fast search by username.
v.usernames.Store(c.Ufrag, c)
})); err != nil {
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
}
// Response client with local answer.
if _, err = w.Write([]byte(localSDPAnswer)); err != nil {
return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer)
}
logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
return nil
}
func (v *srsWebRTCServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envWebRTCServer()
if !strings.Contains(endpoint, ":") {
endpoint = fmt.Sprintf(":%v", endpoint)
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "WebRTC server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := listener.ReadFromUDP(buf)
if err != nil {
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var connection *RTCConnection
// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) {
return nil
}
var pkt RTCStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}
// Search the connection in fast cache.
if s, ok := v.usernames.Load(pkt.Username); ok {
connection = s
return nil
}
// Load connection by username.
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
connection = s.Initialize(ctx, v.listener)
logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
}
// Cache connection for fast search.
if connection != nil {
v.usernames.Store(pkt.Username, connection)
}
return nil
}(); err != nil {
return err
}
// Search the connection by addr.
if s, ok := v.addresses.Load(addr.String()); ok {
connection = s
} else if connection != nil {
// Cache the address for fast search.
v.addresses.Store(addr.String(), connection)
}
// If connection is not found, ignore the packet.
if connection == nil {
// TODO: Should logging the dropped packet, only logging the first one for each address.
return nil
}
// Proxy the packet to backend.
if err := connection.HandlePacket(addr, data); err != nil {
return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL)
}
return nil
}
// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
// connection, identify by the ufrag in sdp offer/answer and ICE binding request.
//
// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is
// in the client request. The RTCConnection is stateful, and need to sync the ufrag between
// proxy servers.
//
// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch
// to another UDP address, it may connect to another WebRTC proxy, then we should discover the
// RTCConnection by the ufrag from the ICE binding request.
type RTCConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The ufrag for this WebRTC connection.
Ufrag string `json:"ufrag"`
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
}
func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
if listener != nil {
v.listenerUDP = listener
}
return v
}
func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
v.clientUDP = addr
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx); err != nil {
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return nil
}
// Proxy all messages from backend to client.
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
break
}
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
break
}
}
}()
if _, err := v.backendUDP.Write(data); err != nil {
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
}
return nil
}
func (v *RTCConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}
// Pick a backend SRS server to proxy the RTC stream.
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}
// Parse UDP port from backend.
if len(backend.RTC) == 0 {
return errors.Errorf("no udp server")
}
_, _, udpPort, err := parseListenEndpoint(backend.RTC[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
}
return nil
}
type RTCICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
RemoteICEPwd string `json:"remote_pwd"`
// The local ufrag, used for ICE username and session id.
LocalICEUfrag string `json:"local_ufrag"`
// The local pwd, used for ICE password.
LocalICEPwd string `json:"local_pwd"`
}
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}
type RTCStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}
p := data
v.MessageType = binary.BigEndian.Uint16(p)
messageLen := binary.BigEndian.Uint16(p[2:])
//magicCookie := p[:8]
//transactionID := p[:20]
p = p[20:]
if len(p) != int(messageLen) {
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
}
for len(p) > 0 {
typ := binary.BigEndian.Uint16(p)
length := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(length) {
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
}
value := p[:length]
p = p[length:]
if length%4 != 0 {
p = p[4-length%4:]
}
switch typ {
case 0x0006:
v.Username = string(value)
}
}
return nil
}

655
proxy/rtmp.go Normal file
View file

@ -0,0 +1,655 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/rtmp"
)
// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
// server. It will figure out the backend server to proxy to. Unlike the edge server, it will
// not cache the stream, but just proxy the stream to backend.
type srsRTMPServer struct {
// The TCP listener for RTMP server.
listener *net.TCPListener
// The random number generator.
rd *rand.Rand
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer {
v := &srsRTMPServer{
rd: rand.New(rand.NewSource(time.Now().UnixNano())),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsRTMPServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsRTMPServer) Run(ctx context.Context) error {
endpoint := envRtmpServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
addr, err := net.ResolveTCPAddr("tcp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return errors.Wrapf(err, "listen rtmp addr %v", addr)
}
v.listener = listener
logger.Df(ctx, "RTMP server listen at %v", addr)
v.wg.Add(1)
go func() {
defer v.wg.Done()
for {
conn, err := v.listener.AcceptTCP()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "RTMP server accept err %+v", err)
} else {
logger.Df(ctx, "RTMP server done")
}
return
}
v.wg.Add(1)
go func(ctx context.Context, conn *net.TCPConn) {
defer v.wg.Done()
defer conn.Close()
handleErr := func(err error) {
if isPeerClosedError(err) {
logger.Df(ctx, "RTMP peer is closed")
} else {
logger.Wf(ctx, "RTMP serve err %+v", err)
}
}
rc := NewRTMPConnection(func(client *RTMPConnection) {
client.rd = v.rd
})
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
logger.Df(ctx, "RTMP client done")
}
}(logger.WithContext(ctx), conn)
}
}()
return nil
}
// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between
// proxy servers.
//
// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request,
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
// connection is stateless.
type RTMPConnection struct {
// The random number generator.
rd *rand.Rand
}
func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
v := &RTMPConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
// If any goroutine quit, cancel another one.
parentCtx := ctx
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var backend *RTMPClientToBackend
if true {
go func() {
<-ctx.Done()
conn.Close()
if backend != nil {
backend.Close()
}
}()
}
// Simple handshake with client.
hs := rtmp.NewHandshake(v.rd)
if _, err := hs.ReadC0S0(conn); err != nil {
return errors.Wrapf(err, "read c0")
}
if _, err := hs.ReadC1S1(conn); err != nil {
return errors.Wrapf(err, "read c1")
}
if err := hs.WriteC0S0(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC1S1(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write s2")
}
if _, err := hs.ReadC2S2(conn); err != nil {
return errors.Wrapf(err, "read c2")
}
client := rtmp.NewProtocol(conn)
logger.Df(ctx, "RTMP simple handshake done")
// Expect RTMP connect command with tcUrl.
var connectReq *rtmp.ConnectAppPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
return errors.Wrapf(err, "expect connect req")
}
if true {
ack := rtmp.NewWindowAcknowledgementSize()
ack.AckSize = 2500000
if err := client.WritePacket(ctx, ack, 0); err != nil {
return errors.Wrapf(err, "write set ack size")
}
}
if true {
chunk := rtmp.NewSetChunkSize()
chunk.ChunkSize = 128
if err := client.WritePacket(ctx, chunk, 0); err != nil {
return errors.Wrapf(err, "write set chunk size")
}
}
connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
connectResData := rtmp.NewAmf0EcmaArray()
connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
connectRes.Args.Set("data", connectResData)
if err := client.WritePacket(ctx, connectRes, 0); err != nil {
return errors.Wrapf(err, "write connect res")
}
tcUrl := connectReq.TcUrl()
logger.Df(ctx, "RTMP connect app %v", tcUrl)
// Expect RTMP command to identify the client, a publisher or viewer.
var currentStreamID, nextStreamID int
var streamName string
var clientType RTMPClientType
for clientType == "" {
var identifyReq rtmp.Packet
if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil {
return errors.Wrapf(err, "expect identify req")
}
var response rtmp.Packet
switch pkt := identifyReq.(type) {
case *rtmp.CallPacket:
if pkt.CommandName == "createStream" {
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
response = identifyRes
nextStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
} else if pkt.CommandName == "getStreamLength" {
// Ignore and do not reply these packets.
} else {
// For releaseStream, FCPublish, etc.
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.TransactionID = pkt.TransactionID
identifyRes.CommandName = "_result"
identifyRes.CommandObject = rtmp.NewAmf0Null()
identifyRes.Args = rtmp.NewAmf0Undefined()
}
case *rtmp.PublishPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypePublisher
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onFCPublish"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
identifyRes.Args = data
case *rtmp.PlayPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypeViewer
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset"))
data.Set("description", rtmp.NewAmf0String("Playing and resetting stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
}
if response != nil {
if err := client.WritePacket(ctx, response, currentStreamID); err != nil {
return errors.Wrapf(err, "write identify res for req=%v, stream=%v",
identifyReq, currentStreamID)
}
}
// Update the stream ID for next request.
currentStreamID = nextStreamID
}
logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
tcUrl, streamName, currentStreamID, clientType)
// Find a backend SRS server to proxy the RTMP stream.
backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
client.rd, client.typ = v.rd, clientType
})
defer backend.Close()
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
}
// Start the streaming.
if clientType == RTMPClientTypePublisher {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start publish")
}
} else if clientType == RTMPClientTypeViewer {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start"))
data.Set("description", rtmp.NewAmf0String("Started playing stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start play")
}
}
logger.Df(ctx, "RTMP start streaming")
// For all proxy goroutines.
var wg sync.WaitGroup
defer wg.Wait()
// Proxy all message from backend to client.
wg.Add(1)
var r0 error
go func() {
defer wg.Done()
defer cancel()
r0 = func() error {
for {
m, err := backend.client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Proxy all messages from client to backend.
wg.Add(1)
var r1 error
go func() {
defer wg.Done()
defer cancel()
r1 = func() error {
for {
m, err := client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := backend.client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Wait until all goroutine quit.
wg.Wait()
// Reset the error if caused by another goroutine.
if r0 != nil {
return errors.Wrapf(r0, "proxy backend->client")
}
if r1 != nil {
return errors.Wrapf(r1, "proxy client->backend")
}
return parentCtx.Err()
}
type RTMPClientType string
const (
RTMPClientTypePublisher RTMPClientType = "publisher"
RTMPClientTypeViewer RTMPClientType = "viewer"
)
// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
type RTMPClientToBackend struct {
// The random number generator.
rd *rand.Rand
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
// The stream type.
typ RTMPClientType
}
func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
v := &RTMPClientToBackend{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPClientToBackend) Close() error {
if v.tcpConn != nil {
v.tcpConn.Close()
}
return nil
}
func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
// Build the stream URL in vhost/app/stream schema.
streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
if err != nil {
return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse RTMP port from backend.
if len(backend.RTMP) == 0 {
return errors.Errorf("no rtmp server %+v for %v", backend, streamURL)
}
var rtmpPort int
if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0])
} else {
rtmpPort = int(iv)
}
// Connect to backend SRS server via TCP client.
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
c, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
}
v.tcpConn = c
hs := rtmp.NewHandshake(v.rd)
client := rtmp.NewProtocol(c)
v.client = client
// Simple RTMP handshake with server.
if err := hs.WriteC0S0(c); err != nil {
return errors.Wrapf(err, "write c0")
}
if err := hs.WriteC1S1(c); err != nil {
return errors.Wrapf(err, "write c1")
}
if _, err = hs.ReadC0S0(c); err != nil {
return errors.Wrapf(err, "read s0")
}
if _, err := hs.ReadC1S1(c); err != nil {
return errors.Wrapf(err, "read s1")
}
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrapf(err, "read c2")
}
logger.Df(ctx, "backend simple handshake done, server=%v", addr)
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write c2")
}
// Connect RTMP app on tcUrl with server.
if true {
connectApp := rtmp.NewConnectAppPacket()
connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
if err := client.WritePacket(ctx, connectApp, 1); err != nil {
return errors.Wrapf(err, "write connect app")
}
}
if true {
var connectAppRes *rtmp.ConnectAppResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
return errors.Wrapf(err, "expect connect app res")
}
logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID())
}
// Play or view RTMP stream with server.
if v.typ == RTMPClientTypeViewer {
return v.play(ctx, client, streamName)
}
// Publish RTMP stream with server.
return v.publish(ctx, client, streamName)
}
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
identifyReq.TransactionID = 2
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "releaseStream")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "_result" {
break
}
}
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "FCPublish"
identifyReq.TransactionID = 3
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "FCPublish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect FCPublish res")
}
if identifyRes.CommandName == "_result" {
break
}
}
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
if true {
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect publish res")
}
// Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
if identifyRes.CommandName == "onStatus" {
if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil {
return errors.Errorf("onStatus args not object")
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
return errors.Errorf("onStatus code not string")
} else if *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
}
break
}
}
logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)
return nil
}
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
playStream := rtmp.NewPlayPacket()
playStream.StreamName = *rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
return errors.Wrapf(err, "play")
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" {
break
}
}
return nil
}

771
proxy/rtmp/amf0.go Normal file
View file

@ -0,0 +1,771 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"encoding"
"encoding/binary"
"fmt"
"math"
"sync"
"srs-proxy/errors"
)
// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview
type amf0Marker uint8
const (
amf0MarkerNumber amf0Marker = iota // 0
amf0MarkerBoolean // 1
amf0MarkerString // 2
amf0MarkerObject // 3
amf0MarkerMovieClip // 4
amf0MarkerNull // 5
amf0MarkerUndefined // 6
amf0MarkerReference // 7
amf0MarkerEcmaArray // 8
amf0MarkerObjectEnd // 9
amf0MarkerStrictArray // 10
amf0MarkerDate // 11
amf0MarkerLongString // 12
amf0MarkerUnsupported // 13
amf0MarkerRecordSet // 14
amf0MarkerXmlDocument // 15
amf0MarkerTypedObject // 16
amf0MarkerAvmPlusObject // 17
amf0MarkerForbidden amf0Marker = 0xff
)
func (v amf0Marker) String() string {
switch v {
case amf0MarkerNumber:
return "Amf0Number"
case amf0MarkerBoolean:
return "amf0Boolean"
case amf0MarkerString:
return "Amf0String"
case amf0MarkerObject:
return "Amf0Object"
case amf0MarkerNull:
return "Null"
case amf0MarkerUndefined:
return "Undefined"
case amf0MarkerReference:
return "Reference"
case amf0MarkerEcmaArray:
return "EcmaArray"
case amf0MarkerObjectEnd:
return "ObjectEnd"
case amf0MarkerStrictArray:
return "StrictArray"
case amf0MarkerDate:
return "Date"
case amf0MarkerLongString:
return "LongString"
case amf0MarkerUnsupported:
return "Unsupported"
case amf0MarkerXmlDocument:
return "XmlDocument"
case amf0MarkerTypedObject:
return "TypedObject"
case amf0MarkerAvmPlusObject:
return "AvmPlusObject"
case amf0MarkerMovieClip:
return "MovieClip"
case amf0MarkerRecordSet:
return "RecordSet"
default:
return "Forbidden"
}
}
// For utest to mock it.
type amf0Buffer interface {
Bytes() []byte
WriteByte(c byte) error
Write(p []byte) (n int, err error)
}
var createBuffer = func() amf0Buffer {
return &bytes.Buffer{}
}
// All AMF0 things.
type amf0Any interface {
// Binary marshaler and unmarshaler.
encoding.BinaryUnmarshaler
encoding.BinaryMarshaler
// Get the size of bytes to marshal this object.
Size() int
// Get the Marker of any AMF0 stuff.
amf0Marker() amf0Marker
}
type amf0Converter struct {
from amf0Any
}
func NewAmf0Converter(from amf0Any) *amf0Converter {
return &amf0Converter{from: from}
}
func (v *amf0Converter) ToNumber() *amf0Number {
return amf0AnyTo[*amf0Number](v.from)
}
func (v *amf0Converter) ToBoolean() *amf0Boolean {
return amf0AnyTo[*amf0Boolean](v.from)
}
func (v *amf0Converter) ToString() *amf0String {
return amf0AnyTo[*amf0String](v.from)
}
func (v *amf0Converter) ToObject() *amf0Object {
return amf0AnyTo[*amf0Object](v.from)
}
func (v *amf0Converter) ToNull() *amf0Null {
return amf0AnyTo[*amf0Null](v.from)
}
func (v *amf0Converter) ToUndefined() *amf0Undefined {
return amf0AnyTo[*amf0Undefined](v.from)
}
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
return amf0AnyTo[*amf0EcmaArray](v.from)
}
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
return amf0AnyTo[*amf0StrictArray](v.from)
}
// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
}
}
return to
}
// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
if len(p) < 1 {
return nil, errors.Errorf("require 1 bytes only %v", len(p))
}
m := amf0Marker(p[0])
switch m {
case amf0MarkerNumber:
return NewAmf0Number(0), nil
case amf0MarkerBoolean:
return NewAmf0Boolean(false), nil
case amf0MarkerString:
return NewAmf0String(""), nil
case amf0MarkerObject:
return NewAmf0Object(), nil
case amf0MarkerNull:
return NewAmf0Null(), nil
case amf0MarkerUndefined:
return NewAmf0Undefined(), nil
case amf0MarkerReference:
case amf0MarkerEcmaArray:
return NewAmf0EcmaArray(), nil
case amf0MarkerObjectEnd:
return &amf0ObjectEOF{}, nil
case amf0MarkerStrictArray:
return NewAmf0StrictArray(), nil
case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument,
amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip,
amf0MarkerRecordSet:
return nil, errors.Errorf("Marker %v is not supported", m)
}
return nil, errors.Errorf("Marker %v is invalid", m)
}
// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8
type amf0UTF8 string
func (v *amf0UTF8) Size() int {
return 2 + len(string(*v))
}
func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
size := uint16(p[0])<<8 | uint16(p[1])
if p = data[2:]; len(p) < int(size) {
return errors.Errorf("require %v bytes only %v", int(size), len(p))
}
*v = amf0UTF8(string(p[:size]))
return
}
func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
data = make([]byte, v.Size())
size := uint16(len(string(*v)))
data[0] = byte(size >> 8)
data[1] = byte(size)
if size > 0 {
copy(data[2:], []byte(*v))
}
return
}
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
type amf0Number float64
func NewAmf0Number(f float64) *amf0Number {
v := amf0Number(f)
return &v
}
func (v *amf0Number) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func (v *amf0Number) Size() int {
return 1 + 8
}
func (v *amf0Number) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 9 {
return errors.Errorf("require 9 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerNumber {
return errors.Errorf("Amf0Number amf0Marker %v is illegal", m)
}
f := binary.BigEndian.Uint64(p[1:])
*v = amf0Number(math.Float64frombits(f))
return
}
func (v *amf0Number) MarshalBinary() (data []byte, err error) {
data = make([]byte, 9)
data[0] = byte(amf0MarkerNumber)
f := math.Float64bits(float64(*v))
binary.BigEndian.PutUint64(data[1:], f)
return
}
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
type amf0String string
func NewAmf0String(s string) *amf0String {
v := amf0String(s)
return &v
}
func (v *amf0String) amf0Marker() amf0Marker {
return amf0MarkerString
}
func (v *amf0String) Size() int {
u := amf0UTF8(*v)
return 1 + u.Size()
}
func (v *amf0String) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerString {
return errors.Errorf("Amf0String amf0Marker %v is illegal", m)
}
var sv amf0UTF8
if err = sv.UnmarshalBinary(p[1:]); err != nil {
return errors.WithMessage(err, "utf8")
}
*v = amf0String(string(sv))
return
}
func (v *amf0String) MarshalBinary() (data []byte, err error) {
u := amf0UTF8(*v)
var pb []byte
if pb, err = u.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "utf8")
}
data = append([]byte{byte(amf0MarkerString)}, pb...)
return
}
// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type
type amf0ObjectEOF struct {
}
func (v *amf0ObjectEOF) amf0Marker() amf0Marker {
return amf0MarkerObjectEnd
}
func (v *amf0ObjectEOF) Size() int {
return 3
}
func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) {
p := data
if len(p) < 3 {
return errors.Errorf("require 3 bytes only %v", len(p))
}
if p[0] != 0 || p[1] != 0 || p[2] != 9 {
return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3])
}
return
}
func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
return []byte{0, 0, 9}, nil
}
// Use array for object and ecma array, to keep the original order.
type amf0Property struct {
key amf0UTF8
value amf0Any
}
// The object-like AMF0 structure, like object and ecma array and strict array.
type amf0ObjectBase struct {
properties []*amf0Property
lock sync.Mutex
}
func (v *amf0ObjectBase) Size() int {
v.lock.Lock()
defer v.lock.Unlock()
var size int
for _, p := range v.properties {
key, value := p.key, p.value
size += key.Size() + value.Size()
}
return size
}
func (v *amf0ObjectBase) Get(key string) amf0Any {
v.lock.Lock()
defer v.lock.Unlock()
for _, p := range v.properties {
if string(p.key) == key {
return p.value
}
}
return nil
}
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
v.lock.Lock()
defer v.lock.Unlock()
prop := &amf0Property{key: amf0UTF8(key), value: value}
var ok bool
for i, p := range v.properties {
if string(p.key) == key {
v.properties[i] = prop
ok = true
}
}
if !ok {
v.properties = append(v.properties, prop)
}
return v
}
func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) {
// if no eof, elems specified by maxElems.
if !eof && maxElems < 0 {
return errors.Errorf("maxElems=%v without eof", maxElems)
}
// if eof, maxElems must be -1.
if eof && maxElems != -1 {
return errors.Errorf("maxElems=%v with eof", maxElems)
}
readOne := func() (amf0UTF8, amf0Any, error) {
var u amf0UTF8
if err = u.UnmarshalBinary(p); err != nil {
return "", nil, errors.WithMessage(err, "prop name")
}
p = p[u.Size():]
var a amf0Any
if a, err = Amf0Discovery(p); err != nil {
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
}
return u, a, nil
}
pushOne := func(u amf0UTF8, a amf0Any) error {
// For object property, consume the whole bytes.
if err = a.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
}
v.Set(string(u), a)
p = p[a.Size():]
return nil
}
for eof {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
// For object EOF, we should only consume total 3bytes.
if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd {
// 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte.
p = p[1:]
return nil
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
for len(v.properties) < maxElems {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
return
}
func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
v.lock.Lock()
defer v.lock.Unlock()
var pb []byte
for _, p := range v.properties {
key, value := p.key, p.value
if pb, err = key.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "write %v", string(key))
}
if pb, err = value.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "marshal value for %v", string(key))
}
}
return
}
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
type amf0Object struct {
amf0ObjectBase
eof amf0ObjectEOF
}
func NewAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0Object) amf0Marker() amf0Marker {
return amf0MarkerObject
}
func (v *amf0Object) Size() int {
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerObject {
return errors.Errorf("Amf0Object amf0Marker %v is illegal", m)
}
p = p[1:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
type amf0EcmaArray struct {
amf0ObjectBase
count uint32
eof amf0ObjectEOF
}
func NewAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0EcmaArray) amf0Marker() amf0Marker {
return amf0MarkerEcmaArray
}
func (v *amf0EcmaArray) Size() int {
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray {
return errors.Errorf("EcmaArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
type amf0StrictArray struct {
amf0ObjectBase
count uint32
}
func NewAmf0StrictArray() *amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0StrictArray) amf0Marker() amf0Marker {
return amf0MarkerStrictArray
}
func (v *amf0StrictArray) Size() int {
return int(1) + 4 + v.amf0ObjectBase.Size()
}
func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerStrictArray {
return errors.Errorf("StrictArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if int(v.count) <= 0 {
return
}
if err = v.unmarshal(p, false, int(v.count)); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
return b.Bytes(), nil
}
// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined.
type amf0SingleMarkerObject struct {
target amf0Marker
}
func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject {
return amf0SingleMarkerObject{target: m}
}
func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker {
return v.target
}
func (v *amf0SingleMarkerObject) Size() int {
return int(1)
}
func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != v.target {
return errors.Errorf("%v amf0Marker %v is illegal", v.target, m)
}
return
}
func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
return []byte{byte(v.target)}, nil
}
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
type amf0Null struct {
amf0SingleMarkerObject
}
func NewAmf0Null() *amf0Null {
v := amf0Null{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
return &v
}
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
type amf0Undefined struct {
amf0SingleMarkerObject
}
func NewAmf0Undefined() amf0Any {
v := amf0Undefined{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
return &v
}
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
type amf0Boolean bool
func NewAmf0Boolean(b bool) amf0Any {
v := amf0Boolean(b)
return &v
}
func (v *amf0Boolean) amf0Marker() amf0Marker {
return amf0MarkerBoolean
}
func (v *amf0Boolean) Size() int {
return int(2)
}
func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerBoolean {
return errors.Errorf("BOOL amf0Marker %v is illegal", m)
}
if p[1] == 0 {
*v = false
} else {
*v = true
}
return
}
func (v *amf0Boolean) MarshalBinary() (data []byte, err error) {
var b byte
if *v {
b = 1
}
return []byte{byte(amf0MarkerBoolean), b}, nil
}

1792
proxy/rtmp/rtmp.go Normal file

File diff suppressed because it is too large Load diff

44
proxy/signal.go Normal file
View file

@ -0,0 +1,44 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func installSignals(ctx context.Context, cancel context.CancelFunc) {
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
go func() {
for s := range sc {
logger.Df(ctx, "Got signal %v", s)
cancel()
}
}()
}
func installForceQuit(ctx context.Context) error {
var forceTimeout time.Duration
if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil {
return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout())
} else {
forceTimeout = t
}
go func() {
<-ctx.Done()
time.Sleep(forceTimeout)
logger.Wf(ctx, "Force to exit by timeout")
os.Exit(1)
}()
return nil
}

553
proxy/srs.go Normal file
View file

@ -0,0 +1,553 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// If server heartbeat in this duration, it's alive.
const srsServerAliveDuration = 300 * time.Second
// If HLS streaming update in this duration, it's alive.
const srsHLSAliveDuration = 120 * time.Second
// If WebRTC streaming update in this duration, it's alive.
const srsRTCAliveDuration = 120 * time.Second
type SRSServer struct {
// The server IP.
IP string `json:"ip,omitempty"`
// The server device ID, configured by user.
DeviceID string `json:"device_id,omitempty"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID string `json:"server_id,omitempty"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID string `json:"service_id,omitempty"`
// The process id of SRS, always change when restarted, mandatory.
PID string `json:"pid,omitempty"`
// The RTMP listen endpoints.
RTMP []string `json:"rtmp,omitempty"`
// The HTTP Stream listen endpoints.
HTTP []string `json:"http,omitempty"`
// The HTTP API listen endpoints.
API []string `json:"api,omitempty"`
// The SRT server listen endpoints.
SRT []string `json:"srt,omitempty"`
// The RTC server listen endpoints.
RTC []string `json:"rtc,omitempty"`
// Last update time.
UpdatedAt time.Time `json:"update_at,omitempty"`
}
func (v *SRSServer) ID() string {
return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID)
}
func (v *SRSServer) String() string {
return fmt.Sprintf("%v", v)
}
func (v *SRSServer) Format(f fmt.State, c rune) {
switch c {
case 'v', 's':
if f.Flag('+') {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID))
if v.DeviceID != "" {
sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID))
}
if len(v.RTMP) > 0 {
sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ",")))
}
if len(v.HTTP) > 0 {
sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ",")))
}
if len(v.API) > 0 {
sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ",")))
}
if len(v.SRT) > 0 {
sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ",")))
}
if len(v.RTC) > 0 {
sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ",")))
}
sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999")))
fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String())
} else {
fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
}
default:
fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
}
}
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
v := &SRSServer{}
for _, opt := range opts {
opt(v)
}
return v
}
// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
func NewDefaultSRSForDebugging() (*SRSServer, error) {
if envDefaultBackendEnabled() != "on" {
return nil, nil
}
if envDefaultBackendIP() == "" {
return nil, fmt.Errorf("empty default backend ip")
}
if envDefaultBackendRTMP() == "" {
return nil, fmt.Errorf("empty default backend rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP = envDefaultBackendIP()
srs.RTMP = []string{envDefaultBackendRTMP()}
srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
srs.ServiceID = logger.GenerateContextID()
srs.PID = fmt.Sprintf("%v", os.Getpid())
srs.UpdatedAt = time.Now()
})
if envDefaultBackendHttp() != "" {
server.HTTP = []string{envDefaultBackendHttp()}
}
if envDefaultBackendAPI() != "" {
server.API = []string{envDefaultBackendAPI()}
}
if envDefaultBackendRTC() != "" {
server.RTC = []string{envDefaultBackendRTC()}
}
if envDefaultBackendSRT() != "" {
server.SRT = []string{envDefaultBackendSRT()}
}
return server, nil
}
// SRSLoadBalancer is the interface to load balance the SRS servers.
type SRSLoadBalancer interface {
// Initialize the load balancer.
Initialize(ctx context.Context) error
// Update the backer server.
Update(ctx context.Context, server *SRSServer) error
// Pick a backend server for the specified stream URL.
Pick(ctx context.Context, streamURL string) (*SRSServer, error)
// Load or store the HLS streaming for the specified stream URL.
LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
// Store the WebRTC streaming for the specified stream URL.
StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
}
// srsLoadBalancer is the global SRS load balancer.
var srsLoadBalancer SRSLoadBalancer
// srsMemoryLoadBalancer stores state in memory.
type srsMemoryLoadBalancer struct {
// All available SRS servers, key is server ID.
servers sync.Map[string, *SRSServer]
// The picked server to servce client by specified stream URL, key is stream url.
picked sync.Map[string, *SRSServer]
// The HLS streaming, key is stream URL.
hlsStreamURL sync.Map[string, *HLSPlayStream]
// The HLS streaming, key is SPBHID.
hlsSPBHID sync.Map[string, *HLSPlayStream]
// The WebRTC streaming, key is stream URL.
rtcStreamURL sync.Map[string, *RTCConnection]
// The WebRTC streaming, key is ufrag.
rtcUfrag sync.Map[string, *RTCConnection]
}
func NewMemoryLoadBalancer() SRSLoadBalancer {
return &srsMemoryLoadBalancer{}
}
func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
v.servers.Store(server.ID(), server)
return nil
}
func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
// Always proxy to the same server for the same stream URL.
if server, ok := v.picked.Load(streamURL); ok {
return server, nil
}
// Gather all servers that were alive within the last few seconds.
var servers []*SRSServer
v.servers.Range(func(key string, server *SRSServer) bool {
if time.Since(server.UpdatedAt) < srsServerAliveDuration {
servers = append(servers, server)
}
return true
})
// If no servers available, use all possible servers.
if len(servers) == 0 {
v.servers.Range(func(key string, server *SRSServer) bool {
servers = append(servers, server)
return true
})
}
// No server found, failed.
if len(servers) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// Pick a server randomly from servers.
server := servers[rand.Intn(len(servers))]
v.picked.Store(streamURL, server)
return server, nil
}
func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
// Load the HLS streaming for the SPBHID, for TS files.
if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
} else {
return actual, nil
}
}
func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
// Update the HLS streaming for the stream URL, for M3u8.
actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
if actual == nil {
return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
}
// Update the HLS streaming for the SPBHID, for TS files.
v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual)
return actual, nil
}
func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
// Update the WebRTC streaming for the stream URL.
v.rtcStreamURL.Store(streamURL, value)
// Update the WebRTC streaming for the ufrag.
v.rtcUfrag.Store(value.Ufrag, value)
return nil
}
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
} else {
return actual, nil
}
}
type srsRedisLoadBalancer struct {
// The redis client sdk.
rdb *redis.Client
}
func NewRedisLoadBalancer() SRSLoadBalancer {
return &srsRedisLoadBalancer{}
}
func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
redisDatabase, err := strconv.Atoi(envRedisDB())
if err != nil {
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
}
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
Password: envRedisPassword(),
DB: redisDatabase,
})
v.rdb = rdb
if err := rdb.Ping(ctx).Err(); err != nil {
return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
}
logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
b, err := json.Marshal(server)
if err != nil {
return errors.Wrapf(err, "marshal server %+v", server)
}
key := v.redisKeyServer(server.ID())
if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v server %+v", key, server)
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// Check each server expiration, if not exists in redis, remove from servers.
for i := len(serverKeys) - 1; i >= 0; i-- {
if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
}
}
// Add server to servers if not exists.
var found bool
for _, serverKey := range serverKeys {
if serverKey == key {
found = true
break
}
}
if !found {
serverKeys = append(serverKeys, key)
}
// Update all servers to redis.
b, err = json.Marshal(serverKeys)
if err != nil {
return errors.Wrapf(err, "marshal servers %+v", serverKeys)
}
if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
}
return nil
}
func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
// Always proxy to the same server for the same stream URL.
if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
// If server not exists, ignore and pick another server for the stream URL.
if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
var server SRSServer
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
}
// TODO: If server fail, we should migrate the streams to another server.
return &server, nil
}
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// No server found, failed.
if len(serverKeys) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// All server should be alive, if not, should have been removed by redis. So we only
// random pick one that is always available.
var serverKey string
var server SRSServer
for i := 0; i < 3; i++ {
tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
if err == nil && len(b) > 0 {
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
}
serverKey = tryServerKey
break
}
}
if serverKey == "" {
return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
}
// Update the picked server for the stream URL.
if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
}
return &server, nil
}
func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
key := v.redisKeySPBHID(spbhid)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
b, err := json.Marshal(value)
if err != nil {
return nil, errors.Wrapf(err, "marshal HLS %v", value)
}
key := v.redisKeyHLS(streamURL)
if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
}
key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
}
// Query the HLS streaming from redis.
b2, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b2, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
b, err := json.Marshal(value)
if err != nil {
return errors.Wrapf(err, "marshal WebRTC %v", value)
}
key := v.redisKeyRTC(streamURL)
if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
}
key2 := v.redisKeyUfrag(value.Ufrag)
if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
}
return nil
}
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
key := v.redisKeyUfrag(ufrag)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
}
var actual RTCConnection
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
}
func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
}
func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
return fmt.Sprintf("srs-proxy-server:%v", serverID)
}
func (v *srsRedisLoadBalancer) redisKeyServers() string {
return fmt.Sprintf("srs-proxy-all-servers")
}

574
proxy/srt.go Normal file
View file

@ -0,0 +1,574 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to
// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the
// backend server.
type srsSRTServer struct {
// The UDP listener for SRT server.
listener *net.UDPConn
// The SRT connections, identify by the socket ID.
sockets sync.Map[uint32, *SRTConnection]
// The system start time.
start time.Time
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer {
v := &srsSRTServer{
start: time.Now(),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsSRTServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsSRTServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envSRTServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "SRT server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := v.listener.ReadFromUDP(buf)
if err != nil {
// TODO: If SRT server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
socketID := srtParseSocketID(data)
var pkt *SRTHandshakePacket
if srtIsHandshake(data) {
pkt = &SRTHandshakePacket{}
if err := pkt.UnmarshalBinary(data); err != nil {
return err
}
if socketID == 0 {
socketID = pkt.SRTSocketID
}
}
conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) {
c.ctx = logger.WithContext(ctx)
c.listenerUDP, c.socketID = v.listener, socketID
c.start = v.start
}))
ctx = conn.ctx
if !ok {
logger.Df(ctx, "Create new SRT connection skt=%v", socketID)
}
if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil {
return errors.Wrapf(err, "handle packet")
} else if newSocketID != 0 && newSocketID != socketID {
// The connection may use a new socket ID.
// TODO: FIXME: Should cleanup the dead SRT connection.
v.sockets.Store(newSocketID, conn)
}
return nil
}
// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT
// connection, identify by the socket ID.
//
// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in
// the client request. The SRTConnection is stateless, and no need to sync between proxy servers.
//
// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the
// client should never switch to another network or port. If this occurs, the client may be served
// by a different proxy server and fail because the other proxy server cannot identify the client.
type SRTConnection struct {
// The stream context for SRT connection.
ctx context.Context
// The current socket ID.
socketID uint32
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
// Listener start time.
start time.Time
// Handshake packets with client.
handshake0 *SRTHandshakePacket
handshake1 *SRTHandshakePacket
handshake2 *SRTHandshakePacket
handshake3 *SRTHandshakePacket
}
func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection {
v := &SRTConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) {
ctx := v.ctx
// If not handshake, try to proxy to backend directly.
if pkt == nil {
// Proxy client message to backend.
if v.backendUDP != nil {
if _, err := v.backendUDP.Write(data); err != nil {
return v.socketID, errors.Wrapf(err, "write to backend")
}
}
return v.socketID, nil
}
// Handle handshake messages.
if err := v.handleHandshake(ctx, pkt, addr, data); err != nil {
return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt)
}
return v.socketID, nil
}
func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error {
// Handle handshake 0 and 1 messages.
if pkt.SynCookie == 0 {
// Save handshake 0 packet.
v.handshake0 = pkt
logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0)
// Response handshake 1.
v.handshake1 = &SRTHandshakePacket{
ControlFlag: pkt.ControlFlag,
ControlType: 0,
SubType: 0,
AdditionalInfo: 0,
Timestamp: uint32(time.Since(v.start).Microseconds()),
SocketID: pkt.SRTSocketID,
Version: 5,
EncryptionField: 0,
ExtensionField: 0x4A17,
InitSequence: pkt.InitSequence,
MTU: pkt.MTU,
FlowWindow: pkt.FlowWindow,
HandshakeType: 1,
SRTSocketID: pkt.SRTSocketID,
SynCookie: 0x418d5e4e,
PeerIP: net.ParseIP("127.0.0.1"),
}
logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1)
if b, err := v.handshake1.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 1")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 1")
}
return nil
}
// Handle handshake 2 and 3 messages.
// Parse stream id from packet.
streamID, err := pkt.StreamID()
if err != nil {
return errors.Wrapf(err, "parse stream id")
}
// Save handshake packet.
v.handshake2 = pkt
logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID)
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx, streamID); err != nil {
return errors.Wrapf(err, "connect backend for %v", streamID)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return errors.Errorf("no backend for %v", streamID)
}
// Proxy handshake 0 to backend server.
if b, err := v.handshake0.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 0")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 0")
}
logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0)
// Read handshake 1 from backend server.
b := make([]byte, 4096)
handshake1p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 1")
} else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 1")
}
logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p)
// Proxy handshake 2 to backend server.
handshake2p := *v.handshake2
handshake2p.SynCookie = handshake1p.SynCookie
if b, err := handshake2p.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 2")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 2")
}
logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p)
// Read handshake 3 from backend server.
handshake3p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 3")
} else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 3")
}
logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p)
// Response handshake 3 to client.
v.handshake3 = &*handshake3p
v.handshake3.SynCookie = v.handshake1.SynCookie
v.socketID = handshake3p.SRTSocketID
logger.Df(ctx, "Handshake 3: %v", v.handshake3)
if b, err := v.handshake3.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 3")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 3")
}
// Start a goroutine to proxy message from backend to client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
go func() {
for ctx.Err() == nil {
nn, err := v.backendUDP.Read(b)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
return
}
if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
return
}
}
}()
return nil
}
func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error {
if v.backendUDP != nil {
return nil
}
// Parse stream id to host and resource.
host, resource, err := parseSRTStreamID(streamID)
if err != nil {
return errors.Wrapf(err, "parse stream id %v", streamID)
}
if host == "" {
host = "localhost"
}
streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource))
if err != nil {
return errors.Wrapf(err, "build stream url %v", streamID)
}
// Pick a backend SRS server to proxy the SRT stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse UDP port from backend.
if len(backend.SRT) == 0 {
return errors.Errorf("no udp server %v for %v", backend, streamURL)
}
_, _, udpPort, err := parseListenEndpoint(backend.SRT[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL)
} else {
v.backendUDP = backendUDP
}
return nil
}
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1
type SRTHandshakePacket struct {
// F: 1 bit. Packet Type Flag. The control packet has this flag set to
// "1". The data packet has this flag set to "0".
ControlFlag uint8
// Control Type: 15 bits. Control Packet Type. The use of these bits
// is determined by the control packet type definition.
// Handshake control packets (Control Type = 0x0000) are used to
// exchange peer configurations, to agree on connection parameters, and
// to establish a connection.
ControlType uint16
// Subtype: 16 bits. This field specifies an additional subtype for
// specific packets.
SubType uint16
// Type-specific Information: 32 bits. The use of this field depends on
// the particular control packet type. Handshake packets do not use
// this field.
AdditionalInfo uint32
// Timestamp: 32 bits.
Timestamp uint32
// Destination Socket ID: 32 bits.
SocketID uint32
// Version: 32 bits. A base protocol version number. Currently used
// values are 4 and 5. Values greater than 5 are reserved for future
// use.
Version uint32
// Encryption Field: 16 bits. Block cipher family and key size. The
// values of this field are described in Table 2. The default value
// is AES-128.
// 0 | No Encryption Advertised
// 2 | AES-128
// 3 | AES-192
// 4 | AES-256
EncryptionField uint16
// Extension Field: 16 bits. This field is message specific extension
// related to Handshake Type field. The value MUST be set to 0
// except for the following cases. (1) If the handshake control
// packet is the INDUCTION message, this field is sent back by the
// Listener. (2) In the case of a CONCLUSION message, this field
// value should contain a combination of Extension Type values.
// 0x00000001 | HSREQ
// 0x00000002 | KMREQ
// 0x00000004 | CONFIG
// 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1
ExtensionField uint16
// Initial Packet Sequence Number: 32 bits. The sequence number of the
// very first data packet to be sent.
InitSequence uint32
// Maximum Transmission Unit Size: 32 bits. This value is typically set
// to 1500, which is the default Maximum Transmission Unit (MTU) size
// for Ethernet, but can be less.
MTU uint32
// Maximum Flow Window Size: 32 bits. The value of this field is the
// maximum number of data packets allowed to be "in flight" (i.e. the
// number of sent packets for which an ACK control packet has not yet
// been received).
FlowWindow uint32
// Handshake Type: 32 bits. This field indicates the handshake packet
// type.
// 0xFFFFFFFD | DONE
// 0xFFFFFFFE | AGREEMENT
// 0xFFFFFFFF | CONCLUSION
// 0x00000000 | WAVEHAND
// 0x00000001 | INDUCTION
HandshakeType uint32
// SRT Socket ID: 32 bits. This field holds the ID of the source SRT
// socket from which a handshake packet is issued.
SRTSocketID uint32
// SYN Cookie: 32 bits. Randomized value for processing a handshake.
// The value of this field is specified by the handshake message
// type.
SynCookie uint32
// Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's
// sender. The value consists of four 32-bit fields.
PeerIP net.IP
// Extensions.
// Extension Type: 16 bits. The value of this field is used to process
// an integrated handshake. Each extension can have a pair of
// request and response types.
// Extension Length: 16 bits. The length of the Extension Contents
// field in four-byte blocks.
// Extension Contents: variable length. The payload of the extension.
ExtraData []byte
}
func (v *SRTHandshakePacket) IsData() bool {
return v.ControlFlag == 0x00
}
func (v *SRTHandshakePacket) IsControl() bool {
return v.ControlFlag == 0x80
}
func (v *SRTHandshakePacket) IsHandshake() bool {
return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00
}
func (v *SRTHandshakePacket) StreamID() (string, error) {
p := v.ExtraData
for {
if len(p) < 2 {
return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData))
}
extType := binary.BigEndian.Uint16(p)
extSize := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(extSize*4) {
return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData))
}
// Ignore other packets except stream id.
if extType != 0x05 {
p = p[extSize*4:]
continue
}
// We must copy it, because we will decode the stream id.
data := append([]byte{}, p[:extSize*4]...)
// Reverse the stream id encoded in little-endian to big-endian.
for i := 0; i < len(data); i += 4 {
value := binary.LittleEndian.Uint32(data[i:])
binary.BigEndian.PutUint32(data[i:], value)
}
// Trim the trailing zero bytes.
data = bytes.TrimRight(data, "\x00")
return string(data), nil
}
}
func (v *SRTHandshakePacket) String() string {
return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB",
v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData))
}
func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error {
if len(b) < 4 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.ControlFlag = b[0] & 0x80
v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff
v.SubType = binary.BigEndian.Uint16(b[2:4])
if len(b) < 64 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.AdditionalInfo = binary.BigEndian.Uint32(b[4:])
v.Timestamp = binary.BigEndian.Uint32(b[8:])
v.SocketID = binary.BigEndian.Uint32(b[12:])
v.Version = binary.BigEndian.Uint32(b[16:])
v.EncryptionField = binary.BigEndian.Uint16(b[20:])
v.ExtensionField = binary.BigEndian.Uint16(b[22:])
v.InitSequence = binary.BigEndian.Uint32(b[24:])
v.MTU = binary.BigEndian.Uint32(b[28:])
v.FlowWindow = binary.BigEndian.Uint32(b[32:])
v.HandshakeType = binary.BigEndian.Uint32(b[36:])
v.SRTSocketID = binary.BigEndian.Uint32(b[40:])
v.SynCookie = binary.BigEndian.Uint32(b[44:])
// Only support IPv4.
v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48])
v.ExtraData = b[64:]
return nil
}
func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) {
b := make([]byte, 64+len(v.ExtraData))
binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType)
binary.BigEndian.PutUint16(b[2:], v.SubType)
binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo)
binary.BigEndian.PutUint32(b[8:], v.Timestamp)
binary.BigEndian.PutUint32(b[12:], v.SocketID)
binary.BigEndian.PutUint32(b[16:], v.Version)
binary.BigEndian.PutUint16(b[20:], v.EncryptionField)
binary.BigEndian.PutUint16(b[22:], v.ExtensionField)
binary.BigEndian.PutUint32(b[24:], v.InitSequence)
binary.BigEndian.PutUint32(b[28:], v.MTU)
binary.BigEndian.PutUint32(b[32:], v.FlowWindow)
binary.BigEndian.PutUint32(b[36:], v.HandshakeType)
binary.BigEndian.PutUint32(b[40:], v.SRTSocketID)
binary.BigEndian.PutUint32(b[44:], v.SynCookie)
// Only support IPv4.
ip := v.PeerIP.To4()
b[48] = ip[3]
b[49] = ip[2]
b[50] = ip[1]
b[51] = ip[0]
if len(v.ExtraData) > 0 {
copy(b[64:], v.ExtraData)
}
return b, nil
}

45
proxy/sync/map.go Normal file
View file

@ -0,0 +1,45 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package sync
import "sync"
type Map[K comparable, V any] struct {
m sync.Map
}
func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
return value, ok
}
return v.(V), ok
}
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
v, loaded := m.m.LoadAndDelete(key)
if !loaded {
return value, loaded
}
return v.(V), loaded
}
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
a, loaded := m.m.LoadOrStore(key, value)
return a.(V), loaded
}
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value any) bool {
return f(key.(K), value.(V))
})
}
func (m *Map[K, V]) Store(key K, value V) {
m.m.Store(key, value)
}

276
proxy/utils.go Normal file
View file

@ -0,0 +1,276 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"encoding/json"
stdErr "errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"reflect"
"regexp"
"strconv"
"strings"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) {
w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version()))
b, err := json.Marshal(data)
if err != nil {
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
logger.Wf(ctx, "HTTP API error %+v", err)
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, fmt.Sprintf("%v", err))
}
func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
// Always support CORS. Note that browser may send origin header for m3u8, but no origin header
// for ts. So we always response CORS header.
if true {
// SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin,
// headers, expose headers and methods.
w.Header().Set("Access-Control-Allow-Origin", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
w.Header().Set("Access-Control-Allow-Headers", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
w.Header().Set("Access-Control-Allow-Methods", "*")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return true
}
return false
}
func parseGracefullyQuitTimeout() (time.Duration, error) {
if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
} else {
return t, nil
}
}
// ParseBody read the body from r, and unmarshal JSON to v.
func ParseBody(r io.ReadCloser, v interface{}) error {
b, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrapf(err, "read body")
}
defer r.Close()
if len(b) == 0 {
return nil
}
if err := json.Unmarshal(b, v); err != nil {
return errors.Wrapf(err, "json unmarshal %v", string(b))
}
return nil
}
// buildStreamURL build as vhost/app/stream for stream URL r.
func buildStreamURL(r string) (string, error) {
u, err := url.Parse(r)
if err != nil {
return "", errors.Wrapf(err, "parse url %v", r)
}
// If not domain or ip in hostname, it's __defaultVhost__.
defaultVhost := !strings.Contains(u.Hostname(), ".")
// If hostname is actually an IP address, it's __defaultVhost__.
if ip := net.ParseIP(u.Hostname()); ip.To4() != nil {
defaultVhost = true
}
if defaultVhost {
return fmt.Sprintf("__defaultVhost__%v", u.Path), nil
}
// Ignore port, only use hostname as vhost.
return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil
}
// isPeerClosedError indicates whether peer object closed the connection.
func isPeerClosedError(err error) bool {
causeErr := errors.Cause(err)
if stdErr.Is(causeErr, io.EOF) {
return true
}
if stdErr.Is(causeErr, syscall.EPIPE) {
return true
}
if netErr, ok := causeErr.(*net.OpError); ok {
if sysErr, ok := netErr.Err.(*os.SyscallError); ok {
if stdErr.Is(sysErr.Err, syscall.ECONNRESET) {
return true
}
}
}
return false
}
// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL
// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL
// with extension.
func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
hostname := "__defaultVhost__"
if strings.Contains(r.Host, ":") {
if v, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = v
}
}
var appStream, streamExt string
// Parse app/stream from query string.
q := r.URL.Query()
if app := q.Get("app"); app != "" {
appStream = "/" + app
}
if stream := q.Get("stream"); stream != "" {
appStream = fmt.Sprintf("%v/%v", appStream, stream)
}
// Parse app/stream from path.
if appStream == "" {
streamExt = path.Ext(r.URL.Path)
appStream = strings.TrimSuffix(r.URL.Path, streamExt)
}
unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream)
fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt)
return
}
// rtcIsSTUN returns true if data of UDP payload is a STUN packet.
func rtcIsSTUN(data []byte) bool {
return len(data) > 0 && (data[0] == 0 || data[0] == 1)
}
// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet.
func rtcIsRTPOrRTCP(data []byte) bool {
return len(data) >= 12 && (data[0]&0xC0) == 0x80
}
// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet.
func srtIsHandshake(data []byte) bool {
return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000
}
// srtParseSocketID parse the socket id from the SRT packet.
func srtParseSocketID(data []byte) uint32 {
if len(data) >= 16 {
return binary.BigEndian.Uint32(data[12:])
}
return 0
}
// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(sdp)
if len(ufragMatch) <= 1 {
return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
}
ufrag = ufragMatch[1]
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(sdp)
if len(pwdMatch) <= 1 {
return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
}
pwd = pwdMatch[1]
}
return ufrag, pwd, nil
}
// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required).
// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url
func parseSRTStreamID(sid string) (host, resource string, err error) {
if true {
hostRe := regexp.MustCompile(`h=([^,]+)`)
hostMatch := hostRe.FindStringSubmatch(sid)
if len(hostMatch) > 1 {
host = hostMatch[1]
}
}
if true {
resourceRe := regexp.MustCompile(`r=([^,]+)`)
resourceMatch := resourceRe.FindStringSubmatch(sid)
if len(resourceMatch) <= 1 {
return "", "", errors.Errorf("no resource in sid %v", sid)
}
resource = resourceMatch[1]
}
return host, resource, nil
}
// parseListenEndpoint parse the listen endpoint as:
// port The tcp listen port, like 1935.
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
// If no colon in ep, it's port in string.
if !strings.Contains(ep, ":") {
if p, err := strconv.Atoi(ep); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", ep)
} else {
return "tcp", nil, uint16(p), nil
}
}
// Must be protocol://ip:port schema.
parts := strings.Split(ep, ":")
if len(parts) != 3 {
return "", nil, 0, errors.Errorf("invalid endpoint %v", ep)
}
if p, err := strconv.Atoi(parts[2]); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2])
} else {
return parts[0], net.ParseIP(parts[1]), uint16(p), nil
}
}

27
proxy/version.go Normal file
View file

@ -0,0 +1,27 @@
// Copyright (c) 2024 Winlin
//
// SPDX-License-Identifier: MIT
package main
import "fmt"
func VersionMajor() int {
return 1
}
// VersionMinor specifies the typical version of SRS we adapt to.
func VersionMinor() int {
return 5
}
func VersionRevision() int {
return 0
}
func Version() string {
return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision())
}
func Signature() string {
return "SRSProxy"
}

View file

@ -1,20 +1,32 @@
.PHONY: help default clean bench test
.PHONY: help default clean bench pcap test all
default: bench test
clean:
rm -rf ./objs
all: bench test pcap test
#########################################################################################################
# SRS benchmark tool for SRS, janus, GB28181.
./objs/.format.bench.txt: *.go janus/*.go ./objs/.format.srs.txt ./objs/.format.gb28181.txt
gofmt -w *.go janus
mkdir -p objs && echo "done" > ./objs/.format.bench.txt
bench: ./objs/srs_bench ./objs/pcap_simulator
bench: ./objs/srs_bench
./objs/srs_bench: ./objs/.format.bench.txt *.go janus/*.go srs/*.go vnet/*.go gb28181/*.go Makefile
go build -mod=vendor -o objs/srs_bench .
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Darwin)
SRT_PREFIX := $(shell brew --prefix srt)
CGO_CFLAGS := -I$(SRT_PREFIX)/include
CGO_LDFLAGS := -L$(SRT_PREFIX)/lib -lsrt
else ifeq ($(UNAME_S),Linux)
CGO_CFLAGS := -I/usr/local/include
CGO_LDFLAGS := -L/usr/local/lib -lsrt -L/usr/local/ssl/lib -lcrypto -lstdc++ -lm -ldl
endif
./objs/srs_bench: ./objs/.format.bench.txt *.go janus/*.go srs/*.go vnet/*.go gb28181/*.go live/*.go Makefile
CGO_CFLAGS="$(CGO_CFLAGS)" CGO_LDFLAGS="$(CGO_LDFLAGS)" go build -mod=vendor -o objs/srs_bench .
#########################################################################################################
# For all regression tests.
@ -35,6 +47,8 @@ test: ./objs/srs_test ./objs/srs_gb28181_test ./objs/srs_blackbox_test
gofmt -w pcap
mkdir -p objs && echo "done" > ./objs/.format.pcap.txt
pcap: ./objs/pcap_simulator
./objs/pcap_simulator: ./objs/.format.pcap.txt pcap/*.go Makefile
go build -mod=vendor -o ./objs/pcap_simulator ./pcap
@ -59,9 +73,10 @@ test: ./objs/srs_test ./objs/srs_gb28181_test ./objs/srs_blackbox_test
#########################################################################################################
# Help menu.
help:
@echo "Usage: make [default|bench|test|clean]"
@echo "Usage: make [default|bench|pcap|test|clean]"
@echo " default The default entry for make is bench+test"
@echo " bench Make the bench to ./objs/srs_bench"
@echo " pcap Make the pcap simulator to ./objs/pcap_simulator"
@echo " test Make the test tool to ./objs/srs_test and ./objs/srs_gb28181_test ./objs/srs_blackbox_test"
@echo " clean Remove all tools at ./objs"

View file

@ -14,6 +14,8 @@ git clone -b feature/rtc https://github.com/ossrs/srs-bench.git &&
cd srs-bench && make
```
> Note: 依赖Go编译工具建议使用 Go 1.17 及以上的版本。
编译会生成下面的工具:
* `./objs/srs_bench` 压测模拟大量客户端的负载测试支持SRS、GB28181和Janus三种场景。
@ -31,9 +33,11 @@ cd srs/trunk && ./configure --h265=on --gb28181=on && make &&
./objs/srs -c conf/console.conf
```
> Note: Use valgrind to check memory leak, please use `valgrind --leak-check=full ./objs/srs -c conf/console.conf >/dev/null` to start SRS.
具体场景,请按下面的操作启动测试。
## Player for Live
## Player for WHEP
直播播放压测,一个流,很多个播放。
@ -49,7 +53,7 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/
./objs/srs_bench -sr webrtc://localhost/live/livestream -nn 100
```
## Publisher for Live or RTC
## Publisher for WHIP
直播或会议场景推流压测,一般会推多个流。
@ -63,7 +67,7 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/
> 注意帧率是原始视频的帧率由于264中没有这个信息所以需要传递。
## Multipel Player or Publisher for RTC
## Multiple WHIP or WHEP for RTC
会议场景的播放压测会多个客户端播放多个流比如3人会议那么就有3个推流每个流有2个播放。
@ -84,7 +88,7 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/
> 备注URL的变量格式参考Go的`fmt.Sprintf`,比如可以用`webrtc://localhost/live/livestream_%03d`
<a name="dvr"></a>
## DVR for Benchmark
## DVR for RTC Benchmark
录制场景,主要是把内容录制下来后,可分析,也可以用于推流。
@ -120,6 +124,37 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/
> Note: 可以传递更多参数详细参考SRS支持的参数。
## Reconnecting Load Test
建立连接和断开重连的压测可以测试SRS在多个Source时是否有内存泄露问题参考 [#3667](https://github.com/ossrs/srs/discussions/3667#discussioncomment-8969107)
RTMP重连测试
```bash
for ((i=0;;i++)); do
./objs/srs_bench -sfu=live -pr=rtmp://localhost/live${i}/stream -sn=1000 -cap=true;
sleep 10;
done
```
SRT重连测试
```bash
for ((i=0;;i++)); do
./objs/srs_bench -sfu=live -pr='srt://127.0.0.1:10080?streamid=#!::'m=publish,r=live${i}/stream -sn=1000 -cap=true;
sleep 10;
done
```
WebRTC重连测试
```bash
for ((i=0;;i++)); do
./objs/srs_bench -sfu=rtc -pr=webrtc://localhost/live${i}/livestream -sn=1000 -cap=true;
sleep 10;
done
```
## Regression Test
回归测试需要先启动[SRS](https://github.com/ossrs/srs/issues/307)支持WebRTC推拉流
@ -329,4 +364,50 @@ make -j10 && ./objs/srs_bench -sfu janus \
-nn 5
```
## Install LIBSRT
我们使用 [srtgo](https://github.com/Haivision/srtgo) 库测试SRT协议需要安装libsrt库
参考[macOS](https://github.com/Haivision/srt/blob/master/docs/build/build-macOS.md)
```bash
brew install srt
```
如果是Ubuntu可以参考[Ubuntu](https://github.com/Haivision/srt/blob/master/docs/build/package-managers.md):
```bash
apt-get install -y libsrt
```
安装完libsrt后直接编译srs-bench即可
```bash
make
```
## Ubuntu Docker
如果使用Ubuntu编译推荐使用 `ossrs/srs:ubuntu20` 作为镜像编译已经编译了openssl和libsrt启动容器
```bash
docker run --rm -it -v $(pwd):/g -w /g ossrs/srs:ubuntu20 make
```
## GoLand
使用GoLand编译和调试时需要设置libsrt的环境变量首先可以使用brew获取路径
```bash
brew --prefix srt
#/opt/homebrew/opt/srt
```
然后在GoLand中编辑配置 `Edit Configurations`,添加环境变量:
```bash
CGO_CFLAGS=-I/opt/homebrew/opt/srt/include;CGO_LDFLAGS=-L/opt/homebrew/opt/srt/lib -lsrt
```
> Note: 特别注意的是CGO_LDFLAGS是可以有空格的不能使用字符串否则找不到库。
2021.01, Winlin

View file

@ -900,6 +900,14 @@ func TestSlow_SrtPublish_HttpTsPlay_HEVC_Basic(t *testing.T) {
r1 = ffmpeg.Run(ctx, cancel)
}()
// Should wait for TS to generate the contents.
select {
case <-ctx.Done():
r2 = fmt.Errorf("timeout")
return
case <-time.After(5 * time.Second):
}
// Start FFprobe to detect and verify stream.
duration := time.Duration(*srsFFprobeDuration) * time.Millisecond
ffprobe := NewFFprobe(func(v *ffprobeClient) {
@ -912,9 +920,6 @@ func TestSlow_SrtPublish_HttpTsPlay_HEVC_Basic(t *testing.T) {
defer wg.Done()
<-svr.ReadyCtx().Done()
// wait for ffmpeg
time.Sleep(3 * time.Second)
r2 = ffprobe.Run(ctx, cancel)
}()
@ -930,8 +935,8 @@ func TestSlow_SrtPublish_HttpTsPlay_HEVC_Basic(t *testing.T) {
}
// Note that HLS score is low, so we only check duration.
if dv := m.Duration(); dv < duration {
r5 = errors.Errorf("short duration=%v < %v, %v, %v", dv, duration, m.String(), str)
if dv := m.Duration(); dv < duration/2 {
r5 = errors.Errorf("short duration=%v < %v, %v, %v", dv, duration/2, m.String(), str)
}
if v := m.Video(); v == nil {
@ -949,7 +954,6 @@ func TestSlow_SrtPublish_HlsPlay_HEVC_Basic(t *testing.T) {
// Setup the max timeout for this case.
ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
// Check a set of errors.
var r0, r1, r2, r3, r4 error
defer func(ctx context.Context) {
@ -995,12 +999,17 @@ func TestSlow_SrtPublish_HlsPlay_HEVC_Basic(t *testing.T) {
defer wg.Done()
<-svr.ReadyCtx().Done()
// wait for ffmpeg
time.Sleep(3 * time.Second)
r1 = ffmpeg.Run(ctx, cancel)
}()
// Should wait for HLS to generate the ts files.
select {
case <-ctx.Done():
r2 = fmt.Errorf("timeout")
return
case <-time.After(20 * time.Second):
}
// Start FFprobe to detect and verify stream.
duration := time.Duration(*srsFFprobeDuration) * time.Millisecond
ffprobe := NewFFprobe(func(v *ffprobeClient) {

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
@ -57,7 +57,11 @@ func Parse(ctx context.Context) interface{} {
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target server that can be rtc, live, janus, or gb28181. Default: rtc"))
fmt.Println(fmt.Sprintf(" rtc/srs: SRS WebRTC SFU server, for WebRTC/WHIP/WHEP."))
fmt.Println(fmt.Sprintf(" live: SRS live streaming server, for RTMP/HTTP-FLV/HLS."))
fmt.Println(fmt.Sprintf(" janus: Janus WebRTC SFU server, for janus private protocol."))
fmt.Println(fmt.Sprintf(" gb28181: GB media server, for GB protocol."))
fmt.Println(fmt.Sprintf("SIP:"))
fmt.Println(fmt.Sprintf(" -user The SIP username, ID of device."))
fmt.Println(fmt.Sprintf(" -random Append N number to user as random device ID, like 1320000001."))

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -1,6 +1,6 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
// # Copyright (c) 2022-2024 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in

View file

@ -5,6 +5,7 @@ go 1.17
require (
github.com/ghettovoice/gosip v0.0.0-20220929080231-de8ba881be83
github.com/google/gopacket v1.1.19
github.com/haivision/srtgo v0.0.0-20230627061225-a70d53fcd618
github.com/ossrs/go-oryx-lib v0.0.9
github.com/pion/ice/v2 v2.3.6
github.com/pion/interceptor v0.1.17
@ -28,6 +29,7 @@ require (
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.8 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/pion/datachannel v1.5.5 // indirect
github.com/pion/dtls/v2 v2.2.7 // indirect
@ -40,12 +42,12 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect
github.com/sirupsen/logrus v1.4.2 // indirect
github.com/stretchr/testify v1.8.4 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5 // indirect
github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/term v0.8.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View file

@ -31,6 +31,8 @@ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/haivision/srtgo v0.0.0-20230627061225-a70d53fcd618 h1:oGPTZa7I5wqmQs/UhWHj3ln6/CjQX2yQt784xx6H0wI=
github.com/haivision/srtgo v0.0.0-20230627061225-a70d53fcd618/go.mod h1:aTd4vOr9wtzkCbbocUFh6atlJy7H/iV5jhqEWlTdCdA=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
@ -44,6 +46,8 @@ github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaa
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@ -115,6 +119,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -122,8 +127,9 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5 h1:hNna6Fi0eP1f2sMBe/rJicDmaHmoXGe1Ta84FPYHLuE=
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5/go.mod h1:f1SCnEOt6sc3fOJfPQDRDzHOtSXuTtnz0ImG9kPRDV0=
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
@ -140,8 +146,10 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
@ -161,8 +169,9 @@ golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -178,6 +187,7 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200926100807-9d91bd62050c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -193,8 +203,10 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -202,8 +214,10 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -213,8 +227,9 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=

View file

@ -66,7 +66,11 @@ func Parse(ctx context.Context) {
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target server that can be rtc, live, janus, or gb28181. Default: rtc"))
fmt.Println(fmt.Sprintf(" rtc/srs: SRS WebRTC SFU server, for WebRTC/WHIP/WHEP."))
fmt.Println(fmt.Sprintf(" live: SRS live streaming server, for RTMP/HTTP-FLV/HLS."))
fmt.Println(fmt.Sprintf(" janus: Janus WebRTC SFU server, for janus private protocol."))
fmt.Println(fmt.Sprintf(" gb28181: GB media server, for GB protocol."))
fmt.Println(fmt.Sprintf(" -nn The number of clients to simulate. Default: 1"))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))

195
trunk/3rdparty/srs-bench/live/live.go vendored Normal file
View file

@ -0,0 +1,195 @@
// The MIT License (MIT)
//
// # Copyright (c) 2021 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package live
import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
)
var closeAfterPublished bool
var pr string
var streams, delay int
var statListen string
func Parse(ctx context.Context) {
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
var sfu string
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
fl.BoolVar(&closeAfterPublished, "cap", false, "")
fl.StringVar(&pr, "pr", "", "")
fl.IntVar(&streams, "sn", 1, "")
fl.IntVar(&delay, "delay", 10, "")
fl.StringVar(&statListen, "stat", "", "")
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target server that can be rtc, live, janus, or gb28181. Default: rtc"))
fmt.Println(fmt.Sprintf(" rtc/srs: SRS WebRTC SFU server, for WebRTC/WHIP/WHEP."))
fmt.Println(fmt.Sprintf(" live: SRS live streaming server, for RTMP/HTTP-FLV/HLS."))
fmt.Println(fmt.Sprintf(" janus: Janus WebRTC SFU server, for janus private protocol."))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port."))
fmt.Println(fmt.Sprintf("Publisher:"))
fmt.Println(fmt.Sprintf(" -pr The url to publish. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -cap Whether to close connection after publish. Default: false"))
fmt.Println(fmt.Sprintf("\n例如1个推流无媒体传输:"))
fmt.Println(fmt.Sprintf(" %v -pr=rtmp://localhost/live/livestream -cap=true", os.Args[0]))
fmt.Println(fmt.Sprintf("\n例如2个推流无媒体传输"))
fmt.Println(fmt.Sprintf(" %v -pr=rtmp://localhost/live/livestream_%%d -sn=2 -cap=true", os.Args[0]))
fmt.Println()
}
_ = fl.Parse(os.Args[1:])
showHelp := streams <= 0
if pr == "" {
showHelp = true
}
if showHelp {
fl.Usage()
os.Exit(-1)
}
if statListen != "" && !strings.Contains(statListen, ":") {
statListen = ":" + statListen
}
summaryDesc := fmt.Sprintf("streams=%v", streams)
if pr != "" {
summaryDesc = fmt.Sprintf("%v, publish=(url=%v,cap=%v)",
summaryDesc, pr, closeAfterPublished)
}
logger.Tf(ctx, "Run benchmark with %v", summaryDesc)
}
func Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Run tasks.
var wg sync.WaitGroup
defer wg.Wait()
// Run STAT API server.
wg.Add(1)
go func() {
defer wg.Done()
if statListen == "" {
return
}
var lc net.ListenConfig
ln, err := lc.Listen(ctx, "tcp", statListen)
if err != nil {
logger.Ef(ctx, "stat listen err+%v", err)
cancel()
return
}
mux := http.NewServeMux()
handleStat(ctx, mux, statListen)
srv := &http.Server{
Handler: mux,
BaseContext: func(listener net.Listener) context.Context {
return ctx
},
}
go func() {
<-ctx.Done()
srv.Shutdown(ctx)
}()
logger.Tf(ctx, "Stat listen at %v", statListen)
if err := srv.Serve(ln); err != nil {
if ctx.Err() == nil {
logger.Ef(ctx, "stat serve err+%v", err)
cancel()
}
return
}
}()
// Run all publishers.
publisherStartedCtx, publisherStartedCancel := context.WithCancel(ctx)
defer publisherStartedCancel()
for i := 0; pr != "" && i < streams && ctx.Err() == nil; i++ {
r_auto := pr
if streams > 1 && !strings.Contains(r_auto, "%") {
r_auto += "%d"
}
r2 := r_auto
if strings.Contains(r2, "%") {
r2 = fmt.Sprintf(r2, i)
}
gStatLive.Publishers.Expect++
gStatLive.Publishers.Alive++
wg.Add(1)
go func(pr string) {
defer wg.Done()
defer func() {
gStatLive.Publishers.Alive--
logger.Tf(ctx, "Publisher %v done, alive=%v", pr, gStatLive.Publishers.Alive)
<- publisherStartedCtx.Done()
if gStatLive.Publishers.Alive == 0 {
cancel()
}
}()
if err := startPublish(ctx, pr, closeAfterPublished); err != nil {
if errors.Cause(err) != context.Canceled {
logger.Wf(ctx, "Run err %+v", err)
}
}
}(r2)
if delay > 0 {
time.Sleep(time.Duration(delay) * time.Millisecond)
}
}
return nil
}

View file

@ -0,0 +1,210 @@
// The MIT License (MIT)
//
// # Copyright (c) 2021 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package live
import (
"context"
"fmt"
"math/rand"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/haivision/srtgo"
"github.com/ossrs/go-oryx-lib/amf0"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/go-oryx-lib/rtmp"
)
func startPublish(ctx context.Context, r string, closeAfterPublished bool) error {
ctx = logger.WithContext(ctx)
logger.Tf(ctx, "Run publish url=%v, cap=%v", r, closeAfterPublished)
u, err := url.Parse(r)
if err != nil {
return errors.Wrapf(err, "parse %v", r)
}
if u.Scheme == "rtmp" {
return startPublishRTMP(ctx, u, closeAfterPublished)
} else if u.Scheme == "srt" {
return startPublishSRT(ctx, u, closeAfterPublished)
}
return fmt.Errorf("invalid schema %v of %v", u.Scheme, r)
}
func startPublishSRT(ctx context.Context, u *url.URL, closeAfterPublished bool) (err error) {
// Parse host and port.
port := 1935
if u.Port() != "" {
if port, err = strconv.Atoi(u.Port()); err != nil {
return errors.Wrapf(err, "parse port %v", u.Port())
}
}
ips, err := net.LookupIP(u.Hostname())
if err != nil {
return errors.Wrapf(err, "lookup %v", u.Hostname())
}
if len(ips) == 0 {
return errors.Errorf("no ips for %v", u.Hostname())
}
logger.Tf(ctx, "Parse url %v to host=%v, ip=%v, port=%v",
u.String(), u.Hostname(), ips[0], port)
// Setup libsrt.
client := srtgo.NewSrtSocket(ips[0].To4().String(), uint16(port),
map[string]string{
"transtype": "live",
"tsbpdmode": "false",
"tlpktdrop": "false",
"latency": "0",
"streamid": fmt.Sprintf("#%v", u.Fragment),
},
)
defer client.Close()
if err := client.Connect(); err != nil {
return errors.Wrapf(err, "SRT connect to %v:%v", u.Hostname(), port)
}
logger.Tf(ctx, "Connect to SRT server %v:%v success", u.Hostname(), port)
// We should wait for a while after connected to SRT server before quit. Because SRT server use timeout
// to detect UDP connection status, so we should never reconnect very fast.
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
logger.Tf(ctx, "SRT publish stream success, stream=%v", u.Fragment)
}
if closeAfterPublished {
logger.Tf(ctx, "Close connection after published")
return nil
}
return nil
}
func startPublishRTMP(ctx context.Context, u *url.URL, closeAfterPublished bool) (err error) {
parts := strings.Split(u.Path, "/")
if len(parts) == 0 {
return errors.Errorf("invalid path %v", u.Path)
}
app, stream := strings.Join(parts[:len(parts)-1], "/"), parts[len(parts)-1]
// Parse host and port.
port := 1935
if u.Port() != "" {
if port, err = strconv.Atoi(u.Port()); err != nil {
return errors.Wrapf(err, "parse port %v", u.Port())
}
}
ips, err := net.LookupIP(u.Hostname())
if err != nil {
return errors.Wrapf(err, "lookup %v", u.Hostname())
}
if len(ips) == 0 {
return errors.Errorf("no ips for %v", u.Hostname())
}
logger.Tf(ctx, "Parse url %v to host=%v, ip=%v, port=%v, app=%v, stream=%v",
u.String(), u.Hostname(), ips[0], port, app, stream)
// Connect via TCP client.
c, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ips[0], Port: port})
if err != nil {
return errors.Wrapf(err, "dial %v %v", u.Hostname(), u.Port())
}
defer c.Close()
logger.Tf(ctx, "Connect to RTMP server %v:%v success", u.Hostname(), port)
// RTMP Handshake.
rd := rand.New(rand.NewSource(time.Now().UnixNano()))
hs := rtmp.NewHandshake(rd)
if err := hs.WriteC0S0(c); err != nil {
return errors.Wrap(err, "write c0")
}
if err := hs.WriteC1S1(c); err != nil {
return errors.Wrap(err, "write c1")
}
if _, err = hs.ReadC0S0(c); err != nil {
return errors.Wrap(err, "read s1")
}
s1, err := hs.ReadC1S1(c)
if err != nil {
return errors.Wrap(err, "read s1")
}
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrap(err, "read s2")
}
if err := hs.WriteC2S2(c, s1); err != nil {
return errors.Wrap(err, "write c2")
}
logger.Tf(ctx, "RTMP handshake with %v:%v success", ips[0], port)
// Do connect and publish.
client := rtmp.NewProtocol(c)
connectApp := rtmp.NewConnectAppPacket()
tcURL := fmt.Sprintf("rtmp://%v%v", u.Hostname(), app)
connectApp.CommandObject.Set("tcUrl", amf0.NewString(tcURL))
if err = client.WritePacket(connectApp, 1); err != nil {
return errors.Wrap(err, "write connect app")
}
var connectAppRes *rtmp.ConnectAppResPacket
if _, err = client.ExpectPacket(&connectAppRes); err != nil {
return errors.Wrap(err, "expect connect app res")
}
logger.Tf(ctx, "RTMP connect app success, tcUrl=%v", tcURL)
createStream := rtmp.NewCreateStreamPacket()
if err = client.WritePacket(createStream, 1); err != nil {
return errors.Wrap(err, "write create stream")
}
var createStreamRes *rtmp.CreateStreamResPacket
if _, err = client.ExpectPacket(&createStreamRes); err != nil {
return errors.Wrap(err, "expect create stream res")
}
logger.Tf(ctx, "RTMP create stream success")
publish := rtmp.NewPublishPacket()
publish.StreamName = *amf0.NewString(stream)
if err = client.WritePacket(publish, 1); err != nil {
return errors.Wrap(err, "write publish")
}
logger.Tf(ctx, "RTMP publish stream success, stream=%v", stream)
if closeAfterPublished {
logger.Tf(ctx, "Close connection after published")
return nil
}
return nil
}

68
trunk/3rdparty/srs-bench/live/stat.go vendored Normal file
View file

@ -0,0 +1,68 @@
// The MIT License (MIT)
//
// # Copyright (c) 2021 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package live
import (
"context"
"encoding/json"
"net/http"
"strings"
"github.com/ossrs/go-oryx-lib/logger"
)
type statLive struct {
Publishers struct {
Expect int `json:"expect"`
Alive int `json:"alive"`
} `json:"publishers"`
Subscribers struct {
Expect int `json:"expect"`
Alive int `json:"alive"`
} `json:"subscribers"`
PeerConnection interface{} `json:"random-pc"`
}
var gStatLive statLive
func handleStat(ctx context.Context, mux *http.ServeMux, l string) {
if strings.HasPrefix(l, ":") {
l = "127.0.0.1" + l
}
logger.Tf(ctx, "Handle http://%v/api/v1/sb/live", l)
mux.HandleFunc("/api/v1/sb/live", func(w http.ResponseWriter, r *http.Request) {
res := &struct {
Code int `json:"code"`
Data interface{} `json:"data"`
}{
0, &gStatLive,
}
b, err := json.Marshal(res)
if err != nil {
logger.Wf(ctx, "marshal %v err %+v", res, err)
return
}
w.Write(b)
})
}

View file

@ -24,27 +24,31 @@ import (
"context"
"flag"
"fmt"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181"
"github.com/ossrs/srs-bench/janus"
"github.com/ossrs/srs-bench/srs"
"io/ioutil"
"os"
"os/signal"
"syscall"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181"
"github.com/ossrs/srs-bench/janus"
"github.com/ossrs/srs-bench/live"
"github.com/ossrs/srs-bench/srs"
)
func main() {
var sfu string
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fl.SetOutput(ioutil.Discard)
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
fl.StringVar(&sfu, "sfu", "rtc", "")
_ = fl.Parse(os.Args[1:])
ctx := context.Background()
var conf interface{}
if sfu == "srs" {
if sfu == "rtc" || sfu == "srs" {
srs.Parse(ctx)
} else if sfu == "live" {
live.Parse(ctx)
} else if sfu == "gb28181" {
conf = gb28181.Parse(ctx)
} else if sfu == "janus" {
@ -52,7 +56,11 @@ func main() {
} else {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target server that can be rtc, live, janus, or gb28181. Default: rtc"))
fmt.Println(fmt.Sprintf(" rtc/srs: SRS WebRTC SFU server, for WebRTC/WHIP/WHEP."))
fmt.Println(fmt.Sprintf(" live: SRS live streaming server, for RTMP/HTTP-FLV/HLS."))
fmt.Println(fmt.Sprintf(" janus: Janus WebRTC SFU server, for janus private protocol."))
fmt.Println(fmt.Sprintf(" gb28181: GB media server, for GB protocol."))
os.Exit(-1)
}
@ -67,8 +75,10 @@ func main() {
}()
var err error
if sfu == "srs" {
if sfu == "rtc" || sfu == "srs" {
err = srs.Run(ctx)
} else if sfu == "live" {
err = live.Run(ctx)
} else if sfu == "gb28181" {
err = gb28181.Run(ctx, conf)
} else if sfu == "janus" {

View file

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"os"
"strings"
"time"
"github.com/google/gopacket"
@ -62,10 +63,20 @@ func doMain(ctx context.Context) error {
}
defer f.Close()
var source *gopacket.PacketSource
if strings.HasSuffix(filename, ".pcap") {
r, err := pcapgo.NewReader(f)
if err != nil {
return errors.Wrapf(err, "new reader")
}
source = gopacket.NewPacketSource(r, r.LinkType())
} else {
r, err := pcapgo.NewNgReader(f, pcapgo.DefaultNgReaderOptions)
if err != nil {
return errors.Wrapf(err, "new reader")
}
source = gopacket.NewPacketSource(r, r.LinkType())
}
// TODO: FIXME: Should start a goroutine to consume bytes from conn.
conn, err := net.Dial("tcp", server)
@ -76,7 +87,6 @@ func doMain(ctx context.Context) error {
var packetNumber uint64
var previousTime *time.Time
source := gopacket.NewPacketSource(r, r.LinkType())
for packet := range source.Packets() {
packetNumber++
@ -90,7 +100,7 @@ func doMain(ctx context.Context) error {
if len(payload) == 0 {
continue
}
if tcp.DstPort != 1935 {
if tcp.DstPort != 1935 && tcp.DstPort != 19350 {
continue
}

View file

@ -34,7 +34,7 @@ import (
)
// @see https://github.com/pion/webrtc/blob/master/examples/play-from-disk/main.go
func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps int, enableAudioLevel, enableTWCC bool) error {
func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps int, enableAudioLevel, enableTWCC, closeAfterPublished bool) error {
ctx = logger.WithContext(ctx)
logger.Tf(ctx, "Run publish url=%v, audio=%v, video=%v, fps=%v, audio-level=%v, twcc=%v",
@ -77,10 +77,13 @@ func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
return nil, err
}
if sourceAudio != "" {
// For CAP, we always add audio track, because both audio and video are disabled for CAP, which will
// cause failed when exchange SDP.
if sourceAudio != "" || closeAfterPublished {
aIngester = newAudioIngester(sourceAudio)
registry.Add(&rtpInteceptorFactory{aIngester.audioLevelInterceptor})
}
if sourceVideo != "" {
vIngester = newVideoIngester(sourceVideo)
registry.Add(&rtpInteceptorFactory{vIngester.markerInterceptor})
@ -178,6 +181,7 @@ func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
// Wait for event from context or tracks.
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(1)
go func() {
@ -186,6 +190,18 @@ func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
doClose() // Interrupt the RTCP read.
}()
// If CAP, directly close the connection after published.
if closeAfterPublished {
select {
case <-ctx.Done():
case <-pcDoneCtx.Done():
}
logger.Tf(ctx, "Close connection after published")
cancel()
return nil
}
wg.Add(1)
go func() {
defer wg.Done()
@ -295,6 +311,5 @@ func startPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i
}
}()
wg.Wait()
return nil
}

View file

@ -2213,6 +2213,7 @@ func TestRtcDTLS_ClientActive_Corrupt_ClientHello(t *testing.T) {
// No.2 srs-server: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone
// [Corrupt] No.3 srs-bench: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished
// No.4 srs-server: Alert (Level: Fatal, Description: Illegal Parameter)
//
// [Corrupt] No.1 srs-bench: ClientHello(Epoch=0, Sequence=0), change length from 129 to 0xf.
// No.2 srs-server: Alert (Level: Fatal, Description: Illegal Parameter)
func TestRtcDTLS_ClientActive_Corrupt_Certificate(t *testing.T) {

View file

@ -24,7 +24,6 @@ import (
"bytes"
"context"
"fmt"
"github.com/pkg/errors"
"math/rand"
"os"
"sync"
@ -36,6 +35,7 @@ import (
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/go-oryx-lib/rtmp"
"github.com/pion/interceptor"
"github.com/pkg/errors"
)
func TestRtmpPublishPlay(t *testing.T) {
@ -623,7 +623,10 @@ func TestRtmpPublish_HttpFlvPlayNoVideo(t *testing.T) {
go func() {
defer wg.Done()
publisher.onSendPacket = func(m *rtmp.Message) error {
time.Sleep(1 * time.Millisecond)
// Note that must greater than the cost of ffmpeg-opus, which is about 4ms, otherwise,
// the publisher will always get audio frames to transcode and won't accept new players
// connection and finally failed the case.
time.Sleep(5 * time.Millisecond)
return nil
}
if r0 = publisher.Ingest(ctx, *srsPublishAvatar); r0 != nil {

View file

@ -46,6 +46,8 @@ var clients, streams, delay int
var statListen string
var closeAfterPublished bool
func Parse(ctx context.Context) {
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
@ -71,10 +73,16 @@ func Parse(ctx context.Context) {
fl.StringVar(&statListen, "stat", "", "")
fl.BoolVar(&closeAfterPublished, "cap", false, "")
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target server that can be rtc, live, janus, or gb28181. Default: rtc"))
fmt.Println(fmt.Sprintf(" rtc/srs: SRS WebRTC SFU server, for WebRTC/WHIP/WHEP."))
fmt.Println(fmt.Sprintf(" live: SRS live streaming server, for RTMP/HTTP-FLV/HLS."))
fmt.Println(fmt.Sprintf(" janus: Janus WebRTC SFU server, for janus private protocol."))
fmt.Println(fmt.Sprintf(" gb28181: GB media server, for GB protocol."))
fmt.Println(fmt.Sprintf(" -nn The number of clients to simulate. Default: 1"))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
@ -91,6 +99,7 @@ func Parse(ctx context.Context) {
fmt.Println(fmt.Sprintf(" -fps [Optional] The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -sa [Optional] The file path to read audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -sv [Optional] The file path to read video, ignore if empty."))
fmt.Println(fmt.Sprintf(" -cap Whether to close connection after publish. Default: false"))
fmt.Println(fmt.Sprintf("\n例如1个播放1个推流:"))
fmt.Println(fmt.Sprintf(" %v -sr webrtc://localhost/live/livestream", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -pr webrtc://localhost/live/livestream -sa avatar.ogg -sv avatar.h264 -fps 25", os.Args[0]))
@ -114,7 +123,7 @@ func Parse(ctx context.Context) {
if sr == "" && pr == "" {
showHelp = true
}
if pr != "" && (sourceAudio == "" && sourceVideo == "") {
if pr != "" && !closeAfterPublished && (sourceAudio == "" && sourceVideo == "") {
showHelp = true
}
if showHelp {
@ -131,8 +140,8 @@ func Parse(ctx context.Context) {
summaryDesc = fmt.Sprintf("%v, play(url=%v, da=%v, dv=%v, pli=%v)", summaryDesc, sr, dumpAudio, dumpVideo, pli)
}
if pr != "" {
summaryDesc = fmt.Sprintf("%v, publish(url=%v, sa=%v, sv=%v, fps=%v)",
summaryDesc, pr, sourceAudio, sourceVideo, fps)
summaryDesc = fmt.Sprintf("%v, publish(url=%v, sa=%v, sv=%v, fps=%v, cap=%v)",
summaryDesc, pr, sourceAudio, sourceVideo, fps, closeAfterPublished)
}
logger.Tf(ctx, "Run benchmark with %v", summaryDesc)
@ -161,6 +170,7 @@ func Run(ctx context.Context) error {
// Run tasks.
var wg sync.WaitGroup
defer wg.Wait()
// Run STAT API server.
wg.Add(1)
@ -266,7 +276,7 @@ func Run(ctx context.Context) error {
gStatRTC.Publishers.Alive--
}()
if err := startPublish(ctx, pr, sourceAudio, sourceVideo, fps, audioLevel, videoTWCC); err != nil {
if err := startPublish(ctx, pr, sourceAudio, sourceVideo, fps, audioLevel, videoTWCC, closeAfterPublished); err != nil {
if errors.Cause(err) != context.Canceled {
logger.Wf(ctx, "Run err %+v", err)
}
@ -276,7 +286,5 @@ func Run(ctx context.Context) error {
time.Sleep(time.Duration(delay) * time.Millisecond)
}
wg.Wait()
return nil
}

View file

@ -0,0 +1,159 @@
package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"sync"
"time"
)
func main() {
if err := doMain(); err != nil {
panic(err)
}
}
func doMain() error {
hashID := buildHashID()
listener, err := net.Listen("tcp", ":1935")
if err != nil {
return err
}
trace(hashID, "Listen at %v", listener.Addr())
for {
client, err := listener.Accept()
if err != nil {
return err
}
backend, err := net.Dial("tcp", "localhost:19350")
if err != nil {
return err
}
go serve(client, backend)
}
return nil
}
func serve(client, backend net.Conn) {
defer client.Close()
defer backend.Close()
hashID := buildHashID()
if err := doServe(hashID, client, backend); err != nil {
trace(hashID, "Serve error %v", err)
}
}
func doServe(hashID string, client, backend net.Conn) error {
var wg sync.WaitGroup
var r0 error
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if c, ok := client.(*net.TCPConn); ok {
c.SetNoDelay(true)
}
if c, ok := backend.(*net.TCPConn); ok {
c.SetNoDelay(true)
}
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
for {
buf := make([]byte, 128*1024)
nn, err := client.Read(buf)
if err != nil {
trace(hashID, "Read from client error %v", err)
r0 = err
return
}
if nn == 0 {
trace(hashID, "Read from client EOF")
return
}
_, err = backend.Write(buf[:nn])
if err != nil {
trace(hashID, "Write to RTMP backend error %v", err)
r0 = err
return
}
trace(hashID, "Copy %v bytes to RTMP backend", nn)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
for {
buf := make([]byte, 128*1024)
nn, err := backend.Read(buf)
if err != nil {
trace(hashID, "Read from RTMP backend error %v", err)
r0 = err
return
}
if nn == 0 {
trace(hashID, "Read from RTMP backend EOF")
return
}
_, err = client.Write(buf[:nn])
if err != nil {
trace(hashID, "Write to client error %v", err)
r0 = err
return
}
trace(hashID, "Copy %v bytes to RTMP client", nn)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
defer client.Close()
defer backend.Close()
<-ctx.Done()
trace(hashID, "Context is done, close the connections")
}()
trace(hashID, "Start proxing client %v over %v to backend %v", client.RemoteAddr(), backend.LocalAddr(), backend.RemoteAddr())
wg.Wait()
trace(hashID, "Finish proxing client %v over %v to backend %v", client.RemoteAddr(), backend.LocalAddr(), backend.RemoteAddr())
return r0
}
func trace(id, msg string, a ...interface{}) {
fmt.Println(fmt.Sprintf("[%v][%v] %v",
time.Now().Format("2006-01-02 15:04:05.000"), id,
fmt.Sprintf(msg, a...),
))
}
func buildHashID() string {
randomData := make([]byte, 16)
if _, err := rand.Read(randomData); err != nil {
return ""
}
hash := sha256.Sum256(randomData)
return hex.EncodeToString(hash[:])[:6]
}

View file

@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

View file

@ -0,0 +1,63 @@
[![PkgGoDev](https://pkg.go.dev/badge/github.com/haivision/srtgo)](https://pkg.go.dev/github.com/haivision/srtgo)
# srtgo
Go bindings for [SRT](https://github.com/Haivision/srt) (Secure Reliable Transport), the open source transport technology that optimizes streaming performance across unpredictable networks.
## Why srtgo?
The purpose of srtgo is easing the adoption of SRT transport technology. Using Go, with just a few lines of code you can implement an application that sends/receives data with all the benefits of SRT technology: security and reliability, while keeping latency low.
## Is this a new implementation of SRT?
No! We are just exposing the great work done by the community in the [SRT project](https://github.com/Haivision/srt) as a golang library. All the functionality and implementation still resides in the official SRT project.
# Features supported
* Basic API exposed to easy develop SRT sender/receiver apps
* Caller and Listener mode
* Live transport type
* File transport type
* Message/Buffer API
* SRT transport options up to SRT 1.4.1
* SRT Stats retrieval
# Usage
Example of a SRT receiver application:
``` go
package main
import (
"github.com/haivision/srtgo"
"fmt"
)
func main() {
options := make(map[string]string)
options["transtype"] = "file"
sck := srtgo.NewSrtSocket("0.0.0.0", 8090, options)
defer sck.Close()
sck.Listen(1)
s, _ := sck.Accept()
defer s.Close()
buf := make([]byte, 2048)
for {
n, _ := s.Read(buf)
if n == 0 {
break
}
fmt.Println("Received %d bytes", n)
}
//....
}
```
# Dependencies
* srtlib
You can find detailed instructions about how to install srtlib in its [README file](https://github.com/Haivision/srt#requirements)
gosrt has been developed with srt 1.4.1 as its main target and has been successfully tested in srt 1.3.4 and above.

View file

@ -0,0 +1,69 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
SRTSOCKET srt_accept_wrapped(SRTSOCKET lsn, struct sockaddr* addr, int* addrlen, int *srterror, int *syserror)
{
int ret = srt_accept(lsn, addr, addrlen);
if (ret < 0) {
*srterror = srt_getlasterror(syserror);
}
return ret;
}
*/
import "C"
import (
"fmt"
"net"
"syscall"
"unsafe"
)
func srtAcceptImpl(lsn C.SRTSOCKET, addr *C.struct_sockaddr, addrlen *C.int) (C.SRTSOCKET, error) {
srterr := C.int(0)
syserr := C.int(0)
socket := C.srt_accept_wrapped(lsn, addr, addrlen, &srterr, &syserr)
if srterr != 0 {
srterror := SRTErrno(srterr)
if syserr < 0 {
srterror.wrapSysErr(syscall.Errno(syserr))
}
return socket, srterror
}
return socket, nil
}
// Accept an incoming connection
func (s SrtSocket) Accept() (*SrtSocket, *net.UDPAddr, error) {
var err error
if !s.blocking {
err = s.pd.wait(ModeRead)
if err != nil {
return nil, nil, err
}
}
var addr syscall.RawSockaddrAny
sclen := C.int(syscall.SizeofSockaddrAny)
socket, err := srtAcceptImpl(s.socket, (*C.struct_sockaddr)(unsafe.Pointer(&addr)), &sclen)
if err != nil {
return nil, nil, err
}
if socket == SRT_INVALID_SOCK {
return nil, nil, fmt.Errorf("srt accept, error accepting the connection: %w", srtGetAndClearError())
}
newSocket, err := newFromSocket(&s, socket)
if err != nil {
return nil, nil, fmt.Errorf("new socket could not be created: %w", err)
}
udpAddr, err := udpAddrFromSockaddr(&addr)
if err != nil {
return nil, nil, err
}
return newSocket, udpAddr, nil
}

View file

@ -0,0 +1,7 @@
#include <srt/srt.h>
int srtListenCBWrapper(void* opaque, SRTSOCKET ns, int hs_version, struct sockaddr* peeraddr, char* streamid);
void srtConnectCBWrapper(void* opaque, SRTSOCKET ns, int errorcode, struct sockaddr* peeraddr, int token);
int srtListenCB(void* opaque, SRTSOCKET ns, int hs_version, const struct sockaddr* peeraddr, const char* streamid);
void srtConnectCB(void* opaque, SRTSOCKET ns, int errorcode, const struct sockaddr* peeraddr, int token);

View file

@ -0,0 +1,17 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include "callback.h"
int srtListenCB(void* opaque, SRTSOCKET ns, int hs_version, const struct sockaddr* peeraddr, const char* streamid)
{
return srtListenCBWrapper(opaque, ns, hs_version, (struct sockaddr*)peeraddr, (char*)streamid);
}
void srtConnectCB(void* opaque, SRTSOCKET ns, int errorcode, const struct sockaddr* peeraddr, int token)
{
srtConnectCBWrapper(opaque, ns, errorcode, (struct sockaddr*)peeraddr, token);
}
*/
import "C"

View file

@ -0,0 +1,242 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
*/
import "C"
import (
"strconv"
"syscall"
)
type SrtInvalidSock struct{}
type SrtRendezvousUnbound struct{}
type SrtSockConnected struct{}
type SrtConnectionRejected struct{}
type SrtConnectTimeout struct{}
type SrtSocketClosed struct{}
type SrtEpollTimeout struct{}
func (m *SrtInvalidSock) Error() string {
return "Socket u indicates no valid socket ID"
}
func (m *SrtRendezvousUnbound) Error() string {
return "Socket u is in rendezvous mode, but it wasn't bound"
}
func (m *SrtSockConnected) Error() string {
return "Socket u is already connected"
}
func (m *SrtConnectionRejected) Error() string {
return "Connection has been rejected"
}
func (m *SrtConnectTimeout) Error() string {
return "Connection has been timed out"
}
func (m *SrtSocketClosed) Error() string {
return "The socket has been closed"
}
func (m *SrtEpollTimeout) Error() string {
return "Operation has timed out"
}
func (m *SrtEpollTimeout) Timeout() bool {
return true
}
func (m *SrtEpollTimeout) Temporary() bool {
return true
}
//MUST be called from same OS thread that generated the error (i.e.: use runtime.LockOSThread())
func srtGetAndClearError() error {
defer C.srt_clearlasterror()
eSysErrno := C.int(0)
errno := C.srt_getlasterror(&eSysErrno)
srterr := SRTErrno(errno)
if eSysErrno != 0 {
return srterr.wrapSysErr(syscall.Errno(eSysErrno))
}
return srterr
}
//Based of off golang errno handling: https://cs.opensource.google/go/go/+/refs/tags/go1.16.6:src/syscall/syscall_unix.go;l=114
type SRTErrno int
func (e SRTErrno) Error() string {
//Workaround for unknown being -1
if e == Unknown {
return "Internal error when setting the right error code"
}
if 0 <= int(e) && int(e) < len(srterrors) {
s := srterrors[e]
if s != "" {
return s
}
}
return "srterrno: " + strconv.Itoa(int(e))
}
func (e SRTErrno) Is(target error) bool {
//for backwards compat
switch target.(type) {
case *SrtInvalidSock:
return e == EInvSock
case *SrtRendezvousUnbound:
return e == ERdvUnbound
case *SrtSockConnected:
return e == EConnSock
case *SrtConnectionRejected:
return e == EConnRej
case *SrtConnectTimeout:
return e == ETimeout
case *SrtSocketClosed:
return e == ESClosed
}
return false
}
func (e SRTErrno) Temporary() bool {
return e == EAsyncFAIL || e == EAsyncRCV || e == EAsyncSND || e == ECongest || e == ETimeout
}
func (e SRTErrno) Timeout() bool {
return e == ETimeout
}
func (e SRTErrno) wrapSysErr(errno syscall.Errno) error {
return &srtErrnoSysErrnoWrapped{
e: e,
eSys: errno,
}
}
type srtErrnoSysErrnoWrapped struct {
e SRTErrno
eSys syscall.Errno
}
func (e *srtErrnoSysErrnoWrapped) Error() string {
return e.e.Error()
}
func (e *srtErrnoSysErrnoWrapped) Is(target error) bool {
return e.e.Is(target)
}
func (e *srtErrnoSysErrnoWrapped) Temporary() bool {
return e.e.Temporary()
}
func (e *srtErrnoSysErrnoWrapped) Timeout() bool {
return e.e.Timeout()
}
func (e *srtErrnoSysErrnoWrapped) Unwrap() error {
return error(e.eSys)
}
//Shadows SRT_ERRNO srtcore/srt.h line 490+
const (
Unknown = SRTErrno(C.SRT_EUNKNOWN)
Success = SRTErrno(C.SRT_SUCCESS)
//Major: SETUP
EConnSetup = SRTErrno(C.SRT_ECONNSETUP)
ENoServer = SRTErrno(C.SRT_ENOSERVER)
EConnRej = SRTErrno(C.SRT_ECONNREJ)
ESockFail = SRTErrno(C.SRT_ESOCKFAIL)
ESecFail = SRTErrno(C.SRT_ESECFAIL)
ESClosed = SRTErrno(C.SRT_ESCLOSED)
//Major: CONNECTION
EConnFail = SRTErrno(C.SRT_ECONNFAIL)
EConnLost = SRTErrno(C.SRT_ECONNLOST)
ENoConn = SRTErrno(C.SRT_ENOCONN)
//Major: SYSTEMRES
EResource = SRTErrno(C.SRT_ERESOURCE)
EThread = SRTErrno(C.SRT_ETHREAD)
EnoBuf = SRTErrno(C.SRT_ENOBUF)
ESysObj = SRTErrno(C.SRT_ESYSOBJ)
//Major: FILESYSTEM
EFile = SRTErrno(C.SRT_EFILE)
EInvRdOff = SRTErrno(C.SRT_EINVRDOFF)
ERdPerm = SRTErrno(C.SRT_ERDPERM)
EInvWrOff = SRTErrno(C.SRT_EINVWROFF)
EWrPerm = SRTErrno(C.SRT_EWRPERM)
//Major: NOTSUP
EInvOp = SRTErrno(C.SRT_EINVOP)
EBoundSock = SRTErrno(C.SRT_EBOUNDSOCK)
EConnSock = SRTErrno(C.SRT_ECONNSOCK)
EInvParam = SRTErrno(C.SRT_EINVPARAM)
EInvSock = SRTErrno(C.SRT_EINVSOCK)
EUnboundSock = SRTErrno(C.SRT_EUNBOUNDSOCK)
ENoListen = SRTErrno(C.SRT_ENOLISTEN)
ERdvNoServ = SRTErrno(C.SRT_ERDVNOSERV)
ERdvUnbound = SRTErrno(C.SRT_ERDVUNBOUND)
EInvalMsgAPI = SRTErrno(C.SRT_EINVALMSGAPI)
EInvalBufferAPI = SRTErrno(C.SRT_EINVALBUFFERAPI)
EDupListen = SRTErrno(C.SRT_EDUPLISTEN)
ELargeMsg = SRTErrno(C.SRT_ELARGEMSG)
EInvPollID = SRTErrno(C.SRT_EINVPOLLID)
EPollEmpty = SRTErrno(C.SRT_EPOLLEMPTY)
//EBindConflict = SRTErrno(C.SRT_EBINDCONFLICT)
//Major: AGAIN
EAsyncFAIL = SRTErrno(C.SRT_EASYNCFAIL)
EAsyncSND = SRTErrno(C.SRT_EASYNCSND)
EAsyncRCV = SRTErrno(C.SRT_EASYNCRCV)
ETimeout = SRTErrno(C.SRT_ETIMEOUT)
ECongest = SRTErrno(C.SRT_ECONGEST)
//Major: PEERERROR
EPeer = SRTErrno(C.SRT_EPEERERR)
)
//Unknown cannot be here since it would have a negative index!
//Error strings taken from: https://github.com/Haivision/srt/blob/master/docs/API/API-functions.md
var srterrors = [...]string{
Success: "The value set when the last error was cleared and no error has occurred since then",
EConnSetup: "General setup error resulting from internal system state",
ENoServer: "Connection timed out while attempting to connect to the remote address",
EConnRej: "Connection has been rejected",
ESockFail: "An error occurred when trying to call a system function on an internally used UDP socket",
ESecFail: "A possible tampering with the handshake packets was detected, or encryption request wasn't properly fulfilled.",
ESClosed: "A socket that was vital for an operation called in blocking mode has been closed during the operation",
EConnFail: "General connection failure of unknown details",
EConnLost: "The socket was properly connected, but the connection has been broken",
ENoConn: "The socket is not connected",
EResource: "System or standard library error reported unexpectedly for unknown purpose",
EThread: "System was unable to spawn a new thread when requried",
EnoBuf: "System was unable to allocate memory for buffers",
ESysObj: "System was unable to allocate system specific objects",
EFile: "General filesystem error (for functions operating with file transmission)",
EInvRdOff: "Failure when trying to read from a given position in the file",
ERdPerm: "Read permission was denied when trying to read from file",
EInvWrOff: "Failed to set position in the written file",
EWrPerm: "Write permission was denied when trying to write to a file",
EInvOp: "Invalid operation performed for the current state of a socket",
EBoundSock: "The socket is currently bound and the required operation cannot be performed in this state",
EConnSock: "The socket is currently connected and therefore performing the required operation is not possible",
EInvParam: "Call parameters for API functions have some requirements that were not satisfied",
EInvSock: "The API function required an ID of an entity (socket or group) and it was invalid",
EUnboundSock: "The operation to be performed on a socket requires that it first be explicitly bound",
ENoListen: "The socket passed for the operation is required to be in the listen state",
ERdvNoServ: "The required operation cannot be performed when the socket is set to rendezvous mode",
ERdvUnbound: "An attempt was made to connect to a socket set to rendezvous mode that was not first bound",
EInvalMsgAPI: "The function was used incorrectly in the message API",
EInvalBufferAPI: "The function was used incorrectly in the stream (buffer) API",
EDupListen: "The port tried to be bound for listening is already busy",
ELargeMsg: "Size exceeded",
EInvPollID: "The epoll ID passed to an epoll function is invalid",
EPollEmpty: "The epoll container currently has no subscribed sockets",
//EBindConflict: "SRT_EBINDCONFLICT",
EAsyncFAIL: "General asynchronous failure (not in use currently)",
EAsyncSND: "Sending operation is not ready to perform",
EAsyncRCV: "Receiving operation is not ready to perform",
ETimeout: "The operation timed out",
ECongest: "With SRTO_TSBPDMODE and SRTO_TLPKTDROP set to true, some packets were dropped by sender",
EPeer: "Receiver peer is writing to a file that the agent is sending",
}

View file

@ -0,0 +1,66 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
extern void srtLogCB(void* opaque, int level, const char* file, int line, const char* area, const char* message);
*/
import "C"
import (
"sync"
"unsafe"
gopointer "github.com/mattn/go-pointer"
)
type LogCallBackFunc func(level SrtLogLevel, file string, line int, area, message string)
type SrtLogLevel int
const (
// SrtLogLevelEmerg = int(C.LOG_EMERG)
// SrtLogLevelAlert = int(C.LOG_ALERT)
SrtLogLevelCrit SrtLogLevel = SrtLogLevel(C.LOG_CRIT)
SrtLogLevelErr SrtLogLevel = SrtLogLevel(C.LOG_ERR)
SrtLogLevelWarning SrtLogLevel = SrtLogLevel(C.LOG_WARNING)
SrtLogLevelNotice SrtLogLevel = SrtLogLevel(C.LOG_NOTICE)
SrtLogLevelInfo SrtLogLevel = SrtLogLevel(C.LOG_INFO)
SrtLogLevelDebug SrtLogLevel = SrtLogLevel(C.LOG_DEBUG)
SrtLogLevelTrace SrtLogLevel = SrtLogLevel(8)
)
var (
logCBPtr unsafe.Pointer = nil
logCBPtrLock sync.Mutex
)
//export srtLogCBWrapper
func srtLogCBWrapper(arg unsafe.Pointer, level C.int, file *C.char, line C.int, area, message *C.char) {
userCB := gopointer.Restore(arg).(LogCallBackFunc)
go userCB(SrtLogLevel(level), C.GoString(file), int(line), C.GoString(area), C.GoString(message))
}
func SrtSetLogLevel(level SrtLogLevel) {
C.srt_setloglevel(C.int(level))
}
func SrtSetLogHandler(cb LogCallBackFunc) {
ptr := gopointer.Save(cb)
C.srt_setloghandler(ptr, (*C.SRT_LOG_HANDLER_FN)(C.srtLogCB))
storeLogCBPtr(ptr)
}
func SrtUnsetLogHandler() {
C.srt_setloghandler(nil, nil)
storeLogCBPtr(nil)
}
func storeLogCBPtr(ptr unsafe.Pointer) {
logCBPtrLock.Lock()
defer logCBPtrLock.Unlock()
if logCBPtr != nil {
gopointer.Unref(logCBPtr)
}
logCBPtr = ptr
}

View file

@ -0,0 +1,14 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
extern void srtLogCBWrapper (void* opaque, int level, char* file, int line, char* area, char* message);
void srtLogCB(void* opaque, int level, const char* file, int line, const char* area, const char* message)
{
srtLogCBWrapper(opaque, level, (char*)file, line, (char*)area,(char*) message);
}
*/
import "C"

View file

@ -0,0 +1,87 @@
package srtgo
//#include <srt/srt.h>
import "C"
import (
"encoding/binary"
"fmt"
"net"
"syscall"
"unsafe"
)
func ntohs(val uint16) uint16 {
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
return binary.BigEndian.Uint16((*tmp)[:])
}
func udpAddrFromSockaddr(addr *syscall.RawSockaddrAny) (*net.UDPAddr, error) {
var udpAddr net.UDPAddr
switch addr.Addr.Family {
case afINET6:
ptr := (*syscall.RawSockaddrInet6)(unsafe.Pointer(addr))
udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = ptr.Addr[:]
case afINET4:
ptr := (*syscall.RawSockaddrInet4)(unsafe.Pointer(addr))
udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
default:
return nil, fmt.Errorf("unknown address family: %v", addr.Addr.Family)
}
return &udpAddr, nil
}
func sockAddrFromIp4(ip net.IP, port uint16) (*C.struct_sockaddr, int, error) {
var raw syscall.RawSockaddrInet4
raw.Family = afINET4
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(port >> 8)
p[1] = byte(port)
copy(raw.Addr[:], ip.To4())
return (*C.struct_sockaddr)(unsafe.Pointer(&raw)), int(sizeofSockAddrInet4), nil
}
func sockAddrFromIp6(ip net.IP, port uint16) (*C.struct_sockaddr, int, error) {
var raw syscall.RawSockaddrInet6
raw.Family = afINET6
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(port >> 8)
p[1] = byte(port)
copy(raw.Addr[:], ip.To16())
return (*C.struct_sockaddr)(unsafe.Pointer(&raw)), int(sizeofSockAddrInet6), nil
}
func CreateAddrInet(name string, port uint16) (*C.struct_sockaddr, int, error) {
ip := net.ParseIP(name)
if ip == nil {
ips, err := net.LookupIP(name)
if err != nil {
return nil, 0, fmt.Errorf("Error in CreateAddrInet, LookupIP")
}
ip = ips[0]
}
if ip.To4() != nil {
return sockAddrFromIp4(ip, port)
} else if ip.To16() != nil {
return sockAddrFromIp6(ip, port)
}
return nil, 0, fmt.Errorf("Error in CreateAddrInet, LookupIP")
}

View file

@ -0,0 +1,17 @@
//go:build !windows
package srtgo
import (
"syscall"
"golang.org/x/sys/unix"
)
const (
sizeofSockAddrInet4 = syscall.SizeofSockaddrInet4
sizeofSockAddrInet6 = syscall.SizeofSockaddrInet6
sizeofSockaddrAny = syscall.SizeofSockaddrAny
afINET4 = unix.AF_INET
afINET6 = unix.AF_INET6
)

View file

@ -0,0 +1,29 @@
//go:build windows
package srtgo
import (
"unsafe"
"golang.org/x/sys/windows"
)
const (
afINET4 = windows.AF_INET
afINET6 = windows.AF_INET6
)
var (
sizeofSockAddrInet4 uint64 = 0
sizeofSockAddrInet6 uint64 = 0
sizeofSockaddrAny uint64 = 0
)
func init() {
inet4 := windows.RawSockaddrInet4{}
inet6 := windows.RawSockaddrInet6{}
any := windows.RawSockaddrAny{}
sizeofSockAddrInet4 = uint64(unsafe.Sizeof(inet4))
sizeofSockAddrInet6 = uint64(unsafe.Sizeof(inet6))
sizeofSockaddrAny = uint64(unsafe.Sizeof(any))
}

View file

@ -0,0 +1,269 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
*/
import "C"
import (
"sync"
"sync/atomic"
"time"
)
const (
pollDefault = int32(iota)
pollReady = int32(iota)
pollWait = int32(iota)
)
type PollMode int
const (
ModeRead = PollMode(iota)
ModeWrite
)
/*
pollDesc contains the polling state for the associated SrtSocket
closing: socket is closing, reject all poll operations
pollErr: an error occured on the socket, indicates it's not useable anymore.
unblockRd: is used to unblock the poller when the socket becomes ready for io
rdState: polling state for read operations
rdDeadline: deadline in NS before poll operation times out, -1 means timedout (needs to be cleared), 0 is without timeout
rdSeq: sequence number protects against spurious signalling of timeouts when timer is reset.
rdTimer: timer used to enforce deadline.
*/
type pollDesc struct {
lock sync.Mutex
closing bool
fd C.SRTSOCKET
pollErr bool
unblockRd chan interface{}
rdState int32
rdLock sync.Mutex
rdDeadline int64
rdSeq int64
rdTimer *time.Timer
rtSeq int64
unblockWr chan interface{}
wrState int32
wrLock sync.Mutex
wdDeadline int64
wdSeq int64
wdTimer *time.Timer
wtSeq int64
pollS *pollServer
}
var pdPool = sync.Pool{
New: func() interface{} {
return &pollDesc{
unblockRd: make(chan interface{}, 1),
unblockWr: make(chan interface{}, 1),
rdTimer: time.NewTimer(0),
wdTimer: time.NewTimer(0),
}
},
}
func pollDescInit(s C.SRTSOCKET) *pollDesc {
pd := pdPool.Get().(*pollDesc)
pd.lock.Lock()
defer pd.lock.Unlock()
pd.fd = s
pd.rdState = pollDefault
pd.wrState = pollDefault
pd.pollS = pollServerCtx()
pd.closing = false
pd.pollErr = false
pd.rdSeq++
pd.wdSeq++
pd.pollS.pollOpen(pd)
return pd
}
func (pd *pollDesc) release() {
pd.lock.Lock()
defer pd.lock.Unlock()
if !pd.closing || pd.rdState == pollWait || pd.wrState == pollWait {
panic("returning open or blocked upon pollDesc")
}
pd.fd = 0
pdPool.Put(pd)
}
func (pd *pollDesc) wait(mode PollMode) error {
defer pd.reset(mode)
if err := pd.checkPollErr(mode); err != nil {
return err
}
state := &pd.rdState
unblockChan := pd.unblockRd
expiryChan := pd.rdTimer.C
timerSeq := int64(0)
pd.lock.Lock()
if mode == ModeRead {
timerSeq = pd.rtSeq
pd.rdLock.Lock()
defer pd.rdLock.Unlock()
} else if mode == ModeWrite {
timerSeq = pd.wtSeq
state = &pd.wrState
unblockChan = pd.unblockWr
expiryChan = pd.wdTimer.C
pd.wrLock.Lock()
defer pd.wrLock.Unlock()
}
for {
old := *state
if old == pollReady {
*state = pollDefault
pd.lock.Unlock()
return nil
}
if atomic.CompareAndSwapInt32(state, pollDefault, pollWait) {
break
}
}
pd.lock.Unlock()
wait:
for {
select {
case <-unblockChan:
break wait
case <-expiryChan:
pd.lock.Lock()
if mode == ModeRead {
if timerSeq == pd.rdSeq {
pd.rdDeadline = -1
pd.lock.Unlock()
break wait
}
timerSeq = pd.rtSeq
}
if mode == ModeWrite {
if timerSeq == pd.wdSeq {
pd.wdDeadline = -1
pd.lock.Unlock()
break wait
}
timerSeq = pd.wtSeq
}
pd.lock.Unlock()
}
}
err := pd.checkPollErr(mode)
return err
}
func (pd *pollDesc) close() {
pd.lock.Lock()
defer pd.lock.Unlock()
if pd.closing {
return
}
pd.closing = true
pd.pollS.pollClose(pd)
}
func (pd *pollDesc) checkPollErr(mode PollMode) error {
pd.lock.Lock()
defer pd.lock.Unlock()
if pd.closing {
return &SrtSocketClosed{}
}
if mode == ModeRead && pd.rdDeadline < 0 || mode == ModeWrite && pd.wdDeadline < 0 {
return &SrtEpollTimeout{}
}
if pd.pollErr {
return &SrtSocketClosed{}
}
return nil
}
func (pd *pollDesc) setDeadline(t time.Time, mode PollMode) {
pd.lock.Lock()
defer pd.lock.Unlock()
var d int64
if !t.IsZero() {
d = int64(time.Until(t))
if d == 0 {
d = -1
}
}
if mode == ModeRead || mode == ModeRead+ModeWrite {
pd.rdSeq++
pd.rtSeq = pd.rdSeq
if pd.rdDeadline > 0 {
pd.rdTimer.Stop()
}
pd.rdDeadline = d
if d > 0 {
pd.rdTimer.Reset(time.Duration(d))
}
if d < 0 {
pd.unblock(ModeRead, false, false)
}
}
if mode == ModeWrite || mode == ModeRead+ModeWrite {
pd.wdSeq++
pd.wtSeq = pd.wdSeq
if pd.wdDeadline > 0 {
pd.wdTimer.Stop()
}
pd.wdDeadline = d
if d > 0 {
pd.wdTimer.Reset(time.Duration(d))
}
if d < 0 {
pd.unblock(ModeWrite, false, false)
}
}
}
func (pd *pollDesc) unblock(mode PollMode, pollerr, ioready bool) {
if pollerr {
pd.lock.Lock()
pd.pollErr = pollerr
pd.lock.Unlock()
}
state := &pd.rdState
unblockChan := pd.unblockRd
if mode == ModeWrite {
state = &pd.wrState
unblockChan = pd.unblockWr
}
pd.lock.Lock()
old := atomic.LoadInt32(state)
if ioready {
atomic.StoreInt32(state, pollReady)
}
pd.lock.Unlock()
if old == pollWait {
//make sure we never block here
select {
case unblockChan <- struct{}{}:
//
default:
//
}
}
}
func (pd *pollDesc) reset(mode PollMode) {
if mode == ModeRead {
pd.rdLock.Lock()
pd.rdState = pollDefault
pd.rdLock.Unlock()
} else if mode == ModeWrite {
pd.wrLock.Lock()
pd.wrState = pollDefault
pd.wrLock.Unlock()
}
}

View file

@ -0,0 +1,109 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
*/
import "C"
import (
"sync"
"unsafe"
)
var (
phctx *pollServer
once sync.Once
)
func pollServerCtx() *pollServer {
once.Do(pollServerCtxInit)
return phctx
}
func pollServerCtxInit() {
eid := C.srt_epoll_create()
C.srt_epoll_set(eid, C.SRT_EPOLL_ENABLE_EMPTY)
phctx = &pollServer{
srtEpollDescr: eid,
pollDescs: make(map[C.SRTSOCKET]*pollDesc),
}
go phctx.run()
}
type pollServer struct {
srtEpollDescr C.int
pollDescLock sync.Mutex
pollDescs map[C.SRTSOCKET]*pollDesc
}
func (p *pollServer) pollOpen(pd *pollDesc) {
//use uint because otherwise with ET it would overflow :/ (srt should accept an uint instead, or fix it's SRT_EPOLL_ET definition)
events := C.uint(C.SRT_EPOLL_IN | C.SRT_EPOLL_OUT | C.SRT_EPOLL_ERR | C.SRT_EPOLL_ET)
//via unsafe.Pointer because we cannot cast *C.uint to *C.int directly
//block poller
p.pollDescLock.Lock()
ret := C.srt_epoll_add_usock(p.srtEpollDescr, pd.fd, (*C.int)(unsafe.Pointer(&events)))
if ret == -1 {
panic("ERROR ADDING FD TO EPOLL")
}
p.pollDescs[pd.fd] = pd
p.pollDescLock.Unlock()
}
func (p *pollServer) pollClose(pd *pollDesc) {
sockstate := C.srt_getsockstate(pd.fd)
//Broken/closed sockets get removed internally by SRT lib
if sockstate == C.SRTS_BROKEN || sockstate == C.SRTS_CLOSING || sockstate == C.SRTS_CLOSED || sockstate == C.SRTS_NONEXIST {
return
}
ret := C.srt_epoll_remove_usock(p.srtEpollDescr, pd.fd)
if ret == -1 {
panic("ERROR REMOVING FD FROM EPOLL")
}
p.pollDescLock.Lock()
delete(p.pollDescs, pd.fd)
p.pollDescLock.Unlock()
}
func init() {
}
func (p *pollServer) run() {
timeoutMs := C.int64_t(-1)
fds := [128]C.SRT_EPOLL_EVENT{}
fdlen := C.int(128)
for {
res := C.srt_epoll_uwait(p.srtEpollDescr, &fds[0], fdlen, timeoutMs)
if res == 0 {
continue //Shouldn't happen with -1
} else if res == -1 {
panic("srt_epoll_error")
} else if res > 0 {
max := int(res)
if fdlen < res {
max = int(fdlen)
}
p.pollDescLock.Lock()
for i := 0; i < max; i++ {
s := fds[i].fd
events := fds[i].events
pd := p.pollDescs[s]
if events&C.SRT_EPOLL_ERR != 0 {
pd.unblock(ModeRead, true, false)
pd.unblock(ModeWrite, true, false)
continue
}
if events&C.SRT_EPOLL_IN != 0 {
pd.unblock(ModeRead, false, true)
}
if events&C.SRT_EPOLL_OUT != 0 {
pd.unblock(ModeWrite, false, true)
}
}
p.pollDescLock.Unlock()
}
}
}

View file

@ -0,0 +1,54 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
int srt_recvmsg2_wrapped(SRTSOCKET u, char* buf, int len, SRT_MSGCTRL *mctrl, int *srterror, int *syserror)
{
int ret = srt_recvmsg2(u, buf, len, mctrl);
if (ret < 0) {
*srterror = srt_getlasterror(syserror);
}
return ret;
}
*/
import "C"
import (
"errors"
"syscall"
"unsafe"
)
func srtRecvMsg2Impl(u C.SRTSOCKET, buf []byte, msgctrl *C.SRT_MSGCTRL) (n int, err error) {
srterr := C.int(0)
syserr := C.int(0)
n = int(C.srt_recvmsg2_wrapped(u, (*C.char)(unsafe.Pointer(&buf[0])), C.int(len(buf)), msgctrl, &srterr, &syserr))
if n < 0 {
srterror := SRTErrno(srterr)
if syserr < 0 {
srterror.wrapSysErr(syscall.Errno(syserr))
}
err = srterror
n = 0
}
return
}
// Read data from the SRT socket
func (s SrtSocket) Read(b []byte) (n int, err error) {
//Fastpath
if !s.blocking {
s.pd.reset(ModeRead)
}
n, err = srtRecvMsg2Impl(s.socket, b, nil)
for {
if !errors.Is(err, error(EAsyncRCV)) || s.blocking {
return
}
s.pd.wait(ModeRead)
n, err = srtRecvMsg2Impl(s.socket, b, nil)
}
}

View file

@ -0,0 +1,578 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
#include <srt/access_control.h>
#include "callback.h"
static const SRTSOCKET get_srt_invalid_sock() { return SRT_INVALID_SOCK; };
static const int get_srt_error() { return SRT_ERROR; };
static const int get_srt_error_reject_predefined() { return SRT_REJC_PREDEFINED; };
static const int get_srt_error_reject_userdefined() { return SRT_REJC_USERDEFINED; };
*/
import "C"
import (
"errors"
"fmt"
"net"
"runtime"
"strconv"
"sync"
"syscall"
"time"
"unsafe"
gopointer "github.com/mattn/go-pointer"
)
// SRT Socket mode
const (
ModeFailure = iota
ModeListener
ModeCaller
ModeRendezvouz
)
// Binding ops
const (
bindingPre = 0
bindingPost = 1
)
// SrtSocket - SRT socket
type SrtSocket struct {
socket C.int
blocking bool
pd *pollDesc
host string
port uint16
options map[string]string
mode int
pktSize int
pollTimeout int64
}
var (
callbackMutex sync.Mutex
listenCallbackMap map[C.int]unsafe.Pointer = make(map[C.int]unsafe.Pointer)
connectCallbackMap map[C.int]unsafe.Pointer = make(map[C.int]unsafe.Pointer)
)
// Static consts from library
var (
SRT_INVALID_SOCK = C.get_srt_invalid_sock()
SRT_ERROR = C.get_srt_error()
SRTS_CONNECTED = C.SRTS_CONNECTED
)
const defaultPacketSize = 1456
// InitSRT - Initialize srt library
func InitSRT() {
C.srt_startup()
}
// CleanupSRT - Cleanup SRT lib
func CleanupSRT() {
C.srt_cleanup()
}
// NewSrtSocket - Create a new SRT Socket
func NewSrtSocket(host string, port uint16, options map[string]string) *SrtSocket {
s := new(SrtSocket)
s.socket = C.srt_create_socket()
if s.socket == SRT_INVALID_SOCK {
return nil
}
s.host = host
s.port = port
s.options = options
s.pollTimeout = -1
val, exists := options["pktsize"]
if exists {
pktSize, err := strconv.Atoi(val)
if err != nil {
s.pktSize = pktSize
}
}
if s.pktSize <= 0 {
s.pktSize = defaultPacketSize
}
val, exists = options["blocking"]
if exists && val != "0" {
s.blocking = true
}
if !s.blocking {
s.pd = pollDescInit(s.socket)
}
finalizer := func(obj interface{}) {
sf := obj.(*SrtSocket)
sf.Close()
if sf.pd != nil {
sf.pd.release()
}
}
//Cleanup SrtSocket if no references exist anymore
runtime.SetFinalizer(s, finalizer)
var err error
s.mode, err = s.preconfiguration()
if err != nil {
return nil
}
return s
}
func newFromSocket(acceptSocket *SrtSocket, socket C.SRTSOCKET) (*SrtSocket, error) {
s := new(SrtSocket)
s.socket = socket
s.pktSize = acceptSocket.pktSize
s.blocking = acceptSocket.blocking
s.pollTimeout = acceptSocket.pollTimeout
err := acceptSocket.postconfiguration(s)
if err != nil {
return nil, err
}
if !s.blocking {
s.pd = pollDescInit(s.socket)
}
finalizer := func(obj interface{}) {
sf := obj.(*SrtSocket)
sf.Close()
if sf.pd != nil {
sf.pd.release()
}
}
//Cleanup SrtSocket if no references exist anymore
runtime.SetFinalizer(s, finalizer)
return s, nil
}
func (s SrtSocket) GetSocket() C.int {
return s.socket
}
// Listen for incoming connections. The backlog setting defines how many sockets
// may be allowed to wait until they are accepted (excessive connection requests
// are rejected in advance)
func (s *SrtSocket) Listen(backlog int) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
nbacklog := C.int(backlog)
sa, salen, err := CreateAddrInet(s.host, s.port)
if err != nil {
return err
}
res := C.srt_bind(s.socket, sa, C.int(salen))
if res == SRT_ERROR {
C.srt_close(s.socket)
return fmt.Errorf("Error in srt_bind: %w", srtGetAndClearError())
}
res = C.srt_listen(s.socket, nbacklog)
if res == SRT_ERROR {
C.srt_close(s.socket)
return fmt.Errorf("Error in srt_listen: %w", srtGetAndClearError())
}
err = s.postconfiguration(s)
if err != nil {
return fmt.Errorf("Error setting post socket options")
}
return nil
}
// Connect to a remote endpoint
func (s *SrtSocket) Connect() error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
sa, salen, err := CreateAddrInet(s.host, s.port)
if err != nil {
return err
}
res := C.srt_connect(s.socket, sa, C.int(salen))
if res == SRT_ERROR {
C.srt_close(s.socket)
return srtGetAndClearError()
}
if !s.blocking {
if err := s.pd.wait(ModeWrite); err != nil {
return err
}
}
err = s.postconfiguration(s)
if err != nil {
return fmt.Errorf("Error setting post socket options in connect")
}
return nil
}
// Stats - Retrieve stats from the SRT socket
func (s SrtSocket) Stats() (*SrtStats, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
var stats C.SRT_TRACEBSTATS = C.SRT_TRACEBSTATS{}
var b C.int = 1
if C.srt_bstats(s.socket, &stats, b) == SRT_ERROR {
return nil, fmt.Errorf("Error getting stats, %w", srtGetAndClearError())
}
return newSrtStats(&stats), nil
}
// Mode - Return working mode of the SRT socket
func (s SrtSocket) Mode() int {
return s.mode
}
// PacketSize - Return packet size of the SRT socket
func (s SrtSocket) PacketSize() int {
return s.pktSize
}
// PollTimeout - Return polling max time, for connect/read/write operations.
// Only applied when socket is in non-blocking mode.
func (s SrtSocket) PollTimeout() time.Duration {
return time.Duration(s.pollTimeout) * time.Millisecond
}
// SetPollTimeout - Sets polling max time, for connect/read/write operations.
// Only applied when socket is in non-blocking mode.
func (s *SrtSocket) SetPollTimeout(pollTimeout time.Duration) {
s.pollTimeout = pollTimeout.Milliseconds()
}
func (s *SrtSocket) SetDeadline(deadline time.Time) {
s.pd.setDeadline(deadline, ModeRead+ModeWrite)
}
func (s *SrtSocket) SetReadDeadline(deadline time.Time) {
s.pd.setDeadline(deadline, ModeRead)
}
func (s *SrtSocket) SetWriteDeadline(deadline time.Time) {
s.pd.setDeadline(deadline, ModeWrite)
}
// Close the SRT socket
func (s *SrtSocket) Close() {
C.srt_close(s.socket)
s.socket = SRT_INVALID_SOCK
if !s.blocking {
s.pd.close()
}
callbackMutex.Lock()
if ptr, exists := listenCallbackMap[s.socket]; exists {
gopointer.Unref(ptr)
}
if ptr, exists := connectCallbackMap[s.socket]; exists {
gopointer.Unref(ptr)
}
callbackMutex.Unlock()
}
// ListenCallbackFunc specifies a function to be called before a connecting socket is passed to accept
type ListenCallbackFunc func(socket *SrtSocket, version int, addr *net.UDPAddr, streamid string) bool
//export srtListenCBWrapper
func srtListenCBWrapper(arg unsafe.Pointer, socket C.SRTSOCKET, hsVersion C.int, peeraddr *C.struct_sockaddr, streamid *C.char) C.int {
userCB := gopointer.Restore(arg).(ListenCallbackFunc)
s := new(SrtSocket)
s.socket = socket
udpAddr, _ := udpAddrFromSockaddr((*syscall.RawSockaddrAny)(unsafe.Pointer(peeraddr)))
if userCB(s, int(hsVersion), udpAddr, C.GoString(streamid)) {
return 0
}
return SRT_ERROR
}
// SetListenCallback - set a function to be called early in the handshake before a client
// is handed to accept on a listening socket.
// The connection can be rejected by returning false from the callback.
// See examples/echo-receiver for more details.
func (s SrtSocket) SetListenCallback(cb ListenCallbackFunc) {
ptr := gopointer.Save(cb)
C.srt_listen_callback(s.socket, (*C.srt_listen_callback_fn)(C.srtListenCB), ptr)
callbackMutex.Lock()
defer callbackMutex.Unlock()
if listenCallbackMap[s.socket] != nil {
gopointer.Unref(listenCallbackMap[s.socket])
}
listenCallbackMap[s.socket] = ptr
}
// ConnectCallbackFunc specifies a function to be called after a socket or connection in a group has failed.
type ConnectCallbackFunc func(socket *SrtSocket, err error, addr *net.UDPAddr, token int)
//export srtConnectCBWrapper
func srtConnectCBWrapper(arg unsafe.Pointer, socket C.SRTSOCKET, errcode C.int, peeraddr *C.struct_sockaddr, token C.int) {
userCB := gopointer.Restore(arg).(ConnectCallbackFunc)
s := new(SrtSocket)
s.socket = socket
udpAddr, _ := udpAddrFromSockaddr((*syscall.RawSockaddrAny)(unsafe.Pointer(peeraddr)))
userCB(s, SRTErrno(errcode), udpAddr, int(token))
}
// SetConnectCallback - set a function to be called after a socket or connection in a group has failed
// Note that the function is not guaranteed to be called if the socket is set to blocking mode.
func (s SrtSocket) SetConnectCallback(cb ConnectCallbackFunc) {
ptr := gopointer.Save(cb)
C.srt_connect_callback(s.socket, (*C.srt_connect_callback_fn)(C.srtConnectCB), ptr)
callbackMutex.Lock()
defer callbackMutex.Unlock()
if connectCallbackMap[s.socket] != nil {
gopointer.Unref(connectCallbackMap[s.socket])
}
connectCallbackMap[s.socket] = ptr
}
// Rejection reasons
var (
// Start of range for predefined rejection reasons
RejectionReasonPredefined = int(C.get_srt_error_reject_predefined())
// General syntax error in the SocketID specification (also a fallback code for undefined cases)
RejectionReasonBadRequest = RejectionReasonPredefined + 400
// Authentication failed, provided that the user was correctly identified and access to the required resource would be granted
RejectionReasonUnauthorized = RejectionReasonPredefined + 401
// The server is too heavily loaded, or you have exceeded credits for accessing the service and the resource.
RejectionReasonOverload = RejectionReasonPredefined + 402
// Access denied to the resource by any kind of reason
RejectionReasonForbidden = RejectionReasonPredefined + 403
// Resource not found at this time.
RejectionReasonNotFound = RejectionReasonPredefined + 404
// The mode specified in `m` key in StreamID is not supported for this request.
RejectionReasonBadMode = RejectionReasonPredefined + 405
// The requested parameters specified in SocketID cannot be satisfied for the requested resource. Also when m=publish and the data format is not acceptable.
RejectionReasonUnacceptable = RejectionReasonPredefined + 406
// Start of range for application defined rejection reasons
RejectionReasonUserDefined = int(C.get_srt_error_reject_predefined())
)
// SetRejectReason - set custom reason for connection reject
func (s SrtSocket) SetRejectReason(value int) error {
res := C.srt_setrejectreason(s.socket, C.int(value))
if res == SRT_ERROR {
return errors.New(C.GoString(C.srt_getlasterror_str()))
}
return nil
}
// GetSockOptByte - return byte value obtained with srt_getsockopt
func (s SrtSocket) GetSockOptByte(opt int) (byte, error) {
var v byte
l := 1
err := s.getSockOpt(opt, unsafe.Pointer(&v), &l)
return v, err
}
// GetSockOptBool - return bool value obtained with srt_getsockopt
func (s SrtSocket) GetSockOptBool(opt int) (bool, error) {
var v int32
l := 4
err := s.getSockOpt(opt, unsafe.Pointer(&v), &l)
if v == 1 {
return true, err
}
return false, err
}
// GetSockOptInt - return int value obtained with srt_getsockopt
func (s SrtSocket) GetSockOptInt(opt int) (int, error) {
var v int32
l := 4
err := s.getSockOpt(opt, unsafe.Pointer(&v), &l)
return int(v), err
}
// GetSockOptInt64 - return int64 value obtained with srt_getsockopt
func (s SrtSocket) GetSockOptInt64(opt int) (int64, error) {
var v int64
l := 8
err := s.getSockOpt(opt, unsafe.Pointer(&v), &l)
return v, err
}
// GetSockOptString - return string value obtained with srt_getsockopt
func (s SrtSocket) GetSockOptString(opt int) (string, error) {
buf := make([]byte, 256)
l := len(buf)
err := s.getSockOpt(opt, unsafe.Pointer(&buf[0]), &l)
if err != nil {
return "", err
}
return string(buf[:l]), nil
}
// SetSockOptByte - set byte value using srt_setsockopt
func (s SrtSocket) SetSockOptByte(opt int, value byte) error {
return s.setSockOpt(opt, unsafe.Pointer(&value), 1)
}
// SetSockOptBool - set bool value using srt_setsockopt
func (s SrtSocket) SetSockOptBool(opt int, value bool) error {
val := int(0)
if value {
val = 1
}
return s.setSockOpt(opt, unsafe.Pointer(&val), 4)
}
// SetSockOptInt - set int value using srt_setsockopt
func (s SrtSocket) SetSockOptInt(opt int, value int) error {
return s.setSockOpt(opt, unsafe.Pointer(&value), 4)
}
// SetSockOptInt64 - set int64 value using srt_setsockopt
func (s SrtSocket) SetSockOptInt64(opt int, value int64) error {
return s.setSockOpt(opt, unsafe.Pointer(&value), 8)
}
// SetSockOptString - set string value using srt_setsockopt
func (s SrtSocket) SetSockOptString(opt int, value string) error {
return s.setSockOpt(opt, unsafe.Pointer(&[]byte(value)[0]), len(value))
}
func (s SrtSocket) setSockOpt(opt int, data unsafe.Pointer, size int) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
res := C.srt_setsockopt(s.socket, 0, C.SRT_SOCKOPT(opt), data, C.int(size))
if res == -1 {
return fmt.Errorf("Error calling srt_setsockopt %w", srtGetAndClearError())
}
return nil
}
func (s SrtSocket) getSockOpt(opt int, data unsafe.Pointer, size *int) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
res := C.srt_getsockopt(s.socket, 0, C.SRT_SOCKOPT(opt), data, (*C.int)(unsafe.Pointer(size)))
if res == -1 {
return fmt.Errorf("Error calling srt_getsockopt %w", srtGetAndClearError())
}
return nil
}
func (s SrtSocket) preconfiguration() (int, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
var blocking C.int
if s.blocking {
blocking = C.int(1)
} else {
blocking = C.int(0)
}
result := C.srt_setsockopt(s.socket, 0, C.SRTO_RCVSYN, unsafe.Pointer(&blocking), C.int(unsafe.Sizeof(blocking)))
if result == -1 {
return ModeFailure, fmt.Errorf("could not set SRTO_RCVSYN flag: %w", srtGetAndClearError())
}
var mode int
modeVal, ok := s.options["mode"]
if !ok {
modeVal = "default"
}
if modeVal == "client" || modeVal == "caller" {
mode = ModeCaller
} else if modeVal == "server" || modeVal == "listener" {
mode = ModeListener
} else if modeVal == "default" {
if s.host == "" {
mode = ModeListener
} else {
// Host is given, so check also "adapter"
if _, ok := s.options["adapter"]; ok {
mode = ModeRendezvouz
} else {
mode = ModeCaller
}
}
} else {
mode = ModeFailure
}
if linger, ok := s.options["linger"]; ok {
li, err := strconv.Atoi(linger)
if err == nil {
if err := setSocketLingerOption(s.socket, int32(li)); err != nil {
return ModeFailure, fmt.Errorf("could not set LINGER option %w", err)
}
} else {
return ModeFailure, fmt.Errorf("could not set LINGER option %w", err)
}
}
err := setSocketOptions(s.socket, bindingPre, s.options)
if err != nil {
return ModeFailure, fmt.Errorf("Error setting socket options: %w", err)
}
return mode, nil
}
func (s SrtSocket) postconfiguration(sck *SrtSocket) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
var blocking C.int
if s.blocking {
blocking = 1
} else {
blocking = 0
}
res := C.srt_setsockopt(sck.socket, 0, C.SRTO_SNDSYN, unsafe.Pointer(&blocking), C.int(unsafe.Sizeof(blocking)))
if res == -1 {
return fmt.Errorf("Error in postconfiguration setting SRTO_SNDSYN: %w", srtGetAndClearError())
}
res = C.srt_setsockopt(sck.socket, 0, C.SRTO_RCVSYN, unsafe.Pointer(&blocking), C.int(unsafe.Sizeof(blocking)))
if res == -1 {
return fmt.Errorf("Error in postconfiguration setting SRTO_RCVSYN: %w", srtGetAndClearError())
}
err := setSocketOptions(sck.socket, bindingPost, s.options)
return err
}

View file

@ -0,0 +1,191 @@
package srtgo
// #cgo LDFLAGS: -lsrt
// #include <srt/srt.h>
import "C"
import (
"errors"
"fmt"
"strconv"
"syscall"
"unsafe"
)
const (
transTypeLive = 0
transTypeFile = 1
)
const (
tInteger32 = 0
tInteger64 = 1
tString = 2
tBoolean = 3
tTransType = 4
SRTO_TRANSTYPE = C.SRTO_TRANSTYPE
SRTO_MAXBW = C.SRTO_MAXBW
SRTO_PBKEYLEN = C.SRTO_PBKEYLEN
SRTO_PASSPHRASE = C.SRTO_PASSPHRASE
SRTO_MSS = C.SRTO_MSS
SRTO_FC = C.SRTO_FC
SRTO_SNDBUF = C.SRTO_SNDBUF
SRTO_RCVBUF = C.SRTO_RCVBUF
SRTO_IPTTL = C.SRTO_IPTTL
SRTO_IPTOS = C.SRTO_IPTOS
SRTO_INPUTBW = C.SRTO_INPUTBW
SRTO_OHEADBW = C.SRTO_OHEADBW
SRTO_LATENCY = C.SRTO_LATENCY
SRTO_TSBPDMODE = C.SRTO_TSBPDMODE
SRTO_TLPKTDROP = C.SRTO_TLPKTDROP
SRTO_SNDDROPDELAY = C.SRTO_SNDDROPDELAY
SRTO_NAKREPORT = C.SRTO_NAKREPORT
SRTO_CONNTIMEO = C.SRTO_CONNTIMEO
SRTO_LOSSMAXTTL = C.SRTO_LOSSMAXTTL
SRTO_RCVLATENCY = C.SRTO_RCVLATENCY
SRTO_PEERLATENCY = C.SRTO_PEERLATENCY
SRTO_MINVERSION = C.SRTO_MINVERSION
SRTO_STREAMID = C.SRTO_STREAMID
SRTO_CONGESTION = C.SRTO_CONGESTION
SRTO_MESSAGEAPI = C.SRTO_MESSAGEAPI
SRTO_PAYLOADSIZE = C.SRTO_PAYLOADSIZE
SRTO_KMREFRESHRATE = C.SRTO_KMREFRESHRATE
SRTO_KMPREANNOUNCE = C.SRTO_KMPREANNOUNCE
SRTO_ENFORCEDENCRYPTION = C.SRTO_ENFORCEDENCRYPTION
SRTO_PEERIDLETIMEO = C.SRTO_PEERIDLETIMEO
SRTO_PACKETFILTER = C.SRTO_PACKETFILTER
SRTO_STATE = C.SRTO_STATE
)
type socketOption struct {
name string
level int
option int
binding int
dataType int
}
// List of possible srt socket options
var SocketOptions = []socketOption{
{"transtype", 0, SRTO_TRANSTYPE, bindingPre, tTransType},
{"maxbw", 0, SRTO_MAXBW, bindingPre, tInteger64},
{"pbkeylen", 0, SRTO_PBKEYLEN, bindingPre, tInteger32},
{"passphrase", 0, SRTO_PASSPHRASE, bindingPre, tString},
{"mss", 0, SRTO_MSS, bindingPre, tInteger32},
{"fc", 0, SRTO_FC, bindingPre, tInteger32},
{"sndbuf", 0, SRTO_SNDBUF, bindingPre, tInteger32},
{"rcvbuf", 0, SRTO_RCVBUF, bindingPre, tInteger32},
{"ipttl", 0, SRTO_IPTTL, bindingPre, tInteger32},
{"iptos", 0, SRTO_IPTOS, bindingPre, tInteger32},
{"inputbw", 0, SRTO_INPUTBW, bindingPost, tInteger64},
{"oheadbw", 0, SRTO_OHEADBW, bindingPost, tInteger32},
{"latency", 0, SRTO_LATENCY, bindingPre, tInteger32},
{"tsbpdmode", 0, SRTO_TSBPDMODE, bindingPre, tBoolean},
{"tlpktdrop", 0, SRTO_TLPKTDROP, bindingPre, tBoolean},
{"snddropdelay", 0, SRTO_SNDDROPDELAY, bindingPost, tInteger32},
{"nakreport", 0, SRTO_NAKREPORT, bindingPre, tBoolean},
{"conntimeo", 0, SRTO_CONNTIMEO, bindingPre, tInteger32},
{"lossmaxttl", 0, SRTO_LOSSMAXTTL, bindingPre, tInteger32},
{"rcvlatency", 0, SRTO_RCVLATENCY, bindingPre, tInteger32},
{"peerlatency", 0, SRTO_PEERLATENCY, bindingPre, tInteger32},
{"minversion", 0, SRTO_MINVERSION, bindingPre, tInteger32},
{"streamid", 0, SRTO_STREAMID, bindingPre, tString},
{"congestion", 0, SRTO_CONGESTION, bindingPre, tString},
{"messageapi", 0, SRTO_MESSAGEAPI, bindingPre, tBoolean},
{"payloadsize", 0, SRTO_PAYLOADSIZE, bindingPre, tInteger32},
{"kmrefreshrate", 0, SRTO_KMREFRESHRATE, bindingPre, tInteger32},
{"kmpreannounce", 0, SRTO_KMPREANNOUNCE, bindingPre, tInteger32},
{"enforcedencryption", 0, SRTO_ENFORCEDENCRYPTION, bindingPre, tBoolean},
{"peeridletimeo", 0, SRTO_PEERIDLETIMEO, bindingPre, tInteger32},
{"packetfilter", 0, SRTO_PACKETFILTER, bindingPre, tString},
}
func setSocketLingerOption(s C.int, li int32) error {
var lin syscall.Linger
lin.Linger = li
if lin.Linger > 0 {
lin.Onoff = 1
} else {
lin.Onoff = 0
}
res := C.srt_setsockopt(s, bindingPre, C.SRTO_LINGER, unsafe.Pointer(&lin), C.int(unsafe.Sizeof(lin)))
if res == SRT_ERROR {
return errors.New("failed to set linger")
}
return nil
}
func getSocketLingerOption(s *SrtSocket) (int32, error) {
var lin syscall.Linger
size := int(unsafe.Sizeof(lin))
err := s.getSockOpt(C.SRTO_LINGER, unsafe.Pointer(&lin), &size)
if err != nil {
return 0, err
}
if lin.Onoff == 0 {
return 0, nil
}
return lin.Linger, nil
}
// Set socket options for SRT
func setSocketOptions(s C.int, binding int, options map[string]string) error {
for _, so := range SocketOptions {
if val, ok := options[so.name]; ok {
if so.binding == binding {
if so.dataType == tInteger32 {
v, err := strconv.Atoi(val)
v32 := int32(v)
if err == nil {
result := C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v32), C.int32_t(unsafe.Sizeof(v32)))
if result == -1 {
return fmt.Errorf("warning - error setting option %s to %s, %w", so.name, val, srtGetAndClearError())
}
}
} else if so.dataType == tInteger64 {
v, err := strconv.ParseInt(val, 10, 64)
if err == nil {
result := C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v), C.int32_t(unsafe.Sizeof(v)))
if result == -1 {
return fmt.Errorf("warning - error setting option %s to %s, %w", so.name, val, srtGetAndClearError())
}
}
} else if so.dataType == tString {
sval := C.CString(val)
defer C.free(unsafe.Pointer(sval))
result := C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(sval), C.int32_t(len(val)))
if result == -1 {
return fmt.Errorf("warning - error setting option %s to %s, %w", so.name, val, srtGetAndClearError())
}
} else if so.dataType == tBoolean {
var result C.int
if val == "1" {
v := C.char(1)
result = C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v), C.int32_t(unsafe.Sizeof(v)))
} else if val == "0" {
v := C.char(0)
result = C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v), C.int32_t(unsafe.Sizeof(v)))
}
if result == -1 {
return fmt.Errorf("warning - error setting option %s to %s, %w", so.name, val, srtGetAndClearError())
}
} else if so.dataType == tTransType {
var result C.int
if val == "live" {
var v int32 = C.SRTT_LIVE
result = C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v), C.int32_t(unsafe.Sizeof(v)))
} else if val == "file" {
var v int32 = C.SRTT_FILE
result = C.srt_setsockflag(s, C.SRT_SOCKOPT(so.option), unsafe.Pointer(&v), C.int32_t(unsafe.Sizeof(v)))
}
if result == -1 {
return fmt.Errorf("warning - error setting option %s to %s: %w", so.name, val, srtGetAndClearError())
}
}
}
}
}
return nil
}

View file

@ -0,0 +1,188 @@
package srtgo
// #cgo LDFLAGS: -lsrt
// #include <srt/srt.h>
import "C"
type SrtStats struct {
// Global measurements
MsTimeStamp int64 // time since the UDT entity is started, in milliseconds
PktSentTotal int64 // total number of sent data packets, including retransmissions
PktRecvTotal int64 // total number of received packets
PktSndLossTotal int // total number of lost packets (sender side)
PktRcvLossTotal int // total number of lost packets (receiver side)
PktRetransTotal int // total number of retransmitted packets
PktSentACKTotal int // total number of sent ACK packets
PktRecvACKTotal int // total number of received ACK packets
PktSentNAKTotal int // total number of sent NAK packets
PktRecvNAKTotal int // total number of received NAK packets
UsSndDurationTotal int64 // total time duration when UDT is sending data (idle time exclusive)
PktSndDropTotal int // number of too-late-to-send dropped packets
PktRcvDropTotal int // number of too-late-to play missing packets
PktRcvUndecryptTotal int // number of undecrypted packets
ByteSentTotal int64 // total number of sent data bytes, including retransmissions
ByteRecvTotal int64 // total number of received bytes
ByteRcvLossTotal int64 // total number of lost bytes
ByteRetransTotal int64 // total number of retransmitted bytes
ByteSndDropTotal int64 // number of too-late-to-send dropped bytes
ByteRcvDropTotal int64 // number of too-late-to play missing bytes (estimate based on average packet size)
ByteRcvUndecryptTotal int64 // number of undecrypted bytes
// Local measurements
PktSent int64 // number of sent data packets, including retransmissions
PktRecv int64 // number of received packets
PktSndLoss int // number of lost packets (sender side)
PktRcvLoss int // number of lost packets (receiver side)
PktRetrans int // number of retransmitted packets
PktRcvRetrans int // number of retransmitted packets received
PktSentACK int // number of sent ACK packets
PktRecvACK int // number of received ACK packets
PktSentNAK int // number of sent NAK packets
PktRecvNAK int // number of received NAK packets
MbpsSendRate float64 // sending rate in Mb/s
MbpsRecvRate float64 // receiving rate in Mb/s
UsSndDuration int64 // busy sending time (i.e., idle time exclusive)
PktReorderDistance int // size of order discrepancy in received sequences
PktRcvAvgBelatedTime float64 // average time of packet delay for belated packets (packets with sequence past the ACK)
PktRcvBelated int64 // number of received AND IGNORED packets due to having come too late
PktSndDrop int // number of too-late-to-send dropped packets
PktRcvDrop int // number of too-late-to play missing packets
PktRcvUndecrypt int // number of undecrypted packets
ByteSent int64 // number of sent data bytes, including retransmissions
ByteRecv int64 // number of received bytes
ByteRcvLoss int64 // number of retransmitted Bytes
ByteRetrans int64 // number of retransmitted Bytes
ByteSndDrop int64 // number of too-late-to-send dropped Bytes
ByteRcvDrop int64 // number of too-late-to play missing Bytes (estimate based on average packet size)
ByteRcvUndecrypt int64 // number of undecrypted bytes
// Instant measurements
UsPktSndPeriod float64 // packet sending period, in microseconds
PktFlowWindow int // flow window size, in number of packets
PktCongestionWindow int // congestion window size, in number of packets
PktFlightSize int // number of packets on flight
MsRTT float64 // RTT, in milliseconds
MbpsBandwidth float64 // estimated bandwidth, in Mb/s
ByteAvailSndBuf int // available UDT sender buffer size
ByteAvailRcvBuf int // available UDT receiver buffer size
MbpsMaxBW float64 // Transmit Bandwidth ceiling (Mbps)
ByteMSS int // MTU
PktSndBuf int // UnACKed packets in UDT sender
ByteSndBuf int // UnACKed bytes in UDT sender
MsSndBuf int // UnACKed timespan (msec) of UDT sender
MsSndTsbPdDelay int // Timestamp-based Packet Delivery Delay
PktRcvBuf int // Undelivered packets in UDT receiver
ByteRcvBuf int // Undelivered bytes of UDT receiver
MsRcvBuf int // Undelivered timespan (msec) of UDT receiver
MsRcvTsbPdDelay int // Timestamp-based Packet Delivery Delay
PktSndFilterExtraTotal int // number of control packets supplied by packet filter
PktRcvFilterExtraTotal int // number of control packets received and not supplied back
PktRcvFilterSupplyTotal int // number of packets that the filter supplied extra (e.g. FEC rebuilt)
PktRcvFilterLossTotal int // number of packet loss not coverable by filter
PktSndFilterExtra int // number of control packets supplied by packet filter
PktRcvFilterExtra int // number of control packets received and not supplied back
PktRcvFilterSupply int // number of packets that the filter supplied extra (e.g. FEC rebuilt)
PktRcvFilterLoss int // number of packet loss not coverable by filter
PktReorderTolerance int // packet reorder tolerance value
}
func newSrtStats(stats *C.SRT_TRACEBSTATS) *SrtStats {
s := new(SrtStats)
s.MsTimeStamp = int64(stats.msTimeStamp)
s.PktSentTotal = int64(stats.pktSentTotal)
s.PktRecvTotal = int64(stats.pktRecvTotal)
s.PktSndLossTotal = int(stats.pktSndLossTotal)
s.PktRcvLossTotal = int(stats.pktRcvLossTotal)
s.PktRetransTotal = int(stats.pktRetransTotal)
s.PktSentACKTotal = int(stats.pktSentACKTotal)
s.PktRecvACKTotal = int(stats.pktRecvACKTotal)
s.PktSentNAKTotal = int(stats.pktSentNAKTotal)
s.PktRecvNAKTotal = int(stats.pktRecvNAKTotal)
s.UsSndDurationTotal = int64(stats.usSndDurationTotal)
s.PktSndDropTotal = int(stats.pktSndDropTotal)
s.PktRcvDropTotal = int(stats.pktRcvDropTotal)
s.PktRcvUndecryptTotal = int(stats.pktRcvUndecryptTotal)
s.ByteSentTotal = int64(stats.byteSentTotal)
s.ByteRecvTotal = int64(stats.byteRecvTotal)
s.ByteRcvLossTotal = int64(stats.byteRcvLossTotal)
s.ByteRetransTotal = int64(stats.byteRetransTotal)
s.ByteSndDropTotal = int64(stats.byteSndDropTotal)
s.ByteRcvDropTotal = int64(stats.byteRcvDropTotal)
s.ByteRcvUndecryptTotal = int64(stats.byteRcvUndecryptTotal)
s.PktSent = int64(stats.pktSent)
s.PktRecv = int64(stats.pktRecv)
s.PktSndLoss = int(stats.pktSndLoss)
s.PktRcvLoss = int(stats.pktRcvLoss)
s.PktRetrans = int(stats.pktRetrans)
s.PktRcvRetrans = int(stats.pktRcvRetrans)
s.PktSentACK = int(stats.pktSentACK)
s.PktRecvACK = int(stats.pktRecvACK)
s.PktSentNAK = int(stats.pktSentNAK)
s.PktRecvNAK = int(stats.pktRecvNAK)
s.MbpsSendRate = float64(stats.mbpsSendRate)
s.MbpsRecvRate = float64(stats.mbpsRecvRate)
s.UsSndDuration = int64(stats.usSndDuration)
s.PktReorderDistance = int(stats.pktReorderDistance)
s.PktRcvAvgBelatedTime = float64(stats.pktRcvAvgBelatedTime)
s.PktRcvBelated = int64(stats.pktRcvBelated)
s.PktSndDrop = int(stats.pktSndDrop)
s.PktRcvDrop = int(stats.pktRcvDrop)
s.PktRcvUndecrypt = int(stats.pktRcvUndecrypt)
s.ByteSent = int64(stats.byteSent)
s.ByteRecv = int64(stats.byteRecv)
s.ByteRcvLoss = int64(stats.byteRcvLoss)
s.ByteRetrans = int64(stats.byteRetrans)
s.ByteSndDrop = int64(stats.byteSndDrop)
s.ByteRcvDrop = int64(stats.byteRcvDrop)
s.ByteRcvUndecrypt = int64(stats.byteRcvUndecrypt)
s.UsPktSndPeriod = float64(stats.usPktSndPeriod)
s.PktFlowWindow = int(stats.pktFlowWindow)
s.PktCongestionWindow = int(stats.pktCongestionWindow)
s.PktFlightSize = int(stats.pktFlightSize)
s.MsRTT = float64(stats.msRTT)
s.MbpsBandwidth = float64(stats.mbpsBandwidth)
s.ByteAvailSndBuf = int(stats.byteAvailSndBuf)
s.ByteAvailRcvBuf = int(stats.byteAvailRcvBuf)
s.MbpsMaxBW = float64(stats.mbpsMaxBW)
s.ByteMSS = int(stats.byteMSS)
s.PktSndBuf = int(stats.pktSndBuf)
s.ByteSndBuf = int(stats.byteSndBuf)
s.MsSndBuf = int(stats.msSndBuf)
s.MsSndTsbPdDelay = int(stats.msSndTsbPdDelay)
s.PktRcvBuf = int(stats.pktRcvBuf)
s.ByteRcvBuf = int(stats.byteRcvBuf)
s.MsRcvBuf = int(stats.msRcvBuf)
s.MsRcvTsbPdDelay = int(stats.msRcvTsbPdDelay)
s.PktSndFilterExtraTotal = int(stats.pktSndFilterExtraTotal)
s.PktRcvFilterExtraTotal = int(stats.pktRcvFilterExtraTotal)
s.PktRcvFilterSupplyTotal = int(stats.pktRcvFilterSupplyTotal)
s.PktRcvFilterLossTotal = int(stats.pktRcvFilterLossTotal)
s.PktSndFilterExtra = int(stats.pktSndFilterExtra)
s.PktRcvFilterExtra = int(stats.pktRcvFilterExtra)
s.PktRcvFilterSupply = int(stats.pktRcvFilterSupply)
s.PktRcvFilterLoss = int(stats.pktRcvFilterLoss)
s.PktReorderTolerance = int(stats.pktReorderTolerance)
return s
}

View file

@ -0,0 +1,55 @@
package srtgo
/*
#cgo LDFLAGS: -lsrt
#include <srt/srt.h>
int srt_sendmsg2_wrapped(SRTSOCKET u, const char* buf, int len, SRT_MSGCTRL *mctrl, int *srterror, int *syserror)
{
int ret = srt_sendmsg2(u, buf, len, mctrl);
if (ret < 0) {
*srterror = srt_getlasterror(syserror);
}
return ret;
}
*/
import "C"
import (
"errors"
"syscall"
"unsafe"
)
func srtSendMsg2Impl(u C.SRTSOCKET, buf []byte, msgctrl *C.SRT_MSGCTRL) (n int, err error) {
srterr := C.int(0)
syserr := C.int(0)
n = int(C.srt_sendmsg2_wrapped(u, (*C.char)(unsafe.Pointer(&buf[0])), C.int(len(buf)), msgctrl, &srterr, &syserr))
if n < 0 {
srterror := SRTErrno(srterr)
if syserr < 0 {
srterror.wrapSysErr(syscall.Errno(syserr))
}
err = srterror
n = 0
}
return
}
// Write data to the SRT socket
func (s SrtSocket) Write(b []byte) (n int, err error) {
//Fastpath:
if !s.blocking {
s.pd.reset(ModeWrite)
}
n, err = srtSendMsg2Impl(s.socket, b, nil)
for {
if !errors.Is(err, error(EAsyncSND)) || s.blocking {
return
}
s.pd.wait(ModeWrite)
n, err = srtSendMsg2Impl(s.socket, b, nil)
}
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2019 Yasuhiro Matsumoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,29 @@
# go-pointer
Utility for cgo
## Usage
https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
In go 1.6, cgo argument can't be passed Go pointer.
```
var s string
C.pass_pointer(pointer.Save(&s))
v := *(pointer.Restore(C.get_from_pointer()).(*string))
```
## Installation
```
go get github.com/mattn/go-pointer
```
## License
MIT
## Author
Yasuhiro Matsumoto (a.k.a mattn)

View file

@ -0,0 +1 @@
package pointer

View file

@ -0,0 +1,57 @@
package pointer
// #include <stdlib.h>
import "C"
import (
"sync"
"unsafe"
)
var (
mutex sync.RWMutex
store = map[unsafe.Pointer]interface{}{}
)
func Save(v interface{}) unsafe.Pointer {
if v == nil {
return nil
}
// Generate real fake C pointer.
// This pointer will not store any data, but will bi used for indexing purposes.
// Since Go doest allow to cast dangling pointer to unsafe.Pointer, we do rally allocate one byte.
// Why we need indexing, because Go doest allow C code to store pointers to Go data.
var ptr unsafe.Pointer = C.malloc(C.size_t(1))
if ptr == nil {
panic("can't allocate 'cgo-pointer hack index pointer': ptr == nil")
}
mutex.Lock()
store[ptr] = v
mutex.Unlock()
return ptr
}
func Restore(ptr unsafe.Pointer) (v interface{}) {
if ptr == nil {
return nil
}
mutex.RLock()
v = store[ptr]
mutex.RUnlock()
return
}
func Unref(ptr unsafe.Pointer) {
if ptr == nil {
return
}
mutex.Lock()
delete(store, ptr)
mutex.Unlock()
C.free(ptr)
}

View file

@ -28,6 +28,8 @@ var (
uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1))
uintptrType = reflect.TypeOf(uintptr(1))
float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1))
@ -308,11 +310,11 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
if !obj1Value.CanConvert(timeType) {
break
}
// time.Time can compared!
// time.Time can be compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
@ -328,7 +330,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
if !obj1Value.CanConvert(bytesType) {
break
}
@ -345,6 +347,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
case reflect.Uintptr:
{
uintptrObj1, ok := obj1.(uintptr)
if !ok {
uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
}
uintptrObj2, ok := obj2.(uintptr)
if !ok {
uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
}
if uintptrObj1 > uintptrObj2 {
return compareGreater, true
}
if uintptrObj1 == uintptrObj2 {
return compareEqual, true
}
if uintptrObj1 < uintptrObj2 {
return compareLess, true
}
}
}
return compareEqual, false

View file

@ -1,16 +0,0 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View file

@ -1,16 +0,0 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View file

@ -1,7 +1,4 @@
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
package assert
@ -107,7 +104,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -616,6 +613,16 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil.
//
// assert.NotNilf(t, err, "error message %s", "formatted")
@ -660,10 +667,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -747,10 +756,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()

View file

@ -1,7 +1,4 @@
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
package assert
@ -189,7 +186,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...)
}
// EqualValues asserts that two objects are equal or convertable to the same types
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
//
// a.EqualValues(uint32(123), int32(123))
@ -200,7 +197,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...)
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1221,6 +1218,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
return NotErrorIsf(a.t, err, target, msg, args...)
}
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil.
//
// a.NotNil(err)
@ -1309,10 +1326,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
// a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1320,10 +1339,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...)
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1483,10 +1504,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...)
}
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
// a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1494,10 +1516,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...)
}
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()

View file

@ -19,7 +19,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v3"
"gopkg.in/yaml.v3"
)
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -110,7 +110,12 @@ func copyExportedFields(expected interface{}) interface{} {
return result.Interface()
case reflect.Array, reflect.Slice:
result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
var result reflect.Value
if expectedKind == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
}
for i := 0; i < expectedValue.Len(); i++ {
index := expectedValue.Index(i)
if isNil(index) {
@ -140,6 +145,8 @@ func copyExportedFields(expected interface{}) interface{} {
// structures.
//
// This function does no assertion of any kind.
//
// Deprecated: Use [EqualExportedValues] instead.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
expectedCleaned := copyExportedFields(expected)
actualCleaned := copyExportedFields(actual)
@ -153,19 +160,42 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
return true
}
actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
actualValue := reflect.ValueOf(actual)
if !expectedValue.IsValid() || !actualValue.IsValid() {
return false
}
expectedType := expectedValue.Type()
actualType := actualValue.Type()
if !expectedType.ConvertibleTo(actualType) {
return false
}
if !isNumericType(expectedType) || !isNumericType(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(
expectedValue.Convert(actualType).Interface(), actual,
)
}
// If BOTH values are numeric, there are chances of false positives due
// to overflow or underflow. So, we need to make sure to always convert
// the smaller type to a larger type before comparing.
if expectedType.Size() >= actualType.Size() {
return actualValue.Convert(expectedType).Interface() == expected
}
return expectedValue.Convert(actualType).Interface() == actual
}
// isNumericType returns true if the type is one of:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
// float32, float64, complex64, complex128
func isNumericType(t reflect.Type) bool {
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
}
/* CallerInfo is necessary because the assert functions use the testing object
internally, causing it to print the file:line of the assert method, rather than where
the problem actually occurred in calling code.*/
@ -266,7 +296,7 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
// Aligns the provided message so that all lines after the first line start at the same location as the first line.
// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab).
// The longestLabelLen parameter specifies the length of the longest label in the output (required becaues this is the
// The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the
// basis on which the alignment occurs).
func indentMessageLines(message string, longestLabelLen int) string {
outBuf := new(bytes.Buffer)
@ -382,6 +412,25 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg
return true
}
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
interfaceType := reflect.TypeOf(interfaceObject).Elem()
if object == nil {
return Fail(t, fmt.Sprintf("Cannot check if nil does not implement %v", interfaceType), msgAndArgs...)
}
if reflect.TypeOf(object).Implements(interfaceType) {
return Fail(t, fmt.Sprintf("%T implements %v", object, interfaceType), msgAndArgs...)
}
return true
}
// IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@ -496,7 +545,7 @@ func samePointers(first, second interface{}) bool {
// representations appropriate to be presented to the user.
//
// If the values are not of like type, the returned strings will be prefixed
// with the type name, and the value will be enclosed in parenthesis similar
// with the type name, and the value will be enclosed in parentheses similar
// to a type conversion in the Go grammar.
func formatUnequalValues(expected, actual interface{}) (e string, a string) {
if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
@ -523,7 +572,7 @@ func truncatingFormat(data interface{}) string {
return value
}
// EqualValues asserts that two objects are equal or convertable to the same types
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
//
// assert.EqualValues(t, uint32(123), int32(123))
@ -566,12 +615,19 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
}
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
}
if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
}
expected = copyExportedFields(expected)
@ -620,17 +676,6 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, "Expected value not to be nil.", msgAndArgs...)
}
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool {
if object == nil {
@ -638,16 +683,13 @@ func isNil(object interface{}) bool {
}
value := reflect.ValueOf(object)
kind := value.Kind()
isNilableKind := containsKind(
[]reflect.Kind{
switch value.Kind() {
case
reflect.Chan, reflect.Func,
reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice, reflect.UnsafePointer},
kind)
reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
if isNilableKind && value.IsNil() {
return true
return value.IsNil()
}
return false
@ -731,16 +773,14 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
}
// getLen try to get length of object.
// return (false, 0) if impossible.
func getLen(x interface{}) (ok bool, length int) {
// getLen tries to get the length of an object.
// It returns (0, false) if impossible.
func getLen(x interface{}) (length int, ok bool) {
v := reflect.ValueOf(x)
defer func() {
if e := recover(); e != nil {
ok = false
}
ok = recover() == nil
}()
return true, v.Len()
return v.Len(), true
}
// Len asserts that the specified object has specific length.
@ -751,13 +791,13 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok {
h.Helper()
}
ok, l := getLen(object)
l, ok := getLen(object)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...)
return Fail(t, fmt.Sprintf("\"%v\" could not be applied builtin len()", object), msgAndArgs...)
}
if l != length {
return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...)
return Fail(t, fmt.Sprintf("\"%v\" should have %d item(s), but has %d", object, length, l), msgAndArgs...)
}
return true
}
@ -919,10 +959,11 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
}
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
// assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -975,10 +1016,12 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
// assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -1439,7 +1482,7 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
h.Helper()
}
if math.IsNaN(epsilon) {
return Fail(t, "epsilon must not be NaN")
return Fail(t, "epsilon must not be NaN", msgAndArgs...)
}
actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil {
@ -1458,19 +1501,26 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if h, ok := t.(tHelper); ok {
h.Helper()
}
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
if expected == nil || actual == nil {
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
expectedSlice := reflect.ValueOf(expected)
actualSlice := reflect.ValueOf(actual)
for i := 0; i < actualSlice.Len(); i++ {
result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon)
if !result {
return result
if expectedSlice.Type().Kind() != reflect.Slice {
return Fail(t, "Expected value must be slice", msgAndArgs...)
}
expectedLen := expectedSlice.Len()
if !IsType(t, expected, actual) || !Len(t, actual, expectedLen) {
return false
}
for i := 0; i < expectedLen; i++ {
if !InEpsilon(t, expectedSlice.Index(i).Interface(), actualSlice.Index(i).Interface(), epsilon, "at index %d", i) {
return false
}
}
@ -1870,23 +1920,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
}
// FailNow panics.
func (c *CollectT) FailNow() {
func (*CollectT) FailNow() {
panic("Assertion failed")
}
// Reset clears the collected errors.
func (c *CollectT) Reset() {
c.errors = nil
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (*CollectT) Reset() {
panic("Reset() is deprecated")
}
// Copy copies the collected errors to the supplied t.
func (c *CollectT) Copy(t TestingT) {
if tt, ok := t.(tHelper); ok {
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (*CollectT) Copy(TestingT) {
panic("Copy() is deprecated")
}
// EventuallyWithT asserts that given condition will be met in waitFor time,
@ -1912,8 +1957,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper()
}
collect := new(CollectT)
ch := make(chan bool, 1)
var lastFinishedTickErrs []error
ch := make(chan []error, 1)
timer := time.NewTimer(waitFor)
defer timer.Stop()
@ -1924,19 +1969,25 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; {
select {
case <-timer.C:
collect.Copy(t)
for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
collect.Reset()
go func() {
condition(collect)
ch <- len(collect.errors) == 0
collect := new(CollectT)
defer func() {
ch <- collect.errors
}()
case v := <-ch:
if v {
condition(collect)
}()
case errs := <-ch:
if len(errs) == 0 {
return true
}
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C
}
}

View file

@ -12,7 +12,7 @@ import (
// an error if building a new request fails.
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
w := httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil)
req, err := http.NewRequest(method, url, http.NoBody)
if err != nil {
return -1, err
}
@ -32,12 +32,12 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
}
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
}
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
if !isSuccessCode {
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code))
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
}
return isSuccessCode
@ -54,12 +54,12 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
}
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
}
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
if !isRedirectCode {
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code))
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
}
return isRedirectCode
@ -76,12 +76,12 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
}
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
}
isErrorCode := code >= http.StatusBadRequest
if !isErrorCode {
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code))
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
}
return isErrorCode
@ -98,12 +98,12 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
}
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
}
successful := code == statuscode
if !successful {
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code))
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...)
}
return successful
@ -113,7 +113,10 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
// empty string if building a new request fails.
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {
w := httptest.NewRecorder()
req, err := http.NewRequest(method, url+"?"+values.Encode(), nil)
if len(values) > 0 {
url += "?" + values.Encode()
}
req, err := http.NewRequest(method, url, http.NoBody)
if err != nil {
return ""
}
@ -135,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str))
if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body))
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
}
return contains
@ -155,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str))
if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body))
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
}
return !contains

View file

@ -1,7 +1,4 @@
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
package require
@ -235,7 +232,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
t.FailNow()
}
// EqualValues asserts that two objects are equal or convertable to the same types
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
//
// assert.EqualValues(t, uint32(123), int32(123))
@ -249,7 +246,7 @@ func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArg
t.FailNow()
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -1546,6 +1543,32 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
t.FailNow()
}
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplements(t, interfaceObject, object, msgAndArgs...) {
return
}
t.FailNow()
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.NotImplementsf(t, interfaceObject, object, msg, args...) {
return
}
t.FailNow()
}
// NotNil asserts that the specified object is not nil.
//
// assert.NotNil(t, err)
@ -1658,10 +1681,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
t.FailNow()
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
// assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -1672,10 +1697,12 @@ func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...i
t.FailNow()
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -1880,10 +1907,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
t.FailNow()
}
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
// assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
@ -1894,10 +1922,11 @@ func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...inte
t.FailNow()
}
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()

View file

@ -1,7 +1,4 @@
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
package require
@ -190,7 +187,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
EqualExportedValuesf(a.t, expected, actual, msg, args...)
}
// EqualValues asserts that two objects are equal or convertable to the same types
// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
//
// a.EqualValues(uint32(123), int32(123))
@ -201,7 +198,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
EqualValues(a.t, expected, actual, msgAndArgs...)
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1222,6 +1219,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
NotErrorIsf(a.t, err, target, msg, args...)
}
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil.
//
// a.NotNil(err)
@ -1310,10 +1327,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
NotSamef(a.t, expected, actual, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubset asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
// a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1321,10 +1340,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
NotSubset(a.t, list, subset, msgAndArgs...)
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// contain all elements given in the specified subset list(array, slice...) or
// map.
//
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
// a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1484,10 +1505,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
Samef(a.t, expected, actual, msg, args...)
}
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
// a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@ -1495,10 +1517,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
Subset(a.t, list, subset, msgAndArgs...)
}
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
// Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset list(array, slice...) or map.
//
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
// a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()

View file

@ -733,13 +733,14 @@ func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag
return true
}
// ReadOptionalASN1Boolean sets *out to the value of the next ASN.1 BOOLEAN or,
// if the next bytes are not an ASN.1 BOOLEAN, to the value of defaultValue.
// It reports whether the operation was successful.
func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool {
// ReadOptionalASN1Boolean attempts to read an optional ASN.1 BOOLEAN
// explicitly tagged with tag into out and advances. If no element with a
// matching tag is present, it sets "out" to defaultValue instead. It reports
// whether the read was successful.
func (s *String) ReadOptionalASN1Boolean(out *bool, tag asn1.Tag, defaultValue bool) bool {
var present bool
var child String
if !s.ReadOptionalASN1(&child, &present, asn1.BOOLEAN) {
if !s.ReadOptionalASN1(&child, &present, tag) {
return false
}
@ -748,7 +749,7 @@ func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool {
return true
}
return s.ReadASN1Boolean(out)
return child.ReadASN1Boolean(out)
}
func (s *String) readASN1(out *String, outTag *asn1.Tag, skipHeader bool) bool {

View file

@ -95,6 +95,11 @@ func (b *Builder) AddUint32(v uint32) {
b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
// AddUint48 appends a big-endian, 48-bit value to the byte string.
func (b *Builder) AddUint48(v uint64) {
b.add(byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
// AddUint64 appends a big-endian, 64-bit value to the byte string.
func (b *Builder) AddUint64(v uint64) {
b.add(byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))

View file

@ -81,6 +81,17 @@ func (s *String) ReadUint32(out *uint32) bool {
return true
}
// ReadUint48 decodes a big-endian, 48-bit value into out and advances over it.
// It reports whether the read was successful.
func (s *String) ReadUint48(out *uint64) bool {
v := s.read(6)
if v == nil {
return false
}
*out = uint64(v[0])<<40 | uint64(v[1])<<32 | uint64(v[2])<<24 | uint64(v[3])<<16 | uint64(v[4])<<8 | uint64(v[5])
return true
}
// ReadUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful.
func (s *String) ReadUint64(out *uint64) bool {

View file

@ -1,7 +1,6 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package field

View file

@ -1,7 +1,6 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"

View file

@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field

View file

@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field

View file

@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
#include "textflag.h"

Some files were not shown because too many files have changed in this diff Show more