diff --git a/README.md b/README.md index a487d349d..f76b13bbc 100755 --- a/README.md +++ b/README.md @@ -186,6 +186,7 @@ Other documents: ## V4 changes +* v4.0, 2021-03-09, DTLS: Fix ARQ bug, use openssl timeout. 4.0.84 * v4.0, 2021-03-08, DTLS: Fix dead loop by duplicated Alert message. 4.0.83 * v4.0, 2021-03-08, Fix bug when client DTLS is passive. 4.0.82 * v4.0, 2021-03-03, Fix [#2106][bug #2106], [#2011][bug #2011], RTMP/AAC transcode to Opus bug. 4.0.81 diff --git a/trunk/.gitignore b/trunk/.gitignore index 1cbd7b7e6..44e1d851d 100644 --- a/trunk/.gitignore +++ b/trunk/.gitignore @@ -34,7 +34,7 @@ /research/speex/ /test/ .DS_Store -srs +./srs *.dSYM/ *.gcov *.ts diff --git a/trunk/3rdparty/ccache/build_ccache.sh b/trunk/3rdparty/ccache/build_ccache.sh deleted file mode 100755 index 53b187882..000000000 --- a/trunk/3rdparty/ccache/build_ccache.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -# check exists. -if [[ -f /usr/local/bin/ccache ]]; then - echo "ccache is ok"; - exit 0; -fi - -# check sudoer. -sudo echo "ok" > /dev/null 2>&1; -ret=$?; if [[ 0 -ne ${ret} ]]; then echo "you must be sudoer"; exit 1; fi - -unzip ccache-3.1.9.zip && cd ccache-3.1.9 && ./configure && make -ret=$?; if [[ $ret -ne 0 ]]; then echo "build ccache failed."; exit $ret; fi - -sudo cp ccache /usr/local/bin && sudo ln -s ccache /usr/local/bin/gcc && sudo ln -s ccache /usr/local/bin/g++ && sudo ln -s ccache /usr/local/bin/cc && sudo ln -s ccache /usr/local/bin/c++ -ret=$?; if [[ $ret -ne 0 ]]; then echo "install ccache failed."; exit $ret; fi diff --git a/trunk/3rdparty/ccache/ccache-3.1.9.zip b/trunk/3rdparty/ccache/ccache-3.1.9.zip deleted file mode 100644 index 10c96dd0c..000000000 Binary files a/trunk/3rdparty/ccache/ccache-3.1.9.zip and /dev/null differ diff --git a/trunk/3rdparty/ccache/readme.txt b/trunk/3rdparty/ccache/readme.txt deleted file mode 100644 index 611d9eaee..000000000 --- a/trunk/3rdparty/ccache/readme.txt +++ /dev/null @@ -1,11 +0,0 @@ -ccache是samba组织提供的加速编译过程的工具, -使用虚拟机编译可以考虑用这个工具,让编译过程飞快。 - -链接: - http://ccache.samba.org/ - http://samba.org/ftp/ccache/ccache-3.1.9.tar.xz - http://ccache.samba.org/manual.html - -安装方法: - bash build_ccache.sh -注意:要求以sudoer执行,要修改文件。 \ No newline at end of file diff --git a/trunk/3rdparty/srs-bench/LICENSE b/trunk/3rdparty/srs-bench/LICENSE new file mode 100644 index 000000000..77ba5769d --- /dev/null +++ b/trunk/3rdparty/srs-bench/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2021 srs-bench(ossrs) + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/trunk/3rdparty/srs-bench/Makefile b/trunk/3rdparty/srs-bench/Makefile index f4ee3f9d0..dedfc5fe5 100644 --- a/trunk/3rdparty/srs-bench/Makefile +++ b/trunk/3rdparty/srs-bench/Makefile @@ -5,18 +5,18 @@ default: bench test clean: rm -f ./objs/srs_bench ./objs/srs_test -.format.txt: *.go rtc/*.go srs/*.go +.format.txt: *.go srs/*.go vnet/*.go gofmt -w . echo "done" > .format.txt bench: ./objs/srs_bench -./objs/srs_bench: .format.txt *.go rtc/*.go srs/*.go Makefile +./objs/srs_bench: .format.txt *.go srs/*.go vnet/*.go Makefile go build -mod=vendor -o objs/srs_bench . test: ./objs/srs_test -./objs/srs_test: .format.txt *.go rtc/*.go srs/*.go Makefile +./objs/srs_test: .format.txt *.go srs/*.go vnet/*.go Makefile go test ./srs -mod=vendor -c -o ./objs/srs_test help: diff --git a/trunk/3rdparty/srs-bench/README.md b/trunk/3rdparty/srs-bench/README.md index 061976383..415cb195f 100644 --- a/trunk/3rdparty/srs-bench/README.md +++ b/trunk/3rdparty/srs-bench/README.md @@ -102,8 +102,7 @@ ffmpeg -re -i doc/source.200kbps.768x320.flv -c copy -f flv -y rtmp://localhost/ 回归测试需要先启动[SRS](https://github.com/ossrs/srs/issues/307),支持WebRTC推拉流: ```bash -eip=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}') -if [[ ! -z $eip ]]; then +if [[ ! -z $(ifconfig en0 inet| grep 'inet '|awk '{print $2}') ]]; then docker run -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \ --rm --env CANDIDATE=$(ifconfig en0 inet| grep 'inet '|awk '{print $2}')\ registry.cn-hangzhou.aliyuncs.com/ossrs/srs:v4.0.76 objs/srs -c conf/rtc.conf @@ -119,7 +118,20 @@ go test ./srs -mod=vendor -v 也可以用make编译出重复使用的二进制: ```bash -make test && ./objs/srs_test -test.v +make && ./objs/srs_test -test.v +``` + +> Note: 注意由于pion不支持`DTLS 1.0`,所以SFU必须要支持`DTLS 1.2`才行。 + +运行结果如下: + +```bash +$ make && ./objs/srs_test -test.v +=== RUN TestRTCServerVersion +--- PASS: TestRTCServerVersion (0.00s) +=== RUN TestRTCServerPublishPlay +--- PASS: TestRTCServerPublishPlay (1.28s) +PASS ``` 可以给回归测试传参数,这样可以测试不同的序列,比如: @@ -127,23 +139,43 @@ make test && ./objs/srs_test -test.v ```bash go test ./srs -mod=vendor -v -srs-server=127.0.0.1 # Or -make test && ./objs/srs_test -test.v -srs-server=127.0.0.1 +make && ./objs/srs_test -test.v -srs-server=127.0.0.1 ``` 支持的参数如下: * `-srs-server`,RTC服务器地址。默认值:`127.0.0.1` * `-srs-stream`,RTC流地址。默认值:`/rtc/regression` -* `-srs-log`,是否开启详细日志。默认值:`false` * `-srs-timeout`,每个Case的超时时间,毫秒。默认值:`3000`,即3秒。 -* `-srs-play-pli`,播放时,PLI的间隔,毫秒。默认值:`5000`,即5秒。 -* `-srs-play-ok-packets`,播放时,收到多少个包认为是测试通过,默认值:`10` * `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.ogg` * `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264` * `-srs-publish-video-fps`,推流时,视频文件的FPS。默认值:`25` +* `-srs-vnet-client-ip`,设置[pion/vnet](https://github.com/ossrs/srs-bench/blob/feature/rtc/vnet/example_test.go)客户端的虚拟IP,不能和服务器IP冲突。默认值:`192.168.168.168` 其他不常用参数: +* `-srs-log`,是否开启详细日志。默认值:`false` +* `-srs-play-ok-packets`,播放时,收到多少个包认为是测试通过,默认值:`10` +* `-srs-publish-ok-packets`,推流时,发送多少个包认为时测试通过,默认值:`10` * `-srs-https`,是否连接HTTPS-API。默认值:`false`,即连接HTTP-API。 +* `-srs-play-pli`,播放时,PLI的间隔,毫秒。默认值:`5000`,即5秒。 +* `-srs-dtls-drop-packets`,DTLS丢包测试,丢了多少个包算成功,默认值:`5` + +## GCOVR + +本机生成覆盖率时,我们使用工具[gcovr](https://gcovr.com/en/stable/guide.html)。 + +在macOS上安装gcovr: + +```bash +pip3 install gcovr +``` + +在CentOS上安装gcovr: + +```bash +yum install -y python2-pip && +pip install lxml && pip install gcovr +``` 2021.01, Winlin diff --git a/trunk/3rdparty/srs-bench/avatar.h264 b/trunk/3rdparty/srs-bench/avatar.h264 new file mode 100644 index 000000000..a82911a64 Binary files /dev/null and b/trunk/3rdparty/srs-bench/avatar.h264 differ diff --git a/trunk/3rdparty/srs-bench/main.go b/trunk/3rdparty/srs-bench/main.go index a2384a524..d56fa4995 100644 --- a/trunk/3rdparty/srs-bench/main.go +++ b/trunk/3rdparty/srs-bench/main.go @@ -1,3 +1,23 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package main import ( @@ -42,7 +62,7 @@ func main() { flag.IntVar(&delay, "delay", 50, "") var statListen string - flag.StringVar(&statListen, "stat", ":18000", "") + flag.StringVar(&statListen, "stat", "", "") flag.Usage = func() { fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0])) @@ -52,7 +72,7 @@ func main() { fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50")) fmt.Println(fmt.Sprintf(" -al [Optional] Whether enable audio-level. Default: true")) fmt.Println(fmt.Sprintf(" -twcc [Optional] Whether enable vdieo-twcc. Default: true")) - fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port. Default: :18000")) + fmt.Println(fmt.Sprintf(" -stat [Optional] The stat server API listen port.")) fmt.Println(fmt.Sprintf("Player or Subscriber:")) fmt.Println(fmt.Sprintf(" -sr The url to play/subscribe. If sn exceed 1, auto append variable %%d.")) fmt.Println(fmt.Sprintf(" -da [Optional] The file path to dump audio, ignore if empty.")) diff --git a/trunk/3rdparty/srs-bench/rtc/pion_constants.go b/trunk/3rdparty/srs-bench/rtc/pion_constants.go deleted file mode 100644 index 4b44a7064..000000000 --- a/trunk/3rdparty/srs-bench/rtc/pion_constants.go +++ /dev/null @@ -1,5 +0,0 @@ -package rtc - -const ( - rtpOutboundMTU = 1200 -) diff --git a/trunk/3rdparty/srs-bench/rtc/pion_mediaengine.go b/trunk/3rdparty/srs-bench/rtc/pion_mediaengine.go deleted file mode 100644 index a86e71f24..000000000 --- a/trunk/3rdparty/srs-bench/rtc/pion_mediaengine.go +++ /dev/null @@ -1,27 +0,0 @@ -package rtc - -import ( - "github.com/pion/rtp" - "github.com/pion/rtp/codecs" - "github.com/pion/webrtc/v3" - "strings" -) - -func payloaderForCodec(codec webrtc.RTPCodecCapability) (rtp.Payloader, error) { - switch strings.ToLower(codec.MimeType) { - case strings.ToLower(webrtc.MimeTypeH264): - return &codecs.H264Payloader{}, nil - case strings.ToLower(webrtc.MimeTypeOpus): - return &codecs.OpusPayloader{}, nil - case strings.ToLower(webrtc.MimeTypeVP8): - return &codecs.VP8Payloader{}, nil - case strings.ToLower(webrtc.MimeTypeVP9): - return &codecs.VP9Payloader{}, nil - case strings.ToLower(webrtc.MimeTypeG722): - return &codecs.G722Payloader{}, nil - case strings.ToLower(webrtc.MimeTypePCMU), strings.ToLower(webrtc.MimeTypePCMA): - return &codecs.G711Payloader{}, nil - default: - return nil, webrtc.ErrNoPayloaderForCodec - } -} diff --git a/trunk/3rdparty/srs-bench/rtc/pion_rtpcodec.go b/trunk/3rdparty/srs-bench/rtc/pion_rtpcodec.go deleted file mode 100644 index c7f85828b..000000000 --- a/trunk/3rdparty/srs-bench/rtc/pion_rtpcodec.go +++ /dev/null @@ -1,27 +0,0 @@ -package rtc - -import ( - "github.com/pion/webrtc/v3" - "strings" -) - -// Do a fuzzy find for a codec in the list of codecs -// Used for lookup up a codec in an existing list to find a match -func codecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) { - // First attempt to match on MimeType + SDPFmtpLine - for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && - c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine { - return c, nil - } - } - - // Fallback to just MimeType - for _, c := range haystack { - if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { - return c, nil - } - } - - return webrtc.RTPCodecParameters{}, webrtc.ErrCodecNotFound -} diff --git a/trunk/3rdparty/srs-bench/rtc/pion_track_local_static.go b/trunk/3rdparty/srs-bench/rtc/pion_track_local_static.go deleted file mode 100644 index 4bf40cdb9..000000000 --- a/trunk/3rdparty/srs-bench/rtc/pion_track_local_static.go +++ /dev/null @@ -1,246 +0,0 @@ -package rtc - -import ( - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" - "github.com/pion/webrtc/v3/pkg/media" - "strings" - "sync" -) - -// trackBinding is a single bind for a Track -// Bind can be called multiple times, this stores the -// result for a single bind call so that it can be used when writing -type trackBinding struct { - id string - ssrc webrtc.SSRC - payloadType webrtc.PayloadType - writeStream webrtc.TrackLocalWriter -} - -// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets. -// If you wish to send a media.Sample use TrackLocalStaticSample -type TrackLocalStaticRTP struct { - mu sync.RWMutex - bindings []trackBinding - codec webrtc.RTPCodecCapability - id, streamID string -} - -// NewTrackLocalStaticRTP returns a TrackLocalStaticRTP. -func NewTrackLocalStaticRTP(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticRTP, error) { - return &TrackLocalStaticRTP{ - codec: c, - bindings: []trackBinding{}, - id: id, - streamID: streamID, - }, nil -} - -// Bind is called by the PeerConnection after negotiation is complete -// This asserts that the code requested is supported by the remote peer. -// If so it setups all the state (SSRC and PayloadType) to have a call -func (s *TrackLocalStaticRTP) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { - s.mu.Lock() - defer s.mu.Unlock() - - parameters := webrtc.RTPCodecParameters{RTPCodecCapability: s.codec} - if codec, err := codecParametersFuzzySearch(parameters, t.CodecParameters()); err == nil { - s.bindings = append(s.bindings, trackBinding{ - ssrc: t.SSRC(), - payloadType: codec.PayloadType, - writeStream: t.WriteStream(), - id: t.ID(), - }) - return codec, nil - } - - return webrtc.RTPCodecParameters{}, webrtc.ErrUnsupportedCodec -} - -// Unbind implements the teardown logic when the track is no longer needed. This happens -// because a track has been stopped. -func (s *TrackLocalStaticRTP) Unbind(t webrtc.TrackLocalContext) error { - s.mu.Lock() - defer s.mu.Unlock() - - for i := range s.bindings { - if s.bindings[i].id == t.ID() { - s.bindings[i] = s.bindings[len(s.bindings)-1] - s.bindings = s.bindings[:len(s.bindings)-1] - return nil - } - } - - return webrtc.ErrUnbindFailed -} - -// ID is the unique identifier for this Track. This should be unique for the -// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' -// and StreamID would be 'desktop' or 'webcam' -func (s *TrackLocalStaticRTP) ID() string { return s.id } - -// StreamID is the group this track belongs too. This must be unique -func (s *TrackLocalStaticRTP) StreamID() string { return s.streamID } - -// Kind controls if this TrackLocal is audio or video -func (s *TrackLocalStaticRTP) Kind() webrtc.RTPCodecType { - switch { - case strings.HasPrefix(s.codec.MimeType, "audio/"): - return webrtc.RTPCodecTypeAudio - case strings.HasPrefix(s.codec.MimeType, "video/"): - return webrtc.RTPCodecTypeVideo - default: - return webrtc.RTPCodecType(0) - } -} - -// Codec gets the Codec of the track -func (s *TrackLocalStaticRTP) Codec() webrtc.RTPCodecCapability { - return s.codec -} - -// WriteRTP writes a RTP Packet to the TrackLocalStaticRTP -// If one PeerConnection fails the packets will still be sent to -// all PeerConnections. The error message will contain the ID of the failed -// PeerConnections so you can remove them -func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error { - s.mu.RLock() - defer s.mu.RUnlock() - - writeErrs := []error{} - outboundPacket := *p - - for _, b := range s.bindings { - outboundPacket.Header.SSRC = uint32(b.ssrc) - outboundPacket.Header.PayloadType = uint8(b.payloadType) - if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil { - writeErrs = append(writeErrs, err) - } - } - - return FlattenErrs(writeErrs) -} - -// Write writes a RTP Packet as a buffer to the TrackLocalStaticRTP -// If one PeerConnection fails the packets will still be sent to -// all PeerConnections. The error message will contain the ID of the failed -// PeerConnections so you can remove them -func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) { - packet := &rtp.Packet{} - if err = packet.Unmarshal(b); err != nil { - return 0, err - } - - return len(b), s.WriteRTP(packet) -} - -// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. -// If you wish to send a RTP Packet use TrackLocalStaticRTP -type TrackLocalStaticSample struct { - packetizer rtp.Packetizer - rtpTrack *TrackLocalStaticRTP - clockRate float64 - - // Set the callback before write RTP packet. - OnBeforeWritePacket func(rtp *rtp.Packet) -} - -// NewTrackLocalStaticSample returns a TrackLocalStaticSample -func NewTrackLocalStaticSample(c webrtc.RTPCodecCapability, id, streamID string) (*TrackLocalStaticSample, error) { - rtpTrack, err := NewTrackLocalStaticRTP(c, id, streamID) - if err != nil { - return nil, err - } - - return &TrackLocalStaticSample{ - rtpTrack: rtpTrack, - }, nil -} - -// ID is the unique identifier for this Track. This should be unique for the -// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' -// and StreamID would be 'desktop' or 'webcam' -func (s *TrackLocalStaticSample) ID() string { return s.rtpTrack.ID() } - -// StreamID is the group this track belongs too. This must be unique -func (s *TrackLocalStaticSample) StreamID() string { return s.rtpTrack.StreamID() } - -// Kind controls if this TrackLocal is audio or video -func (s *TrackLocalStaticSample) Kind() webrtc.RTPCodecType { return s.rtpTrack.Kind() } - -// Codec gets the Codec of the track -func (s *TrackLocalStaticSample) Codec() webrtc.RTPCodecCapability { - return s.rtpTrack.Codec() -} - -// Bind is called by the PeerConnection after negotiation is complete -// This asserts that the code requested is supported by the remote peer. -// If so it setups all the state (SSRC and PayloadType) to have a call -func (s *TrackLocalStaticSample) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { - codec, err := s.rtpTrack.Bind(t) - if err != nil { - return codec, err - } - - s.rtpTrack.mu.Lock() - defer s.rtpTrack.mu.Unlock() - - // We only need one packetizer - if s.packetizer != nil { - return codec, nil - } - - payloader, err := payloaderForCodec(codec.RTPCodecCapability) - if err != nil { - return codec, err - } - - s.packetizer = rtp.NewPacketizer( - rtpOutboundMTU, - 0, // Value is handled when writing - 0, // Value is handled when writing - payloader, - rtp.NewRandomSequencer(), - codec.ClockRate, - ) - s.clockRate = float64(codec.RTPCodecCapability.ClockRate) - return codec, nil -} - -// Unbind implements the teardown logic when the track is no longer needed. This happens -// because a track has been stopped. -func (s *TrackLocalStaticSample) Unbind(t webrtc.TrackLocalContext) error { - return s.rtpTrack.Unbind(t) -} - -// WriteSample writes a Sample to the TrackLocalStaticSample -// If one PeerConnection fails the packets will still be sent to -// all PeerConnections. The error message will contain the ID of the failed -// PeerConnections so you can remove them -func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error { - s.rtpTrack.mu.RLock() - p := s.packetizer - clockRate := s.clockRate - s.rtpTrack.mu.RUnlock() - - if p == nil { - return nil - } - - samples := sample.Duration.Seconds() * clockRate - packets := p.(rtp.Packetizer).Packetize(sample.Data, uint32(samples)) - - writeErrs := []error{} - for _, p := range packets { - if s.OnBeforeWritePacket != nil { - s.OnBeforeWritePacket(p) - } - - if err := s.rtpTrack.WriteRTP(p); err != nil { - writeErrs = append(writeErrs, err) - } - } - - return FlattenErrs(writeErrs) -} diff --git a/trunk/3rdparty/srs-bench/rtc/pion_util.go b/trunk/3rdparty/srs-bench/rtc/pion_util.go deleted file mode 100644 index 75ffe4601..000000000 --- a/trunk/3rdparty/srs-bench/rtc/pion_util.go +++ /dev/null @@ -1,10 +0,0 @@ -package rtc - -import "fmt" - -func FlattenErrs(errors []error) error { - if len(errors) == 0 { - return nil - } - return fmt.Errorf("%v", errors) -} diff --git a/trunk/3rdparty/srs-bench/srs/ingester.go b/trunk/3rdparty/srs-bench/srs/ingester.go new file mode 100644 index 000000000..1e3161a89 --- /dev/null +++ b/trunk/3rdparty/srs-bench/srs/ingester.go @@ -0,0 +1,285 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package srs + +import ( + "context" + "github.com/ossrs/go-oryx-lib/errors" + "github.com/ossrs/go-oryx-lib/logger" + "github.com/pion/interceptor" + "github.com/pion/rtp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v3" + "github.com/pion/webrtc/v3/pkg/media" + "github.com/pion/webrtc/v3/pkg/media/h264reader" + "github.com/pion/webrtc/v3/pkg/media/oggreader" + "io" + "os" + "strings" + "time" +) + +type videoIngester struct { + sourceVideo string + fps int + markerInterceptor *RTPInterceptor + sVideoTrack *webrtc.TrackLocalStaticSample + sVideoSender *webrtc.RTPSender +} + +func NewVideoIngester(sourceVideo string) *videoIngester { + return &videoIngester{markerInterceptor: &RTPInterceptor{}, sourceVideo: sourceVideo} +} + +func (v *videoIngester) Close() error { + if v.sVideoSender != nil { + v.sVideoSender.Stop() + v.sVideoSender = nil + } + return nil +} + +func (v *videoIngester) AddTrack(pc *webrtc.PeerConnection, fps int) error { + v.fps = fps + + mimeType, trackID := "video/H264", "video" + if strings.HasSuffix(v.sourceVideo, ".ivf") { + mimeType = "video/VP8" + } + + var err error + v.sVideoTrack, err = webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 90000}, trackID, "pion", + ) + if err != nil { + return errors.Wrapf(err, "Create video track") + } + + v.sVideoSender, err = pc.AddTrack(v.sVideoTrack) + if err != nil { + return errors.Wrapf(err, "Add video track") + } + return err +} + +func (v *videoIngester) Ingest(ctx context.Context) error { + source, sender, track, fps := v.sourceVideo, v.sVideoSender, v.sVideoTrack, v.fps + + f, err := os.Open(source) + if err != nil { + return errors.Wrapf(err, "Open file %v", source) + } + defer f.Close() + + // TODO: FIXME: Support ivf for vp8. + h264, err := h264reader.NewReader(f) + if err != nil { + return errors.Wrapf(err, "Open h264 %v", source) + } + + enc := sender.GetParameters().Encodings[0] + codec := sender.GetParameters().Codecs[0] + headers := sender.GetParameters().HeaderExtensions + logger.Tf(ctx, "Video %v, tbn=%v, fps=%v, ssrc=%v, pt=%v, header=%v", + codec.MimeType, codec.ClockRate, fps, enc.SSRC, codec.PayloadType, headers) + + clock := newWallClock() + sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 / uint64(fps)) + for ctx.Err() == nil { + var sps, pps *h264reader.NAL + var oFrames []*h264reader.NAL + for ctx.Err() == nil { + frame, err := h264.NextNAL() + if err == io.EOF { + return io.EOF + } + if err != nil { + return errors.Wrapf(err, "Read h264") + } + + oFrames = append(oFrames, frame) + logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, RefIdc=%v, %v bytes", + frame.UnitType.String(), frame.PictureOrderCount, frame.ForbiddenZeroBit, frame.RefIdc, len(frame.Data)) + + if frame.UnitType == h264reader.NalUnitTypeSPS { + sps = frame + } else if frame.UnitType == h264reader.NalUnitTypePPS { + pps = frame + } else { + break + } + } + + var frames []*h264reader.NAL + // Package SPS/PPS to STAP-A + if sps != nil && pps != nil { + stapA := packageAsSTAPA(sps, pps) + frames = append(frames, stapA) + } + // Append other original frames. + for _, frame := range oFrames { + if frame.UnitType != h264reader.NalUnitTypeSPS && frame.UnitType != h264reader.NalUnitTypePPS { + frames = append(frames, frame) + } + } + + // Covert frames to sample(buffers). + for i, frame := range frames { + sample := media.Sample{Data: frame.Data, Duration: sampleDuration} + // Use the sample timestamp for frames. + if i != len(frames)-1 { + sample.Duration = 0 + } + + // For STAP-A, set marker to false, to make Chrome happy. + if ri := v.markerInterceptor; ri.rtpWriter == nil { + ri.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + // TODO: Should we decode to check whether SPS/PPS? + if len(payload) > 0 && payload[0]&0x1f == 24 { + header.Marker = false // 24, STAP-A + } + return ri.nextRTPWriter.Write(header, payload, attributes) + } + } + + if err = track.WriteSample(sample); err != nil { + return errors.Wrapf(err, "Write sample") + } + } + + if d := clock.Tick(sampleDuration); d > 0 { + time.Sleep(d) + } + } + + return ctx.Err() +} + +type audioIngester struct { + sourceAudio string + audioLevelInterceptor *RTPInterceptor + sAudioTrack *webrtc.TrackLocalStaticSample + sAudioSender *webrtc.RTPSender +} + +func NewAudioIngester(sourceAudio string) *audioIngester { + return &audioIngester{audioLevelInterceptor: &RTPInterceptor{}, sourceAudio: sourceAudio} +} + +func (v *audioIngester) Close() error { + if v.sAudioSender != nil { + v.sAudioSender.Stop() + v.sAudioSender = nil + } + return nil +} + +func (v *audioIngester) AddTrack(pc *webrtc.PeerConnection) error { + var err error + + mimeType, trackID := "audio/opus", "audio" + v.sAudioTrack, err = webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 48000, Channels: 2}, trackID, "pion", + ) + if err != nil { + return errors.Wrapf(err, "Create audio track") + } + + v.sAudioSender, err = pc.AddTrack(v.sAudioTrack) + if err != nil { + return errors.Wrapf(err, "Add audio track") + } + + return nil +} + +func (v *audioIngester) Ingest(ctx context.Context) error { + source, sender, track := v.sourceAudio, v.sAudioSender, v.sAudioTrack + + f, err := os.Open(source) + if err != nil { + return errors.Wrapf(err, "Open file %v", source) + } + defer f.Close() + + ogg, _, err := oggreader.NewWith(f) + if err != nil { + return errors.Wrapf(err, "Open ogg %v", source) + } + + enc := sender.GetParameters().Encodings[0] + codec := sender.GetParameters().Codecs[0] + headers := sender.GetParameters().HeaderExtensions + logger.Tf(ctx, "Audio %v, tbn=%v, channels=%v, ssrc=%v, pt=%v, header=%v", + codec.MimeType, codec.ClockRate, codec.Channels, enc.SSRC, codec.PayloadType, headers) + + // Whether should encode the audio-level in RTP header. + var audioLevel *webrtc.RTPHeaderExtensionParameter + for _, h := range headers { + if h.URI == sdp.AudioLevelURI { + audioLevel = &h + } + } + + clock := newWallClock() + var lastGranule uint64 + + for ctx.Err() == nil { + pageData, pageHeader, err := ogg.ParseNextPage() + if err == io.EOF { + return io.EOF + } + if err != nil { + return errors.Wrapf(err, "Read ogg") + } + + // The amount of samples is the difference between the last and current timestamp + sampleCount := uint64(pageHeader.GranulePosition - lastGranule) + lastGranule = pageHeader.GranulePosition + sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 * sampleCount / uint64(codec.ClockRate)) + + // For audio-level, set the extensions if negotiated. + if ri := v.audioLevelInterceptor; ri.rtpWriter == nil { + ri.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if audioLevel != nil { + audioLevelPayload, err := new(rtp.AudioLevelExtension).Marshal() + if err != nil { + return 0, err + } + + header.SetExtension(uint8(audioLevel.ID), audioLevelPayload) + } + + return ri.nextRTPWriter.Write(header, payload, attributes) + } + } + + if err = track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil { + return errors.Wrapf(err, "Write sample") + } + + if d := clock.Tick(sampleDuration); d > 0 { + time.Sleep(d) + } + } + + return ctx.Err() +} diff --git a/trunk/3rdparty/srs-bench/srs/interceptor.go b/trunk/3rdparty/srs-bench/srs/interceptor.go new file mode 100644 index 000000000..d853aaf7d --- /dev/null +++ b/trunk/3rdparty/srs-bench/srs/interceptor.go @@ -0,0 +1,175 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package srs + +import ( + "github.com/pion/interceptor" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +type RTPInterceptorOptionFunc func(i *RTPInterceptor) + +// Common RTP packet interceptor for benchmark. +// @remark Should never merge with RTCPInterceptor, because they has the same Write interface. +type RTPInterceptor struct { + localInfo *interceptor.StreamInfo + remoteInfo *interceptor.StreamInfo + // If rtpReader is nil, use the default next one to read. + rtpReader interceptor.RTPReaderFunc + nextRTPReader interceptor.RTPReader + // If rtpWriter is nil, use the default next one to write. + rtpWriter interceptor.RTPWriterFunc + nextRTPWriter interceptor.RTPWriter + BypassInterceptor +} + +func NewRTPInterceptor(options ...RTPInterceptorOptionFunc) *RTPInterceptor { + v := &RTPInterceptor{} + for _, opt := range options { + opt(v) + } + return v +} + +func (v *RTPInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + if v.localInfo != nil { + return writer // Only handle one stream. + } + + v.localInfo = info + v.nextRTPWriter = writer + return v // Handle all RTP +} + +func (v *RTPInterceptor) Write(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if v.rtpWriter != nil { + return v.rtpWriter(header, payload, attributes) + } + return v.nextRTPWriter.Write(header, payload, attributes) +} + +func (v *RTPInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { + if v.localInfo == nil || v.localInfo.ID != info.ID { + return + } + v.localInfo = nil // Reset the interceptor. +} + +func (v *RTPInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + if v.remoteInfo != nil { + return reader // Only handle one stream. + } + + v.nextRTPReader = reader + return v // Handle all RTP +} + +func (v *RTPInterceptor) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if v.rtpReader != nil { + return v.rtpReader(b, a) + } + return v.nextRTPReader.Read(b, a) +} + +func (v *RTPInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { + if v.remoteInfo == nil || v.remoteInfo.ID != info.ID { + return + } + v.remoteInfo = nil +} + +type RTCPInterceptorOptionFunc func(i *RTCPInterceptor) + +// Common RTCP packet interceptor for benchmark. +// @remark Should never merge with RTPInterceptor, because they has the same Write interface. +type RTCPInterceptor struct { + // If rtcpReader is nil, use the default next one to read. + rtcpReader interceptor.RTCPReaderFunc + nextRTCPReader interceptor.RTCPReader + // If rtcpWriter is nil, use the default next one to write. + rtcpWriter interceptor.RTCPWriterFunc + nextRTCPWriter interceptor.RTCPWriter + BypassInterceptor +} + +func NewRTCPInterceptor(options ...RTCPInterceptorOptionFunc) *RTCPInterceptor { + v := &RTCPInterceptor{} + for _, opt := range options { + opt(v) + } + return v +} + +func (v *RTCPInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + v.nextRTCPReader = reader + return v // Handle all RTCP +} + +func (v *RTCPInterceptor) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if v.rtcpReader != nil { + return v.rtcpReader(b, a) + } + return v.nextRTCPReader.Read(b, a) +} + +func (v *RTCPInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + v.nextRTCPWriter = writer + return v // Handle all RTCP +} + +func (v *RTCPInterceptor) Write(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) { + if v.rtcpWriter != nil { + return v.rtcpWriter(pkts, attributes) + } + return v.nextRTCPWriter.Write(pkts, attributes) +} + +// Do nothing. +type BypassInterceptor struct { + interceptor.Interceptor +} + +func (v *BypassInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return reader +} + +func (v *BypassInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + return writer +} + +func (v *BypassInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + return writer +} + +func (v *BypassInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { +} + +func (v *BypassInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return reader +} + +func (v *BypassInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { +} + +func (v *BypassInterceptor) Close() error { + return nil +} diff --git a/trunk/3rdparty/srs-bench/srs/player.go b/trunk/3rdparty/srs-bench/srs/player.go index 8977248e6..0947ad41c 100644 --- a/trunk/3rdparty/srs-bench/srs/player.go +++ b/trunk/3rdparty/srs-bench/srs/player.go @@ -1,3 +1,23 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package srs import ( @@ -65,7 +85,14 @@ func StartPlay(ctx context.Context, r, dumpAudio, dumpVideo string, enableAudioL if err != nil { return errors.Wrapf(err, "Create PC") } - defer pc.Close() + + var receivers []*webrtc.RTPReceiver + defer func() { + pc.Close() + for _, receiver := range receivers { + receiver.Stop() + } + }() pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionRecvonly, @@ -132,6 +159,8 @@ func StartPlay(ctx context.Context, r, dumpAudio, dumpVideo string, enableAudioL } }() + receivers = append(receivers, receiver) + codec := track.Codec() trackDesc := fmt.Sprintf("channels=%v", codec.Channels) diff --git a/trunk/3rdparty/srs-bench/srs/publisher.go b/trunk/3rdparty/srs-bench/srs/publisher.go index 172ff9e21..8d38fb055 100644 --- a/trunk/3rdparty/srs-bench/srs/publisher.go +++ b/trunk/3rdparty/srs-bench/srs/publisher.go @@ -1,20 +1,33 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package srs import ( "context" "github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/logger" - "github.com/ossrs/srs-bench/rtc" "github.com/pion/interceptor" - "github.com/pion/rtp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" - "github.com/pion/webrtc/v3/pkg/media" - "github.com/pion/webrtc/v3/pkg/media/h264reader" - "github.com/pion/webrtc/v3/pkg/media/oggreader" "io" - "os" - "strings" "sync" "time" ) @@ -26,7 +39,12 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v, audio-level=%v, twcc=%v", r, sourceAudio, sourceVideo, fps, enableAudioLevel, enableTWCC) - // For audio-level. + // Filter for SPS/PPS marker. + var aIngester *audioIngester + var vIngester *videoIngester + + // For audio-level and sps/pps marker. + // TODO: FIXME: Should share with player. webrtcNewPeerConnection := func(configuration webrtc.Configuration) (*webrtc.PeerConnection, error) { m := &webrtc.MediaEngine{} if err := m.RegisterDefaultCodecs(); err != nil { @@ -53,12 +71,21 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i } } - i := &interceptor.Registry{} - if err := webrtc.RegisterDefaultInterceptors(m, i); err != nil { + registry := &interceptor.Registry{} + if err := webrtc.RegisterDefaultInterceptors(m, registry); err != nil { return nil, err } - api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(i)) + if sourceAudio != "" { + aIngester = NewAudioIngester(sourceAudio) + registry.Add(aIngester.audioLevelInterceptor) + } + if sourceVideo != "" { + vIngester = NewVideoIngester(sourceVideo) + registry.Add(vIngester.markerInterceptor) + } + + api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(registry)) return api.NewPeerConnection(configuration) } @@ -66,46 +93,30 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i if err != nil { return errors.Wrapf(err, "Create PC") } - defer pc.Close() - var sVideoTrack *rtc.TrackLocalStaticSample - var sVideoSender *webrtc.RTPSender - if sourceVideo != "" { - mimeType, trackID := "video/H264", "video" - if strings.HasSuffix(sourceVideo, ".ivf") { - mimeType = "video/VP8" + doClose := func() { + if pc != nil { + pc.Close() } + if vIngester != nil { + vIngester.Close() + } + if aIngester != nil { + aIngester.Close() + } + } + defer doClose() - sVideoTrack, err = rtc.NewTrackLocalStaticSample( - webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 90000}, trackID, "pion", - ) - if err != nil { - return errors.Wrapf(err, "Create video track") + if vIngester != nil { + if err := vIngester.AddTrack(pc, fps); err != nil { + return errors.Wrapf(err, "Add track") } - - sVideoSender, err = pc.AddTrack(sVideoTrack) - if err != nil { - return errors.Wrapf(err, "Add video track") - } - sVideoSender.Stop() } - var sAudioTrack *rtc.TrackLocalStaticSample - var sAudioSender *webrtc.RTPSender - if sourceAudio != "" { - mimeType, trackID := "audio/opus", "audio" - sAudioTrack, err = rtc.NewTrackLocalStaticSample( - webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 48000, Channels: 2}, trackID, "pion", - ) - if err != nil { - return errors.Wrapf(err, "Create audio track") + if aIngester != nil { + if err := aIngester.AddTrack(pc); err != nil { + return errors.Wrapf(err, "Add track") } - - sAudioSender, err = pc.AddTrack(sAudioTrack) - if err != nil { - return errors.Wrapf(err, "Add audio track") - } - defer sAudioSender.Stop() } offer, err := pc.CreateOffer(nil) @@ -139,9 +150,11 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i logger.Tf(ctx, "Signaling state %v", state) }) - sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) { - logger.Tf(ctx, "DTLS state %v", state) - }) + if aIngester != nil { + aIngester.sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) { + logger.Tf(ctx, "DTLS state %v", state) + }) + } ctx, cancel := context.WithCancel(ctx) pcDone, pcDoneCancel := context.WithCancel(context.Background()) @@ -168,8 +181,15 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i wg.Add(1) go func() { defer wg.Done() + <-ctx.Done() + doClose() // Interrupt the RTCP read. + }() - if sAudioSender == nil { + wg.Add(1) + go func() { + defer wg.Done() + + if aIngester == nil { return } @@ -181,7 +201,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i buf := make([]byte, 1500) for ctx.Err() == nil { - if _, _, err := sAudioSender.Read(buf); err != nil { + if _, _, err := aIngester.sAudioSender.Read(buf); err != nil { return } } @@ -191,7 +211,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i go func() { defer wg.Done() - if sAudioTrack == nil { + if aIngester == nil { return } @@ -201,8 +221,9 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest audio %v", sourceAudio) } + // Read audio and send out. for ctx.Err() == nil { - if err := readAudioTrackFromDisk(ctx, sourceAudio, sAudioSender, sAudioTrack); err != nil { + if err := aIngester.Ingest(ctx); err != nil { if errors.Cause(err) == io.EOF { logger.Tf(ctx, "EOF, restart ingest audio %v", sourceAudio) continue @@ -216,7 +237,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i go func() { defer wg.Done() - if sVideoSender == nil { + if vIngester == nil { return } @@ -228,7 +249,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i buf := make([]byte, 1500) for ctx.Err() == nil { - if _, _, err := sVideoSender.Read(buf); err != nil { + if _, _, err := vIngester.sVideoSender.Read(buf); err != nil { return } } @@ -238,7 +259,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i go func() { defer wg.Done() - if sVideoTrack == nil { + if vIngester == nil { return } @@ -249,7 +270,7 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i } for ctx.Err() == nil { - if err := readVideoTrackFromDisk(ctx, sourceVideo, sVideoSender, fps, sVideoTrack); err != nil { + if err := vIngester.Ingest(ctx); err != nil { if errors.Cause(err) == io.EOF { logger.Tf(ctx, "EOF, restart ingest video %v", sourceVideo) continue @@ -276,154 +297,3 @@ func StartPublish(ctx context.Context, r, sourceAudio, sourceVideo string, fps i wg.Wait() return nil } - -func readAudioTrackFromDisk(ctx context.Context, source string, sender *webrtc.RTPSender, track *rtc.TrackLocalStaticSample) error { - f, err := os.Open(source) - if err != nil { - return errors.Wrapf(err, "Open file %v", source) - } - defer f.Close() - - ogg, _, err := oggreader.NewWith(f) - if err != nil { - return errors.Wrapf(err, "Open ogg %v", source) - } - - enc := sender.GetParameters().Encodings[0] - codec := sender.GetParameters().Codecs[0] - headers := sender.GetParameters().HeaderExtensions - logger.Tf(ctx, "Audio %v, tbn=%v, channels=%v, ssrc=%v, pt=%v, header=%v", - codec.MimeType, codec.ClockRate, codec.Channels, enc.SSRC, codec.PayloadType, headers) - - // Whether should encode the audio-level in RTP header. - var audioLevel *webrtc.RTPHeaderExtensionParameter - for _, h := range headers { - if h.URI == sdp.AudioLevelURI { - audioLevel = &h - } - } - - clock := newWallClock() - var lastGranule uint64 - - for ctx.Err() == nil { - pageData, pageHeader, err := ogg.ParseNextPage() - if err == io.EOF { - return nil - } - if err != nil { - return errors.Wrapf(err, "Read ogg") - } - - // The amount of samples is the difference between the last and current timestamp - sampleCount := uint64(pageHeader.GranulePosition - lastGranule) - lastGranule = pageHeader.GranulePosition - sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 * sampleCount / uint64(codec.ClockRate)) - - // For audio-level, set the extensions if negotiated. - track.OnBeforeWritePacket = func(p *rtp.Packet) { - if audioLevel != nil { - if b, err := new(rtp.AudioLevelExtension).Marshal(); err == nil { - p.SetExtension(uint8(audioLevel.ID), b) - } - } - } - - if err = track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil { - return errors.Wrapf(err, "Write sample") - } - - if d := clock.Tick(sampleDuration); d > 0 { - time.Sleep(d) - } - } - - return nil -} - -func readVideoTrackFromDisk(ctx context.Context, source string, sender *webrtc.RTPSender, fps int, track *rtc.TrackLocalStaticSample) error { - f, err := os.Open(source) - if err != nil { - return errors.Wrapf(err, "Open file %v", source) - } - defer f.Close() - - // TODO: FIXME: Support ivf for vp8. - h264, err := h264reader.NewReader(f) - if err != nil { - return errors.Wrapf(err, "Open h264 %v", source) - } - - enc := sender.GetParameters().Encodings[0] - codec := sender.GetParameters().Codecs[0] - headers := sender.GetParameters().HeaderExtensions - logger.Tf(ctx, "Video %v, tbn=%v, fps=%v, ssrc=%v, pt=%v, header=%v", - codec.MimeType, codec.ClockRate, fps, enc.SSRC, codec.PayloadType, headers) - - clock := newWallClock() - sampleDuration := time.Duration(uint64(time.Millisecond) * 1000 / uint64(fps)) - for ctx.Err() == nil { - var sps, pps *h264reader.NAL - var oFrames []*h264reader.NAL - for ctx.Err() == nil { - frame, err := h264.NextNAL() - if err == io.EOF { - return nil - } - if err != nil { - return errors.Wrapf(err, "Read h264") - } - - oFrames = append(oFrames, frame) - logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, RefIdc=%v, %v bytes", - frame.UnitType.String(), frame.PictureOrderCount, frame.ForbiddenZeroBit, frame.RefIdc, len(frame.Data)) - - if frame.UnitType == h264reader.NalUnitTypeSPS { - sps = frame - } else if frame.UnitType == h264reader.NalUnitTypePPS { - pps = frame - } else { - break - } - } - - var frames []*h264reader.NAL - // Package SPS/PPS to STAP-A - if sps != nil && pps != nil { - stapA := packageAsSTAPA(sps, pps) - frames = append(frames, stapA) - } - // Append other original frames. - for _, frame := range oFrames { - if frame.UnitType != h264reader.NalUnitTypeSPS && frame.UnitType != h264reader.NalUnitTypePPS { - frames = append(frames, frame) - } - } - - // Covert frames to sample(buffers). - for i, frame := range frames { - sample := media.Sample{Data: frame.Data, Duration: sampleDuration} - // Use the sample timestamp for frames. - if i != len(frames)-1 { - sample.Duration = 0 - } - - // For STAP-A, set marker to false, to make Chrome happy. - track.OnBeforeWritePacket = func(p *rtp.Packet) { - if i < len(frames)-1 { - p.Header.Marker = false - } - } - - if err = track.WriteSample(sample); err != nil { - return errors.Wrapf(err, "Write sample") - } - } - - if d := clock.Tick(sampleDuration); d > 0 { - time.Sleep(d) - } - } - - return nil -} diff --git a/trunk/3rdparty/srs-bench/srs/rtc_test.go b/trunk/3rdparty/srs-bench/srs/rtc_test.go index da7d89949..4ac869c42 100644 --- a/trunk/3rdparty/srs-bench/srs/rtc_test.go +++ b/trunk/3rdparty/srs-bench/srs/rtc_test.go @@ -1,449 +1,1467 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package srs import ( "context" - "encoding/json" - "flag" "fmt" "github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/logger" - "github.com/ossrs/srs-bench/rtc" + "github.com/pion/interceptor" "github.com/pion/rtcp" - "github.com/pion/webrtc/v3" - "io" - "io/ioutil" - "net/http" + "github.com/pion/rtp" + "github.com/pion/transport/vnet" + "math/rand" "os" - "strings" "sync" "testing" "time" ) -var srsSchema = "http" -var srsHttps = flag.Bool("srs-https", false, "Whther connect to HTTPS-API") -var srsServer = flag.String("srs-server", "127.0.0.1", "The RTC server to connect to") -var srsStream = flag.String("srs-stream", "/rtc/regression", "The RTC stream to play") -var srsLog = flag.Bool("srs-log", false, "Whether enable the detail log") -var srsTimeout = flag.Int("srs-timeout", 3000, "For each case, the timeout in ms") -var srsPlayPLI = flag.Int("srs-play-pli", 5000, "The PLI interval in seconds for player.") -var srsPlayOKPackets = flag.Int("srs-play-ok-packets", 10, "If got N packets, it's ok, or fail") -var srsPublishAudio = flag.String("srs-publish-audio", "avatar.ogg", "The audio file for publisher.") -var srsPublishVideo = flag.String("srs-publish-video", "avatar.h264", "The video file for publisher.") -var srsPublishVideoFps = flag.Int("srs-publish-video-fps", 25, "The video fps for publisher.") - -func TestMain(m *testing.M) { - // Should parse it first. - flag.Parse() - - // The stream should starts with /, for example, /rtc/regression - if strings.HasPrefix(*srsStream, "/") { - *srsStream = "/" + *srsStream - } - - // Generate srs protocol from whether use HTTPS. - if *srsHttps { - srsSchema = "https" - } - - // Disable the logger during all tests. - logger.Tf(nil, "sys log %v", *srsLog) - - if *srsLog == false { - olw := logger.Switch(ioutil.Discard) - defer func() { - logger.Switch(olw) - }() - } - - // Run tests. - os.Exit(m.Run()) -} - -func TestRTCServerVersion(t *testing.T) { - api := fmt.Sprintf("http://%v:1985/api/v1/versions", *srsServer) - req, err := http.NewRequest("POST", api, nil) - if err != nil { - t.Errorf("Request %v", api) - return - } - - res, err := http.DefaultClient.Do(req) - if err != nil { - t.Errorf("Do request %v", api) - return - } - - b, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Errorf("Read body of %v", api) - return - } - - obj := struct { - Code int `json:"code"` - Server string `json:"server"` - Data struct { - Major int `json:"major"` - Minor int `json:"minor"` - Revision int `json:"revision"` - Version string `json:"version"` - } `json:"data"` - }{} - if err := json.Unmarshal(b, &obj); err != nil { - t.Errorf("Parse %v", string(b)) - return - } - if obj.Code != 0 { - t.Errorf("Server err code=%v, server=%v", obj.Code, obj.Server) - return - } - if obj.Data.Major == 0 && obj.Data.Minor == 0 { - t.Errorf("Invalid version %v", obj.Data) - return - } -} - -func TestRTCServerPublishPlay(t *testing.T) { +// Basic use scenario, publish a stream, then play it. +func TestRtcBasic_PublishPlay(t *testing.T) { ctx := logger.WithContext(context.Background()) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond) - r := fmt.Sprintf("%v://%v%v", srsSchema, *srsServer, *srsStream) - publishReady, publishReadyCancel := context.WithCancel(context.Background()) - - startPlay := func(ctx context.Context) error { - logger.Tf(ctx, "Start play url=%v", r) - - pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - if err != nil { - return errors.Wrapf(err, "Create PC") + var r0, r1, r2, r3 error + defer func(ctx context.Context) { + if err := filterTestError(ctx.Err(), r0, r1, r2, r3); err != nil { + t.Errorf("Fail for err %+v", err) + } else { + logger.Tf(ctx, "test done with err %+v", err) } - defer pc.Close() - - pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionRecvonly, - }) - pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionRecvonly, - }) - - offer, err := pc.CreateOffer(nil) - if err != nil { - return errors.Wrapf(err, "Create Offer") - } - - if err := pc.SetLocalDescription(offer); err != nil { - return errors.Wrapf(err, "Set offer %v", offer) - } - - answer, err := apiRtcRequest(ctx, "/rtc/v1/play", r, offer.SDP) - if err != nil { - return errors.Wrapf(err, "Api request offer=%v", offer.SDP) - } - - if err := pc.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, SDP: answer, - }); err != nil { - return errors.Wrapf(err, "Set answer %v", answer) - } - - handleTrack := func(ctx context.Context, track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) error { - // Send a PLI on an interval so that the publisher is pushing a keyframe - go func() { - if track.Kind() == webrtc.RTPCodecTypeAudio { - return - } - - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Duration(*srsPlayPLI) * time.Millisecond): - _ = pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{ - MediaSSRC: uint32(track.SSRC()), - }}) - } - } - }() - - // Try to read packets of track. - for i := 0; i < *srsPlayOKPackets && ctx.Err() == nil; i++ { - _, _, err := track.ReadRTP() - if err != nil { - return errors.Wrapf(err, "Read RTP") - } - } - - // Completed. - cancel() - - return nil - } - - pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - err = handleTrack(ctx, track, receiver) - if err != nil { - codec := track.Codec() - err = errors.Wrapf(err, "Handle track %v, pt=%v", codec.MimeType, codec.PayloadType) - cancel() - } - }) - - pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { - if state == webrtc.ICEConnectionStateFailed || state == webrtc.ICEConnectionStateClosed { - err = errors.Errorf("Close for ICE state %v", state) - cancel() - } - }) - - <-ctx.Done() - return err - } - - startPublish := func(ctx context.Context) error { - sourceVideo := *srsPublishVideo - sourceAudio := *srsPublishAudio - fps := *srsPublishVideoFps - - logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v", - r, sourceAudio, sourceVideo, fps) - - pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - if err != nil { - return errors.Wrapf(err, "Create PC") - } - defer pc.Close() - - var sVideoTrack *rtc.TrackLocalStaticSample - var sVideoSender *webrtc.RTPSender - if sourceVideo != "" { - mimeType, trackID := "video/H264", "video" - if strings.HasSuffix(sourceVideo, ".ivf") { - mimeType = "video/VP8" - } - - sVideoTrack, err = rtc.NewTrackLocalStaticSample( - webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 90000}, trackID, "pion", - ) - if err != nil { - return errors.Wrapf(err, "Create video track") - } - - sVideoSender, err = pc.AddTrack(sVideoTrack) - if err != nil { - return errors.Wrapf(err, "Add video track") - } - sVideoSender.Stop() - } - - var sAudioTrack *rtc.TrackLocalStaticSample - var sAudioSender *webrtc.RTPSender - if sourceAudio != "" { - mimeType, trackID := "audio/opus", "audio" - sAudioTrack, err = rtc.NewTrackLocalStaticSample( - webrtc.RTPCodecCapability{MimeType: mimeType, ClockRate: 48000, Channels: 2}, trackID, "pion", - ) - if err != nil { - return errors.Wrapf(err, "Create audio track") - } - - sAudioSender, err = pc.AddTrack(sAudioTrack) - if err != nil { - return errors.Wrapf(err, "Add audio track") - } - defer sAudioSender.Stop() - } - - offer, err := pc.CreateOffer(nil) - if err != nil { - return errors.Wrapf(err, "Create Offer") - } - - if err := pc.SetLocalDescription(offer); err != nil { - return errors.Wrapf(err, "Set offer %v", offer) - } - - answer, err := apiRtcRequest(ctx, "/rtc/v1/publish", r, offer.SDP) - if err != nil { - return errors.Wrapf(err, "Api request offer=%v", offer.SDP) - } - - if err := pc.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, SDP: answer, - }); err != nil { - return errors.Wrapf(err, "Set answer %v", answer) - } - - logger.Tf(ctx, "State signaling=%v, ice=%v, conn=%v", pc.SignalingState(), pc.ICEConnectionState(), pc.ConnectionState()) - - ctx, cancel := context.WithCancel(ctx) - pcDone, pcDoneCancel := context.WithCancel(context.Background()) - pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - logger.Tf(ctx, "PC state %v", state) - - if state == webrtc.PeerConnectionStateConnected { - pcDoneCancel() - publishReadyCancel() - } - - if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - err = errors.Errorf("Close for PC state %v", state) - cancel() - } - }) - - // Wait for event from context or tracks. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - - if sAudioSender == nil { - return - } - - select { - case <-ctx.Done(): - case <-pcDone.Done(): - } - - buf := make([]byte, 1500) - for ctx.Err() == nil { - if _, _, err := sAudioSender.Read(buf); err != nil { - return - } - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - - if sAudioTrack == nil { - return - } - - select { - case <-ctx.Done(): - case <-pcDone.Done(): - } - - for ctx.Err() == nil { - if err := readAudioTrackFromDisk(ctx, sourceAudio, sAudioSender, sAudioTrack); err != nil { - if errors.Cause(err) == io.EOF { - logger.Tf(ctx, "EOF, restart ingest audio %v", sourceAudio) - continue - } - logger.Wf(ctx, "Ignore audio err %+v", err) - } - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - - if sVideoSender == nil { - return - } - - select { - case <-ctx.Done(): - case <-pcDone.Done(): - logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start read video packets") - } - - buf := make([]byte, 1500) - for ctx.Err() == nil { - if _, _, err := sVideoSender.Read(buf); err != nil { - return - } - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - - if sVideoTrack == nil { - return - } - - select { - case <-ctx.Done(): - case <-pcDone.Done(): - logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest video %v", sourceVideo) - } - - for ctx.Err() == nil { - if err := readVideoTrackFromDisk(ctx, sourceVideo, sVideoSender, fps, sVideoTrack); err != nil { - if errors.Cause(err) == io.EOF { - logger.Tf(ctx, "EOF, restart ingest video %v", sourceVideo) - continue - } - logger.Wf(ctx, "Ignore video err %+v", err) - } - } - }() - - wg.Wait() - return err - } + }(ctx) var wg sync.WaitGroup - errs := make(chan error, 0) + defer wg.Wait() + // The event notify. + var thePublisher *TestPublisher + var thePlayer *TestPlayer + mainReady, mainReadyCancel := context.WithCancel(context.Background()) + publishReady, publishReadyCancel := context.WithCancel(context.Background()) + + // Objects init. wg.Add(1) go func() { defer wg.Done() + defer cancel() + + doInit := func() error { + playOK := *srsPlayOKPackets + vnetClientIP := *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("basic-publish-play-%v-%v", os.Getpid(), rand.Int()) + play := NewTestPlayer(api, func(play *TestPlayer) { + play.streamSuffix = streamSuffix + }) + defer play.Close() + + pub := NewTestPublisher(api, func(pub *TestPublisher) { + pub.streamSuffix = streamSuffix + pub.iceReadyCancel = publishReadyCancel + }) + defer pub.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nnWriteRTP, nnReadRTP, nnWriteRTCP, nnReadRTCP int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpReader = func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + nn, attr, err := i.nextRTPReader.Read(buf, attributes) + nnReadRTP++ + return nn, attr, err + } + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + nn, err := i.nextRTPWriter.Write(header, payload, attributes) + + nnWriteRTP++ + logger.Tf(ctx, "publish rtp=(read:%v write:%v), rtcp=(read:%v write:%v) packets", + nnReadRTP, nnWriteRTP, nnReadRTCP, nnWriteRTCP) + return nn, err + } + })) + api.registry.Add(NewRTCPInterceptor(func(i *RTCPInterceptor) { + i.rtcpReader = func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + nn, attr, err := i.nextRTCPReader.Read(buf, attributes) + nnReadRTCP++ + return nn, attr, err + } + i.rtcpWriter = func(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) { + nn, err := i.nextRTCPWriter.Write(pkts, attributes) + nnWriteRTCP++ + return nn, err + } + })) + }, func(api *TestWebRTCAPI) { + var nn uint64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpReader = func(payload []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + if nn++; nn >= uint64(playOK) { + cancel() // Completed. + } + logger.Tf(ctx, "play got %v packets", nn) + return i.nextRTPReader.Read(payload, attributes) + } + })) + }); err != nil { + return err + } + + // Set the available objects. + mainReadyCancel() + thePublisher = pub + thePlayer = play + + <-ctx.Done() + return nil + } + + if err := doInit(); err != nil { + r1 = err + } + }() + + // Run publisher. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + return + case <-mainReady.Done(): + } + + doPublish := func() error { + if err := thePublisher.Run(logger.WithContext(ctx), cancel); err != nil { + return err + } + + logger.Tf(ctx, "pub done") + return nil + } + if err := doPublish(); err != nil { + r2 = err + } + }() + + // Run player. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + return + case <-mainReady.Done(): + } - // Wait for publisher to start first. select { case <-ctx.Done(): return case <-publishReady.Done(): } - errs <- startPlay(logger.WithContext(ctx)) - cancel() - }() - - wg.Add(1) - go func() { - defer wg.Done() - - errs <- startPublish(logger.WithContext(ctx)) - cancel() - }() - - wg.Add(1) - go func() { - defer wg.Done() - - select { - case <-ctx.Done(): - case <-time.After(time.Duration(*srsTimeout) * time.Millisecond): - errs <- errors.Errorf("timeout for %vms", *srsTimeout) - cancel() - } - }() - - testDone, testDoneCancel := context.WithCancel(context.Background()) - go func() { - wg.Wait() - testDoneCancel() - }() - - // Handle errs, the test result. - for { - select { - case <-testDone.Done(): - return - case err := <-errs: - if err != nil && err != context.Canceled && !t.Failed() { - t.Errorf("err %+v", err) + doPlay := func() error { + if err := thePlayer.Run(logger.WithContext(ctx), cancel); err != nil { + return err } + + logger.Tf(ctx, "play done") + return nil } + if err := doPlay(); err != nil { + r3 = err + } + + }() +} + +// The srs-server is DTLS server(passive), srs-bench is DTLS client which is active mode. +// No.1 srs-bench: ClientHello +// No.2 srs-server: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.3 srs-bench: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.4 srs-server: ChangeCipherSpec, Finished +func TestRtcDTLS_ClientActive_Default(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + logger.Tf(ctx, "Chunk %v, ok=%v %v bytes", chunk, ok, len(c.UserData())) + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// No.1 srs-server: ClientHello +// No.2 srs-bench: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.3 srs-server: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.4 srs-bench: ChangeCipherSpec, Finished +func TestRtcDTLS_ClientPassive_Default(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + logger.Tf(ctx, "Chunk %v, ok=%v %v bytes", chunk, ok, len(c.UserData())) + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS server, srs-bench is DTLS client which is active mode. +// When srs-bench close the PC, it will send DTLS alert and might retransmit it. +func TestRtcDTLS_ClientActive_Duplicated_Alert(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS { + return true + } + + // Copy the alert to server, ignore error. + if chunk.content == DTLSContentTypeAlert { + _, _ = api.proxy.Deliver(c.SourceAddr(), c.DestinationAddr(), c.UserData()) + _, _ = api.proxy.Deliver(c.SourceAddr(), c.DestinationAddr(), c.UserData()) + } + + logger.Tf(ctx, "Chunk %v, ok=%v %v bytes", chunk, ok, len(c.UserData())) + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client, srs-bench is DTLS server which is passive mode. +// When srs-bench close the PC, it will send DTLS alert and might retransmit it. +func TestRtcDTLS_ClientPassive_Duplicated_Alert(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS { + return true + } + + // Copy the alert to server, ignore error. + if chunk.content == DTLSContentTypeAlert { + _, _ = api.proxy.Deliver(c.SourceAddr(), c.DestinationAddr(), c.UserData()) + _, _ = api.proxy.Deliver(c.SourceAddr(), c.DestinationAddr(), c.UserData()) + } + + logger.Tf(ctx, "Chunk %v, ok=%v %v bytes", chunk, ok, len(c.UserData())) + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS server, srs-bench is DTLS client which is active mode. +// [Drop] No.1 srs-bench: ClientHello(Epoch=0, Sequence=0) +// [ARQ] No.2 srs-bench: ClientHello(Epoch=0, Sequence=1) +// No.3 srs-server: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.4 srs-bench: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-server: ChangeCipherSpec, Finished +// +// @remark The pion is active, so it can be consider a benchmark for DTLS server. +func TestRtcDTLS_ClientActive_ARQ_ClientHello_ByDropped_ClientHello(t *testing.T) { + var r0 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-arq-client-hello-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnClientHello, nnMaxDrop := 0, 1 + var lastClientHello *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || chunk.handshake != DTLSHandshakeTypeClientHello { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if lastClientHello != nil && record.Equals(lastClientHello) { + r0 = errors.Errorf("dup record %v", record) + } + lastClientHello = record + + nnClientHello++ + ok = (nnClientHello > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnClientHello, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client, srs-bench is DTLS server which is passive mode. +// [Drop] No.1 srs-server: ClientHello(Epoch=0, Sequence=0) +// [ARQ] No.2 srs-server: ClientHello(Epoch=0, Sequence=1) +// No.3 srs-bench: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.4 srs-server: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-bench: ChangeCipherSpec, Finished +// +// @remark If retransmit the ClientHello, with the same epoch+sequence, peer will request HelloVerifyRequest, then +// openssl will create a new ClientHello with increased sequence. It's ok, but waste a lots of duplicated ClientHello +// packets, so we fail the test, requires the epoch+sequence never dup, even for ARQ. +func TestRtcDTLS_ClientPassive_ARQ_ClientHello_ByDropped_ClientHello(t *testing.T) { + var r0 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-arq-client-hello-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnClientHello, nnMaxDrop := 0, 1 + var lastClientHello *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || chunk.handshake != DTLSHandshakeTypeClientHello { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if lastClientHello != nil && record.Equals(lastClientHello) { + r0 = errors.Errorf("dup record %v", record) + } + lastClientHello = record + + nnClientHello++ + ok = (nnClientHello > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnClientHello, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS server, srs-bench is DTLS client which is active mode. +// No.1 srs-bench: ClientHello(Epoch=0, Sequence=0) +// [Drop] No.2 srs-server: ServerHello(Epoch=0, Sequence=0), Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// [ARQ] No.2 srs-bench: ClientHello(Epoch=0, Sequence=1) +// [ARQ] No.3 srs-server: ServerHello(Epoch=0, Sequence=5), Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.4 srs-bench: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-server: ChangeCipherSpec, Finished +// +// @remark The pion is active, so it can be consider a benchmark for DTLS server. +func TestRtcDTLS_ClientActive_ARQ_ClientHello_ByDropped_ServerHello(t *testing.T) { + var r0, r1 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-arq-client-hello-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnServerHello, nnMaxDrop := 0, 1 + var lastClientHello, lastServerHello *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || + (chunk.handshake != DTLSHandshakeTypeClientHello && chunk.handshake != DTLSHandshakeTypeServerHello) { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if chunk.handshake == DTLSHandshakeTypeClientHello { + if lastClientHello != nil && record.Equals(lastClientHello) { + r0 = errors.Errorf("dup record %v", record) + } + lastClientHello = record + return true + } + + if lastServerHello != nil && record.Equals(lastServerHello) { + r1 = errors.Errorf("dup record %v", record) + } + lastServerHello = record + + nnServerHello++ + ok = (nnServerHello > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnServerHello, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0, r1); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client, srs-bench is DTLS server which is passive mode. +// No.1 srs-server: ClientHello(Epoch=0, Sequence=0) +// [Drop] No.2 srs-bench: ServerHello(Epoch=0, Sequence=0), Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// [ARQ] No.2 srs-server: ClientHello(Epoch=0, Sequence=1) +// [ARQ] No.3 srs-bench: ServerHello(Epoch=0, Sequence=5), Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.4 srs-server: Certificate, ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-bench: ChangeCipherSpec, Finished +// +// @remark If retransmit the ClientHello, with the same epoch+sequence, peer will request HelloVerifyRequest, then +// openssl will create a new ClientHello with increased sequence. It's ok, but waste a lots of duplicated ClientHello +// packets, so we fail the test, requires the epoch+sequence never dup, even for ARQ. +func TestRtcDTLS_ClientPassive_ARQ_ClientHello_ByDropped_ServerHello(t *testing.T) { + var r0, r1 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-arq-client-hello-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnServerHello, nnMaxDrop := 0, 1 + var lastClientHello, lastServerHello *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || + (chunk.handshake != DTLSHandshakeTypeClientHello && chunk.handshake != DTLSHandshakeTypeServerHello) { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if chunk.handshake == DTLSHandshakeTypeClientHello { + if lastClientHello != nil && record.Equals(lastClientHello) { + r0 = errors.Errorf("dup record %v", record) + } + lastClientHello = record + return true + } + + if lastServerHello != nil && record.Equals(lastServerHello) { + r1 = errors.Errorf("dup record %v", record) + } + lastServerHello = record + + nnServerHello++ + ok = (nnServerHello > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnServerHello, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0, r1); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS server, srs-bench is DTLS client which is active mode. +// No.1 srs-bench: ClientHello +// No.2 srs-server: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// [Drop] No.3 srs-bench: Certificate(Epoch=0, Sequence=0), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [ARQ] No.4 srs-bench: Certificate(Epoch=0, Sequence=5), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-server: ChangeCipherSpec, Finished +// +// @remark The pion is active, so it can be consider a benchmark for DTLS server. +func TestRtcDTLS_ClientActive_ARQ_Certificate_ByDropped_Certificate(t *testing.T) { + var r0 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-arq-certificate-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnCertificate, nnMaxDrop := 0, 1 + var lastCertificate *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || chunk.handshake != DTLSHandshakeTypeCertificate { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if lastCertificate != nil && lastCertificate.Equals(record) { + r0 = errors.Errorf("dup record %v", record) + } + lastCertificate = record + + nnCertificate++ + ok = (nnCertificate > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnCertificate, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client, srs-bench is DTLS server which is passive mode. +// No.1 srs-server: ClientHello +// No.2 srs-bench: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// [Drop] No.3 srs-server: Certificate(Epoch=0, Sequence=0), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [ARQ] No.4 srs-server: Certificate(Epoch=0, Sequence=5), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// No.5 srs-bench: ChangeCipherSpec, Finished +// +// @remark If retransmit the Certificate, with the same epoch+sequence, peer will drop the message. It's ok right now, but +// wast some packets, so we check the epoch+sequence which should never dup, even for ARQ. +func TestRtcDTLS_ClientPassive_ARQ_Certificate_ByDropped_Certificate(t *testing.T) { + var r0 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-arq-certificate-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnCertificate, nnMaxDrop := 0, 1 + var lastCertificate *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || chunk.chunk != ChunkTypeDTLS || chunk.content != DTLSContentTypeHandshake || chunk.handshake != DTLSHandshakeTypeCertificate { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if lastCertificate != nil && lastCertificate.Equals(record) { + r0 = errors.Errorf("dup record %v", record) + } + lastCertificate = record + + nnCertificate++ + ok = (nnCertificate > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnCertificate, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS server, srs-bench is DTLS client which is active mode. +// No.1 srs-bench: ClientHello +// No.2 srs-server: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.3 srs-bench: Certificate(Epoch=0, Sequence=0), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [Drop] No.5 srs-server: ChangeCipherSpec, Finished +// [ARQ] No.6 srs-bench: Certificate(Epoch=0, Sequence=5), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [ARQ] No.7 srs-server: ChangeCipherSpec, Finished +// +// @remark The pion is active, so it can be consider a benchmark for DTLS server. +func TestRtcDTLS_ClientActive_ARQ_Certificate_ByDropped_ChangeCipherSpec(t *testing.T) { + var r0, r1 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-active-arq-certificate-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupActive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnCertificate, nnMaxDrop := 0, 1 + var lastChangeCipherSepc, lastCertifidate *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || (!chunk.IsChangeCipherSpec() && !chunk.IsCertificate()) { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if chunk.IsCertificate() { + if lastCertifidate != nil && record.Equals(lastCertifidate) { + r0 = errors.Errorf("dup record %v", record) + } + lastCertifidate = record + return true + } + + if lastChangeCipherSepc != nil && lastChangeCipherSepc.Equals(record) { + r1 = errors.Errorf("dup record %v", record) + } + lastChangeCipherSepc = record + + nnCertificate++ + ok = (nnCertificate > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnCertificate, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0, r1); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client, srs-bench is DTLS server which is passive mode. +// No.1 srs-server: ClientHello +// No.2 srs-bench: ServerHello, Certificate, ServerKeyExchange, CertificateRequest, ServerHelloDone +// No.3 srs-server: Certificate(Epoch=0, Sequence=0), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [Drop] No.5 srs-bench: ChangeCipherSpec, Finished +// [ARQ] No.6 srs-server: Certificate(Epoch=0, Sequence=5), ClientKeyExchange, CertificateVerify, ChangeCipherSpec, Finished +// [ARQ] No.7 srs-bench: ChangeCipherSpec, Finished +// +// @remark If retransmit the Certificate, with the same epoch+sequence, peer will drop the message, and never generate the +// ChangeCipherSpec, which will cause DTLS fail. +func TestRtcDTLS_ClientPassive_ARQ_Certificate_ByDropped_ChangeCipherSpec(t *testing.T) { + var r0, r1 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-arq-certificate-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnCertificate, nnMaxDrop := 0, 1 + var lastChangeCipherSepc, lastCertifidate *DTLSRecord + + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed || (!chunk.IsChangeCipherSpec() && !chunk.IsCertificate()) { + return true + } + + record, err := NewDTLSRecord(c.UserData()) + if err != nil { + return true + } + + if chunk.IsCertificate() { + if lastCertifidate != nil && record.Equals(lastCertifidate) { + r0 = errors.Errorf("dup record %v", record) + } + lastCertifidate = record + return true + } + + if lastChangeCipherSepc != nil && lastChangeCipherSepc.Equals(record) { + r1 = errors.Errorf("dup record %v", record) + } + lastChangeCipherSepc = record + + nnCertificate++ + ok = (nnCertificate > nnMaxDrop) + logger.Tf(ctx, "NN=%v, Chunk %v, %v, ok=%v %v bytes", nnCertificate, chunk, record, ok, len(c.UserData())) + return + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0, r1); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// Drop all DTLS packets when got ClientHello, to test the server ARQ thread cleanup. +func TestRtcDTLS_ClientPassive_ARQ_DropAllAfter_ClientHello(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + vnetClientIP, dtlsDropPackets := *srsVnetClientIP, *srsDTLSDropPackets + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + nnDrop, dropAll := 0, false + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() { + if chunk.IsClientHello() { + dropAll = true + } + + if !dropAll { + return true + } + + if nnDrop++; nnDrop >= dtlsDropPackets { + cancel() // Done, server transmit 5 Client Hello. + } + + logger.Tf(ctx, "N=%v, Drop chunk %v %v bytes", nnDrop, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// Drop all DTLS packets when got ServerHello, to test the server ARQ thread cleanup. +func TestRtcDTLS_ClientPassive_ARQ_DropAllAfter_ServerHello(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + vnetClientIP, dtlsDropPackets := *srsVnetClientIP, *srsDTLSDropPackets + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + nnDrop, dropAll := 0, false + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() { + if chunk.IsServerHello() { + dropAll = true + } + + if !dropAll { + return true + } + + if nnDrop++; nnDrop >= dtlsDropPackets { + cancel() // Done, server transmit 5 Client Hello. + } + + logger.Tf(ctx, "N=%v, Drop chunk %v %v bytes", nnDrop, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// Drop all DTLS packets when got Certificate, to test the server ARQ thread cleanup. +func TestRtcDTLS_ClientPassive_ARQ_DropAllAfter_Certificate(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + vnetClientIP, dtlsDropPackets := *srsVnetClientIP, *srsDTLSDropPackets + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + nnDrop, dropAll := 0, false + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() { + if chunk.IsCertificate() { + dropAll = true + } + + if !dropAll { + return true + } + + if nnDrop++; nnDrop >= dtlsDropPackets { + cancel() // Done, server transmit 5 Client Hello. + } + + logger.Tf(ctx, "N=%v, Drop chunk %v %v bytes", nnDrop, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// Drop all DTLS packets when got ChangeCipherSpec, to test the server ARQ thread cleanup. +func TestRtcDTLS_ClientPassive_ARQ_DropAllAfter_ChangeCipherSpec(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + vnetClientIP, dtlsDropPackets := *srsVnetClientIP, *srsDTLSDropPackets + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + nnDrop, dropAll := 0, false + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() || chunk.IsChangeCipherSpec() { + if chunk.IsChangeCipherSpec() { + dropAll = true + } + + if !dropAll { + return true + } + + if nnDrop++; nnDrop >= dtlsDropPackets { + cancel() // Done, server transmit 5 Client Hello. + } + + logger.Tf(ctx, "N=%v, Drop chunk %v %v bytes", nnDrop, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// For very bad network, we drop 4 ClientHello consume about 750ms, then drop 4 Certificate +// which also consume about 750ms, but finally should be done successfully. +func TestRtcDTLS_ClientPassive_ARQ_VeryBadNetwork(t *testing.T) { + if err := filterTestError(func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP, dtlsDropPackets := *srsPublishOKPackets, *srsVnetClientIP, *srsDTLSDropPackets + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnDropClientHello, nnDropCertificate := 0, 0 + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() { + if !chunk.IsClientHello() && !chunk.IsCertificate() { + return true + } + + if chunk.IsClientHello() { + if nnDropClientHello >= 4 { + return true + } + nnDropClientHello++ + } + + if chunk.IsCertificate() { + if nnDropCertificate >= dtlsDropPackets { + return true + } + nnDropCertificate++ + } + + logger.Tf(ctx, "N=%v/%v, Drop chunk %v %v bytes", nnDropClientHello, nnDropCertificate, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }()); err != nil { + t.Errorf("err %+v", err) + } +} + +// The srs-server is DTLS client(client), srs-bench is DTLS server which is passive mode. +// If we retransmit 2 ClientHello packets, consumed 150ms, server might wait at 200ms. +// Then we retransmit the Certificate, server reset the timer and retransmit it in 50ms, not 200ms. +func TestRtcDTLS_ClientPassive_ARQ_Certificate_After_ClientHello(t *testing.T) { + var r0 error + err := func() error { + ctx, cancel := context.WithTimeout(logger.WithContext(context.Background()), time.Duration(*srsTimeout)*time.Millisecond) + publishOK, vnetClientIP := *srsPublishOKPackets, *srsVnetClientIP + + // Create top level test object. + api, err := NewTestWebRTCAPI() + if err != nil { + return err + } + defer api.Close() + + streamSuffix := fmt.Sprintf("dtls-passive-no-arq-%v-%v", os.Getpid(), rand.Int()) + p := NewTestPublisher(api, func(p *TestPublisher) { + p.streamSuffix = streamSuffix + p.onOffer = testUtilSetupPassive + }) + defer p.Close() + + if err := api.Setup(vnetClientIP, func(api *TestWebRTCAPI) { + var nn int64 + api.registry.Add(NewRTPInterceptor(func(i *RTPInterceptor) { + i.rtpWriter = func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if nn++; nn >= int64(publishOK) { + cancel() // Send enough packets, done. + } + logger.Tf(ctx, "publish write %v packets", nn) + return i.nextRTPWriter.Write(header, payload, attributes) + } + })) + }, func(api *TestWebRTCAPI) { + nnDropClientHello, nnDropCertificate := 0, 0 + var firstCertificate time.Time + api.router.AddChunkFilter(func(c vnet.Chunk) (ok bool) { + chunk, parsed := NewChunkMessageType(c) + if !parsed { + return true + } + + if chunk.IsHandshake() { + if !chunk.IsClientHello() && !chunk.IsCertificate() { + return true + } + + if chunk.IsClientHello() { + if nnDropClientHello > 3 { + return true + } + nnDropClientHello++ + } + + if chunk.IsCertificate() { + if nnDropCertificate == 0 { + firstCertificate = time.Now() + } else if nnDropCertificate == 1 { + if duration := time.Now().Sub(firstCertificate); duration > 150*time.Millisecond { + r0 = fmt.Errorf("ARQ between ClientHello and Certificate too large %v", duration) + } else { + logger.Tf(ctx, "ARQ between ClientHello and Certificate is %v", duration) + } + cancel() + } + nnDropCertificate++ + } + + logger.Tf(ctx, "N=%v/%v, Drop chunk %v %v bytes", nnDropClientHello, nnDropCertificate, chunk, len(c.UserData())) + return false + } + + return true + }) + }); err != nil { + return err + } + + return p.Run(ctx, cancel) + }() + if err := filterTestError(err, r0); err != nil { + t.Errorf("err %+v", err) } } diff --git a/trunk/3rdparty/srs-bench/srs/stat.go b/trunk/3rdparty/srs-bench/srs/stat.go index 35cbe3737..3ca7ed79a 100644 --- a/trunk/3rdparty/srs-bench/srs/stat.go +++ b/trunk/3rdparty/srs-bench/srs/stat.go @@ -1,3 +1,23 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package srs import ( diff --git a/trunk/3rdparty/srs-bench/srs/util.go b/trunk/3rdparty/srs-bench/srs/util.go index 4e6b2f783..8c5a5434d 100644 --- a/trunk/3rdparty/srs-bench/srs/util.go +++ b/trunk/3rdparty/srs-bench/srs/util.go @@ -1,3 +1,23 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. package srs import ( @@ -7,10 +27,14 @@ import ( "fmt" "github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/logger" + "github.com/pion/transport/vnet" + "github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3/pkg/media/h264reader" "io/ioutil" + "net" "net/http" "net/url" + "strconv" "strings" "time" ) @@ -140,3 +164,305 @@ func (v *wallClock) Tick(d time.Duration) time.Duration { } return 0 } + +// Set to active, as DTLS client, to start ClientHello. +func testUtilSetupActive(s *webrtc.SessionDescription) error { + if strings.Contains(s.SDP, "setup:passive") { + return errors.New("set to active") + } + + s.SDP = strings.ReplaceAll(s.SDP, "setup:actpass", "setup:active") + return nil +} + +// Set to passive, as DTLS client, to start ClientHello. +func testUtilSetupPassive(s *webrtc.SessionDescription) error { + if strings.Contains(s.SDP, "setup:active") { + return errors.New("set to passive") + } + + s.SDP = strings.ReplaceAll(s.SDP, "setup:actpass", "setup:passive") + return nil +} + +// Parse address from SDP. +// candidate:0 1 udp 2130706431 192.168.3.8 8000 typ host generation 0 +func parseAddressOfCandidate(answerSDP string) (*net.UDPAddr, error) { + answer := webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: answerSDP} + answerObject, err := answer.Unmarshal() + if err != nil { + return nil, errors.Wrapf(err, "unmarshal answer %v", answerSDP) + } + + if len(answerObject.MediaDescriptions) == 0 { + return nil, errors.New("no media") + } + + candidate, ok := answerObject.MediaDescriptions[0].Attribute("candidate") + if !ok { + return nil, errors.New("no candidate") + } + + // candidate:0 1 udp 2130706431 192.168.3.8 8000 typ host generation 0 + attrs := strings.Split(candidate, " ") + if len(attrs) <= 6 { + return nil, errors.Errorf("no address in %v", candidate) + } + + // Parse ip and port from answer. + ip := attrs[4] + port, err := strconv.Atoi(attrs[5]) + if err != nil { + return nil, errors.Wrapf(err, "invalid port %v", candidate) + } + + address := fmt.Sprintf("%v:%v", ip, port) + addr, err := net.ResolveUDPAddr("udp4", address) + if err != nil { + return nil, errors.Wrapf(err, "parse %v", address) + } + + return addr, nil +} + +// Filter the test error, ignore context.Canceled +func filterTestError(errs ...error) error { + var filteredErrors []error + + for _, err := range errs { + if err == nil || errors.Cause(err) == context.Canceled { + continue + } + filteredErrors = append(filteredErrors, err) + } + + if len(filteredErrors) == 0 { + return nil + } + if len(filteredErrors) == 1 { + return filteredErrors[0] + } + + var descs []string + for i, err := range filteredErrors[1:] { + descs = append(descs, fmt.Sprintf("err #%d, %+v", i, err)) + } + return errors.Wrapf(filteredErrors[0], "with %v", strings.Join(descs, ",")) +} + +// For STUN packet, 0x00 is binding request, 0x01 is binding success response. +// @see srs_is_stun of https://github.com/ossrs/srs +func srsIsStun(b []byte) bool { + return len(b) > 0 && (b[0] == 0 || b[0] == 1) +} + +// change_cipher_spec(20), alert(21), handshake(22), application_data(23) +// @see https://tools.ietf.org/html/rfc2246#section-6.2.1 +// @see srs_is_dtls of https://github.com/ossrs/srs +func srsIsDTLS(b []byte) bool { + return (len(b) >= 13 && (b[0] > 19 && b[0] < 64)) +} + +// For RTP or RTCP, the V=2 which is in the high 2bits, 0xC0 (1100 0000) +// @see srs_is_rtp_or_rtcp of https://github.com/ossrs/srs +func srsIsRTPOrRTCP(b []byte) bool { + return (len(b) >= 12 && (b[0]&0xC0) == 0x80) +} + +// For RTCP, PT is [128, 223] (or without marker [0, 95]). +// Literally, RTCP starts from 64 not 0, so PT is [192, 223] (or without marker [64, 95]). +// @note For RTP, the PT is [96, 127], or [224, 255] with marker. +// @see srs_is_rtcp of https://github.com/ossrs/srs +func srsIsRTCP(b []byte) bool { + return (len(b) >= 12) && (b[0]&0x80) != 0 && (b[1] >= 192 && b[1] <= 223) +} + +type ChunkType int + +const ( + ChunkTypeICE ChunkType = iota + 1 + ChunkTypeDTLS + ChunkTypeRTP + ChunkTypeRTCP +) + +func (v ChunkType) String() string { + switch v { + case ChunkTypeICE: + return "ICE" + case ChunkTypeDTLS: + return "DTLS" + case ChunkTypeRTP: + return "RTP" + case ChunkTypeRTCP: + return "RTCP" + default: + return "Unknown" + } +} + +type DTLSContentType int + +const ( + DTLSContentTypeHandshake DTLSContentType = 22 + DTLSContentTypeChangeCipherSpec DTLSContentType = 20 + DTLSContentTypeAlert DTLSContentType = 21 +) + +func (v DTLSContentType) String() string { + switch v { + case DTLSContentTypeHandshake: + return "Handshake" + case DTLSContentTypeChangeCipherSpec: + return "ChangeCipherSpec" + default: + return "Unknown" + } +} + +type DTLSHandshakeType int + +const ( + DTLSHandshakeTypeClientHello DTLSHandshakeType = 1 + DTLSHandshakeTypeServerHello DTLSHandshakeType = 2 + DTLSHandshakeTypeCertificate DTLSHandshakeType = 11 + DTLSHandshakeTypeServerKeyExchange DTLSHandshakeType = 12 + DTLSHandshakeTypeCertificateRequest DTLSHandshakeType = 13 + DTLSHandshakeTypeServerDone DTLSHandshakeType = 14 + DTLSHandshakeTypeCertificateVerify DTLSHandshakeType = 15 + DTLSHandshakeTypeClientKeyExchange DTLSHandshakeType = 16 + DTLSHandshakeTypeFinished DTLSHandshakeType = 20 +) + +func (v DTLSHandshakeType) String() string { + switch v { + case DTLSHandshakeTypeClientHello: + return "ClientHello" + case DTLSHandshakeTypeServerHello: + return "ServerHello" + case DTLSHandshakeTypeCertificate: + return "Certificate" + case DTLSHandshakeTypeServerKeyExchange: + return "ServerKeyExchange" + case DTLSHandshakeTypeCertificateRequest: + return "CertificateRequest" + case DTLSHandshakeTypeServerDone: + return "ServerDone" + case DTLSHandshakeTypeCertificateVerify: + return "CertificateVerify" + case DTLSHandshakeTypeClientKeyExchange: + return "ClientKeyExchange" + case DTLSHandshakeTypeFinished: + return "Finished" + default: + return "Unknown" + } +} + +type ChunkMessageType struct { + chunk ChunkType + content DTLSContentType + handshake DTLSHandshakeType +} + +func (v *ChunkMessageType) String() string { + if v.chunk == ChunkTypeDTLS { + return fmt.Sprintf("%v-%v-%v", v.chunk, v.content, v.handshake) + } + return fmt.Sprintf("%v", v.chunk) +} + +func NewChunkMessageType(c vnet.Chunk) (*ChunkMessageType, bool) { + b := c.UserData() + + if len(b) == 0 { + return nil, false + } + + v := &ChunkMessageType{} + + if srsIsRTPOrRTCP(b) { + if srsIsRTCP(b) { + v.chunk = ChunkTypeRTCP + } else { + v.chunk = ChunkTypeRTP + } + return v, true + } + + if srsIsStun(b) { + v.chunk = ChunkTypeICE + return v, true + } + + if !srsIsDTLS(b) { + return nil, false + } + + v.chunk, v.content = ChunkTypeDTLS, DTLSContentType(b[0]) + if v.content != DTLSContentTypeHandshake { + return v, true + } + + if len(b) < 14 { + return v, false + } + v.handshake = DTLSHandshakeType(b[13]) + return v, true +} + +func (v *ChunkMessageType) IsHandshake() bool { + return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake +} + +func (v *ChunkMessageType) IsClientHello() bool { + return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeClientHello +} + +func (v *ChunkMessageType) IsServerHello() bool { + return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeServerHello +} + +func (v *ChunkMessageType) IsCertificate() bool { + return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeHandshake && v.handshake == DTLSHandshakeTypeCertificate +} + +func (v *ChunkMessageType) IsChangeCipherSpec() bool { + return v.chunk == ChunkTypeDTLS && v.content == DTLSContentTypeChangeCipherSpec +} + +type DTLSRecord struct { + ContentType DTLSContentType + Version uint16 + Epoch uint16 + SequenceNumber uint64 + Length uint16 + Data []byte +} + +func NewDTLSRecord(b []byte) (*DTLSRecord, error) { + v := &DTLSRecord{} + return v, v.Unmarshal(b) +} + +func (v *DTLSRecord) String() string { + return fmt.Sprintf("epoch=%v, sequence=%v", v.Epoch, v.SequenceNumber) +} + +func (v *DTLSRecord) Equals(p *DTLSRecord) bool { + return v.Epoch == p.Epoch && v.SequenceNumber == p.SequenceNumber +} + +func (v *DTLSRecord) Unmarshal(b []byte) error { + if len(b) < 13 { + return errors.Errorf("requires 13B only %v", len(b)) + } + + v.ContentType = DTLSContentType(uint8(b[0])) + v.Version = uint16(b[1])<<8 | uint16(b[2]) + v.Epoch = uint16(b[3])<<8 | uint16(b[4]) + v.SequenceNumber = uint64(b[5])<<40 | uint64(b[6])<<32 | uint64(b[7])<<24 | uint64(b[8])<<16 | uint64(b[9])<<8 | uint64(b[10]) + v.Length = uint16(b[11])<<8 | uint16(b[12]) + v.Data = b[13:] + return nil +} diff --git a/trunk/3rdparty/srs-bench/srs/util_test.go b/trunk/3rdparty/srs-bench/srs/util_test.go new file mode 100644 index 000000000..68187c9ef --- /dev/null +++ b/trunk/3rdparty/srs-bench/srs/util_test.go @@ -0,0 +1,723 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package srs + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "github.com/ossrs/go-oryx-lib/errors" + "github.com/ossrs/go-oryx-lib/logger" + vnet_proxy "github.com/ossrs/srs-bench/vnet" + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/transport/vnet" + "github.com/pion/webrtc/v3" + "io" + "io/ioutil" + "net/http" + "os" + "path" + "strings" + "sync" + "testing" + "time" +) + +var srsSchema = "http" +var srsHttps = flag.Bool("srs-https", false, "Whther connect to HTTPS-API") +var srsServer = flag.String("srs-server", "127.0.0.1", "The RTC server to connect to") +var srsStream = flag.String("srs-stream", "/rtc/regression", "The RTC stream to play") +var srsLog = flag.Bool("srs-log", false, "Whether enable the detail log") +var srsTimeout = flag.Int("srs-timeout", 5000, "For each case, the timeout in ms") +var srsPlayPLI = flag.Int("srs-play-pli", 5000, "The PLI interval in seconds for player.") +var srsPlayOKPackets = flag.Int("srs-play-ok-packets", 10, "If got N packets, it's ok, or fail") +var srsPublishOKPackets = flag.Int("srs-publish-ok-packets", 10, "If send N packets, it's ok, or fail") +var srsPublishAudio = flag.String("srs-publish-audio", "avatar.ogg", "The audio file for publisher.") +var srsPublishVideo = flag.String("srs-publish-video", "avatar.h264", "The video file for publisher.") +var srsPublishVideoFps = flag.Int("srs-publish-video-fps", 25, "The video fps for publisher.") +var srsVnetClientIP = flag.String("srs-vnet-client-ip", "192.168.168.168", "The client ip in pion/vnet.") +var srsDTLSDropPackets = flag.Int("srs-dtls-drop-packets", 5, "If dropped N packets, it's ok, or fail") + +func prepareTest() error { + var err error + + // Should parse it first. + flag.Parse() + + // The stream should starts with /, for example, /rtc/regression + if !strings.HasPrefix(*srsStream, "/") { + *srsStream = "/" + *srsStream + } + + // Generate srs protocol from whether use HTTPS. + if *srsHttps { + srsSchema = "https" + } + + // Check file. + tryOpenFile := func(filename string) (string, error) { + if filename == "" { + return filename, nil + } + + f, err := os.Open(filename) + if err != nil { + nfilename := path.Join("../", filename) + f2, err := os.Open(nfilename) + if err != nil { + return filename, errors.Wrapf(err, "No video file at %v or %v", filename, nfilename) + } + defer f2.Close() + + return nfilename, nil + } + defer f.Close() + + return filename, nil + } + + if *srsPublishVideo, err = tryOpenFile(*srsPublishVideo); err != nil { + return err + } + + if *srsPublishAudio, err = tryOpenFile(*srsPublishAudio); err != nil { + return err + } + + return nil +} + +func TestMain(m *testing.M) { + if err := prepareTest(); err != nil { + logger.Ef(nil, "Prepare test fail, err %+v", err) + os.Exit(-1) + } + + // Disable the logger during all tests. + if *srsLog == false { + olw := logger.Switch(ioutil.Discard) + defer func() { + logger.Switch(olw) + }() + } + + os.Exit(m.Run()) +} + +type TestWebRTCAPIOptionFunc func(api *TestWebRTCAPI) + +type TestWebRTCAPI struct { + // The options to setup the api. + options []TestWebRTCAPIOptionFunc + // The api and settings. + api *webrtc.API + mediaEngine *webrtc.MediaEngine + registry *interceptor.Registry + settingEngine *webrtc.SettingEngine + // The vnet router, can be shared by different apis, but we do not share it. + router *vnet.Router + // The network for api. + network *vnet.Net + // The vnet UDP proxy bind to the router. + proxy *vnet_proxy.UDPProxy +} + +func NewTestWebRTCAPI(options ...TestWebRTCAPIOptionFunc) (*TestWebRTCAPI, error) { + v := &TestWebRTCAPI{} + + v.mediaEngine = &webrtc.MediaEngine{} + if err := v.mediaEngine.RegisterDefaultCodecs(); err != nil { + return nil, err + } + + v.registry = &interceptor.Registry{} + if err := webrtc.RegisterDefaultInterceptors(v.mediaEngine, v.registry); err != nil { + return nil, err + } + + for _, setup := range options { + setup(v) + } + + v.settingEngine = &webrtc.SettingEngine{} + + return v, nil +} + +func (v *TestWebRTCAPI) Close() error { + if v.proxy != nil { + v.proxy.Close() + v.proxy = nil + } + + if v.router != nil { + v.router.Stop() + v.router = nil + } + + return nil +} + +func (v *TestWebRTCAPI) Setup(vnetClientIP string, options ...TestWebRTCAPIOptionFunc) error { + // Setting engine for https://github.com/pion/transport/tree/master/vnet + setupVnet := func(vnetClientIP string) (err error) { + // We create a private router for a api, however, it's possible to share the + // same router between apis. + if v.router, err = vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", // Accept all ip, no sub router. + LoggerFactory: logging.NewDefaultLoggerFactory(), + }); err != nil { + return errors.Wrapf(err, "create router for api") + } + + // Each api should bind to a network, however, it's possible to share it + // for different apis. + v.network = vnet.NewNet(&vnet.NetConfig{ + StaticIP: vnetClientIP, + }) + + if err = v.router.AddNet(v.network); err != nil { + return errors.Wrapf(err, "create network for api") + } + + v.settingEngine.SetVNet(v.network) + + // Create a proxy bind to the router. + if v.proxy, err = vnet_proxy.NewProxy(v.router); err != nil { + return errors.Wrapf(err, "create proxy for router") + } + + return v.router.Start() + } + if err := setupVnet(vnetClientIP); err != nil { + return err + } + + for _, setup := range options { + setup(v) + } + + for _, setup := range v.options { + setup(v) + } + + v.api = webrtc.NewAPI( + webrtc.WithMediaEngine(v.mediaEngine), + webrtc.WithInterceptorRegistry(v.registry), + webrtc.WithSettingEngine(*v.settingEngine), + ) + + return nil +} + +func (v *TestWebRTCAPI) NewPeerConnection(configuration webrtc.Configuration) (*webrtc.PeerConnection, error) { + return v.api.NewPeerConnection(configuration) +} + +type TestPlayerOptionFunc func(p *TestPlayer) + +type TestPlayer struct { + pc *webrtc.PeerConnection + receivers []*webrtc.RTPReceiver + // root api object + api *TestWebRTCAPI + // Optional suffix for stream url. + streamSuffix string +} + +func NewTestPlayer(api *TestWebRTCAPI, options ...TestPlayerOptionFunc) *TestPlayer { + v := &TestPlayer{api: api} + + for _, opt := range options { + opt(v) + } + + return v +} + +func (v *TestPlayer) Close() error { + if v.pc != nil { + v.pc.Close() + v.pc = nil + } + + for _, receiver := range v.receivers { + receiver.Stop() + } + v.receivers = nil + + return nil +} + +func (v *TestPlayer) Run(ctx context.Context, cancel context.CancelFunc) error { + r := fmt.Sprintf("%v://%v%v", srsSchema, *srsServer, *srsStream) + if v.streamSuffix != "" { + r = fmt.Sprintf("%v-%v", r, v.streamSuffix) + } + pli := time.Duration(*srsPlayPLI) * time.Millisecond + logger.Tf(ctx, "Start play url=%v", r) + + pc, err := v.api.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + return errors.Wrapf(err, "Create PC") + } + v.pc = pc + + pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }) + pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }) + + offer, err := pc.CreateOffer(nil) + if err != nil { + return errors.Wrapf(err, "Create Offer") + } + + if err := pc.SetLocalDescription(offer); err != nil { + return errors.Wrapf(err, "Set offer %v", offer) + } + + answer, err := apiRtcRequest(ctx, "/rtc/v1/play", r, offer.SDP) + if err != nil { + return errors.Wrapf(err, "Api request offer=%v", offer.SDP) + } + + // Start a proxy for real server and vnet. + if address, err := parseAddressOfCandidate(answer); err != nil { + return errors.Wrapf(err, "parse address of %v", answer) + } else if err := v.api.proxy.Proxy(v.api.network, address); err != nil { + return errors.Wrapf(err, "proxy %v to %v", v.api.network, address) + } + + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, SDP: answer, + }); err != nil { + return errors.Wrapf(err, "Set answer %v", answer) + } + + handleTrack := func(ctx context.Context, track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) error { + // Send a PLI on an interval so that the publisher is pushing a keyframe + go func() { + if track.Kind() == webrtc.RTPCodecTypeAudio { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-time.After(pli): + _ = pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{ + MediaSSRC: uint32(track.SSRC()), + }}) + } + } + }() + + v.receivers = append(v.receivers, receiver) + + for ctx.Err() == nil { + _, _, err := track.ReadRTP() + if err != nil { + return errors.Wrapf(err, "Read RTP") + } + } + + return nil + } + + pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + err = handleTrack(ctx, track, receiver) + if err != nil { + codec := track.Codec() + err = errors.Wrapf(err, "Handle track %v, pt=%v", codec.MimeType, codec.PayloadType) + cancel() + } + }) + + pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + if state == webrtc.ICEConnectionStateFailed || state == webrtc.ICEConnectionStateClosed { + err = errors.Errorf("Close for ICE state %v", state) + cancel() + } + }) + + <-ctx.Done() + return err +} + +type TestPublisherOptionFunc func(p *TestPublisher) + +type TestPublisher struct { + onOffer func(s *webrtc.SessionDescription) error + onAnswer func(s *webrtc.SessionDescription) error + iceReadyCancel context.CancelFunc + // internal objects + aIngester *audioIngester + vIngester *videoIngester + pc *webrtc.PeerConnection + // root api object + api *TestWebRTCAPI + // Optional suffix for stream url. + streamSuffix string +} + +func NewTestPublisher(api *TestWebRTCAPI, options ...TestPublisherOptionFunc) *TestPublisher { + sourceVideo, sourceAudio := *srsPublishVideo, *srsPublishAudio + + v := &TestPublisher{api: api} + + for _, opt := range options { + opt(v) + } + + // Create ingesters. + if sourceAudio != "" { + v.aIngester = NewAudioIngester(sourceAudio) + } + if sourceVideo != "" { + v.vIngester = NewVideoIngester(sourceVideo) + } + + // Setup the interceptors for packets. + api.options = append(api.options, func(api *TestWebRTCAPI) { + // Filter for RTCP packets. + rtcpInterceptor := &RTCPInterceptor{} + rtcpInterceptor.rtcpReader = func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + return rtcpInterceptor.nextRTCPReader.Read(buf, attributes) + } + rtcpInterceptor.rtcpWriter = func(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) { + return rtcpInterceptor.nextRTCPWriter.Write(pkts, attributes) + } + api.registry.Add(rtcpInterceptor) + + // Filter for ingesters. + if sourceAudio != "" { + api.registry.Add(v.aIngester.audioLevelInterceptor) + } + if sourceVideo != "" { + api.registry.Add(v.vIngester.markerInterceptor) + } + }) + + return v +} + +func (v *TestPublisher) Close() error { + if v.vIngester != nil { + v.vIngester.Close() + } + + if v.aIngester != nil { + v.aIngester.Close() + } + + if v.pc != nil { + v.pc.Close() + } + + return nil +} + +func (v *TestPublisher) SetStreamSuffix(suffix string) *TestPublisher { + v.streamSuffix = suffix + return v +} + +func (v *TestPublisher) Run(ctx context.Context, cancel context.CancelFunc) error { + r := fmt.Sprintf("%v://%v%v", srsSchema, *srsServer, *srsStream) + if v.streamSuffix != "" { + r = fmt.Sprintf("%v-%v", r, v.streamSuffix) + } + sourceVideo, sourceAudio, fps := *srsPublishVideo, *srsPublishAudio, *srsPublishVideoFps + + logger.Tf(ctx, "Start publish url=%v, audio=%v, video=%v, fps=%v", + r, sourceAudio, sourceVideo, fps) + + pc, err := v.api.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + return errors.Wrapf(err, "Create PC") + } + v.pc = pc + + if v.vIngester != nil { + if err := v.vIngester.AddTrack(pc, fps); err != nil { + return errors.Wrapf(err, "Add track") + } + defer v.vIngester.Close() + } + + if v.aIngester != nil { + if err := v.aIngester.AddTrack(pc); err != nil { + return errors.Wrapf(err, "Add track") + } + defer v.aIngester.Close() + } + + offer, err := pc.CreateOffer(nil) + if err != nil { + return errors.Wrapf(err, "Create Offer") + } + + if err := pc.SetLocalDescription(offer); err != nil { + return errors.Wrapf(err, "Set offer %v", offer) + } + + if v.onOffer != nil { + if err := v.onOffer(&offer); err != nil { + return errors.Wrapf(err, "sdp %v %v", offer.Type, offer.SDP) + } + } + + answerSDP, err := apiRtcRequest(ctx, "/rtc/v1/publish", r, offer.SDP) + if err != nil { + return errors.Wrapf(err, "Api request offer=%v", offer.SDP) + } + + // Start a proxy for real server and vnet. + if address, err := parseAddressOfCandidate(answerSDP); err != nil { + return errors.Wrapf(err, "parse address of %v", answerSDP) + } else if err := v.api.proxy.Proxy(v.api.network, address); err != nil { + return errors.Wrapf(err, "proxy %v to %v", v.api.network, address) + } + + answer := &webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, SDP: answerSDP, + } + if v.onAnswer != nil { + if err := v.onAnswer(answer); err != nil { + return errors.Wrapf(err, "on answerSDP") + } + } + + if err := pc.SetRemoteDescription(*answer); err != nil { + return errors.Wrapf(err, "Set answerSDP %v", answerSDP) + } + + logger.Tf(ctx, "State signaling=%v, ice=%v, conn=%v", pc.SignalingState(), pc.ICEConnectionState(), pc.ConnectionState()) + + // ICE state management. + pc.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) { + logger.Tf(ctx, "ICE gather state %v", state) + }) + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + logger.Tf(ctx, "ICE candidate %v %v:%v", candidate.Protocol, candidate.Address, candidate.Port) + + }) + pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + logger.Tf(ctx, "ICE state %v", state) + }) + + pc.OnSignalingStateChange(func(state webrtc.SignalingState) { + logger.Tf(ctx, "Signaling state %v", state) + }) + + if v.aIngester != nil { + v.aIngester.sAudioSender.Transport().OnStateChange(func(state webrtc.DTLSTransportState) { + logger.Tf(ctx, "DTLS state %v", state) + }) + } + + pcDone, pcDoneCancel := context.WithCancel(context.Background()) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + logger.Tf(ctx, "PC state %v", state) + + if state == webrtc.PeerConnectionStateConnected { + pcDoneCancel() + if v.iceReadyCancel != nil { + v.iceReadyCancel() + } + } + + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + err = errors.Errorf("Close for PC state %v", state) + cancel() + } + }) + + // Wait for event from context or tracks. + var wg sync.WaitGroup + var finalErr error + + wg.Add(1) + go func() { + defer wg.Done() + defer logger.Tf(ctx, "ingest notify done") + + <-ctx.Done() + + if v.aIngester != nil && v.aIngester.sAudioSender != nil { + v.aIngester.sAudioSender.Stop() + } + + if v.vIngester != nil && v.vIngester.sVideoSender != nil { + v.vIngester.sVideoSender.Stop() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + if v.aIngester == nil { + return + } + + select { + case <-ctx.Done(): + return + case <-pcDone.Done(): + } + + wg.Add(1) + go func() { + defer wg.Done() + defer logger.Tf(ctx, "aingester sender read done") + + buf := make([]byte, 1500) + for ctx.Err() == nil { + if _, _, err := v.aIngester.sAudioSender.Read(buf); err != nil { + return + } + } + }() + + for { + if err := v.aIngester.Ingest(ctx); err != nil { + if err == io.EOF { + logger.Tf(ctx, "aingester retry for %v", err) + continue + } + if err != context.Canceled { + finalErr = errors.Wrapf(err, "audio") + } + + logger.Tf(ctx, "aingester err=%v, final=%v", err, finalErr) + return + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + if v.vIngester == nil { + return + } + + select { + case <-ctx.Done(): + return + case <-pcDone.Done(): + logger.Tf(ctx, "PC(ICE+DTLS+SRTP) done, start ingest video %v", sourceVideo) + } + + wg.Add(1) + go func() { + defer wg.Done() + defer logger.Tf(ctx, "vingester sender read done") + + buf := make([]byte, 1500) + for ctx.Err() == nil { + // The Read() might block in r.rtcpInterceptor.Read(b, a), + // so that the Stop() can not stop it. + if _, _, err := v.vIngester.sVideoSender.Read(buf); err != nil { + return + } + } + }() + + for { + if err := v.vIngester.Ingest(ctx); err != nil { + if err == io.EOF { + logger.Tf(ctx, "vingester retry for %v", err) + continue + } + if err != context.Canceled { + finalErr = errors.Wrapf(err, "video") + } + + logger.Tf(ctx, "vingester err=%v, final=%v", err, finalErr) + return + } + } + }() + + wg.Wait() + + logger.Tf(ctx, "ingester done ctx=%v, final=%v", ctx.Err(), finalErr) + if finalErr != nil { + return finalErr + } + return ctx.Err() +} + +func TestRTCServerVersion(t *testing.T) { + api := fmt.Sprintf("http://%v:1985/api/v1/versions", *srsServer) + req, err := http.NewRequest("POST", api, nil) + if err != nil { + t.Errorf("Request %v", api) + return + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("Do request %v", api) + return + } + + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("Read body of %v", api) + return + } + + obj := struct { + Code int `json:"code"` + Server string `json:"server"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + }{} + if err := json.Unmarshal(b, &obj); err != nil { + t.Errorf("Parse %v", string(b)) + return + } + if obj.Code != 0 { + t.Errorf("Server err code=%v, server=%v", obj.Code, obj.Server) + return + } + if obj.Data.Major == 0 && obj.Data.Minor == 0 { + t.Errorf("Invalid version %v", obj.Data) + return + } +} diff --git a/trunk/3rdparty/srs-bench/vnet/example_udpproxy_test.go b/trunk/3rdparty/srs-bench/vnet/example_udpproxy_test.go new file mode 100644 index 000000000..d54e71653 --- /dev/null +++ b/trunk/3rdparty/srs-bench/vnet/example_udpproxy_test.go @@ -0,0 +1,278 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package vnet_test + +import ( + "net" + + vnet_proxy "github.com/ossrs/srs-bench/vnet" + "github.com/pion/logging" + "github.com/pion/transport/vnet" +) + +// Proxy many vnet endpoint to one real server endpoint. +// For example: +// vnet(10.0.0.11:5787) => proxy => 192.168.1.10:8000 +// vnet(10.0.0.11:5788) => proxy => 192.168.1.10:8000 +// vnet(10.0.0.11:5789) => proxy => 192.168.1.10:8000 +func ExampleUDPProxyManyToOne() { // nolint:govet + var clientNetwork *vnet.Net + + var serverAddr *net.UDPAddr + if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil { + // handle error + } else { + serverAddr = addr + } + + // Setup the network and proxy. + if true { + // Create vnet WAN with one endpoint, please read from + // https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + // handle error + } + + // Create a network and add to router, for example, for client. + clientNetwork = vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + // handle error + } + + // Start the router. + if err = router.Start(); err != nil { + // handle error + } + defer router.Stop() // nolint:errcheck + + // Create a proxy, bind to the router. + proxy, err := vnet_proxy.NewProxy(router) + if err != nil { + // handle error + } + defer proxy.Close() // nolint:errcheck + + // Start to proxy some addresses, clientNetwork is a hit for proxy, + // that the client in vnet is from this network. + if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { + // handle error + } + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") + if err != nil { + // handle error + } + _, _ = client0.WriteTo([]byte("Hello"), serverAddr) + + client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788") + if err != nil { + // handle error + } + _, _ = client1.WriteTo([]byte("Hello"), serverAddr) + + client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789") + if err != nil { + // handle error + } + _, _ = client2.WriteTo([]byte("Hello"), serverAddr) +} + +// Proxy many vnet endpoint to one real server endpoint. +// For example: +// vnet(10.0.0.11:5787) => proxy => 192.168.1.10:8000 +// vnet(10.0.0.11:5788) => proxy => 192.168.1.10:8000 +func ExampleUDPProxyMultileTimes() { // nolint:govet + var clientNetwork *vnet.Net + + var serverAddr *net.UDPAddr + if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil { + // handle error + } else { + serverAddr = addr + } + + // Setup the network and proxy. + var proxy *vnet_proxy.UDPProxy + if true { + // Create vnet WAN with one endpoint, please read from + // https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + // handle error + } + + // Create a network and add to router, for example, for client. + clientNetwork = vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + // handle error + } + + // Start the router. + if err = router.Start(); err != nil { + // handle error + } + defer router.Stop() // nolint:errcheck + + // Create a proxy, bind to the router. + proxy, err = vnet_proxy.NewProxy(router) + if err != nil { + // handle error + } + defer proxy.Close() // nolint:errcheck + } + + if true { + // Start to proxy some addresses, clientNetwork is a hit for proxy, + // that the client in vnet is from this network. + if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { + // handle error + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") + if err != nil { + // handle error + } + _, _ = client0.WriteTo([]byte("Hello"), serverAddr) + } + + if true { + // It's ok to proxy multiple times, for example, the publisher and player + // might need to proxy when got answer. + if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { + // handle error + } + + client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788") + if err != nil { + // handle error + } + _, _ = client1.WriteTo([]byte("Hello"), serverAddr) + } +} + +// Proxy one vnet endpoint to one real server endpoint. +// For example: +// vnet(10.0.0.11:5787) => proxy0 => 192.168.1.10:8000 +// vnet(10.0.0.11:5788) => proxy1 => 192.168.1.10:8001 +// vnet(10.0.0.11:5789) => proxy2 => 192.168.1.10:8002 +func ExampleUDPProxyOneToOne() { // nolint:govet + var clientNetwork *vnet.Net + + var serverAddr0 *net.UDPAddr + if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000"); err != nil { + // handle error + } else { + serverAddr0 = addr + } + + var serverAddr1 *net.UDPAddr + if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8001"); err != nil { + // handle error + } else { + serverAddr1 = addr + } + + var serverAddr2 *net.UDPAddr + if addr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8002"); err != nil { + // handle error + } else { + serverAddr2 = addr + } + + // Setup the network and proxy. + if true { + // Create vnet WAN with one endpoint, please read from + // https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + // handle error + } + + // Create a network and add to router, for example, for client. + clientNetwork = vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + // handle error + } + + // Start the router. + if err = router.Start(); err != nil { + // handle error + } + defer router.Stop() // nolint:errcheck + + // Create a proxy, bind to the router. + proxy, err := vnet_proxy.NewProxy(router) + if err != nil { + // handle error + } + defer proxy.Close() // nolint:errcheck + + // Start to proxy some addresses, clientNetwork is a hit for proxy, + // that the client in vnet is from this network. + if err := proxy.Proxy(clientNetwork, serverAddr0); err != nil { + // handle error + } + if err := proxy.Proxy(clientNetwork, serverAddr1); err != nil { + // handle error + } + if err := proxy.Proxy(clientNetwork, serverAddr2); err != nil { + // handle error + } + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") + if err != nil { + // handle error + } + _, _ = client0.WriteTo([]byte("Hello"), serverAddr0) + + client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788") + if err != nil { + // handle error + } + _, _ = client1.WriteTo([]byte("Hello"), serverAddr1) + + client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789") + if err != nil { + // handle error + } + _, _ = client2.WriteTo([]byte("Hello"), serverAddr2) +} diff --git a/trunk/3rdparty/srs-bench/vnet/udpproxy.go b/trunk/3rdparty/srs-bench/vnet/udpproxy.go new file mode 100644 index 000000000..c21fe4603 --- /dev/null +++ b/trunk/3rdparty/srs-bench/vnet/udpproxy.go @@ -0,0 +1,222 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package vnet + +import ( + "net" + "sync" + "time" + + "github.com/pion/transport/vnet" +) + +// A UDP proxy between real server(net.UDPConn) and vnet.UDPConn. +// +// High level design: +// .............................................. +// : Virtual Network (vnet) : +// : : +// +-------+ * 1 +----+ +--------+ : +// | :App |------------>|:Net|--o<-----|:Router | ............................. +// +-------+ +----+ | | : UDPProxy : +// : | | +----+ +---------+ +---------+ +--------+ +// : | |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real | +// : | | +----+ | UDPConn | | UDPConn | | Server | +// : | | : +---------+ +---------+ +--------+ +// : | | ............................: +// : +--------+ : +// ............................................... +// +// The whole big picture: +// ...................................... +// : Virtual Network (vnet) : +// : : +// +-------+ * 1 +----+ +--------+ : +// | :App |------------>|:Net|--o<-----|:Router | ............................. +// +-------+ +----+ | | : UDPProxy : +// +-----------+ * 1 +----+ | | +----+ +---------+ +---------+ +--------+ +// |:STUNServer|-------->|:Net|--o<-----| |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real | +// +-----------+ +----+ | | +----+ | UDPConn | | UDPConn | | Server | +// +-----------+ * 1 +----+ | | : +---------+ +---------+ +--------+ +// |:TURNServer|-------->|:Net|--o<-----| | ............................: +// +-----------+ +----+ [1] | | : +// : 1 | | 1 <> : +// : +---<>| |<>----+ [2] : +// : | +--------+ | : +// To form | *| v 0..1 : +// a subnet tree | o [3] +-----+ : +// : | ^ |:NAT | : +// : | | +-----+ : +// : +-------+ : +// ...................................... +type UDPProxy struct { + // The router bind to. + router *vnet.Router + + // Each vnet source, bind to a real socket to server. + // key is real server addr, which is net.Addr + // value is *aUDPProxyWorker + workers sync.Map + + // For each endpoint, we never know when to start and stop proxy, + // so we stop the endpoint when timeout. + timeout time.Duration + + // For utest, to mock the target real server. + // Optional, use the address of received client packet. + mockRealServerAddr *net.UDPAddr +} + +// NewProxy create a proxy, the router for this proxy belongs/bind to. If need to proxy for +// please create a new proxy for each router. For all addresses we proxy, we will create a +// vnet.Net in this router and proxy all packets. +func NewProxy(router *vnet.Router) (*UDPProxy, error) { + v := &UDPProxy{router: router, timeout: 2 * time.Minute} + return v, nil +} + +// Close the proxy, stop all workers. +func (v *UDPProxy) Close() error { + // nolint:godox // TODO: FIXME: Do cleanup. + return nil +} + +// Proxy starts a worker for server, ignore if already started. +func (v *UDPProxy) Proxy(client *vnet.Net, server *net.UDPAddr) error { + // Note that even if the worker exists, it's also ok to create a same worker, + // because the router will use the last one, and the real server will see a address + // change event after we switch to the next worker. + if _, ok := v.workers.Load(server.String()); ok { + // nolint:godox // TODO: Need to restart the stopped worker? + return nil + } + + // Not exists, create a new one. + worker := &aUDPProxyWorker{ + router: v.router, mockRealServerAddr: v.mockRealServerAddr, + } + v.workers.Store(server.String(), worker) + + return worker.Proxy(client, server) +} + +// A proxy worker for a specified proxy server. +type aUDPProxyWorker struct { + router *vnet.Router + mockRealServerAddr *net.UDPAddr + + // Each vnet source, bind to a real socket to server. + // key is vnet client addr, which is net.Addr + // value is *net.UDPConn + endpoints sync.Map +} + +func (v *aUDPProxyWorker) Proxy(client *vnet.Net, serverAddr *net.UDPAddr) error { // nolint:gocognit + // Create vnet for real server by serverAddr. + nw := vnet.NewNet(&vnet.NetConfig{ + StaticIP: serverAddr.IP.String(), + }) + if err := v.router.AddNet(nw); err != nil { + return err + } + + // We must create a "same" vnet.UDPConn as the net.UDPConn, + // which has the same ip:port, to copy packets between them. + vnetSocket, err := nw.ListenUDP("udp4", serverAddr) + if err != nil { + return err + } + + // Start a proxy goroutine. + var findEndpointBy func(addr net.Addr) (*net.UDPConn, error) + // nolint:godox // TODO: FIXME: Do cleanup. + go func() { + buf := make([]byte, 1500) + + for { + n, addr, err := vnetSocket.ReadFrom(buf) + if err != nil { + return + } + + if n <= 0 || addr == nil { + continue // Drop packet + } + + realSocket, err := findEndpointBy(addr) + if err != nil { + continue // Drop packet. + } + + if _, err := realSocket.Write(buf[:n]); err != nil { + return + } + } + }() + + // Got new vnet client, start a new endpoint. + findEndpointBy = func(addr net.Addr) (*net.UDPConn, error) { + // Exists binding. + if value, ok := v.endpoints.Load(addr.String()); ok { + // Exists endpoint, reuse it. + return value.(*net.UDPConn), nil + } + + // The real server we proxy to, for utest to mock it. + realAddr := serverAddr + if v.mockRealServerAddr != nil { + realAddr = v.mockRealServerAddr + } + + // Got new vnet client, create new endpoint. + realSocket, err := net.DialUDP("udp4", nil, realAddr) + if err != nil { + return nil, err + } + + // Bind address. + v.endpoints.Store(addr.String(), realSocket) + + // Got packet from real serverAddr, we should proxy it to vnet. + // nolint:godox // TODO: FIXME: Do cleanup. + go func(vnetClientAddr net.Addr) { + buf := make([]byte, 1500) + for { + n, _, err := realSocket.ReadFrom(buf) + if err != nil { + return + } + + if n <= 0 { + continue // Drop packet + } + + if _, err := vnetSocket.WriteTo(buf[:n], vnetClientAddr); err != nil { + return + } + } + }(addr) + + return realSocket, nil + } + + return nil +} diff --git a/trunk/3rdparty/srs-bench/vnet/udpproxy_direct.go b/trunk/3rdparty/srs-bench/vnet/udpproxy_direct.go new file mode 100644 index 000000000..6d49494ed --- /dev/null +++ b/trunk/3rdparty/srs-bench/vnet/udpproxy_direct.go @@ -0,0 +1,61 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package vnet + +import ( + "net" +) + +func (v *UDPProxy) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) { + v.workers.Range(func(key, value interface{}) bool { + if nn, err := value.(*aUDPProxyWorker).Deliver(sourceAddr, destAddr, b); err != nil { + return false // Fail, abort. + } else if nn == len(b) { + return false // Done. + } + + return true // Deliver by next worker. + }) + return +} + +func (v *aUDPProxyWorker) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) { + addr, ok := sourceAddr.(*net.UDPAddr) + if !ok { + return 0, nil + } + + // TODO: Support deliver packet from real server to vnet. + // If packet is from vent, proxy to real server. + var realSocket *net.UDPConn + if value, ok := v.endpoints.Load(addr.String()); !ok { + return 0, nil + } else { + realSocket = value.(*net.UDPConn) + } + + // Send to real server. + if _, err := realSocket.Write(b); err != nil { + return 0, err + } + + return len(b), nil +} diff --git a/trunk/3rdparty/srs-bench/vnet/udpproxy_direct_test.go b/trunk/3rdparty/srs-bench/vnet/udpproxy_direct_test.go new file mode 100644 index 000000000..b347c682c --- /dev/null +++ b/trunk/3rdparty/srs-bench/vnet/udpproxy_direct_test.go @@ -0,0 +1,184 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package vnet + +import ( + "context" + "fmt" + "github.com/pion/logging" + "github.com/pion/transport/vnet" + "net" + "sync" + "testing" + "time" +) + +// vnet client: +// 10.0.0.11:5787 +// proxy to real server: +// 192.168.1.10:8000 +func TestUDPProxyDirectDeliver(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var r0, r1, r2 error + defer func() { + if r0 != nil || r1 != nil || r2 != nil { + t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v", ctx.Err(), r0, r1, r2) + } + }() + + var wg sync.WaitGroup + defer wg.Wait() + + // Timeout, fail + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + case <-time.After(time.Duration(*testTimeout) * time.Millisecond): + r2 = fmt.Errorf("timeout") + } + }() + + // For utest, we always proxy vnet packets to the random port we listen to. + mockServer := NewMockUDPEchoServer() + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + if err := mockServer.doMockUDPServer(ctx); err != nil { + r0 = err + } + }() + + // Create a vent and proxy. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + // When real server is ready, start the vnet test. + select { + case <-ctx.Done(): + return + case <-mockServer.realServerReady.Done(): + } + + doVnetProxy := func() error { + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + return err + } + + clientNetwork := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + return err + } + + if err := router.Start(); err != nil { + return err + } + defer router.Stop() + + proxy, err := NewProxy(router) + if err != nil { + return err + } + defer proxy.Close() + + // For utest, mock the target real server. + proxy.mockRealServerAddr = mockServer.realServerAddr + + // The real server address to proxy to. + // Note that for utest, we will proxy to a local address. + serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") + if err != nil { + return err + } + + if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { + return err + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") + if err != nil { + return err + } + + // When system quit, interrupt client. + selfKill, selfKillCancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + selfKillCancel() + client.Close() + }() + + // Write by vnet client. + if _, err := client.WriteTo([]byte("Hello"), serverAddr); err != nil { + return err + } + + buf := make([]byte, 1500) + if n, addr, err := client.ReadFrom(buf); err != nil { + if selfKill.Err() == context.Canceled { + return nil + } + return err + } else if n != 5 || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) + } else if string(buf[:n]) != "Hello" { + return fmt.Errorf("data %v", buf[:n]) + } + + // Directly write, simulate the ARQ packet. + // We should got the echo packet also. + if _, err := proxy.Deliver(client.LocalAddr(), serverAddr, []byte("Hello")); err != nil { + return err + } + + if n, addr, err := client.ReadFrom(buf); err != nil { + if selfKill.Err() == context.Canceled { + return nil + } + return err + } else if n != 5 || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) + } else if string(buf[:n]) != "Hello" { + return fmt.Errorf("data %v", buf[:n]) + } + + return err + } + + if err := doVnetProxy(); err != nil { + r1 = err + } + }() +} diff --git a/trunk/3rdparty/srs-bench/vnet/udpproxy_test.go b/trunk/3rdparty/srs-bench/vnet/udpproxy_test.go new file mode 100644 index 000000000..c0c1c4a2b --- /dev/null +++ b/trunk/3rdparty/srs-bench/vnet/udpproxy_test.go @@ -0,0 +1,615 @@ +// The MIT License (MIT) +// +// Copyright (c) 2021 srs-bench(ossrs) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +package vnet + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/vnet" +) + +type MockUDPEchoServer struct { + realServerAddr *net.UDPAddr + realServerReady context.Context + realServerReadyCancel context.CancelFunc +} + +func NewMockUDPEchoServer() *MockUDPEchoServer { + v := &MockUDPEchoServer{} + v.realServerReady, v.realServerReadyCancel = context.WithCancel(context.Background()) + return v +} + +func (v *MockUDPEchoServer) doMockUDPServer(ctx context.Context) error { + // Listen to a random port. + laddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:0") + if err != nil { + return err + } + + conn, err := net.ListenUDP("udp4", laddr) + if err != nil { + return err + } + + v.realServerAddr = conn.LocalAddr().(*net.UDPAddr) + v.realServerReadyCancel() + + // When system quit, interrupt client. + selfKill, selfKillCancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + selfKillCancel() + _ = conn.Close() + }() + + // Note that if they has the same ID, the address should not changed. + addrs := make(map[string]net.Addr) + + // Start an echo UDP server. + buf := make([]byte, 1500) + for ctx.Err() == nil { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + if errors.Is(selfKill.Err(), context.Canceled) { + return nil + } + return err + } else if n == 0 || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113 + } else if nn, err := conn.WriteTo(buf[:n], addr); err != nil { + return err + } else if nn != n { + return fmt.Errorf("nn=%v, n=%v", nn, n) // nolint:goerr113 + } + + // Check the address, shold not change, use content as ID. + clientID := string(buf[:n]) + if oldAddr, ok := addrs[clientID]; ok && oldAddr.String() != addr.String() { + return fmt.Errorf("address change %v to %v", oldAddr.String(), addr.String()) // nolint:goerr113 + } + addrs[clientID] = addr + } + + return nil +} + +var testTimeout = flag.Int("timeout", 5000, "For each case, the timeout in ms") // nolint:gochecknoglobals + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(m.Run()) +} + +// vnet client: +// 10.0.0.11:5787 +// proxy to real server: +// 192.168.1.10:8000 +func TestUDPProxyOne2One(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var r0, r1, r2 error + defer func() { + if r0 != nil || r1 != nil || r2 != nil { + t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v", ctx.Err(), r0, r1, r2) + } + }() + + var wg sync.WaitGroup + defer wg.Wait() + + // Timeout, fail + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + case <-time.After(time.Duration(*testTimeout) * time.Millisecond): + r2 = fmt.Errorf("timeout") // nolint:goerr113 + } + }() + + // For utest, we always proxy vnet packets to the random port we listen to. + mockServer := NewMockUDPEchoServer() + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + if err := mockServer.doMockUDPServer(ctx); err != nil { + r0 = err + } + }() + + // Create a vent and proxy. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + // When real server is ready, start the vnet test. + select { + case <-ctx.Done(): + return + case <-mockServer.realServerReady.Done(): + } + + doVnetProxy := func() error { + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + return err + } + + clientNetwork := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + return err + } + + if err = router.Start(); err != nil { + return err + } + defer router.Stop() // nolint:errcheck + + proxy, err := NewProxy(router) + if err != nil { + return err + } + defer proxy.Close() // nolint:errcheck + + // For utest, mock the target real server. + proxy.mockRealServerAddr = mockServer.realServerAddr + + // The real server address to proxy to. + // Note that for utest, we will proxy to a local address. + serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") + if err != nil { + return err + } + + if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { + return err + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") + if err != nil { + return err + } + + // When system quit, interrupt client. + selfKill, selfKillCancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + selfKillCancel() + _ = client.Close() // nolint:errcheck + }() + + for i := 0; i < 10; i++ { + if _, err = client.WriteTo([]byte("Hello"), serverAddr); err != nil { + return err + } + + var n int + var addr net.Addr + buf := make([]byte, 1500) + if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic + if errors.Is(selfKill.Err(), context.Canceled) { + return nil + } + return err + } else if n != 5 || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113 + } else if string(buf[:n]) != "Hello" { + return fmt.Errorf("data %v", buf[:n]) // nolint:goerr113 + } + + // Wait for awhile for each UDP packet, to simulate real network. + select { + case <-ctx.Done(): + return nil + case <-time.After(30 * time.Millisecond): + } + } + + return err + } + + if err := doVnetProxy(); err != nil { + r1 = err + } + }() +} + +// vnet client: +// 10.0.0.11:5787 +// 10.0.0.11:5788 +// proxy to real server: +// 192.168.1.10:8000 +func TestUDPProxyTwo2One(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var r0, r1, r2, r3 error + defer func() { + if r0 != nil || r1 != nil || r2 != nil || r3 != nil { + t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v, r3=%v", ctx.Err(), r0, r1, r2, r3) + } + }() + + var wg sync.WaitGroup + defer wg.Wait() + + // Timeout, fail + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + case <-time.After(time.Duration(*testTimeout) * time.Millisecond): + r2 = fmt.Errorf("timeout") // nolint:goerr113 + } + }() + + // For utest, we always proxy vnet packets to the random port we listen to. + mockServer := NewMockUDPEchoServer() + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + if err := mockServer.doMockUDPServer(ctx); err != nil { + r0 = err + } + }() + + // Create a vent and proxy. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + // When real server is ready, start the vnet test. + select { + case <-ctx.Done(): + return + case <-mockServer.realServerReady.Done(): + } + + doVnetProxy := func() error { + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + return err + } + + clientNetwork := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + return err + } + + if err = router.Start(); err != nil { + return err + } + defer router.Stop() // nolint:errcheck + + proxy, err := NewProxy(router) + if err != nil { + return err + } + defer proxy.Close() // nolint:errcheck + + // For utest, mock the target real server. + proxy.mockRealServerAddr = mockServer.realServerAddr + + // The real server address to proxy to. + // Note that for utest, we will proxy to a local address. + serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") + if err != nil { + return err + } + + if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { + return err + } + + handClient := func(address, echoData string) error { + // Now, all packets from client, will be proxy to real server, vice versa. + client, err := clientNetwork.ListenPacket("udp4", address) // nolint:govet + if err != nil { + return err + } + + // When system quit, interrupt client. + selfKill, selfKillCancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + selfKillCancel() + _ = client.Close() + }() + + for i := 0; i < 10; i++ { + if _, err := client.WriteTo([]byte(echoData), serverAddr); err != nil { // nolint:govet + return err + } + + var n int + var addr net.Addr + buf := make([]byte, 1400) + if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic + if errors.Is(selfKill.Err(), context.Canceled) { + return nil + } + return err + } else if n != len(echoData) || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113 + } else if string(buf[:n]) != echoData { + return fmt.Errorf("check data %v", buf[:n]) // nolint:goerr113 + } + + // Wait for awhile for each UDP packet, to simulate real network. + select { + case <-ctx.Done(): + return nil + case <-time.After(30 * time.Millisecond): + } + } + + return nil + } + + client0, client0Cancel := context.WithCancel(context.Background()) + go func() { + defer client0Cancel() + address := "10.0.0.11:5787" + if err := handClient(address, "Hello"); err != nil { // nolint:govet + r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113 + } + }() + + client1, client1Cancel := context.WithCancel(context.Background()) + go func() { + defer client1Cancel() + address := "10.0.0.11:5788" + if err := handClient(address, "World"); err != nil { // nolint:govet + r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113 + } + }() + + select { + case <-ctx.Done(): + case <-client0.Done(): + case <-client1.Done(): + } + + return err + } + + if err := doVnetProxy(); err != nil { + r1 = err + } + }() +} + +// vnet client: +// 10.0.0.11:5787 +// proxy to real server: +// 192.168.1.10:8000 +// +// vnet client: +// 10.0.0.11:5788 +// proxy to real server: +// 192.168.1.10:8000 +func TestUDPProxyProxyTwice(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var r0, r1, r2, r3 error + defer func() { + if r0 != nil || r1 != nil || r2 != nil || r3 != nil { + t.Errorf("fail for ctx=%v, r0=%v, r1=%v, r2=%v, r3=%v", ctx.Err(), r0, r1, r2, r3) + } + }() + + var wg sync.WaitGroup + defer wg.Wait() + + // Timeout, fail + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + select { + case <-ctx.Done(): + case <-time.After(time.Duration(*testTimeout) * time.Millisecond): + r2 = fmt.Errorf("timeout") // nolint:goerr113 + } + }() + + // For utest, we always proxy vnet packets to the random port we listen to. + mockServer := NewMockUDPEchoServer() + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + if err := mockServer.doMockUDPServer(ctx); err != nil { + r0 = err + } + }() + + // Create a vent and proxy. + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + // When real server is ready, start the vnet test. + select { + case <-ctx.Done(): + return + case <-mockServer.realServerReady.Done(): + } + + doVnetProxy := func() error { + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + if err != nil { + return err + } + + clientNetwork := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.11", + }) + if err = router.AddNet(clientNetwork); err != nil { + return err + } + + if err = router.Start(); err != nil { + return err + } + defer router.Stop() // nolint:errcheck + + proxy, err := NewProxy(router) + if err != nil { + return err + } + defer proxy.Close() // nolint:errcheck + + // For utest, mock the target real server. + proxy.mockRealServerAddr = mockServer.realServerAddr + + // The real server address to proxy to. + // Note that for utest, we will proxy to a local address. + serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") + if err != nil { + return err + } + + handClient := func(address, echoData string) error { + // We proxy multiple times, for example, in publisher and player, both call + // the proxy when got answer. + if err := proxy.Proxy(clientNetwork, serverAddr); err != nil { // nolint:govet + return err + } + + // Now, all packets from client, will be proxy to real server, vice versa. + client, err := clientNetwork.ListenPacket("udp4", address) // nolint:govet + if err != nil { + return err + } + + // When system quit, interrupt client. + selfKill, selfKillCancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + selfKillCancel() + _ = client.Close() // nolint:errcheck + }() + + for i := 0; i < 10; i++ { + if _, err = client.WriteTo([]byte(echoData), serverAddr); err != nil { + return err + } + + buf := make([]byte, 1500) + if n, addr, err := client.ReadFrom(buf); err != nil { // nolint:gocritic,govet + if errors.Is(selfKill.Err(), context.Canceled) { + return nil + } + return err + } else if n != len(echoData) || addr == nil { + return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:goerr113 + } else if string(buf[:n]) != echoData { + return fmt.Errorf("verify data %v", buf[:n]) // nolint:goerr113 + } + + // Wait for awhile for each UDP packet, to simulate real network. + select { + case <-ctx.Done(): + return nil + case <-time.After(30 * time.Millisecond): + } + } + + return nil + } + + client0, client0Cancel := context.WithCancel(context.Background()) + go func() { + defer client0Cancel() + address := "10.0.0.11:5787" + if err = handClient(address, "Hello"); err != nil { + r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113 + } + }() + + client1, client1Cancel := context.WithCancel(context.Background()) + go func() { + defer client1Cancel() + + // Slower than client0, 60ms. + // To simulate the real player or publisher, might not start at the same time. + select { + case <-ctx.Done(): + return + case <-time.After(150 * time.Millisecond): + } + + address := "10.0.0.11:5788" + if err = handClient(address, "World"); err != nil { + r3 = fmt.Errorf("client %v err %v", address, err) // nolint:goerr113 + } + }() + + select { + case <-ctx.Done(): + case <-client0.Done(): + case <-client1.Done(): + } + + return err + } + + if err := doVnetProxy(); err != nil { + r1 = err + } + }() +} diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index daa894fc1..a528f2631 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -35,6 +35,7 @@ using namespace std; #include #include #include +#include #include #include @@ -43,6 +44,35 @@ using namespace std; // Defined in HTTP/HTTPS client. extern int srs_verify_callback(int preverify_ok, X509_STORE_CTX *ctx); +// Setup the openssl timeout for DTLS packet. +// @see https://www.openssl.org/docs/man1.1.1/man3/DTLS_set_timer_cb.html +// +// Use step timeout for ARQ, [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200] in ms, +// then total timeout is sum([50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200]) = 102350ms. +// +// @remark The connection might be closed for timeout in about 30s by default, which stop the DTLS ARQ. +unsigned int dtls_timer_cb(SSL* dtls, unsigned int previous_us) +{ + SrsDtlsImpl* dtls_impl = (SrsDtlsImpl*)SSL_get_ex_data(dtls, 0); + srs_assert(dtls_impl); + + // Double the timeout. Note that it can be 0. + unsigned int timeout_us = previous_us * 2; + + // If previous_us is 0, for example, the HelloVerifyRequest, we should response it ASAP. + // When got ServerHello, we should reset the timer. + if (previous_us == 0 || dtls_impl->should_reset_timer()) { + timeout_us = 50 * 1000; // in us + } + + // Never exceed the max timeout. + timeout_us = srs_min(timeout_us, 30 * 1000 * 1000); // in us + + srs_info("DTLS: ARQ timer cb timeout=%ums, previous=%ums", timeout_us/1000, previous_us/1000); + + return timeout_us; +} + // Print the information of SSL, DTLS alert as such. void ssl_on_info(const SSL* dtls, int where, int ret) { @@ -377,8 +407,6 @@ SrsDtlsImpl::SrsDtlsImpl(ISrsDtlsCallback* callback) callback_ = callback; handshake_done_for_us = false; - last_outgoing_packet_cache = new uint8_t[kRtpPacketSize]; - nn_last_outgoing_packet = 0; nn_arq_packets = 0; version_ = SrsDtlsVersionAuto; @@ -401,8 +429,6 @@ SrsDtlsImpl::~SrsDtlsImpl() SSL_free(dtls); dtls = NULL; } - - srs_freepa(last_outgoing_packet_cache); } srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role) @@ -431,6 +457,19 @@ srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role) SSL_set_options(dtls, SSL_OP_NO_QUERY_MTU); SSL_set_mtu(dtls, kRtpPacketSize); + // @see https://linux.die.net/man/3/openssl_version_number + // MM NN FF PP S + // 0x1010102fL = 0x1 01 01 02 fL // 1.1.1b release + // MM(major) = 0x1 // 1.* + // NN(minor) = 0x01 // 1.1.* + // FF(fix) = 0x01 // 1.1.1* + // PP(patch) = 'a' + 0x02 - 1 = 'b' // 1.1.1b * + // S(status) = 0xf = release // 1.1.1b release + // @note Status 0 for development, 1 to e for betas 1 to 14, and f for release. +#if OPENSSL_VERSION_NUMBER >= 0x1010102fL // 1.1.1b + DTLS_set_timer_cb(dtls, dtls_timer_cb); +#endif + if ((bio_in = BIO_new(BIO_s_mem())) == NULL) { return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in"); } @@ -461,6 +500,12 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data) { srs_error_t err = srs_success; + // When already done, only for us, we still got message from client, + // it might be our response is lost, or application data. + if (handshake_done_for_us) { + srs_info("DTLS: After done, got %d bytes", nb_data); + } + int r0 = 0; // TODO: FIXME: Why reset it before writing? if ((r0 = BIO_reset(bio_in)) != 1) { @@ -471,7 +516,7 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data) } // Trace the detail of DTLS packet. - state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false, false); + state_trace((uint8_t*)data, nb_data, true, r0, SSL_ERROR_NONE, false); if ((r0 = BIO_write(bio_in, data, nb_data)) <= 0) { // TODO: 0 or -1 maybe block, use BIO_should_retry to check. @@ -502,6 +547,18 @@ srs_error_t SrsDtlsImpl::do_on_dtls(char* data, int nb_data) if (r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE) { break; } + + // We got data in memory, which can not read by SSL_read, generally, it's handshake data. + uint8_t* data = NULL; + int size = BIO_get_mem_data(bio_out, (char**)&data); + + // Logging when got SSL original data. + state_trace((uint8_t*)data, size, false, r0, r1, false); + + if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) { + return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, + srs_string_dumps_hex((char*)data, size, 32).c_str()); + } continue; } @@ -521,6 +578,12 @@ srs_error_t SrsDtlsImpl::do_handshake() { srs_error_t err = srs_success; + // Done for use, ignore handshake packets. If need to ARQ the handshake packets, + // we should use SSL_read to handle it. + if (handshake_done_for_us) { + return err; + } + // Do handshake and get the result. int r0 = SSL_do_handshake(dtls); int r1 = SSL_get_error(dtls, r0); @@ -537,18 +600,10 @@ srs_error_t SrsDtlsImpl::do_handshake() // The data to send out to peer. uint8_t* data = NULL; - int size = BIO_get_mem_data(bio_out, &data); + int size = BIO_get_mem_data(bio_out, (char**)&data); - // Callback when got SSL original data. - bool cache = false; - on_ssl_out_data(data, size, cache); - state_trace((uint8_t*)data, size, false, r0, r1, cache, false); - - // Update the packet cache. - if (size > 0 && data != last_outgoing_packet_cache && size < kRtpPacketSize) { - memcpy(last_outgoing_packet_cache, data, size); - nn_last_outgoing_packet = size; - } + // Logging when got SSL original data. + state_trace((uint8_t*)data, size, false, r0, r1, false); // Callback for the final output data, before send-out. if ((err = on_final_out_data(data, size)) != srs_success) { @@ -569,7 +624,7 @@ srs_error_t SrsDtlsImpl::do_handshake() return err; } -void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq) +void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool arq) { // change_cipher_spec(20), alert(21), handshake(22), application_data(23) // @see https://tools.ietf.org/html/rfc2246#section-6.2.1 @@ -588,8 +643,8 @@ void SrsDtlsImpl::state_trace(uint8_t* data, int length, bool incoming, int r0, handshake_type = (uint8_t)data[13]; } - srs_trace("DTLS: %s %s, done=%u, cache=%u, arq=%u/%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u", - (is_dtls_client()? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, cache, arq, + srs_trace("DTLS: State %s %s, done=%u, arq=%u/%u, r0=%d, r1=%d, len=%u, cnt=%u, size=%u, hs=%u", + (is_dtls_client()? "Active":"Passive"), (incoming? "RECV":"SEND"), handshake_done_for_us, arq, nn_arq_packets, r0, r1, length, content_type, size, handshake_type); } @@ -640,15 +695,9 @@ SrsDtlsClientImpl::SrsDtlsClientImpl(ISrsDtlsCallback* callback) : SrsDtlsImpl(c trd = NULL; state_ = SrsDtlsStateInit; - // The first wait and base interval for ARQ. - arq_interval = 10 * SRS_UTIME_MILLISECONDS; - - // Use step timeout for ARQ, the total timeout is sum(arq_to_ratios)*arq_interval. - // for example, if arq_interval is 10ms, arq_to_ratios is [3, 6, 9, 15, 20, 40, 80, 160], - // then total timeout is sum([3, 6, 9, 15, 20, 40, 80, 160]) * 10ms = 3330ms. - int ratios[] = {3, 6, 9, 15, 20, 40, 80, 160}; - srs_assert(sizeof(arq_to_ratios) == sizeof(ratios)); - memcpy(arq_to_ratios, ratios, sizeof(ratios)); + // the max dtls retry num is 12 in openssl. + arq_max_retry = 12 * 2; // Max ARQ limit shared for ClientHello and Certificate. + reset_timer_ = true; } SrsDtlsClientImpl::~SrsDtlsClientImpl() @@ -672,60 +721,47 @@ srs_error_t SrsDtlsClientImpl::initialize(std::string version, std::string role) } srs_error_t SrsDtlsClientImpl::start_active_handshake() -{ - return do_handshake(); -} - -srs_error_t SrsDtlsClientImpl::on_dtls(char* data, int nb_data) { srs_error_t err = srs_success; - // When got packet, stop the ARQ if server in the first ARQ state SrsDtlsStateServerHello. - // @note But for ARQ state, we should never stop the ARQ, for example, we are in the second ARQ sate - // SrsDtlsStateServerDone, but we got previous late wrong packet ServeHello, which is not the expect - // packet SessionNewTicket, we should never stop the ARQ thread. - if (state_ == SrsDtlsStateServerHello) { - stop_arq(); + if ((err = do_handshake()) != srs_success) { + return srs_error_wrap(err, "start handshake"); } - if ((err = SrsDtlsImpl::on_dtls(data, nb_data)) != srs_success) { - return err; + if ((err = start_arq()) != srs_success) { + return srs_error_wrap(err, "start arq"); } return err; } -void SrsDtlsClientImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached) +bool SrsDtlsClientImpl::should_reset_timer() { - // DTLS client use ARQ thread to send cached packet. - cached = false; + bool v = reset_timer_; + reset_timer_ = false; + return v; } +// Note that only handshake sending packets drives the state, neither ARQ nor the +// final-packets(after handshake done) drives it. srs_error_t SrsDtlsClientImpl::on_final_out_data(uint8_t* data, int size) { srs_error_t err = srs_success; - // Driven ARQ and state for DTLS client. // If we are sending client hello, change from init to new state. - if (state_ == SrsDtlsStateInit && size > 14 && data[13] == 1) { + if (state_ == SrsDtlsStateInit && size > 14 && data[0] == 22 && data[13] == 1) { state_ = SrsDtlsStateClientHello; + return err; } - // If we are sending certificate, change from SrsDtlsStateServerHello to new state. - if (state_ == SrsDtlsStateServerHello && size > 14 && data[13] == 11) { + + // If we are sending certificate, change from SrsDtlsStateClientHello to new state. + if (state_ == SrsDtlsStateClientHello && size > 14 && data[0] == 22 && data[13] == 11) { state_ = SrsDtlsStateClientCertificate; - } - // Try to start the ARQ for client. - if ((state_ == SrsDtlsStateClientHello || state_ == SrsDtlsStateClientCertificate)) { - if (state_ == SrsDtlsStateClientHello) { - state_ = SrsDtlsStateServerHello; - } else if (state_ == SrsDtlsStateClientCertificate) { - state_ = SrsDtlsStateServerDone; - } - - if ((err = start_arq()) != srs_success) { - return srs_error_wrap(err, "start arq"); - } + // When we send out the certificate, we should reset the timer. + reset_timer_ = true; + srs_info("DTLS: Reset the timer for ServerHello"); + return err; } return err; @@ -735,8 +771,15 @@ srs_error_t SrsDtlsClientImpl::on_handshake_done() { srs_error_t err = srs_success; - // When handshake done, stop the ARQ. + // Ignore if done. + if (state_ == SrsDtlsStateClientDone) { + return err; + } + + // Change to done state. state_ = SrsDtlsStateClientDone; + + // When handshake done, stop the ARQ. stop_arq(); // Notify connection the DTLS is done. @@ -756,8 +799,6 @@ srs_error_t SrsDtlsClientImpl::start_arq() { srs_error_t err = srs_success; - srs_info("start arq, state=%u", state_); - // Dispose the previous ARQ thread. srs_freep(trd); trd = new SrsSTCoroutine("dtls", this, _srs_context->get_id()); @@ -772,20 +813,24 @@ srs_error_t SrsDtlsClientImpl::start_arq() void SrsDtlsClientImpl::stop_arq() { - srs_info("stop arq, state=%u", state_); srs_freep(trd); - srs_info("stop arq, done"); } srs_error_t SrsDtlsClientImpl::cycle() { srs_error_t err = srs_success; - // Limit the max retry for ARQ. - for (int i = 0; i < (int)(sizeof(arq_to_ratios) / sizeof(int)); i++) { - srs_utime_t arq_to = arq_interval * arq_to_ratios[i]; - srs_usleep(arq_to); + // Limit the max retry for ARQ, to avoid infinite loop. + // Note that we set the timeout to [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200] in ms, + // but the actual timeout is limit to 1s: + // 50ms, 100ms, 200ms, 400ms, 800ms, (1000ms,600ms), (200ms,1000ms,1000ms,1000ms), + // (400ms,1000ms,1000ms,1000ms,1000ms,1000ms,1000ms), ... + // So when the max ARQ limit to 12 times, the max loop is about 103. + // @remark We change the max sleep to 100ms, so we limit about (103*10)/2=500. + const int max_loop = 512; + int arq_count = 0; + for (int i = 0; arq_count < arq_max_retry && i < max_loop; i++) { // We ignore any error for ARQ thread. if ((err = trd->pull()) != srs_success) { srs_freep(err); @@ -798,27 +843,62 @@ srs_error_t SrsDtlsClientImpl::cycle() } // For DTLS client ARQ, the state should be specified. - if (state_ != SrsDtlsStateServerHello && state_ != SrsDtlsStateServerDone) { + if (state_ != SrsDtlsStateClientHello && state_ != SrsDtlsStateClientCertificate) { return err; } - // Try to retransmit the packet. - uint8_t* data = last_outgoing_packet_cache; - int size = nn_last_outgoing_packet; + // If there is a timeout in progress, it sets *out to the time remaining + // and returns one. Otherwise, it returns zero. + int r0 = 0; timeval to = {0}; + if ((r0 = DTLSv1_get_timeout(dtls, &to)) == 0) { + // No timeout, for example?, wait for a default 50ms. + srs_usleep(50 * SRS_UTIME_MILLISECONDS); + continue; + } + srs_utime_t timeout = to.tv_sec + to.tv_usec; - if (size) { - // Trace the detail of DTLS packet. - state_trace((uint8_t*)data, size, false, 1, SSL_ERROR_NONE, true, true); - nn_arq_packets++; - - if ((err = callback_->write_dtls_data(data, size)) != srs_success) { - return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, - srs_string_dumps_hex((char*)data, size, 32).c_str()); - } + // There is timeout to wait, so we should wait, because there is no packet in openssl. + if (timeout > 0) { + // Never wait too long, because we might need to retransmit other messages. + // For example, we have transmit 2 ClientHello as [50ms, 100ms] then we sleep(200ms), + // during this we reset the openssl timer to 50ms and need to retransmit Certificate, + // we still need to wait 200ms not 50ms. + timeout = srs_min(100 * SRS_UTIME_MILLISECONDS, timeout); + timeout = srs_max(50 * SRS_UTIME_MILLISECONDS, timeout); + srs_usleep(timeout); + continue; } - srs_info("arq cycle, done=%u, state=%u, retry=%d, interval=%dms, to=%dms, size=%d, nn=%d", handshake_done_for_us, - state_, i, srsu2msi(arq_interval), srsu2msi(arq_to), size, nn_arq_packets); + // The timeout is 0, so there must be a ARQ packet to transmit in openssl. + r0 = BIO_reset(bio_out); int r1 = SSL_get_error(dtls, r0); + if (r0 != 1) { + return srs_error_new(ERROR_OpenSslBIOReset, "BIO_reset r0=%d, r1=%d", r0, r1); + } + + // DTLSv1_handle_timeout is called when a DTLS handshake timeout expires. If no timeout + // had expired, it returns 0. Otherwise, it retransmits the previous flight of handshake + // messages and returns 1. If too many timeouts had expired without progress or an error + // occurs, it returns -1. + r0 = DTLSv1_handle_timeout(dtls); r1 = SSL_get_error(dtls, r0); + if (r0 == 0) { + continue; // No timeout had expired. + } + if (r0 != 1) { + return srs_error_new(ERROR_RTC_DTLS, "ARQ r0=%d, r1=%d", r0, r1); + } + + // The data to send out to peer. + uint8_t* data = NULL; + int size = BIO_get_mem_data(bio_out, (char**)&data); + + arq_count++; + nn_arq_packets++; + state_trace((uint8_t*)data, size, false, r0, r1, true); + + if (size > 0 && (err = callback_->write_dtls_data(data, size)) != srs_success) { + return srs_error_wrap(err, "dtls send size=%u, data=[%s]", size, + srs_string_dumps_hex((char*)data, size, 32).c_str()); + } } return err; @@ -848,23 +928,19 @@ srs_error_t SrsDtlsServerImpl::initialize(std::string version, std::string role) srs_error_t SrsDtlsServerImpl::start_active_handshake() { + // For DTLS server, we do nothing, because DTLS client drive it. return srs_success; } -void SrsDtlsServerImpl::on_ssl_out_data(uint8_t*& data, int& size, bool& cached) +bool SrsDtlsServerImpl::should_reset_timer() { - // If outgoing packet is empty, we use the last cache. - // @remark Only for DTLS server, because DTLS client use ARQ thread to send cached packet. - if (size <= 0 && nn_last_outgoing_packet) { - size = nn_last_outgoing_packet; - data = last_outgoing_packet_cache; - nn_arq_packets++; - cached = true; - } + // For DTLS server, we never use timer for ARQ, because DTLS client drive it. + return false; } srs_error_t SrsDtlsServerImpl::on_final_out_data(uint8_t* data, int size) { + // No ARQ, driven by DTLS client packets. return srs_success; } diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 1e28eaf7d..61916a72f 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -121,9 +121,6 @@ protected: // Whether the handshake is done, for us only. // @remark For us only, means peer maybe not done, we also need to handle the DTLS packet. bool handshake_done_for_us; - // DTLS packet cache, only last out-going packet. - uint8_t* last_outgoing_packet_cache; - int nn_last_outgoing_packet; // The stat for ARQ packets. int nn_arq_packets; public: @@ -132,16 +129,16 @@ public: public: virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t start_active_handshake() = 0; + virtual bool should_reset_timer() = 0; virtual srs_error_t on_dtls(char* data, int nb_data); protected: srs_error_t do_on_dtls(char* data, int nb_data); srs_error_t do_handshake(); - void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool cache, bool arq); + void state_trace(uint8_t* data, int length, bool incoming, int r0, int r1, bool arq); public: srs_error_t get_srtp_key(std::string& recv_key, std::string& send_key); void callback_by_ssl(std::string type, std::string desc); protected: - virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached) = 0; virtual srs_error_t on_final_out_data(uint8_t* data, int size) = 0; virtual srs_error_t on_handshake_done() = 0; virtual bool is_dtls_client() = 0; @@ -155,18 +152,19 @@ private: SrsCoroutine* trd; // The DTLS-client state to drive the ARQ thread. SrsDtlsState state_; - // The timeout for ARQ. - srs_utime_t arq_interval; - int arq_to_ratios[8]; + // The max ARQ retry. + int arq_max_retry; + // Should we reset the timer? + // It's true when init, or in state ServerHello. + bool reset_timer_; public: SrsDtlsClientImpl(ISrsDtlsCallback* callback); virtual ~SrsDtlsClientImpl(); public: virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t start_active_handshake(); - virtual srs_error_t on_dtls(char* data, int nb_data); + virtual bool should_reset_timer(); protected: - virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached); virtual srs_error_t on_final_out_data(uint8_t* data, int size); virtual srs_error_t on_handshake_done(); virtual bool is_dtls_client(); @@ -185,8 +183,8 @@ public: public: virtual srs_error_t initialize(std::string version, std::string role); virtual srs_error_t start_active_handshake(); + virtual bool should_reset_timer(); protected: - virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached); virtual srs_error_t on_final_out_data(uint8_t* data, int size); virtual srs_error_t on_handshake_done(); virtual bool is_dtls_client(); diff --git a/trunk/src/core/srs_core_version4.hpp b/trunk/src/core/srs_core_version4.hpp index c67ed23d7..8a76cdb2f 100644 --- a/trunk/src/core/srs_core_version4.hpp +++ b/trunk/src/core/srs_core_version4.hpp @@ -24,6 +24,6 @@ #ifndef SRS_CORE_VERSION4_HPP #define SRS_CORE_VERSION4_HPP -#define SRS_VERSION4_REVISION 83 +#define SRS_VERSION4_REVISION 84 #endif diff --git a/trunk/src/utest/srs_utest_rtc.cpp b/trunk/src/utest/srs_utest_rtc.cpp index bf839aca3..beabf73bf 100644 --- a/trunk/src/utest/srs_utest_rtc.cpp +++ b/trunk/src/utest/srs_utest_rtc.cpp @@ -871,9 +871,12 @@ srs_error_t MockDtlsCallback::cycle() } // Wait for mock io to done, try to switch to coroutine many times. -void mock_wait_dtls_io_done(int count = 100, int interval = 0) +void mock_wait_dtls_io_done(SrsDtlsImpl* client_impl, int count = 100, int interval = 0) { for (int i = 0; i < count; i++) { + if (client_impl) { + dynamic_cast(client_impl)->reset_timer_ = true; + } srs_usleep(interval * SRS_UTIME_MILLISECONDS); } } @@ -895,138 +898,6 @@ public: } }; -VOID TEST(KernelRTCTest, DTLSARQLimitTest) -{ - srs_error_t err = srs_success; - - // ClientHello lost, client retransmit the ClientHello. - if (true) { - MockDtlsCallback cio; SrsDtls client(&cio); - MockDtlsCallback sio; SrsDtls server(&sio); - MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL); - HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); - HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - - // Lost 10 packets, total packets should be 9(max to 9). - // Note that only one server hello. - cio.nn_client_hello_lost = 10; - - HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); - - EXPECT_TRUE(sio.r0 == srs_success); - EXPECT_TRUE(cio.r0 == srs_success); - - EXPECT_FALSE(cio.done); - EXPECT_FALSE(sio.done); - - EXPECT_EQ(9, cio.nn_client_hello); - EXPECT_EQ(0, sio.nn_server_hello); - EXPECT_EQ(0, cio.nn_certificate); - EXPECT_EQ(0, sio.nn_new_session); - EXPECT_EQ(0, sio.nn_change_cipher); - } - - // Certificate lost, client retransmit the Certificate. - if (true) { - MockDtlsCallback cio; SrsDtls client(&cio); - MockDtlsCallback sio; SrsDtls server(&sio); - MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL); - HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); - HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - - // Lost 10 packets, total packets should be 9(max to 9). - // Note that only one server NewSessionTicket. - cio.nn_certificate_lost = 10; - - HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); - - EXPECT_TRUE(sio.r0 == srs_success); - EXPECT_TRUE(cio.r0 == srs_success); - - EXPECT_FALSE(cio.done); - EXPECT_FALSE(sio.done); - - EXPECT_EQ(1, cio.nn_client_hello); - EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_EQ(9, cio.nn_certificate); - EXPECT_EQ(0, sio.nn_new_session); - EXPECT_EQ(0, sio.nn_change_cipher); - } - - // ServerHello lost, client retransmit the ClientHello. - if (true) { - MockDtlsCallback cio; SrsDtls client(&cio); - MockDtlsCallback sio; SrsDtls server(&sio); - MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL); - HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); - HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - - // Lost 10 packets, total packets should be 9(max to 9). - sio.nn_server_hello_lost = 10; - - HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); - - EXPECT_TRUE(sio.r0 == srs_success); - EXPECT_TRUE(cio.r0 == srs_success); - - EXPECT_FALSE(cio.done); - EXPECT_FALSE(sio.done); - - EXPECT_EQ(9, cio.nn_client_hello); - EXPECT_EQ(9, sio.nn_server_hello); - EXPECT_EQ(0, cio.nn_certificate); - EXPECT_EQ(0, sio.nn_new_session); - EXPECT_EQ(0, sio.nn_change_cipher); - } - - // NewSessionTicket lost, client retransmit the Certificate. - if (true) { - MockDtlsCallback cio; SrsDtls client(&cio); - MockDtlsCallback sio; SrsDtls server(&sio); - MockBridgeDtlsIO b0(&cio, &server, NULL); MockBridgeDtlsIO b1(&sio, &client, NULL); - HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); - HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - - // Lost 10 packets, total packets should be 9(max to 9). - sio.nn_new_session_lost = 10; - - HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); - - EXPECT_TRUE(sio.r0 == srs_success); - EXPECT_TRUE(cio.r0 == srs_success); - - // Although the packet is lost, but it's done for server, and not done for client. - EXPECT_FALSE(cio.done); - EXPECT_TRUE(sio.done); - - EXPECT_EQ(1, cio.nn_client_hello); - EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_EQ(9, cio.nn_certificate); - EXPECT_EQ(9, sio.nn_new_session); - EXPECT_EQ(0, sio.nn_change_cipher); - } -} - VOID TEST(KernelRTCTest, DTLSClientARQTest) { srs_error_t err = srs_success; @@ -1040,7 +911,7 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(30, 1); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1050,8 +921,8 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_TRUE(1 <= cio.nn_certificate); - EXPECT_TRUE(1 <= sio.nn_new_session); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } @@ -1063,16 +934,12 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - // Lost 2 packets, total packets should be 3. // Note that only one server hello. - cio.nn_client_hello_lost = 2; + cio.nn_client_hello_lost = 1; HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1080,10 +947,10 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) EXPECT_TRUE(cio.done); EXPECT_TRUE(sio.done); - EXPECT_TRUE(3 <= cio.nn_client_hello); - EXPECT_TRUE(1 <= sio.nn_server_hello); - EXPECT_TRUE(1 <= cio.nn_certificate); - EXPECT_TRUE(1 <= sio.nn_new_session); + EXPECT_EQ(2, cio.nn_client_hello); + EXPECT_EQ(1, sio.nn_server_hello); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } @@ -1095,16 +962,12 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - // Lost 2 packets, total packets should be 3. // Note that only one server NewSessionTicket. - cio.nn_certificate_lost = 2; + cio.nn_certificate_lost = 1; HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1113,9 +976,9 @@ VOID TEST(KernelRTCTest, DTLSClientARQTest) EXPECT_TRUE(sio.done); EXPECT_EQ(1, cio.nn_client_hello); - EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_TRUE(3 <= cio.nn_certificate); - EXPECT_TRUE(1 <= sio.nn_new_session); + EXPECT_EQ(2, sio.nn_server_hello); + EXPECT_EQ(2, cio.nn_certificate); + EXPECT_EQ(0, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } } @@ -1133,7 +996,7 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(30, 1); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1143,8 +1006,8 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_TRUE(1 <= cio.nn_certificate); - EXPECT_TRUE(1 <= sio.nn_new_session); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } @@ -1156,15 +1019,11 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - // Lost 2 packets, total packets should be 3. - sio.nn_server_hello_lost = 2; + sio.nn_server_hello_lost = 1; HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1172,10 +1031,10 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) EXPECT_TRUE(cio.done); EXPECT_TRUE(sio.done); - EXPECT_EQ(3, cio.nn_client_hello); - EXPECT_EQ(3, sio.nn_server_hello); - EXPECT_TRUE(1 <= cio.nn_certificate); - EXPECT_TRUE(1 <= sio.nn_new_session); + EXPECT_EQ(2, cio.nn_client_hello); + EXPECT_EQ(2, sio.nn_server_hello); + EXPECT_EQ(1, cio.nn_certificate); + EXPECT_EQ(1, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } @@ -1187,15 +1046,11 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) HELPER_EXPECT_SUCCESS(client.initialize("active", "dtls1.0")); HELPER_EXPECT_SUCCESS(server.initialize("passive", "dtls1.0")); - // Use very short interval for utest. - dynamic_cast(client.impl)->arq_interval = 1 * SRS_UTIME_MILLISECONDS; - HELPER_ARRAY_INIT(dynamic_cast(client.impl)->arq_to_ratios, 8, 1); - // Lost 2 packets, total packets should be 3. - sio.nn_new_session_lost = 2; + sio.nn_new_session_lost = 1; HELPER_EXPECT_SUCCESS(client.start_active_handshake()); - mock_wait_dtls_io_done(10, 3); + mock_wait_dtls_io_done(client.impl, 15, 5); EXPECT_TRUE(sio.r0 == srs_success); EXPECT_TRUE(cio.r0 == srs_success); @@ -1205,8 +1060,8 @@ VOID TEST(KernelRTCTest, DTLSServerARQTest) EXPECT_EQ(1, cio.nn_client_hello); EXPECT_EQ(1, sio.nn_server_hello); - EXPECT_EQ(3, cio.nn_certificate); - EXPECT_EQ(3, sio.nn_new_session); + EXPECT_EQ(2, cio.nn_certificate); + EXPECT_EQ(2, sio.nn_new_session); EXPECT_EQ(0, sio.nn_change_cipher); } } @@ -1250,10 +1105,10 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest) {4, "auto", "dtls1.0", true, true, false, false}, // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 {5, "auto", "dtls1.2", true, true, false, false}, - // Fail, Client: DTLS v1.0, Server: DTLS v1.2 - {6, "dtls1.0", "dtls1.2", false, false, false, true}, - // Fail, Client: DTLS v1.2, Server: DTLS v1.0 - {7, "dtls1.2", "dtls1.0", false, false, true, false}, + // OK?, Client: DTLS v1.0, Server: DTLS v1.2 + {6, "dtls1.0", "dtls1.2", true, true, false, false}, + // OK?, Client: DTLS v1.2, Server: DTLS v1.0 + {7, "dtls1.2", "dtls1.0", true, true, false, false}, }; for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) { @@ -1266,14 +1121,14 @@ VOID TEST(KernelRTCTest, DTLSClientFlowTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; - mock_wait_dtls_io_done(); + mock_wait_dtls_io_done(client.impl, 15, 5); // Note that the cio error is generated from server, vice versa. - EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; - EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c; - EXPECT_EQ(c.ClientDone, cio.done) << c; EXPECT_EQ(c.ServerDone, sio.done) << c; + + EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; + EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c; } } @@ -1294,10 +1149,10 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest) {4, "auto", "dtls1.0", true, true, false, false}, // OK, Client: DTLS auto(v1.0 or v1.2), Server: DTLS v1.0 {5, "auto", "dtls1.2", true, true, false, false}, - // Fail, Client: DTLS v1.0, Server: DTLS v1.2 - {6, "dtls1.0", "dtls1.2", false, false, false, true}, - // Fail, Client: DTLS v1.2, Server: DTLS v1.0 - {7, "dtls1.2", "dtls1.0", false, false, true, false}, + // OK?, Client: DTLS v1.0, Server: DTLS v1.2 + {6, "dtls1.0", "dtls1.2", true, true, false, false}, + // OK?, Client: DTLS v1.2, Server: DTLS v1.0 + {7, "dtls1.2", "dtls1.0", true, true, false, false}, }; for (int i = 0; i < (int)(sizeof(cases) / sizeof(DTLSFlowCase)); i++) { @@ -1310,14 +1165,14 @@ VOID TEST(KernelRTCTest, DTLSServerFlowTest) HELPER_EXPECT_SUCCESS(server.initialize("passive", c.ServerVersion)) << c; HELPER_EXPECT_SUCCESS(client.start_active_handshake()) << c; - mock_wait_dtls_io_done(); + mock_wait_dtls_io_done(NULL, 15, 5); // Note that the cio error is generated from server, vice versa. - EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; - EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c; - EXPECT_EQ(c.ClientDone, cio.done) << c; EXPECT_EQ(c.ServerDone, sio.done) << c; + + EXPECT_EQ(c.ClientError, sio.r0 != srs_success) << c; + EXPECT_EQ(c.ServerError, cio.r0 != srs_success) << c; } }