1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

GB28181: Support GB28181-2016 protocol. v5.0.74 (#3201)

01. Support GB config as StreamCaster.
02. Support disable GB by --gb28181=off.
03. Add utests for SIP examples.
04. Wireshark plugin to decode TCP/9000 as rtp.rfc4571
05. Support MPEGPS program stream codec.
06. Add utest for PS stream codec.
07. Decode MPEGPS packet stream.
08. Carry RTP and PS packet as helper in PS message.
09. Support recover from error mode.
10. Support process by a pack of PS/TS messages.
11. Add statistic for recovered and msgs dropped.
12. Recover from err position fastly.
13. Define state machine for GB session.
14. Bind context to GB session.
15. Re-invite when media disconnected.
16. Update GitHub actions with GB28181.
17. Support parse CANDIDATE by env or pip.
18. Support mux GB28181 to RTMP.
19. Support regression test by srs-bench.
This commit is contained in:
Winlin 2022-10-06 17:40:58 +08:00 committed by GitHub
parent 9c81a0e1bd
commit 5a420ece3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
298 changed files with 43343 additions and 763 deletions

View file

@ -15,8 +15,8 @@ jobs:
# Build for CentOS 7
- name: Build on CentOS7, baseline
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos7-baseline .
- name: Build on CentOS7, with SRT
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos7-srt .
- name: Build on CentOS7, with all features
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos7-all .
- name: Build on CentOS7, without WebRTC
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos7-no-webrtc .
- name: Build on CentOS7, without ASM
@ -35,8 +35,8 @@ jobs:
# Build for CentOS 6
- name: Build on CentOS6, baseline
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos6-baseline .
- name: Build on CentOS6, with SRT
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos6-srt .
- name: Build on CentOS6, with all features
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target centos6-all .
build-ubuntu16:
name: build-ubuntu16
@ -49,8 +49,8 @@ jobs:
# Build for Ubuntu16
- name: Build on Ubuntu16, baseline
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu16-baseline .
- name: Build on Ubuntu16, with SRT
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu16-srt .
- name: Build on Ubuntu16, with all features
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu16-all .
build-ubuntu18:
name: build-ubuntu18
@ -63,8 +63,8 @@ jobs:
# Build for Ubuntu18
- name: Build on Ubuntu18, baseline
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu18-baseline .
- name: Build on Ubuntu18, with SRT
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu18-srt .
- name: Build on Ubuntu18, with all features
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu18-all .
build-ubuntu20:
name: build-ubuntu20
@ -77,8 +77,8 @@ jobs:
# Build for Ubuntu20
- name: Build on Ubuntu20, baseline
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu20-baseline .
- name: Build on Ubuntu20, with SRT
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu20-srt .
- name: Build on Ubuntu20, with all features
run: DOCKER_BUILDKIT=1 docker build -f trunk/Dockerfile.builds --target ubuntu20-all .
build-cross-arm:
name: build-cross-arm

View file

@ -112,7 +112,7 @@ fi
然后运行回归测试用例,如果只跑一次,可以直接运行:
```bash
go test ./srs -mod=vendor -v
go test ./srs -mod=vendor -v -count=1
```
也可以用make编译出重复使用的二进制
@ -137,7 +137,7 @@ PASS
可以给回归测试传参数,这样可以测试不同的序列,比如:
```bash
go test ./srs -mod=vendor -v -srs-server=127.0.0.1
go test ./srs -mod=vendor -v -count=1 -srs-server=127.0.0.1
# Or
make && ./objs/srs_test -test.v -srs-server=127.0.0.1
```
@ -151,8 +151,8 @@ make && ./objs/srs_test -test.v -srs-log -test.run TestRtcBasic_PublishPlay
支持的参数如下:
* `-srs-server`RTC服务器地址。默认值`127.0.0.1`
* `-srs-stream`RTC流地址。默认值`/rtc/regression`
* `-srs-timeout`每个Case的超时时间毫秒。默认值`3000`即3秒。
* `-srs-stream`RTC流地址,一般会加上随机的后缀。默认值:`/rtc/regression`
* `-srs-timeout`每个Case的超时时间毫秒。默认值`5000`即5秒。
* `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.ogg`
* `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264`
* `-srs-publish-video-fps`推流时视频文件的FPS。默认值`25`
@ -189,7 +189,7 @@ pip install lxml && pip install gcovr
支持Janus的压测使用选项`-sfu janus`可以查看帮助:
```bash
./objs/srs_bench -sfu janus --help
make && ./objs/srs_bench -sfu janus --help
```
首先需要启动Janus推荐使用[janus-docker](https://github.com/winlinvip/janus-docker#usage):
@ -221,4 +221,31 @@ make -j10 && ./objs/srs_bench -sfu janus \
-nn 5
```
## GB28181
支持GB28181的压测使用选项`-sfu gb28181`可以查看帮助:
```bash
make && ./objs/srs_bench -sfu gb28181 --help
```
运行回归测试用例,更多命令请参考[Regression Test](#regression-test)
```bash
go test ./gb28181 -mod=vendor -v -count=1
```
支持的参数如下:
* `-srs-sip`SIP服务器地址。默认值`tcp://127.0.0.1:5060`
* `-srs-stream`GB的user即流名称一般会加上随机的后缀。默认值`3402000000`
* `-srs-timeout`每个Case的超时时间毫秒。默认值`11000`即11秒。
* `-srs-publish-audio`,推流时,使用的音频文件。默认值:`avatar.aac`
* `-srs-publish-video`,推流时,使用的视频文件。默认值:`avatar.h264`
* `-srs-publish-video-fps`推流时视频文件的FPS。默认值`25`
其他不常用参数:
* `-srs-log`,是否开启详细日志。默认值:`false`
2021.01, Winlin

BIN
trunk/3rdparty/srs-bench/avatar.aac vendored Normal file

Binary file not shown.

View file

@ -0,0 +1,145 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"context"
"flag"
"fmt"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"io"
"os"
"strings"
"time"
)
type gbMainConfig struct {
sipConfig SIPConfig
psConfig PSConfig
}
func Parse(ctx context.Context) interface{} {
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
var sfu string
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
c := &gbMainConfig{}
fl.StringVar(&c.sipConfig.addr, "pr", "", "")
fl.StringVar(&c.sipConfig.user, "user", "", "")
fl.StringVar(&c.sipConfig.server, "server", "", "")
fl.StringVar(&c.sipConfig.domain, "domain", "", "")
fl.IntVar(&c.sipConfig.random, "random", 0, "")
fl.StringVar(&c.psConfig.video, "sv", "", "")
fl.StringVar(&c.psConfig.audio, "sa", "", "")
fl.IntVar(&c.psConfig.fps, "fps", 0, "")
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf("SIP:"))
fmt.Println(fmt.Sprintf(" -user The SIP username, ID of device."))
fmt.Println(fmt.Sprintf(" -random Append N number to user as random device ID, like 1320000001."))
fmt.Println(fmt.Sprintf(" -server The SIP server ID, ID of server."))
fmt.Println(fmt.Sprintf(" -domain The SIP domain, domain of server and device."))
fmt.Println(fmt.Sprintf("Publisher:"))
fmt.Println(fmt.Sprintf(" -pr The SIP server address, format is tcp://ip:port over TCP."))
fmt.Println(fmt.Sprintf(" -fps [Optional] The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -sa [Optional] The file path to read audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -sv [Optional] The file path to read video, ignore if empty."))
fmt.Println(fmt.Sprintf("\n例如1个推流"))
fmt.Println(fmt.Sprintf(" %v -sfu gb28181 -pr tcp://127.0.0.1:5060 -user 34020000001320000001 -server 34020000002000000001 -domain 3402000000", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -sfu gb28181 -pr tcp://127.0.0.1:5060 -user 3402000000 -random 10 -server 34020000002000000001 -domain 3402000000", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -sfu gb28181 -pr tcp://127.0.0.1:5060 -user 3402000000 -random 10 -server 34020000002000000001 -domain 3402000000 -sa avatar.aac -sv avatar.h264 -fps 25", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -sfu gb28181 -pr tcp://127.0.0.1:5060 -user livestream -server srs -domain ossrs.io -sa avatar.aac -sv avatar.h264 -fps 25", os.Args[0]))
fmt.Println()
}
if err := fl.Parse(os.Args[1:]); err == flag.ErrHelp {
os.Exit(0)
}
showHelp := c.sipConfig.String() == ""
if showHelp {
fl.Usage()
os.Exit(-1)
}
summaryDesc := ""
if c.sipConfig.addr != "" {
pubString := strings.Join([]string{c.sipConfig.String(), c.psConfig.String()}, ",")
summaryDesc = fmt.Sprintf("%v, publish(%v)", summaryDesc, pubString)
}
logger.Tf(ctx, "Run benchmark with %v", summaryDesc)
return c
}
func Run(ctx context.Context, r0 interface{}) (err error) {
conf := r0.(*gbMainConfig)
ctx, cancel := context.WithCancel(ctx)
session := NewGBSession(&GBSessionConfig{
regTimeout: 3 * time.Hour, inviteTimeout: 3 * time.Hour,
}, &conf.sipConfig)
defer session.Close()
if err := session.Connect(ctx); err != nil {
return errors.Wrapf(err, "connect %v", conf.sipConfig)
}
if err := session.Register(ctx); err != nil {
return errors.Wrapf(err, "register %v", conf.sipConfig)
}
if err := session.Invite(ctx); err != nil {
return errors.Wrapf(err, "invite %v", conf.sipConfig)
}
if conf.psConfig.video == "" || conf.psConfig.audio == "" {
cancel()
return nil
}
ingester := NewPSIngester(&IngesterConfig{
psConfig: conf.psConfig,
ssrc: uint32(session.out.ssrc),
clockRate: session.out.clockRate,
payloadType: uint8(session.out.payloadType),
})
defer ingester.Close()
if ingester.conf.serverAddr, err = utilBuildMediaAddr(session.sip.conf.addr, session.out.mediaPort); err != nil {
return err
}
if err := ingester.Ingest(ctx); err != nil {
if errors.Cause(err) == io.EOF {
logger.Tf(ctx, "EOF, video=%v, audio=%v", conf.psConfig.video, conf.psConfig.audio)
return nil
}
return errors.Wrap(err, "ingest")
}
return nil
}

View file

@ -0,0 +1,45 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"github.com/ossrs/go-oryx-lib/logger"
"io/ioutil"
"os"
"testing"
)
func TestMain(m *testing.M) {
if err := prepareTest(); err != nil {
logger.Ef(nil, "Prepare test fail, err %+v", err)
os.Exit(-1)
}
// Disable the logger during all tests.
if *srsLog == false {
olw := logger.Switch(ioutil.Discard)
defer func() {
logger.Switch(olw)
}()
}
os.Exit(m.Run())
}

View file

@ -0,0 +1,495 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"context"
"fmt"
"github.com/ghettovoice/gosip/sip"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/go-oryx-lib/errors"
"testing"
"time"
)
func TestGbPublishRegularly(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
err := func() error {
t := NewGBTestPublisher()
defer t.Close()
var nnPackets int
t.ingester.onSendPacket = func(pack *PSPackStream) error {
if nnPackets += 1; nnPackets > 10 {
cancel()
}
return nil
}
if err := t.Run(ctx); err != nil {
return err
}
return nil
}()
if err := filterTestError(ctx.Err(), err); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionHandshake(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
err := func() error {
t := NewGBTestSession()
defer t.Close()
// Use fast heartbeat for utest.
t.session.heartbeatInterval = 100 * time.Millisecond
if err := t.Run(ctx); err != nil {
return err
}
var nn int
t.session.onMessageHeartbeat = func(req, res sip.Message) error {
if nn++; nn >= 3 {
t.session.cancel()
}
return nil
}
<-t.session.heartbeatCtx.Done()
return t.session.heartbeatCtx.Err()
}()
if err := filterTestError(ctx.Err(), err); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionHandshakeDropRegisterOk(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
r0 := func() error {
t := NewGBTestSession()
defer t.Close()
conf = t.session.sip.conf
ctx, cancel2 := context.WithCancel(ctx)
t.session.onRegisterDone = func(req, res sip.Message) error {
cancel2()
return nil
}
return t.Run(ctx)
}()
// Use the same session for SIP.
r1 := func() error {
session := NewGBTestSession()
session.session.sip.conf = conf
defer session.Close()
return session.Run(ctx)
}()
if err := filterTestError(ctx.Err(), r0, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionHandshakeDropInviteRequest(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
r0 := func() error {
t := NewGBTestSession()
defer t.Close()
conf = t.session.sip.conf
// Drop the invite request, to simulate the device crash or disconnect when got this message.
ctx2, cancel2 := context.WithCancel(ctx)
t.session.onInviteRequest = func(req sip.Message) error {
cancel2()
return nil
}
return t.Run(ctx2)
}()
// When device restart session when inviting, server should re-invite when got register message.
r1 := func() error {
t := NewGBTestSession()
t.session.sip.conf = conf
defer t.Close()
return t.Run(ctx)
}()
if err := filterTestError(ctx.Err(), r0, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionHandshakeDropInvite200Ack(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
r0 := func() error {
t := NewGBTestSession()
defer t.Close()
conf = t.session.sip.conf
// Drop the invite ok ACK, to simulate the device crash or disconnect when got this message.
ctx2, cancel2 := context.WithCancel(ctx)
t.session.onInviteOkAck = func(req, res sip.Message) error {
cancel2()
return nil
}
return t.Run(ctx2)
}()
// When device restart session when 200 ack of invite, server should be stable state and waiting for media, then
//there should be a media timeout and re-invite.
r1 := func() error {
t := NewGBTestSession()
t.session.sip.conf = conf
defer t.Close()
return t.Run(ctx)
}()
if err := filterTestError(ctx.Err(), r0, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbPublishMediaDisconnect(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
r0 := func() error {
t := NewGBTestPublisher()
defer t.Close()
conf = t.session.sip.conf
var nnPackets int
ctx2, cancel2 := context.WithCancel(ctx)
t.ingester.onSendPacket = func(pack *PSPackStream) error {
if nnPackets += 1; nnPackets > 200 {
cancel2()
}
return nil
}
if err := t.Run(ctx2); err != nil {
return err
}
return nil
}()
r1 := func() error {
t := NewGBTestSession()
t.session.sip.conf = conf
defer t.Close()
return t.Run(ctx)
}()
if err := filterTestError(ctx.Err(), r0, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionBye(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
err := func() error {
t := NewGBTestSession()
defer t.Close()
// Use fast heartbeat for utest.
t.session.heartbeatInterval = 100 * time.Millisecond
if err := t.Run(ctx); err != nil {
return err
}
var nn int
t.session.onMessageHeartbeat = func(req, res sip.Message) error {
if nn++; nn == 3 {
return t.session.Bye(ctx)
}
return nil
}
reconnectTimeout := time.Duration(*srsMediaTimeout+*srsReinviteTimeout+1000) * time.Millisecond
ctx2, cancel2 := context.WithTimeout(ctx, reconnectTimeout)
defer cancel2()
req, err := t.session.sip.Wait(ctx2, sip.INVITE)
if req != nil {
return fmt.Errorf("should not invite after bye")
}
if errors.Cause(err) == context.DeadlineExceeded {
return nil
}
return err
}()
if err := filterTestError(ctx.Err(), err); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbSessionUnregister(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
err := func() error {
t := NewGBTestSession()
defer t.Close()
// Use fast heartbeat for utest.
t.session.heartbeatInterval = 100 * time.Millisecond
if err := t.Run(ctx); err != nil {
return err
}
var nn int
t.session.onMessageHeartbeat = func(req, res sip.Message) error {
if nn++; nn == 3 {
return t.session.UnRegister(ctx)
}
return nil
}
reconnectTimeout := time.Duration(*srsMediaTimeout+*srsReinviteTimeout+1000) * time.Millisecond
ctx2, cancel2 := context.WithTimeout(ctx, reconnectTimeout)
defer cancel2()
req, err := t.session.sip.Wait(ctx2, sip.INVITE)
if req != nil {
return fmt.Errorf("should not invite after bye")
}
if errors.Cause(err) == context.DeadlineExceeded {
return nil
}
return err
}()
if err := filterTestError(ctx.Err(), err); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbPublishReinvite(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
err := func() error {
t := NewGBTestPublisher()
defer t.Close()
conf = t.session.sip.conf
var nnPackets int
ctx2, cancel2 := context.WithCancel(ctx)
t.ingester.onSendPacket = func(pack *PSPackStream) error {
if nnPackets += 1; nnPackets == 3 {
cancel2()
}
return nil
}
if err := t.Run(ctx2); err != nil {
return err
}
return nil
}()
r1 := func() error {
t := NewGBTestSession()
defer t.Close()
t.session.sip.conf = conf
// Only register the device, bind to session.
if err := t.session.Connect(ctx); err != nil {
return err
}
if err := t.session.Register(ctx); err != nil {
return err
}
// We should get reinvite when reconnect to SRS.
req, err := t.session.sip.Wait(ctx, sip.INVITE)
if req == nil {
return fmt.Errorf("should reinvite after disconnect")
}
return err
}()
if err := filterTestError(ctx.Err(), err, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbPublishBye(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
err := func() error {
t := NewGBTestPublisher()
defer t.Close()
conf = t.session.sip.conf
var nnPackets int
ctx2, cancel2 := context.WithCancel(ctx)
t.ingester.onSendPacket = func(pack *PSPackStream) error {
if nnPackets += 1; nnPackets == 10 {
if err := t.session.Bye(ctx2); err != nil {
return err
}
cancel2()
}
return nil
}
if err := t.Run(ctx2); err != nil {
return err
}
return nil
}()
r1 := func() error {
t := NewGBTestSession()
defer t.Close()
t.session.sip.conf = conf
// Only register the device, bind to session.
if err := t.session.Connect(ctx); err != nil {
return err
}
if err := t.session.Register(ctx); err != nil {
return err
}
// We should not get reinvite when reconnect to SRS.
reconnectTimeout := time.Duration(*srsMediaTimeout+*srsReinviteTimeout+1000) * time.Millisecond
ctx2, cancel2 := context.WithTimeout(ctx, reconnectTimeout)
defer cancel2()
req, err := t.session.sip.Wait(ctx2, sip.INVITE)
if req != nil {
return fmt.Errorf("should not invite after bye")
}
if errors.Cause(err) == context.DeadlineExceeded {
return nil
}
return err
}()
if err := filterTestError(ctx.Err(), err, r1); err != nil {
t.Errorf("err %+v", err)
}
}
func TestGbPublishUnregister(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
defer cancel()
var conf *SIPConfig
err := func() error {
t := NewGBTestPublisher()
defer t.Close()
conf = t.session.sip.conf
var nnPackets int
ctx2, cancel2 := context.WithCancel(ctx)
t.ingester.onSendPacket = func(pack *PSPackStream) error {
if nnPackets += 1; nnPackets == 10 {
if err := t.session.UnRegister(ctx2); err != nil {
return err
}
cancel2()
}
return nil
}
if err := t.Run(ctx2); err != nil {
return err
}
return nil
}()
r1 := func() error {
t := NewGBTestSession()
defer t.Close()
t.session.sip.conf = conf
// Only register the device, bind to session.
if err := t.session.Connect(ctx); err != nil {
return err
}
if err := t.session.Register(ctx); err != nil {
return err
}
// We should not get reinvite when reconnect to SRS.
reconnectTimeout := time.Duration(*srsMediaTimeout+*srsReinviteTimeout+1000) * time.Millisecond
ctx2, cancel2 := context.WithTimeout(ctx, reconnectTimeout)
defer cancel2()
req, err := t.session.sip.Wait(ctx2, sip.INVITE)
if req != nil {
return fmt.Errorf("should not invite after bye")
}
if errors.Cause(err) == context.DeadlineExceeded {
return nil
}
return err
}()
if err := filterTestError(ctx.Err(), err, r1); err != nil {
t.Errorf("err %+v", err)
}
}

View file

@ -0,0 +1,418 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"context"
"github.com/ghettovoice/gosip/sip"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/pion/webrtc/v3/pkg/media/h264reader"
"io"
"os"
"strconv"
"strings"
"sync"
"time"
)
type GBSessionConfig struct {
regTimeout time.Duration
inviteTimeout time.Duration
}
type GBSessionOutput struct {
ssrc int64
mediaPort int64
clockRate uint64
payloadType uint8
}
type GBSession struct {
// GB config.
conf *GBSessionConfig
// The output of session.
out *GBSessionOutput
// The SIP session object.
sip *SIPSession
// Callback when REGISTER done.
onRegisterDone func(req, res sip.Message) error
// Callback when got INVITE request.
onInviteRequest func(req sip.Message) error
// Callback when got INVITE 200 OK ACK request.
onInviteOkAck func(req, res sip.Message) error
// Callback when got MESSAGE response.
onMessageHeartbeat func(req, res sip.Message) error
// For heartbeat coroutines.
heartbeatInterval time.Duration
heartbeatCtx context.Context
cancel context.CancelFunc
// WaitGroup for coroutines.
wg sync.WaitGroup
}
func NewGBSession(c *GBSessionConfig, sc *SIPConfig) *GBSession {
return &GBSession{
sip: NewSIPSession(sc),
conf: c,
out: &GBSessionOutput{
clockRate: uint64(90000),
payloadType: uint8(96),
},
heartbeatInterval: 1 * time.Second,
}
}
func (v *GBSession) Close() error {
if v.cancel != nil {
v.cancel()
}
v.sip.Close()
v.wg.Wait()
return nil
}
func (v *GBSession) Connect(ctx context.Context) error {
client := v.sip
if err := client.Connect(ctx); err != nil {
return errors.Wrap(err, "connect")
}
return ctx.Err()
}
func (v *GBSession) Register(ctx context.Context) error {
client := v.sip
for ctx.Err() == nil {
ctx, regCancel := context.WithTimeout(ctx, v.conf.regTimeout)
defer regCancel()
regReq, regRes, err := client.Register(ctx)
if err != nil {
return errors.Wrap(err, "register")
}
logger.Tf(ctx, "Register id=%v, response=%v", regReq.MessageID(), regRes.MessageID())
if v.onRegisterDone != nil {
if err = v.onRegisterDone(regReq, regRes); err != nil {
return errors.Wrap(err, "callback")
}
}
break
}
return ctx.Err()
}
func (v *GBSession) Invite(ctx context.Context) error {
client := v.sip
for ctx.Err() == nil {
ctx, inviteCancel := context.WithTimeout(ctx, v.conf.inviteTimeout)
defer inviteCancel()
inviteReq, err := client.Wait(ctx, sip.INVITE)
if err != nil {
return errors.Wrap(err, "wait")
}
logger.Tf(ctx, "Got INVITE request, Call-ID=%v", sipGetCallID(inviteReq))
if v.onInviteRequest != nil {
if err = v.onInviteRequest(inviteReq); err != nil {
return errors.Wrap(err, "callback")
}
}
if err = client.Trying(ctx, inviteReq); err != nil {
return errors.Wrapf(err, "trying invite is %v", inviteReq.String())
}
time.Sleep(100 * time.Millisecond)
inviteRes, err := client.InviteResponse(ctx, inviteReq)
if err != nil {
return errors.Wrapf(err, "response invite is %v", inviteReq.String())
}
offer := inviteReq.Body()
ssrcStr := strings.Split(strings.Split(offer, "y=")[1], "\r\n")[0]
if v.out.ssrc, err = strconv.ParseInt(ssrcStr, 10, 64); err != nil {
return errors.Wrapf(err, "parse ssrc=%v, sdp %v", ssrcStr, offer)
}
mediaPortStr := strings.Split(strings.Split(offer, "m=video")[1], " ")[1]
if v.out.mediaPort, err = strconv.ParseInt(mediaPortStr, 10, 64); err != nil {
return errors.Wrapf(err, "parse media port=%v, sdp %v", mediaPortStr, offer)
}
logger.Tf(ctx, "Invite id=%v, response=%v, y=%v, ssrc=%v, mediaPort=%v",
inviteReq.MessageID(), inviteRes.MessageID(), ssrcStr, v.out.ssrc, v.out.mediaPort,
)
if v.onInviteOkAck != nil {
if err = v.onInviteOkAck(inviteReq, inviteRes); err != nil {
return errors.Wrap(err, "callback")
}
}
break
}
// Start goroutine for heartbeat every 1s.
v.heartbeatCtx, v.cancel = context.WithCancel(ctx)
go func(ctx context.Context) {
v.wg.Add(1)
defer v.wg.Done()
for ctx.Err() == nil {
req, res, err := client.Message(ctx)
if err != nil {
v.cancel()
logger.Ef(ctx, "heartbeat err %+v", err)
return
}
if v.onMessageHeartbeat != nil {
if err = v.onMessageHeartbeat(req, res); err != nil {
v.cancel()
logger.Ef(ctx, "callback err %+v", err)
return
}
}
select {
case <-ctx.Done():
return
case <-time.After(v.heartbeatInterval):
}
}
}(v.heartbeatCtx)
return ctx.Err()
}
func (v *GBSession) Bye(ctx context.Context) error {
client := v.sip
for ctx.Err() == nil {
ctx, regCancel := context.WithTimeout(ctx, v.conf.regTimeout)
defer regCancel()
regReq, regRes, err := client.Bye(ctx)
if err != nil {
return errors.Wrap(err, "bye")
}
logger.Tf(ctx, "Bye id=%v, response=%v", regReq.MessageID(), regRes.MessageID())
break
}
return ctx.Err()
}
func (v *GBSession) UnRegister(ctx context.Context) error {
client := v.sip
for ctx.Err() == nil {
ctx, regCancel := context.WithTimeout(ctx, v.conf.regTimeout)
defer regCancel()
regReq, regRes, err := client.UnRegister(ctx)
if err != nil {
return errors.Wrap(err, "UnRegister")
}
logger.Tf(ctx, "UnRegister id=%v, response=%v", regReq.MessageID(), regRes.MessageID())
break
}
return ctx.Err()
}
type IngesterConfig struct {
psConfig PSConfig
ssrc uint32
serverAddr string
clockRate uint64
payloadType uint8
}
type PSIngester struct {
conf *IngesterConfig
onSendPacket func(pack *PSPackStream) error
cancel context.CancelFunc
}
func NewPSIngester(c *IngesterConfig) *PSIngester {
return &PSIngester{conf: c}
}
func (v *PSIngester) Close() error {
if v.cancel != nil {
v.cancel()
}
return nil
}
func (v *PSIngester) Ingest(ctx context.Context) error {
ctx, v.cancel = context.WithCancel(ctx)
ps := NewPSClient(uint32(v.conf.ssrc), v.conf.serverAddr)
if err := ps.Connect(ctx); err != nil {
return errors.Wrapf(err, "connect media=%v", v.conf.serverAddr)
}
defer ps.Close()
videoFile, err := os.Open(v.conf.psConfig.video)
if err != nil {
return errors.Wrapf(err, "Open file %v", v.conf.psConfig.video)
}
defer videoFile.Close()
f, err := os.Open(v.conf.psConfig.audio)
if err != nil {
return errors.Wrapf(err, "Open file %v", v.conf.psConfig.audio)
}
defer f.Close()
h264, err := h264reader.NewReader(videoFile)
if err != nil {
return errors.Wrapf(err, "Open h264 %v", v.conf.psConfig.video)
}
audio, err := NewAACReader(f)
if err != nil {
return errors.Wrapf(err, "Open ogg %v", v.conf.psConfig.audio)
}
// Scale the video samples to 1024 according to AAC, that is 1 video frame means 1024 samples.
audioSampleRate := audio.codec.ASC().SampleRate.ToHz()
videoSampleRate := 1024 * 1000 / v.conf.psConfig.fps
logger.Tf(ctx, "PS: Media stream, tbn=%v, ssrc=%v, pt=%v, Video(%v, fps=%v, rate=%v), Audio(%v, rate=%v, channels=%v)",
v.conf.clockRate, v.conf.ssrc, v.conf.payloadType, v.conf.psConfig.video, v.conf.psConfig.fps, videoSampleRate,
v.conf.psConfig.audio, audioSampleRate, audio.codec.ASC().Channels)
lastPrint := time.Now()
var aacSamples, avcSamples uint64
var audioDTS, videoDTS uint64
defer func() {
logger.Tf(ctx, "Consume Video(samples=%v, dts=%v, ts=%.2f) and Audio(samples=%v, dts=%v, ts=%.2f)",
avcSamples, videoDTS, float64(videoDTS)/90.0, aacSamples, audioDTS, float64(audioDTS)/90.0,
)
}()
clock := newWallClock()
var pack *PSPackStream
for ctx.Err() == nil {
if pack == nil {
pack = NewPSPackStream(v.conf.payloadType)
}
// One pack should only contains one video frame.
if !pack.hasVideo {
var sps, pps *h264reader.NAL
var videoFrames []*h264reader.NAL
for ctx.Err() == nil {
frame, err := h264.NextNAL()
if err == io.EOF {
return io.EOF
}
if err != nil {
return errors.Wrapf(err, "Read h264")
}
videoFrames = append(videoFrames, frame)
logger.If(ctx, "NALU %v PictureOrderCount=%v, ForbiddenZeroBit=%v, RefIdc=%v, %v bytes",
frame.UnitType.String(), frame.PictureOrderCount, frame.ForbiddenZeroBit, frame.RefIdc, len(frame.Data))
if frame.UnitType == h264reader.NalUnitTypeSPS {
sps = frame
} else if frame.UnitType == h264reader.NalUnitTypePPS {
pps = frame
} else {
break
}
}
// We convert the video sample rate to be based over 1024, that is 1024 samples means one video frame.
avcSamples += 1024
videoDTS = uint64(v.conf.clockRate*avcSamples) / uint64(videoSampleRate)
if sps != nil || pps != nil {
err = pack.WriteHeader(videoDTS)
} else {
err = pack.WritePackHeader(videoDTS)
}
if err != nil {
return errors.Wrap(err, "pack header")
}
for _, frame := range videoFrames {
if err = pack.WriteVideo(frame.Data, videoDTS); err != nil {
return errors.Wrapf(err, "write video %v", len(frame.Data))
}
}
}
// Always read and consume one audio frame each time.
if true {
audioFrame, err := audio.NextADTSFrame()
if err != nil {
return errors.Wrap(err, "Read AAC")
}
// Each AAC frame contains 1024 samples, DTS = total-samples / sample-rate
aacSamples += 1024
audioDTS = uint64(v.conf.clockRate*aacSamples) / uint64(audioSampleRate)
if time.Now().Sub(lastPrint) > 3*time.Second {
lastPrint = time.Now()
logger.Tf(ctx, "Consume Video(samples=%v, dts=%v, ts=%.2f) and Audio(samples=%v, dts=%v, ts=%.2f)",
avcSamples, videoDTS, float64(videoDTS)/90.0, aacSamples, audioDTS, float64(audioDTS)/90.0,
)
}
if err = pack.WriteAudio(audioFrame, audioDTS); err != nil {
return errors.Wrapf(err, "write audio %v", len(audioFrame))
}
}
// Send pack when got video and enough audio frames.
if pack.hasVideo && videoDTS < audioDTS {
if err := ps.WritePacksOverRTP(pack.packets); err != nil {
return errors.Wrap(err, "write")
}
if v.onSendPacket != nil {
if err := v.onSendPacket(pack); err != nil {
return errors.Wrap(err, "callback")
}
}
pack = nil // Reset pack.
}
// One audio frame(1024 samples), the duration is 1024/audioSampleRate in seconds.
sampleDuration := time.Duration(uint64(time.Second) * 1024 / uint64(audioSampleRate))
if d := clock.Tick(sampleDuration); d > 0 {
time.Sleep(d)
}
}
return nil
}

281
trunk/3rdparty/srs-bench/gb28181/ps.go vendored Normal file
View file

@ -0,0 +1,281 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"context"
"fmt"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/pion/rtp"
"github.com/yapingcat/gomedia/codec"
"github.com/yapingcat/gomedia/mpeg2"
"math"
"net"
"net/url"
"strings"
)
type PSConfig struct {
// The video source file.
video string
// The fps for h264 file.
fps int
// The audio source file.
audio string
}
func (v *PSConfig) String() string {
sb := []string{}
if v.video != "" {
sb = append(sb, fmt.Sprintf("video=%v", v.video))
}
if v.fps > 0 {
sb = append(sb, fmt.Sprintf("fps=%v", v.fps))
}
if v.audio != "" {
sb = append(sb, fmt.Sprintf("audio=%v", v.audio))
}
return strings.Join(sb, ",")
}
type PSClient struct {
// SSRC from SDP.
ssrc uint32
// The server IP address and port to connect to.
serverAddr string
// Inner state, sequence number.
seq uint16
// Inner state, media TCP connection
conn *net.TCPConn
}
func NewPSClient(ssrc uint32, serverAddr string) *PSClient {
return &PSClient{ssrc: ssrc, serverAddr: serverAddr}
}
func (v *PSClient) Close() error {
if v.conn != nil {
v.conn.Close()
}
return nil
}
func (v *PSClient) Connect(ctx context.Context) error {
if u, err := url.Parse(v.serverAddr); err != nil {
return errors.Wrapf(err, "parse addr=%v", v.serverAddr)
} else if addr, err := net.ResolveTCPAddr(u.Scheme, u.Host); err != nil {
return errors.Wrapf(err, "parse addr=%v, scheme=%v, host=%v", v.serverAddr, u.Scheme, u.Host)
} else if v.conn, err = net.DialTCP(u.Scheme, nil, addr); err != nil {
return errors.Wrapf(err, "connect addr=%v as %v", v.serverAddr, addr.String())
}
return nil
}
func (v *PSClient) WritePacksOverRTP(packs []*PSPacket) error {
for _, pack := range packs {
for _, payload := range pack.ps {
v.seq++
p := rtp.Packet{Header: rtp.Header{
Version: 2, PayloadType: uint8(pack.pt), SequenceNumber: v.seq,
Timestamp: uint32(pack.ts), SSRC: uint32(v.ssrc),
}, Payload: payload}
b, err := p.Marshal()
if err != nil {
return errors.Wrapf(err, "rtp marshal")
}
if _, err = v.conn.Write([]byte{uint8(len(b) >> 8), uint8(len(b))}); err != nil {
return errors.Wrapf(err, "write length=%v", len(b))
}
if _, err = v.conn.Write(b); err != nil {
return errors.Wrapf(err, "write payload %v bytes", len(b))
}
}
}
return nil
}
type PSPacketType int
const (
PSPacketTypePackHeader PSPacketType = iota
PSPacketTypeSystemHeader
PSPacketTypeProgramStramMap
PSPacketTypeVideo
PSPacketTypeAudio
)
type PSPacket struct {
t PSPacketType
ts uint64
pt uint8
ps [][]byte
}
func NewPSPacket(t PSPacketType, p []byte, ts uint64, pt uint8) *PSPacket {
v := &PSPacket{t: t, ts: ts, pt: pt}
if p != nil {
v.ps = append(v.ps, p)
}
return v
}
func (v *PSPacket) Append(p []byte) *PSPacket {
v.ps = append(v.ps, p)
return v
}
type PSPackStream struct {
// The RTP paload type.
pt uint8
// Split a big media frame to small PES packets.
ideaPesLength int
// The generated bytes of PS stream data.
packets []*PSPacket
// Whether has video packet.
hasVideo bool
}
func NewPSPackStream(pt uint8) *PSPackStream {
return &PSPackStream{ideaPesLength: 1400, pt: pt}
}
func (v *PSPackStream) WriteHeader(dts uint64) error {
if err := v.WritePackHeader(dts); err != nil {
return err
}
if err := v.WriteSystemHeader(dts); err != nil {
return err
}
if err := v.WriteProgramStreamMap(dts); err != nil {
return err
}
return nil
}
func (v *PSPackStream) WritePackHeader(dts uint64) error {
w := codec.NewBitStreamWriter(1500)
pack := &mpeg2.PSPackHeader{
System_clock_reference_base: dts,
Program_mux_rate: 159953,
Pack_stuffing_length: 6,
}
pack.Encode(w)
v.packets = append(v.packets, NewPSPacket(PSPacketTypePackHeader, w.Bits(), dts, v.pt))
return nil
}
func (v *PSPackStream) WriteSystemHeader(dts uint64) error {
w := codec.NewBitStreamWriter(1500)
system := &mpeg2.System_header{
Rate_bound: 159953,
Video_bound: 1,
Audio_bound: 1,
Streams: []*mpeg2.Elementary_Stream{
// SrsTsPESStreamIdVideoCommon = 0xe0
&mpeg2.Elementary_Stream{Stream_id: uint8(0xe0), P_STD_buffer_bound_scale: 1, P_STD_buffer_size_bound: 128},
// SrsTsPESStreamIdAudioCommon = 0xc0
&mpeg2.Elementary_Stream{Stream_id: uint8(0xc0), P_STD_buffer_bound_scale: 0, P_STD_buffer_size_bound: 8},
// SrsTsPESStreamIdPrivateStream1 = 0xbd
&mpeg2.Elementary_Stream{Stream_id: uint8(0xbd), P_STD_buffer_bound_scale: 1, P_STD_buffer_size_bound: 128},
// SrsTsPESStreamIdPrivateStream2 = 0xbf
&mpeg2.Elementary_Stream{Stream_id: uint8(0xbf), P_STD_buffer_bound_scale: 1, P_STD_buffer_size_bound: 128},
},
}
system.Encode(w)
v.packets = append(v.packets, NewPSPacket(PSPacketTypeSystemHeader, w.Bits(), dts, v.pt))
return nil
}
func (v *PSPackStream) WriteProgramStreamMap(dts uint64) error {
w := codec.NewBitStreamWriter(1500)
psm := &mpeg2.Program_stream_map{
Stream_map: []*mpeg2.Elementary_stream_elem{
// SrsTsPESStreamIdVideoCommon = 0xe0
mpeg2.NewElementary_stream_elem(uint8(mpeg2.PS_STREAM_H264), 0xe0),
// SrsTsPESStreamIdAudioCommon = 0xc0
mpeg2.NewElementary_stream_elem(uint8(mpeg2.PS_STREAM_AAC), 0xc0),
},
}
psm.Encode(w)
v.packets = append(v.packets, NewPSPacket(PSPacketTypeProgramStramMap, w.Bits(), dts, v.pt))
return nil
}
// The nalu is raw data without ANNEXB header.
func (v *PSPackStream) WriteVideo(nalu []byte, dts uint64) error {
// Mux frame payload in AnnexB format. Always fresh NALU header for frame, see srs_avc_insert_aud.
annexb := append([]byte{0, 0, 0, 1}, nalu...)
video := NewPSPacket(PSPacketTypeVideo, nil, dts, v.pt)
for i := 0; i < len(annexb); i += v.ideaPesLength {
payloadLength := int(math.Min(float64(v.ideaPesLength), float64(len(annexb)-i)))
bb := annexb[i : i+payloadLength]
w := codec.NewBitStreamWriter(65535)
pes := &mpeg2.PesPacket{
Stream_id: uint8(0xe0), // SrsTsPESStreamIdVideoCommon = 0xe0
PTS_DTS_flags: uint8(0x03), Dts: dts, Pts: dts, // Both DTS and PTS.
Pes_payload: bb,
}
utilUpdatePesPacketLength(pes)
pes.Encode(w)
video.Append(w.Bits())
}
v.hasVideo = true
v.packets = append(v.packets, video)
return nil
}
// Write AAC ADTS frame.
func (v *PSPackStream) WriteAudio(adts []byte, dts uint64) error {
w := codec.NewBitStreamWriter(65535)
pes := &mpeg2.PesPacket{
Stream_id: uint8(0xc0), // SrsTsPESStreamIdAudioCommon = 0xc0
PTS_DTS_flags: uint8(0x03), Dts: dts, Pts: dts, // Both DTS and PTS.
Pes_payload: adts,
}
utilUpdatePesPacketLength(pes)
pes.Encode(w)
v.packets = append(v.packets, NewPSPacket(PSPacketTypeAudio, w.Bits(), dts, v.pt))
return nil
}

561
trunk/3rdparty/srs-bench/gb28181/sip.go vendored Normal file
View file

@ -0,0 +1,561 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"context"
"fmt"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
"github.com/ghettovoice/gosip/transport"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"math/rand"
"net/url"
"strings"
"sync"
"time"
)
type SIPConfig struct {
// The server address, for example: tcp://127.0.0.1:5060k
addr string
// The SIP domain, for example: ossrs.io or 3402000000
domain string
// The SIP device ID, for example: camera or 34020000001320000001
user string
// The N number of random device ID, for example, 10 means 1320000001
random int
// The SIP server ID, for example: srs or 34020000002000000001
server string
// The cached device id.
deviceID string
}
// The global cache to avoid conflict of deviceID.
// Note that it's not coroutine safe, but it should be OK for utest.
var deviceIDCache map[string]bool
func init() {
deviceIDCache = make(map[string]bool)
}
func (v *SIPConfig) DeviceID() string {
for v.deviceID == "" {
// Generate a random ID.
var rid string
for len(rid) < v.random {
rid += fmt.Sprintf("%v", rand.Uint64())
}
deviceID := fmt.Sprintf("%v%v", v.user, rid[:v.random])
// Ignore if exists.
if _, ok := deviceIDCache[deviceID]; !ok {
v.deviceID = deviceID
deviceIDCache[deviceID] = true
}
}
return v.deviceID
}
func (v *SIPConfig) String() string {
sb := []string{}
if v.addr != "" {
sb = append(sb, fmt.Sprintf("addr=%v", v.addr))
}
if v.domain != "" {
sb = append(sb, fmt.Sprintf("domain=%v", v.domain))
}
if v.user != "" {
sb = append(sb, fmt.Sprintf("user=%v", v.user))
sb = append(sb, fmt.Sprintf("deviceID=%v", v.DeviceID()))
}
if v.random > 0 {
sb = append(sb, fmt.Sprintf("random=%v", v.random))
}
if v.server != "" {
sb = append(sb, fmt.Sprintf("server=%v", v.server))
}
return strings.Join(sb, ",")
}
type SIPSession struct {
conf *SIPConfig
rb *sip.RequestBuilder
requests chan sip.Request
responses chan sip.Response
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
client *SIPClient
seq uint
}
func NewSIPSession(c *SIPConfig) *SIPSession {
return &SIPSession{
conf: c, client: NewSIPClient(), rb: sip.NewRequestBuilder(),
requests: make(chan sip.Request, 1024), responses: make(chan sip.Response, 1024),
seq: 100,
}
}
func (v *SIPSession) Close() error {
if v.cancel != nil {
v.cancel()
}
v.client.Close()
v.wg.Wait()
return nil
}
func (v *SIPSession) Connect(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
ctx, cancel := context.WithCancel(ctx)
v.ctx, v.cancel = ctx, cancel
if err := v.client.Connect(ctx, v.conf.addr); err != nil {
return errors.Wrapf(err, "connect with sipConfig %v", v.conf.String())
}
// Dispatch requests and responses.
go func() {
v.wg.Add(1)
defer v.wg.Done()
for {
select {
case <-v.ctx.Done():
return
case msg := <-v.client.incoming:
if req, ok := msg.(sip.Request); ok {
select {
case v.requests <- req:
case <-v.ctx.Done():
return
}
} else if res, ok := msg.(sip.Response); ok {
select {
case v.responses <- res:
case <-v.ctx.Done():
return
}
} else {
logger.Wf(ctx, "Drop message %v", msg.String())
}
}
}
}()
return nil
}
func (v *SIPSession) Register(ctx context.Context) (sip.Message, sip.Message, error) {
return v.doRegister(ctx, 3600)
}
func (v *SIPSession) UnRegister(ctx context.Context) (sip.Message, sip.Message, error) {
return v.doRegister(ctx, 0)
}
func (v *SIPSession) doRegister(ctx context.Context, expires int) (sip.Message, sip.Message, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
sipPort := sip.Port(5060)
sipCallID := sip.CallID(fmt.Sprintf("%v", rand.Uint64()))
sipBranch := fmt.Sprintf("z9hG4bK_%v", rand.Uint32())
sipTag := fmt.Sprintf("%v", rand.Uint32())
sipMaxForwards := sip.MaxForwards(70)
sipExpires := sip.Expires(uint32(expires))
sipPIP := "192.168.3.99"
v.seq++
rb := v.rb
rb.SetTransport("TCP")
rb.SetMethod(sip.REGISTER)
rb.AddVia(&sip.ViaHop{
ProtocolName: "SIP", ProtocolVersion: "2.0", Transport: "TCP", Host: sipPIP, Port: &sipPort,
Params: sip.NewParams().Add("branch", sip.String{Str: sipBranch}),
})
rb.SetFrom(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: v.conf.domain},
Params: sip.NewParams().Add("tag", sip.String{Str: sipTag}),
})
rb.SetTo(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: v.conf.domain},
})
rb.SetCallID(&sipCallID)
rb.SetSeqNo(v.seq)
rb.SetRecipient(&sip.SipUri{FUser: sip.String{v.conf.server}, FHost: v.conf.domain})
rb.SetContact(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: sipPIP, FPort: &sipPort},
})
rb.SetMaxForwards(&sipMaxForwards)
rb.SetExpires(&sipExpires)
req, err := rb.Build()
if err != nil {
return req, nil, errors.Wrap(err, "build request")
}
if err = v.client.Send(req); err != nil {
return req, nil, errors.Wrapf(err, "send request %v", req.String())
}
callID := sipGetCallID(req)
if callID == "" {
return req, nil, errors.Errorf("Invalid SIP Call-ID register %v", req.String())
}
logger.Tf(ctx, "Send REGISTER request, Call-ID=%v, Expires=%v", callID, expires)
for {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-v.ctx.Done():
return nil, nil, v.ctx.Err()
case msg := <-v.responses:
if tv := sipGetCallID(msg); tv == callID {
return req, msg, nil
} else {
logger.Wf(v.ctx, "Not callID=%v, msg=%v, drop message %v", callID, tv, msg.String())
}
}
}
}
func (v *SIPSession) Trying(ctx context.Context, invite sip.Message) error {
if ctx.Err() != nil {
return ctx.Err()
}
req, ok := invite.(sip.Request)
if !ok {
return errors.Errorf("Invalid SIP request invite %v", invite.String())
}
res := sip.NewResponseFromRequest("", req, sip.StatusCode(100), "Trying", "")
if err := v.client.Send(res); err != nil {
return errors.Wrapf(err, "send response %v", res.String())
}
return nil
}
func (v *SIPSession) InviteResponse(ctx context.Context, invite sip.Message) (sip.Message, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
req, ok := invite.(sip.Request)
if !ok {
return nil, errors.Errorf("Invalid SIP request invite %v", invite.String())
}
callID := sipGetCallID(invite)
if callID == "" {
return nil, errors.Errorf("Invalid SIP Call-ID invite %v", invite.String())
}
res := sip.NewResponseFromRequest("", req, sip.StatusCode(200), "OK", "")
if err := v.client.Send(res); err != nil {
return nil, errors.Wrapf(err, "send response %v", res.String())
}
logger.Tf(ctx, "Send INVITE response, Call-ID=%v", callID)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-v.ctx.Done():
return nil, v.ctx.Err()
case msg := <-v.requests:
// Must be an ACK message.
if !msg.IsAck() {
return msg, errors.Errorf("invalid ACK message %v", msg.String())
}
// Check CALL-ID of ACK, should be equal to 200 OK.
if tv := sipGetCallID(msg); tv == callID {
return msg, nil
} else {
logger.Wf(v.ctx, "Not callID=%v, msg=%v, drop message %v", callID, tv, msg.String())
}
}
}
}
func (v *SIPSession) Message(ctx context.Context) (sip.Message, sip.Message, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
sipPort := sip.Port(5060)
sipCallID := sip.CallID(fmt.Sprintf("%v", rand.Uint64()))
sipBranch := fmt.Sprintf("z9hG4bK_%v", rand.Uint32())
sipTag := fmt.Sprintf("%v", rand.Uint32())
sipMaxForwards := sip.MaxForwards(70)
sipExpires := sip.Expires(3600)
sipPIP := "192.168.3.99"
v.seq++
rb := v.rb
rb.SetTransport("TCP")
rb.SetMethod(sip.MESSAGE)
rb.AddVia(&sip.ViaHop{
ProtocolName: "SIP", ProtocolVersion: "2.0", Transport: "TCP", Host: sipPIP, Port: &sipPort,
Params: sip.NewParams().Add("branch", sip.String{Str: sipBranch}),
})
rb.SetFrom(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: v.conf.domain},
Params: sip.NewParams().Add("tag", sip.String{Str: sipTag}),
})
rb.SetTo(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.server}, FHost: v.conf.domain},
})
rb.SetCallID(&sipCallID)
rb.SetSeqNo(v.seq)
rb.SetRecipient(&sip.SipUri{FUser: sip.String{v.conf.server}, FHost: v.conf.domain})
rb.SetContact(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: sipPIP, FPort: &sipPort},
})
rb.SetMaxForwards(&sipMaxForwards)
rb.SetExpires(&sipExpires)
v.seq++
rb.SetBody(strings.Join([]string{
`<?xml version="1.0" encoding="GB2312"?>`,
"<Notify>",
"<CmdType>Keepalive</CmdType>",
fmt.Sprintf("<SN>%v</SN>", v.seq),
fmt.Sprintf("<DeviceID>%v</DeviceID>", v.conf.DeviceID()),
"<Status>OK</Status>",
"</Notify>\n",
}, "\n"))
req, err := rb.Build()
if err != nil {
return req, nil, errors.Wrap(err, "build request")
}
if err = v.client.Send(req); err != nil {
return req, nil, errors.Wrapf(err, "send request %v", req.String())
}
callID := sipGetCallID(req)
if callID == "" {
return req, nil, errors.Errorf("Invalid SIP Call-ID message %v", req.String())
}
logger.Tf(ctx, "Send MESSAGE request, Call-ID=%v", callID)
for {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-v.ctx.Done():
return nil, nil, v.ctx.Err()
case msg := <-v.responses:
if tv := sipGetCallID(msg); tv == callID {
return req, msg, nil
} else {
logger.Wf(v.ctx, "Not callID=%v, msg=%v, drop message %v", callID, tv, msg.String())
}
}
}
}
func (v *SIPSession) Bye(ctx context.Context) (sip.Message, sip.Message, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
}
sipPort := sip.Port(5060)
sipCallID := sip.CallID(fmt.Sprintf("%v", rand.Uint64()))
sipBranch := fmt.Sprintf("z9hG4bK_%v", rand.Uint32())
sipTag := fmt.Sprintf("%v", rand.Uint32())
sipMaxForwards := sip.MaxForwards(70)
sipExpires := sip.Expires(3600)
sipPIP := "192.168.3.99"
v.seq++
rb := v.rb
rb.SetTransport("TCP")
rb.SetMethod(sip.BYE)
rb.AddVia(&sip.ViaHop{
ProtocolName: "SIP", ProtocolVersion: "2.0", Transport: "TCP", Host: sipPIP, Port: &sipPort,
Params: sip.NewParams().Add("branch", sip.String{Str: sipBranch}),
})
rb.SetFrom(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: v.conf.domain},
Params: sip.NewParams().Add("tag", sip.String{Str: sipTag}),
})
rb.SetTo(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.server}, FHost: v.conf.domain},
})
rb.SetCallID(&sipCallID)
rb.SetSeqNo(v.seq)
rb.SetRecipient(&sip.SipUri{FUser: sip.String{v.conf.server}, FHost: v.conf.domain})
rb.SetContact(&sip.Address{
Uri: &sip.SipUri{FUser: sip.String{v.conf.DeviceID()}, FHost: sipPIP, FPort: &sipPort},
})
rb.SetMaxForwards(&sipMaxForwards)
rb.SetExpires(&sipExpires)
req, err := rb.Build()
if err != nil {
return req, nil, errors.Wrap(err, "build request")
}
if err = v.client.Send(req); err != nil {
return req, nil, errors.Wrapf(err, "send request %v", req.String())
}
callID := sipGetCallID(req)
if callID == "" {
return req, nil, errors.Errorf("Invalid SIP Call-ID bye %v", req.String())
}
logger.Tf(ctx, "Send BYE request, Call-ID=%v", callID)
for {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-v.ctx.Done():
return nil, nil, v.ctx.Err()
case msg := <-v.responses:
if tv := sipGetCallID(msg); tv == callID {
return req, msg, nil
} else {
logger.Wf(v.ctx, "Not callID=%v, msg=%v, drop message %v", callID, tv, msg.String())
}
}
}
}
func (v *SIPSession) Wait(ctx context.Context, method sip.RequestMethod) (sip.Message, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-v.ctx.Done():
return nil, v.ctx.Err()
case msg := <-v.requests:
if r, ok := msg.(sip.Request); ok && r.Method() == method {
return msg, nil
} else {
logger.Wf(v.ctx, "Not method=%v, drop message %v", method, msg.String())
}
}
}
}
type SIPClient struct {
ctx context.Context
cancel context.CancelFunc
incoming chan sip.Message
target *transport.Target
protocol transport.Protocol
cleanupTimeout time.Duration
}
func NewSIPClient() *SIPClient {
return &SIPClient{
cleanupTimeout: 5 * time.Second,
}
}
func (v *SIPClient) Close() error {
if v.cancel != nil {
v.cancel()
}
// Wait for protocol stack to cleanup.
if v.protocol != nil {
select {
case <-time.After(v.cleanupTimeout):
logger.E(v.ctx, "Wait for protocol cleanup timeout")
case <-v.protocol.Done():
logger.T(v.ctx, "SIP protocol stack done")
}
}
return nil
}
func (v *SIPClient) Connect(ctx context.Context, addr string) error {
prURL, err := url.Parse(addr)
if err != nil {
return errors.Wrapf(err, "parse addr=%v", addr)
}
if prURL.Scheme != "tcp" && prURL.Scheme != "tcp4" {
return errors.Errorf("invalid scheme=%v of addr=%v", prURL.Scheme, addr)
}
target, err := transport.NewTargetFromAddr(prURL.Host)
if err != nil {
return errors.Wrapf(err, "create target to %v", prURL.Host)
}
v.target = target
incoming := make(chan sip.Message, 1024)
errs := make(chan error, 1)
cancels := make(chan struct{}, 1)
protocol := transport.NewTcpProtocol(incoming, errs, cancels, nil, log.NewDefaultLogrusLogger())
v.protocol = protocol
v.incoming = incoming
// Convert protocol stack errs to context signal.
ctx, cancel := context.WithCancel(ctx)
v.cancel = cancel
v.ctx = ctx
go func() {
select {
case <-ctx.Done():
return
case r0 := <-errs:
logger.Ef(ctx, "SIP stack err %+v", r0)
cancel()
}
}()
// Covert context signal to cancels for protocol stack.
go func() {
<-ctx.Done()
close(cancels)
logger.Tf(ctx, "Notify SIP stack to cancel")
}()
return nil
}
func (v *SIPClient) Send(msg sip.Message) error {
logger.Tf(v.ctx, "Send msg %v", msg.String())
return v.protocol.Send(v.target, msg)
}

361
trunk/3rdparty/srs-bench/gb28181/util.go vendored Normal file
View file

@ -0,0 +1,361 @@
// The MIT License (MIT)
//
// Copyright (c) 2022 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gb28181
import (
"bufio"
"context"
"flag"
"fmt"
"github.com/ghettovoice/gosip/sip"
"github.com/ossrs/go-oryx-lib/aac"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/yapingcat/gomedia/mpeg2"
"io"
"net"
"net/url"
"os"
"path"
"strings"
"time"
)
var srsLog *bool
var srsTimeout *int
var srsPublishVideoFps *int
var srsSipAddr *string
var srsSipUser *string
var srsSipRandomID *int
var srsSipDomain *string
var srsSipSvrID *string
var srsMediaTimeout *int
var srsReinviteTimeout *int
var srsPublishAudio *string
var srsPublishVideo *string
func prepareTest() (err error) {
srsSipAddr = flag.String("srs-sip", "tcp://127.0.0.1:5060", "The SRS GB server to connect to")
srsSipUser = flag.String("srs-stream", "3402000000", "The GB user/stream to publish")
srsSipRandomID = flag.Int("srs-random", 10, "The GB user/stream random suffix to publish")
srsSipDomain = flag.String("srs-domain", "3402000000", "The GB SIP domain")
srsSipSvrID = flag.String("srs-server", "34020000002000000001", "The GB server ID for SIP")
srsLog = flag.Bool("srs-log", false, "Whether enable the detail log")
srsTimeout = flag.Int("srs-timeout", 11000, "For each case, the timeout in ms")
srsMediaTimeout = flag.Int("srs-media-timeout", 2100, "PS media disconnect timeout in ms")
srsReinviteTimeout = flag.Int("srs-reinvite-timeout", 1200, "When disconnect, SIP re-invite timeout in ms")
srsPublishAudio = flag.String("srs-publish-audio", "avatar.aac", "The audio file for publisher.")
srsPublishVideo = flag.String("srs-publish-video", "avatar.h264", "The video file for publisher.")
srsPublishVideoFps = flag.Int("srs-publish-video-fps", 25, "The video fps for publisher.")
// Should parse it first.
flag.Parse()
// Check file.
tryOpenFile := func(filename string) (string, error) {
if filename == "" {
return filename, nil
}
f, err := os.Open(filename)
if err != nil {
nfilename := path.Join("../", filename)
f2, err := os.Open(nfilename)
if err != nil {
return filename, errors.Wrapf(err, "No video file at %v or %v", filename, nfilename)
}
defer f2.Close()
return nfilename, nil
}
defer f.Close()
return filename, nil
}
if *srsPublishVideo, err = tryOpenFile(*srsPublishVideo); err != nil {
return err
}
if *srsPublishAudio, err = tryOpenFile(*srsPublishAudio); err != nil {
return err
}
return nil
}
type GBTestSession struct {
session *GBSession
}
func NewGBTestSession() *GBTestSession {
sipConfig := SIPConfig{
addr: *srsSipAddr,
domain: *srsSipDomain,
user: *srsSipUser,
random: *srsSipRandomID,
server: *srsSipSvrID,
}
return &GBTestSession{
session: NewGBSession(&GBSessionConfig{
regTimeout: time.Duration(*srsTimeout) * 5 * time.Minute,
inviteTimeout: time.Duration(*srsTimeout) * 5 * time.Minute,
}, &sipConfig),
}
}
func (v *GBTestSession) Close() error {
v.session.Close()
return nil
}
func (v *GBTestSession) Run(ctx context.Context) (err error) {
if err = v.session.Connect(ctx); err != nil {
return errors.Wrap(err, "connect")
}
if err = v.session.Register(ctx); err != nil {
return errors.Wrap(err, "register")
}
if err = v.session.Invite(ctx); err != nil {
return errors.Wrap(err, "invite")
}
return nil
}
type GBTestPublisher struct {
session *GBSession
ingester *PSIngester
}
func NewGBTestPublisher() *GBTestPublisher {
sipConfig := SIPConfig{
addr: *srsSipAddr,
domain: *srsSipDomain,
user: *srsSipUser,
random: *srsSipRandomID,
server: *srsSipSvrID,
}
psConfig := PSConfig{
video: *srsPublishVideo,
fps: *srsPublishVideoFps,
audio: *srsPublishAudio,
}
return &GBTestPublisher{
session: NewGBSession(&GBSessionConfig{
regTimeout: time.Duration(*srsTimeout) * 5 * time.Minute,
inviteTimeout: time.Duration(*srsTimeout) * 5 * time.Minute,
}, &sipConfig),
ingester: NewPSIngester(&IngesterConfig{
psConfig: psConfig,
}),
}
}
func (v *GBTestPublisher) Close() error {
v.ingester.Close()
v.session.Close()
return nil
}
func (v *GBTestPublisher) Run(ctx context.Context) (err error) {
if err = v.session.Connect(ctx); err != nil {
return errors.Wrap(err, "connect")
}
if err = v.session.Register(ctx); err != nil {
return errors.Wrap(err, "register")
}
if err = v.session.Invite(ctx); err != nil {
return errors.Wrap(err, "invite")
}
serverAddr, err := utilBuildMediaAddr(v.session.sip.conf.addr, v.session.out.mediaPort)
if err != nil {
return errors.Wrap(err, "parse")
}
v.ingester.conf.serverAddr = serverAddr
v.ingester.conf.ssrc = uint32(v.session.out.ssrc)
v.ingester.conf.clockRate = v.session.out.clockRate
v.ingester.conf.payloadType = uint8(v.session.out.payloadType)
if err := v.ingester.Ingest(ctx); err != nil {
return errors.Wrap(err, "ingest")
}
return nil
}
// Filter the test error, ignore context.Canceled
func filterTestError(errs ...error) error {
var filteredErrors []error
for _, err := range errs {
if err == nil || errors.Cause(err) == context.Canceled {
continue
}
// If url error, server maybe error, do not print the detail log.
if r0 := errors.Cause(err); r0 != nil {
if r1, ok := r0.(*url.Error); ok {
err = r1
}
}
filteredErrors = append(filteredErrors, err)
}
if len(filteredErrors) == 0 {
return nil
}
if len(filteredErrors) == 1 {
return filteredErrors[0]
}
var descs []string
for i, err := range filteredErrors[1:] {
descs = append(descs, fmt.Sprintf("err #%d, %+v", i, err))
}
return errors.Wrapf(filteredErrors[0], "with %v", strings.Join(descs, ","))
}
type wallClock struct {
start time.Time
duration time.Duration
}
func newWallClock() *wallClock {
return &wallClock{start: time.Now()}
}
func (v *wallClock) Tick(d time.Duration) time.Duration {
v.duration += d
wc := time.Now().Sub(v.start)
re := v.duration - wc
if re > 30*time.Millisecond {
return re
}
return 0
}
func sipGetCallID(m sip.Message) string {
if v, ok := m.CallID(); !ok {
return ""
} else {
return v.Value()
}
}
func utilBuildMediaAddr(addr string, mediaPort int64) (string, error) {
if u, err := url.Parse(addr); err != nil {
return "", errors.Wrapf(err, "parse %v", addr)
} else if addr, err := net.ResolveTCPAddr(u.Scheme, u.Host); err != nil {
return "", errors.Wrapf(err, "parse %v scheme=%v, host=%v", addr, u.Scheme, u.Host)
} else {
return fmt.Sprintf("%v://%v:%v",
u.Scheme, addr.IP.String(), mediaPort,
), nil
}
}
// See SrsMpegPES::decode
func utilUpdatePesPacketLength(pes *mpeg2.PesPacket) {
var nb_required int
if pes.PTS_DTS_flags == 0x2 {
nb_required += 5
}
if pes.PTS_DTS_flags == 0x3 {
nb_required += 10
}
if pes.ESCR_flag > 0 {
nb_required += 6
}
if pes.ES_rate_flag > 0 {
nb_required += 3
}
if pes.DSM_trick_mode_flag > 0 {
nb_required += 1
}
if pes.Additional_copy_info_flag > 0 {
nb_required += 1
}
if pes.PES_CRC_flag > 0 {
nb_required += 2
}
if pes.PES_extension_flag > 0 {
nb_required += 1
}
// Size before PES_header_data_length.
const fixed = uint16(3)
// Size after PES_header_data_length.
pes.PES_header_data_length = uint8(nb_required)
// Size after PES_packet_length
pes.PES_packet_length = uint16(len(pes.Pes_payload)) + fixed + uint16(pes.PES_header_data_length)
}
type AACReader struct {
codec aac.ADTS
r *bufio.Reader
}
func NewAACReader(f io.Reader) (*AACReader, error) {
v := &AACReader{}
var err error
if v.codec, err = aac.NewADTS(); err != nil {
return nil, err
}
v.r = bufio.NewReaderSize(f, 4096)
b, err := v.r.Peek(7 + 1024)
if err != nil {
return nil, err
}
if _, _, err = v.codec.Decode(b); err != nil {
return nil, err
}
return v, nil
}
func (v *AACReader) NextADTSFrame() ([]byte, error) {
b, err := v.r.Peek(7 + 1024)
if err != nil {
return nil, err
}
_, left, err := v.codec.Decode(b)
if err != nil {
return nil, err
}
adts := b[:len(b)-len(left)]
if _, err = v.r.Discard(len(adts)); err != nil {
return nil, err
}
return adts, nil
}

View file

@ -3,6 +3,7 @@ module github.com/ossrs/srs-bench
go 1.15
require (
github.com/ghettovoice/gosip v0.0.0-20220929080231-de8ba881be83
github.com/ossrs/go-oryx-lib v0.0.9
github.com/pion/interceptor v0.0.10
github.com/pion/logging v0.2.2
@ -11,4 +12,6 @@ require (
github.com/pion/sdp/v3 v3.0.4
github.com/pion/transport v0.12.2
github.com/pion/webrtc/v3 v3.0.13
github.com/yapingcat/gomedia/codec v0.0.0-20220617074658-94762898dc25
github.com/yapingcat/gomedia/mpeg2 v0.0.0-20220617074658-94762898dc25
)

View file

@ -1,33 +1,59 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/discoviking/fsm v0.0.0-20150126104936-f4a273feecca/go.mod h1:W+3LQaEkN8qAwwcw0KC546sUEnX86GIT8CcMLZC4mG0=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/ghettovoice/gosip v0.0.0-20220929080231-de8ba881be83 h1:4v14bwSGZH2usyuG9XWZgMbGkVU33ayg0cb68nvKfj0=
github.com/ghettovoice/gosip v0.0.0-20220929080231-de8ba881be83/go.mod h1:yTr3BEYSFe9As6XM7ldyrVgqsPwlnw8Ahc4N28VFM2g=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.1.0-rc.1 h1:VK3aeRXMI8osaS6YCDKNZhU6RKtcP3B2wzqxOogNDz8=
github.com/gobwas/ws v1.1.0-rc.1/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.5 h1:kxhtnfFVi+rYdOALN0B3k9UT86zVJKfBimRaciULW4I=
github.com/google/uuid v1.1.5/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.5 h1:obHEce3upls1IBn1gTw/o7bCv7OJb6Ib/o7wNO+4eKw=
github.com/nxadm/tail v1.4.5/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M=
github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc=
github.com/onsi/gomega v1.10.4 h1:NiTx7EEvBzu9sFOD1zORteLSt3o8gnlvZZwSE9TnY9U=
github.com/onsi/gomega v1.10.4/go.mod h1:g/HbgYopi++010VEqkFgJHKC09uJiW9UkXvMUuKHUCQ=
github.com/ossrs/go-oryx-lib v0.0.9 h1:piZkzit/1hqAcXP31/mvDEDpHVjCmBMmvzF3hN8hUuQ=
github.com/ossrs/go-oryx-lib v0.0.9/go.mod h1:i2tH4TZBzAw5h+HwGrNOKvP/nmZgSQz0OEnLLdzcT/8=
github.com/pion/datachannel v1.4.21 h1:3ZvhNyfmxsAqltQrApLPQMhSFNA+aT87RqyCq4OXmf0=
@ -74,13 +100,28 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b h1:gQZ0qzfKHQIybLANtM3mBXNUtOfsCFXeTsnBqCsx1KM=
github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5 h1:hNna6Fi0eP1f2sMBe/rJicDmaHmoXGe1Ta84FPYHLuE=
github.com/tevino/abool v0.0.0-20170917061928-9b9efcf221b5/go.mod h1:f1SCnEOt6sc3fOJfPQDRDzHOtSXuTtnz0ImG9kPRDV0=
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
github.com/yapingcat/gomedia/codec v0.0.0-20220609081842-9e0c0e8a19a0/go.mod h1:obSECV6X3NPUsLL0olA7DurvQHKMq7J3iBTNQ4bL/vQ=
github.com/yapingcat/gomedia/codec v0.0.0-20220617074658-94762898dc25 h1:1mq/skGEQGCqxHJPKfontELt/a052Gu236H0bge0Qr0=
github.com/yapingcat/gomedia/codec v0.0.0-20220617074658-94762898dc25/go.mod h1:obSECV6X3NPUsLL0olA7DurvQHKMq7J3iBTNQ4bL/vQ=
github.com/yapingcat/gomedia/mpeg2 v0.0.0-20220617074658-94762898dc25 h1:51qjqT2jsOESm/jDi0k0AdQX33Sg4vhw8X6eooj7c8A=
github.com/yapingcat/gomedia/mpeg2 v0.0.0-20220617074658-94762898dc25/go.mod h1:bvxj2Oi5Rwj7eHm2OjqgOIs8x2T0j+V068eS/SAyZLA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
@ -95,12 +136,15 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -108,12 +152,16 @@ golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201214095126-aec9a390925b h1:tv7/y4pd+sR8bcNb2D6o7BNU6zjWm0VjQLac+w7fNNM=
golang.org/x/sys v0.0.0-20201214095126-aec9a390925b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -124,14 +172,17 @@ google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -46,7 +46,7 @@ func Parse(ctx context.Context) {
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
var sfu string
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or janus")
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
fl.StringVar(&sr, "sr", "", "")
fl.IntVar(&pli, "pli", 10, "")
@ -66,7 +66,7 @@ func Parse(ctx context.Context) {
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -nn The number of clients to simulate. Default: 1"))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
@ -77,7 +77,7 @@ func Parse(ctx context.Context) {
fmt.Println(fmt.Sprintf(" -pli [Optional] PLI request interval in seconds. Default: 10"))
fmt.Println(fmt.Sprintf("Publisher:"))
fmt.Println(fmt.Sprintf(" -pr The url to publish. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -fps The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -fps [Optional] The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -sa [Optional] The file path to read audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -sv [Optional] The file path to read video, ignore if empty."))
fmt.Println(fmt.Sprintf("\n例如1个播放1个推流:"))

View file

@ -25,6 +25,7 @@ import (
"flag"
"fmt"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/ossrs/srs-bench/gb28181"
"github.com/ossrs/srs-bench/janus"
"github.com/ossrs/srs-bench/srs"
"io/ioutil"
@ -37,18 +38,21 @@ func main() {
var sfu string
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fl.SetOutput(ioutil.Discard)
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or janus")
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
_ = fl.Parse(os.Args[1:])
ctx := context.Background()
var conf interface{}
if sfu == "srs" {
srs.Parse(ctx)
} else if sfu == "gb28181" {
conf = gb28181.Parse(ctx)
} else if sfu == "janus" {
janus.Parse(ctx)
} else {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
os.Exit(-1)
}
@ -65,6 +69,8 @@ func main() {
var err error
if sfu == "srs" {
err = srs.Run(ctx)
} else if sfu == "gb28181" {
err = gb28181.Run(ctx, conf)
} else if sfu == "janus" {
err = janus.Run(ctx)
}

View file

@ -50,7 +50,7 @@ func Parse(ctx context.Context) {
fl := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
var sfu string
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or janus")
fl.StringVar(&sfu, "sfu", "srs", "The SFU server, srs or gb28181 or janus")
fl.StringVar(&sr, "sr", "", "")
fl.StringVar(&dumpAudio, "da", "", "")
@ -74,7 +74,7 @@ func Parse(ctx context.Context) {
fl.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -sfu The target SFU, srs or gb28181 or janus. Default: srs"))
fmt.Println(fmt.Sprintf(" -nn The number of clients to simulate. Default: 1"))
fmt.Println(fmt.Sprintf(" -sn The number of streams to simulate. Variable: %%d. Default: 1"))
fmt.Println(fmt.Sprintf(" -delay The start delay in ms for each client or stream to simulate. Default: 50"))
@ -88,7 +88,7 @@ func Parse(ctx context.Context) {
fmt.Println(fmt.Sprintf(" -pli [Optional] PLI request interval in seconds. Default: 10"))
fmt.Println(fmt.Sprintf("Publisher:"))
fmt.Println(fmt.Sprintf(" -pr The url to publish. If sn exceed 1, auto append variable %%d."))
fmt.Println(fmt.Sprintf(" -fps The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -fps [Optional] The fps of .h264 source file."))
fmt.Println(fmt.Sprintf(" -sa [Optional] The file path to read audio, ignore if empty."))
fmt.Println(fmt.Sprintf(" -sv [Optional] The file path to read video, ignore if empty."))
fmt.Println(fmt.Sprintf("\n例如1个播放1个推流:"))

View file

@ -73,9 +73,7 @@ var srsPublishAvatar *string
var srsPublishBBB *string
var srsVnetClientIP *string
func prepareTest() error {
var err error
func prepareTest() (err error) {
srsHttps = flag.Bool("srs-https", false, "Whther connect to HTTPS-API")
srsServer = flag.String("srs-server", "127.0.0.1", "The RTC server to connect to")
srsStream = flag.String("srs-stream", "/rtc/regression", "The RTC app/stream to play")

View file

@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2017, The GoSIP authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,85 @@
package log
import (
"fmt"
"strings"
)
// Logger interface used as base logger throughout the library.
type Logger interface {
Print(args ...interface{})
Printf(format string, args ...interface{})
Trace(args ...interface{})
Tracef(format string, args ...interface{})
Debug(args ...interface{})
Debugf(format string, args ...interface{})
Info(args ...interface{})
Infof(format string, args ...interface{})
Warn(args ...interface{})
Warnf(format string, args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Panic(args ...interface{})
Panicf(format string, args ...interface{})
WithPrefix(prefix string) Logger
Prefix() string
WithFields(fields Fields) Logger
Fields() Fields
SetLevel(level Level)
}
type Loggable interface {
Log() Logger
}
type Fields map[string]interface{}
func (fields Fields) String() string {
str := make([]string, 0)
for k, v := range fields {
str = append(str, fmt.Sprintf("%s=%+v", k, v))
}
return strings.Join(str, " ")
}
func (fields Fields) WithFields(newFields Fields) Fields {
allFields := make(Fields)
for k, v := range fields {
allFields[k] = v
}
for k, v := range newFields {
allFields[k] = v
}
return allFields
}
func AddFieldsFrom(logger Logger, values ...interface{}) Logger {
for _, value := range values {
switch v := value.(type) {
case Logger:
logger = logger.WithFields(v.Fields())
case Loggable:
logger = logger.WithFields(v.Log().Fields())
case interface{ Fields() Fields }:
logger = logger.WithFields(v.Fields())
}
}
return logger
}

View file

@ -0,0 +1,148 @@
package log
import (
"github.com/sirupsen/logrus"
prefixed "github.com/x-cray/logrus-prefixed-formatter"
)
type LogrusLogger struct {
log logrus.Ext1FieldLogger
prefix string
fields Fields
}
// Level type
type Level uint32
// These are the different logging levels. You can set the logging level to log
// on your instance of logger, obtained with `logrus.New()`.
const (
// PanicLevel level, highest level of severity. Logs and then calls panic with the
// message passed to Debug, Info, ...
PanicLevel Level = iota
// FatalLevel level. Logs and then calls `logger.Exit(1)`. It will exit even if the
// logging level is set to Panic.
FatalLevel
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
// Commonly used for hooks to send errors to an error tracking service.
ErrorLevel
// WarnLevel level. Non-critical entries that deserve eyes.
WarnLevel
// InfoLevel level. General operational entries about what's going on inside the
// application.
InfoLevel
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
DebugLevel
// TraceLevel level. Designates finer-grained informational events than the Debug.
TraceLevel
)
func NewLogrusLogger(logrus logrus.Ext1FieldLogger, prefix string, fields Fields) *LogrusLogger {
return &LogrusLogger{
log: logrus,
prefix: prefix,
fields: fields,
}
}
func NewDefaultLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.Formatter = &prefixed.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05.000",
}
return NewLogrusLogger(logger, "main", nil)
}
func (l *LogrusLogger) Print(args ...interface{}) {
l.prepareEntry().Print(args...)
}
func (l *LogrusLogger) Printf(format string, args ...interface{}) {
l.prepareEntry().Printf(format, args...)
}
func (l *LogrusLogger) Trace(args ...interface{}) {
l.prepareEntry().Trace(args...)
}
func (l *LogrusLogger) Tracef(format string, args ...interface{}) {
l.prepareEntry().Tracef(format, args...)
}
func (l *LogrusLogger) Debug(args ...interface{}) {
l.prepareEntry().Debug(args...)
}
func (l *LogrusLogger) Debugf(format string, args ...interface{}) {
l.prepareEntry().Debugf(format, args...)
}
func (l *LogrusLogger) Info(args ...interface{}) {
l.prepareEntry().Info(args...)
}
func (l *LogrusLogger) Infof(format string, args ...interface{}) {
l.prepareEntry().Infof(format, args...)
}
func (l *LogrusLogger) Warn(args ...interface{}) {
l.prepareEntry().Warn(args...)
}
func (l *LogrusLogger) Warnf(format string, args ...interface{}) {
l.prepareEntry().Warnf(format, args...)
}
func (l *LogrusLogger) Error(args ...interface{}) {
l.prepareEntry().Error(args...)
}
func (l *LogrusLogger) Errorf(format string, args ...interface{}) {
l.prepareEntry().Errorf(format, args...)
}
func (l *LogrusLogger) Fatal(args ...interface{}) {
l.prepareEntry().Fatal(args...)
}
func (l *LogrusLogger) Fatalf(format string, args ...interface{}) {
l.prepareEntry().Fatalf(format, args...)
}
func (l *LogrusLogger) Panic(args ...interface{}) {
l.prepareEntry().Panic(args...)
}
func (l *LogrusLogger) Panicf(format string, args ...interface{}) {
l.prepareEntry().Panicf(format, args...)
}
func (l *LogrusLogger) WithPrefix(prefix string) Logger {
return NewLogrusLogger(l.log, prefix, l.Fields())
}
func (l *LogrusLogger) Prefix() string {
return l.prefix
}
func (l *LogrusLogger) WithFields(fields Fields) Logger {
return NewLogrusLogger(l.log, l.Prefix(), l.Fields().WithFields(fields))
}
func (l *LogrusLogger) Fields() Fields {
return l.fields
}
func (l *LogrusLogger) prepareEntry() *logrus.Entry {
return l.log.
WithFields(logrus.Fields(l.Fields())).
WithField("prefix", l.Prefix())
}
func (l *LogrusLogger) SetLevel(level Level) {
if ll, ok := l.log.(*logrus.Logger); ok {
ll.SetLevel(logrus.Level(level))
}
}

View file

@ -0,0 +1,279 @@
package sip
import (
"crypto/md5"
"encoding/hex"
"fmt"
"regexp"
"strings"
)
// currently only Digest and MD5
type Authorization struct {
realm string
nonce string
algorithm string
username string
password string
uri string
response string
method string
qop string
nc string
cnonce string
other map[string]string
}
func AuthFromValue(value string) *Authorization {
auth := &Authorization{
algorithm: "MD5",
other: make(map[string]string),
}
re := regexp.MustCompile(`([\w]+)="([^"]+)"`)
matches := re.FindAllStringSubmatch(value, -1)
for _, match := range matches {
switch match[1] {
case "realm":
auth.realm = match[2]
case "algorithm":
auth.algorithm = match[2]
case "nonce":
auth.nonce = match[2]
case "username":
auth.username = match[2]
case "uri":
auth.uri = match[2]
case "response":
auth.response = match[2]
case "qop":
for _, v := range strings.Split(match[2], ",") {
v = strings.Trim(v, " ")
if v == "auth" || v == "auth-int" {
auth.qop = "auth"
break
}
}
case "nc":
auth.nc = match[2]
case "cnonce":
auth.cnonce = match[2]
default:
auth.other[match[1]] = match[2]
}
}
return auth
}
func (auth *Authorization) Realm() string {
return auth.realm
}
func (auth *Authorization) Nonce() string {
return auth.nonce
}
func (auth *Authorization) Algorithm() string {
return auth.algorithm
}
func (auth *Authorization) Username() string {
return auth.username
}
func (auth *Authorization) SetUsername(username string) *Authorization {
auth.username = username
return auth
}
func (auth *Authorization) SetPassword(password string) *Authorization {
auth.password = password
return auth
}
func (auth *Authorization) Uri() string {
return auth.uri
}
func (auth *Authorization) SetUri(uri string) *Authorization {
auth.uri = uri
return auth
}
func (auth *Authorization) SetMethod(method string) *Authorization {
auth.method = method
return auth
}
func (auth *Authorization) Response() string {
return auth.response
}
func (auth *Authorization) SetResponse(response string) {
auth.response = response
}
func (auth *Authorization) Qop() string {
return auth.qop
}
func (auth *Authorization) SetQop(qop string) {
auth.qop = qop
}
func (auth *Authorization) Nc() string {
return auth.nc
}
func (auth *Authorization) SetNc(nc string) {
auth.nc = nc
}
func (auth *Authorization) CNonce() string {
return auth.cnonce
}
func (auth *Authorization) SetCNonce(cnonce string) {
auth.cnonce = cnonce
}
func (auth *Authorization) CalcResponse() string {
return calcResponse(
auth.username,
auth.realm,
auth.password,
auth.method,
auth.uri,
auth.nonce,
auth.qop,
auth.cnonce,
auth.nc,
)
}
func (auth *Authorization) String() string {
if auth == nil {
return "<nil>"
}
str := fmt.Sprintf(
`Digest realm="%s",algorithm=%s,nonce="%s",username="%s",uri="%s",response="%s"`,
auth.realm,
auth.algorithm,
auth.nonce,
auth.username,
auth.uri,
auth.response,
)
if auth.qop == "auth" {
str += fmt.Sprintf(`,qop=%s,nc=%s,cnonce="%s"`, auth.qop, auth.nc, auth.cnonce)
}
return str
}
// calculates Authorization response https://www.ietf.org/rfc/rfc2617.txt
func calcResponse(username, realm, password, method, uri, nonce, qop, cnonce, nc string) string {
calcA1 := func() string {
encoder := md5.New()
encoder.Write([]byte(username + ":" + realm + ":" + password))
return hex.EncodeToString(encoder.Sum(nil))
}
calcA2 := func() string {
encoder := md5.New()
encoder.Write([]byte(method + ":" + uri))
return hex.EncodeToString(encoder.Sum(nil))
}
encoder := md5.New()
encoder.Write([]byte(calcA1() + ":" + nonce + ":"))
if qop != "" {
encoder.Write([]byte(nc + ":" + cnonce + ":" + qop + ":"))
}
encoder.Write([]byte(calcA2()))
return hex.EncodeToString(encoder.Sum(nil))
}
func AuthorizeRequest(request Request, response Response, user, password MaybeString) error {
if user == nil {
return fmt.Errorf("authorize request: user is nil")
}
var authenticateHeaderName, authorizeHeaderName string
if response.StatusCode() == 401 {
// on 401 Unauthorized increase request seq num, add Authorization header and send once again
authenticateHeaderName = "WWW-Authenticate"
authorizeHeaderName = "Authorization"
} else {
// 407 Proxy authentication
authenticateHeaderName = "Proxy-Authenticate"
authorizeHeaderName = "Proxy-Authorization"
}
if hdrs := response.GetHeaders(authenticateHeaderName); len(hdrs) > 0 {
authenticateHeader := hdrs[0].(*GenericHeader)
auth := AuthFromValue(authenticateHeader.Contents).
SetMethod(string(request.Method())).
SetUri(request.Recipient().String()).
SetUsername(user.String())
if password != nil {
auth.SetPassword(password.String())
}
if auth.Qop() == "auth" {
auth.SetNc("00000001")
encoder := md5.New()
encoder.Write([]byte(user.String() + request.Recipient().String()))
if password != nil {
encoder.Write([]byte(password.String()))
}
auth.SetCNonce(hex.EncodeToString(encoder.Sum(nil)))
}
auth.SetResponse(auth.CalcResponse())
if hdrs = request.GetHeaders(authorizeHeaderName); len(hdrs) > 0 {
authorizationHeader := hdrs[0].Clone().(*GenericHeader)
authorizationHeader.Contents = auth.String()
request.ReplaceHeaders(authorizationHeader.Name(), []Header{authorizationHeader})
} else {
request.AppendHeader(&GenericHeader{
HeaderName: authorizeHeaderName,
Contents: auth.String(),
})
}
} else {
return fmt.Errorf("authorize request: header '%s' not found in response", authenticateHeaderName)
}
if viaHop, ok := request.ViaHop(); ok {
viaHop.Params.Add("branch", String{Str: GenerateBranch()})
}
if cseq, ok := request.CSeq(); ok {
cseq := cseq.Clone().(*CSeq)
cseq.SeqNo++
request.ReplaceHeaders(cseq.Name(), []Header{cseq})
}
return nil
}
type Authorizer interface {
AuthorizeRequest(request Request, response Response) error
}
type DefaultAuthorizer struct {
User MaybeString
Password MaybeString
}
func (auth *DefaultAuthorizer) AuthorizeRequest(request Request, response Response) error {
return AuthorizeRequest(request, response, auth.User, auth.Password)
}

View file

@ -0,0 +1,336 @@
package sip
import (
"fmt"
"github.com/ghettovoice/gosip/util"
)
type RequestBuilder struct {
protocol string
protocolVersion string
transport string
host string
method RequestMethod
cseq *CSeq
recipient Uri
body string
callID *CallID
via ViaHeader
from *FromHeader
to *ToHeader
contact *ContactHeader
expires *Expires
userAgent *UserAgentHeader
maxForwards *MaxForwards
supported *SupportedHeader
require *RequireHeader
allow AllowHeader
contentType *ContentType
accept *Accept
route *RouteHeader
generic map[string]Header
}
func NewRequestBuilder() *RequestBuilder {
callID := CallID(util.RandString(32))
maxForwards := MaxForwards(70)
userAgent := UserAgentHeader("GoSIP")
rb := &RequestBuilder{
protocol: "SIP",
protocolVersion: "2.0",
transport: "UDP",
host: "localhost",
cseq: &CSeq{SeqNo: 1},
body: "",
via: make(ViaHeader, 0),
callID: &callID,
userAgent: &userAgent,
maxForwards: &maxForwards,
generic: make(map[string]Header),
}
return rb
}
func (rb *RequestBuilder) SetTransport(transport string) *RequestBuilder {
if transport == "" {
rb.transport = "UDP"
} else {
rb.transport = transport
}
return rb
}
func (rb *RequestBuilder) SetHost(host string) *RequestBuilder {
if host == "" {
rb.host = "localhost"
} else {
rb.host = host
}
return rb
}
func (rb *RequestBuilder) SetMethod(method RequestMethod) *RequestBuilder {
rb.method = method
rb.cseq.MethodName = method
return rb
}
func (rb *RequestBuilder) SetSeqNo(seqNo uint) *RequestBuilder {
rb.cseq.SeqNo = uint32(seqNo)
return rb
}
func (rb *RequestBuilder) SetRecipient(uri Uri) *RequestBuilder {
rb.recipient = uri.Clone()
return rb
}
func (rb *RequestBuilder) SetBody(body string) *RequestBuilder {
rb.body = body
return rb
}
func (rb *RequestBuilder) SetCallID(callID *CallID) *RequestBuilder {
if callID != nil {
rb.callID = callID
}
return rb
}
func (rb *RequestBuilder) AddVia(via *ViaHop) *RequestBuilder {
if via.ProtocolName == "" {
via.ProtocolName = rb.protocol
}
if via.ProtocolVersion == "" {
via.ProtocolVersion = rb.protocolVersion
}
if via.Transport == "" {
via.Transport = rb.transport
}
if via.Host == "" {
via.Host = rb.host
}
if via.Params == nil {
via.Params = NewParams()
}
rb.via = append(rb.via, via)
return rb
}
func (rb *RequestBuilder) SetFrom(address *Address) *RequestBuilder {
if address == nil {
rb.from = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.from = &FromHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetTo(address *Address) *RequestBuilder {
if address == nil {
rb.to = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.to = &ToHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetContact(address *Address) *RequestBuilder {
if address == nil {
rb.contact = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.contact = &ContactHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetExpires(expires *Expires) *RequestBuilder {
rb.expires = expires
return rb
}
func (rb *RequestBuilder) SetUserAgent(userAgent *UserAgentHeader) *RequestBuilder {
rb.userAgent = userAgent
return rb
}
func (rb *RequestBuilder) SetMaxForwards(maxForwards *MaxForwards) *RequestBuilder {
rb.maxForwards = maxForwards
return rb
}
func (rb *RequestBuilder) SetAllow(methods []RequestMethod) *RequestBuilder {
rb.allow = methods
return rb
}
func (rb *RequestBuilder) SetSupported(options []string) *RequestBuilder {
if len(options) == 0 {
rb.supported = nil
} else {
rb.supported = &SupportedHeader{
Options: options,
}
}
return rb
}
func (rb *RequestBuilder) SetRequire(options []string) *RequestBuilder {
if len(options) == 0 {
rb.require = nil
} else {
rb.require = &RequireHeader{
Options: options,
}
}
return rb
}
func (rb *RequestBuilder) SetContentType(contentType *ContentType) *RequestBuilder {
rb.contentType = contentType
return rb
}
func (rb *RequestBuilder) SetAccept(accept *Accept) *RequestBuilder {
rb.accept = accept
return rb
}
func (rb *RequestBuilder) SetRoutes(routes []Uri) *RequestBuilder {
if len(routes) == 0 {
rb.route = nil
} else {
rb.route = &RouteHeader{
Addresses: routes,
}
}
return rb
}
func (rb *RequestBuilder) AddHeader(header Header) *RequestBuilder {
rb.generic[header.Name()] = header
return rb
}
func (rb *RequestBuilder) RemoveHeader(headerName string) *RequestBuilder {
if _, ok := rb.generic[headerName]; ok {
delete(rb.generic, headerName)
}
return rb
}
func (rb *RequestBuilder) Build() (Request, error) {
if rb.method == "" {
return nil, fmt.Errorf("undefined method name")
}
if rb.recipient == nil {
return nil, fmt.Errorf("empty recipient")
}
if rb.from == nil {
return nil, fmt.Errorf("empty 'From' header")
}
if rb.to == nil {
return nil, fmt.Errorf("empty 'From' header")
}
hdrs := make([]Header, 0)
if rb.route != nil {
hdrs = append(hdrs, rb.route)
}
if len(rb.via) != 0 {
via := make(ViaHeader, 0)
for _, viaHop := range rb.via {
via = append(via, viaHop)
}
hdrs = append(hdrs, via)
}
hdrs = append(hdrs, rb.cseq, rb.from, rb.to, rb.callID)
if rb.contact != nil {
hdrs = append(hdrs, rb.contact)
}
if rb.maxForwards != nil {
hdrs = append(hdrs, rb.maxForwards)
}
if rb.expires != nil {
hdrs = append(hdrs, rb.expires)
}
if rb.supported != nil {
hdrs = append(hdrs, rb.supported)
}
if rb.allow != nil {
hdrs = append(hdrs, rb.allow)
}
if rb.contentType != nil {
hdrs = append(hdrs, rb.contentType)
}
if rb.accept != nil {
hdrs = append(hdrs, rb.accept)
}
if rb.userAgent != nil {
hdrs = append(hdrs, rb.userAgent)
}
for _, header := range rb.generic {
hdrs = append(hdrs, header)
}
sipVersion := rb.protocol + "/" + rb.protocolVersion
// basic request
req := NewRequest("", rb.method, rb.recipient, sipVersion, hdrs, "", nil)
req.SetBody(rb.body, true)
return req, nil
}

View file

@ -0,0 +1,412 @@
package sip
import (
"bytes"
"fmt"
"strings"
"github.com/ghettovoice/gosip/util"
)
const (
MTU uint = 1500
DefaultHost = "127.0.0.1"
DefaultProtocol = "UDP"
DefaultUdpPort Port = 5060
DefaultTcpPort Port = 5060
DefaultTlsPort Port = 5061
DefaultWsPort Port = 80
DefaultWssPort Port = 443
)
// TODO should be refactored, currently here the pit
type Address struct {
DisplayName MaybeString
Uri Uri
Params Params
}
func NewAddressFromFromHeader(from *FromHeader) *Address {
addr := &Address{
DisplayName: from.DisplayName,
}
if from.Address != nil {
addr.Uri = from.Address.Clone()
}
if from.Params != nil {
addr.Params = from.Params.Clone()
}
return addr
}
func NewAddressFromToHeader(to *ToHeader) *Address {
addr := &Address{
DisplayName: to.DisplayName,
}
if to.Address != nil {
addr.Uri = to.Address.Clone()
}
if to.Params != nil {
addr.Params = to.Params.Clone()
}
return addr
}
func NewAddressFromContactHeader(cnt *ContactHeader) *Address {
addr := &Address{
DisplayName: cnt.DisplayName,
}
if cnt.Address != nil {
addr.Uri = cnt.Address.Clone()
}
if cnt.Params != nil {
addr.Params = cnt.Params.Clone()
}
return addr
}
func (addr *Address) String() string {
var buffer bytes.Buffer
if addr == nil {
return "<nil>"
}
if addr.DisplayName != nil {
if displayName, ok := addr.DisplayName.(String); ok && displayName.String() != "" {
buffer.WriteString(fmt.Sprintf("\"%s\" ", displayName))
}
}
buffer.WriteString(fmt.Sprintf("<%s>", addr.Uri))
if addr.Params != nil && addr.Params.Length() > 0 {
buffer.WriteString(";")
buffer.WriteString(addr.Params.ToString(';'))
}
return buffer.String()
}
func (addr *Address) Clone() *Address {
var name MaybeString
var uri Uri
var params Params
if addr.DisplayName != nil {
name = String{Str: addr.DisplayName.String()}
}
if addr.Uri != nil {
uri = addr.Uri.Clone()
}
if addr.Params != nil {
params = addr.Params.Clone()
}
return &Address{
DisplayName: name,
Uri: uri,
Params: params,
}
}
func (addr *Address) Equals(other interface{}) bool {
otherPtr, ok := other.(*Address)
if !ok {
return false
}
if addr == otherPtr {
return true
}
if addr == nil && otherPtr != nil || addr != nil && otherPtr == nil {
return false
}
res := true
if addr.DisplayName != otherPtr.DisplayName {
if addr.DisplayName == nil {
res = res && otherPtr.DisplayName == nil
} else {
res = res && addr.DisplayName.Equals(otherPtr.DisplayName)
}
}
if addr.Uri != otherPtr.Uri {
if addr.Uri == nil {
res = res && otherPtr.Uri == nil
} else {
res = res && addr.Uri.Equals(otherPtr.Uri)
}
}
if addr.Params != otherPtr.Params {
if addr.Params == nil {
res = res && otherPtr.Params == nil
} else {
res = res && addr.Params.Equals(otherPtr.Params)
}
}
return res
}
func (addr *Address) AsToHeader() *ToHeader {
to := &ToHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
to.Address = addr.Uri.Clone()
}
if addr.Params != nil {
to.Params = addr.Params.Clone()
}
return to
}
func (addr *Address) AsFromHeader() *FromHeader {
from := &FromHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
from.Address = addr.Uri.Clone()
}
if addr.Params != nil {
from.Params = addr.Params.Clone()
}
return from
}
func (addr *Address) AsContactHeader() *ContactHeader {
cnt := &ContactHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
cnt.Address = addr.Uri.Clone()
}
if addr.Params != nil {
cnt.Params = addr.Params.Clone()
}
return cnt
}
// Port number
type Port uint16
func (port *Port) Clone() *Port {
if port == nil {
return nil
}
newPort := *port
return &newPort
}
func (port *Port) String() string {
if port == nil {
return ""
}
return fmt.Sprintf("%d", *port)
}
func (port *Port) Equals(other interface{}) bool {
if p, ok := other.(*Port); ok {
return util.Uint16PtrEq((*uint16)(port), (*uint16)(p))
}
return false
}
// String wrapper
type MaybeString interface {
String() string
Equals(other interface{}) bool
}
type String struct {
Str string
}
func (str String) String() string {
return str.Str
}
func (str String) Equals(other interface{}) bool {
if v, ok := other.(string); ok {
return str.Str == v
}
if v, ok := other.(String); ok {
return str.Str == v.Str
}
return false
}
type CancelError interface {
Canceled() bool
}
type ExpireError interface {
Expired() bool
}
type MessageError interface {
error
// Malformed indicates that message is syntactically valid but has invalid headers, or
// without required headers.
Malformed() bool
// Broken or incomplete message, or not a SIP message
Broken() bool
}
// Broken or incomplete messages, or not a SIP message.
type BrokenMessageError struct {
Err error
Msg string
}
func (err *BrokenMessageError) Malformed() bool { return false }
func (err *BrokenMessageError) Broken() bool { return true }
func (err *BrokenMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "BrokenMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
// syntactically valid but logically invalid message
type MalformedMessageError struct {
Err error
Msg string
}
func (err *MalformedMessageError) Malformed() bool { return true }
func (err *MalformedMessageError) Broken() bool { return false }
func (err *MalformedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "MalformedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
type UnsupportedMessageError struct {
Err error
Msg string
}
func (err *UnsupportedMessageError) Malformed() bool { return true }
func (err *UnsupportedMessageError) Broken() bool { return false }
func (err *UnsupportedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "UnsupportedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
type UnexpectedMessageError struct {
Err error
Msg string
}
func (err *UnexpectedMessageError) Broken() bool { return false }
func (err *UnexpectedMessageError) Malformed() bool { return false }
func (err *UnexpectedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "UnexpectedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
const RFC3261BranchMagicCookie = "z9hG4bK"
// GenerateBranch returns random unique branch ID.
func GenerateBranch() string {
return strings.Join([]string{
RFC3261BranchMagicCookie,
util.RandString(32),
}, ".")
}
// DefaultPort returns protocol default port by network.
func DefaultPort(protocol string) Port {
switch strings.ToLower(protocol) {
case "tls":
return DefaultTlsPort
case "tcp":
return DefaultTcpPort
case "udp":
return DefaultUdpPort
case "ws":
return DefaultWsPort
case "wss":
return DefaultWssPort
default:
return DefaultTcpPort
}
}
func MakeDialogIDFromMessage(msg Message) (string, error) {
callID, ok := msg.CallID()
if !ok {
return "", fmt.Errorf("missing Call-ID header")
}
to, ok := msg.To()
if !ok {
return "", fmt.Errorf("missing To header")
}
toTag, ok := to.Params.Get("tag")
if !ok {
return "", fmt.Errorf("missing tag param in To header")
}
from, ok := msg.From()
if !ok {
return "", fmt.Errorf("missing To header")
}
fromTag, ok := from.Params.Get("tag")
if !ok {
return "", fmt.Errorf("missing tag param in From header")
}
return MakeDialogID(string(*callID), toTag.String(), fromTag.String()), nil
}
func MakeDialogID(callID, innerID, externalID string) string {
return strings.Join([]string{callID, innerID, externalID}, "__")
}

View file

@ -0,0 +1,37 @@
package sip
import "fmt"
type RequestError struct {
Request Request
Response Response
Code uint
Reason string
}
func NewRequestError(code uint, reason string, request Request, response Response) *RequestError {
err := &RequestError{
Code: code,
Reason: reason,
}
if request != nil {
err.Request = CopyRequest(request)
}
if response != nil {
err.Response = CopyResponse(response)
}
return err
}
func (err *RequestError) Error() string {
if err == nil {
return "<nil>"
}
reason := err.Reason
if err.Code != 0 {
reason += fmt.Sprintf(" (Code %d)", err.Code)
}
return fmt.Sprintf("sip.RequestError: request failed with reason '%s'", reason)
}

View file

@ -0,0 +1,230 @@
package sip
import (
"strconv"
"strings"
)
// Copyright 2009 The Go Authors. All rights reserved.
// This is actually shorten copy of escape/unescape helpers of the net/url package.
type encoding int
const (
EncodeUserPassword encoding = 1 + iota
EncodeHost
EncodeZone
EncodeQueryComponent
)
type EscapeError string
func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
type InvalidHostError string
func (e InvalidHostError) Error() string {
return "invalid character " + strconv.Quote(string(e)) + " in host name"
}
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
func Unescape(s string, mode encoding) (string, error) {
// Count %, check that they're well-formed.
n := 0
hasPlus := false
for i := 0; i < len(s); {
switch s[i] {
case '%':
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", EscapeError(s)
}
// Per https://tools.ietf.org/html/rfc3986#page-21
// in the host component %-encoding can only be used
// for non-ASCII bytes.
// But https://tools.ietf.org/html/rfc6874#section-2
// introduces %25 being allowed to escape a percent sign
// in IPv6 scoped-address literals. Yay.
if mode == EncodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
return "", EscapeError(s[i : i+3])
}
if mode == EncodeZone {
// RFC 6874 says basically "anything goes" for zone identifiers
// and that even non-ASCII can be redundantly escaped,
// but it seems prudent to restrict %-escaped bytes here to those
// that are valid host name bytes in their unescaped form.
// That is, you can use escaping in the zone identifier but not
// to introduce bytes you couldn't just write directly.
// But Windows puts spaces here! Yay.
v := unhex(s[i+1])<<4 | unhex(s[i+2])
if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, EncodeHost) {
return "", EscapeError(s[i : i+3])
}
}
i += 3
case '+':
hasPlus = mode == EncodeQueryComponent
i++
default:
if (mode == EncodeHost || mode == EncodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
return "", InvalidHostError(s[i : i+1])
}
i++
}
}
if n == 0 && !hasPlus {
return s, nil
}
var t strings.Builder
t.Grow(len(s) - 2*n)
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
i += 2
case '+':
t.WriteByte('+')
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
//
// Please be informed that for now shouldEscape does not check all
// reserved characters correctly. See golang.org/issue/5684.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
if mode == EncodeHost || mode == EncodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
// We add [ ] because we include [ipv6]:port as part of host.
// We add < > because they're the only characters left that
// we could possibly allow, and Parse will reject them if we
// escape them (because hosts can't use %-encoding for
// ASCII bytes).
switch c {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
switch c {
case '-', '_', '.', '~': // §2.3 Unreserved characters (mark)
return false
case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case EncodeUserPassword: // §3.2.1
// The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
// userinfo, so we must escape only '@', '/', and '?'.
// The parsing of userinfo treats ':' as special so we must escape
// that too.
return c == '@' || c == '/' || c == '?' || c == ':'
case EncodeQueryComponent: // §3.4
// The RFC reserves (so we must escape) everything.
return true
}
}
// Everything else must be escaped.
return true
}
const upperhex = "0123456789ABCDEF"
func Escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == EncodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
var buf [64]byte
var t []byte
required := len(s) + 2*hexCount
if required <= len(buf) {
t = buf[:required]
} else {
t = make([]byte, required)
}
if hexCount == 0 {
copy(t, s)
return string(t)
}
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == EncodeQueryComponent:
t[j] = c
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = upperhex[c>>4]
t[j+2] = upperhex[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,542 @@
package sip
import (
"bytes"
"strings"
"sync"
uuid "github.com/satori/go.uuid"
"github.com/ghettovoice/gosip/log"
)
// A representation of a SIP method.
// This is syntactic sugar around the string type, so make sure to use
// the Equals method rather than built-in equality, or you'll fall foul of case differences.
// If you're defining your own Method, uppercase is preferred but not compulsory.
type RequestMethod string
// StatusCode - response status code: 1xx - 6xx
type StatusCode uint16
// Determine if the given method equals some other given method.
// This is syntactic sugar for case insensitive equality checking.
func (method *RequestMethod) Equals(other *RequestMethod) bool {
if method != nil && other != nil {
return strings.EqualFold(string(*method), string(*other))
} else {
return method == other
}
}
// It's nicer to avoid using raw strings to represent methods, so the following standard
// method names are defined here as constants for convenience.
const (
INVITE RequestMethod = "INVITE"
ACK RequestMethod = "ACK"
CANCEL RequestMethod = "CANCEL"
BYE RequestMethod = "BYE"
REGISTER RequestMethod = "REGISTER"
OPTIONS RequestMethod = "OPTIONS"
SUBSCRIBE RequestMethod = "SUBSCRIBE"
NOTIFY RequestMethod = "NOTIFY"
REFER RequestMethod = "REFER"
INFO RequestMethod = "INFO"
MESSAGE RequestMethod = "MESSAGE"
PRACK RequestMethod = "PRACK"
UPDATE RequestMethod = "UPDATE"
PUBLISH RequestMethod = "PUBLISH"
)
type MessageID string
func NextMessageID() MessageID {
return MessageID(uuid.Must(uuid.NewV4()).String())
}
// Message introduces common SIP message RFC 3261 - 7.
type Message interface {
MessageID() MessageID
Clone() Message
// Start line returns message start line.
StartLine() string
// String returns string representation of SIP message in RFC 3261 form.
String() string
// Short returns short string info about message.
Short() string
// SipVersion returns SIP protocol version.
SipVersion() string
// SetSipVersion sets SIP protocol version.
SetSipVersion(version string)
// Headers returns all message headers.
Headers() []Header
// GetHeaders returns slice of headers of the given type.
GetHeaders(name string) []Header
// AppendHeader appends header to message.
AppendHeader(header Header)
// PrependHeader prepends header to message.
PrependHeader(header Header)
PrependHeaderAfter(header Header, afterName string)
// RemoveHeader removes header from message.
RemoveHeader(name string)
ReplaceHeaders(name string, headers []Header)
// Body returns message body.
Body() string
// SetBody sets message body.
SetBody(body string, setContentLength bool)
/* Helper getters for common headers */
// CallID returns 'Call-ID' header.
CallID() (*CallID, bool)
// Via returns the top 'Via' header field.
Via() (ViaHeader, bool)
// ViaHop returns the first segment of the top 'Via' header.
ViaHop() (*ViaHop, bool)
// From returns 'From' header field.
From() (*FromHeader, bool)
// To returns 'To' header field.
To() (*ToHeader, bool)
// CSeq returns 'CSeq' header field.
CSeq() (*CSeq, bool)
ContentLength() (*ContentLength, bool)
ContentType() (*ContentType, bool)
Contact() (*ContactHeader, bool)
Transport() string
SetTransport(tp string)
Source() string
SetSource(src string)
Destination() string
SetDestination(dest string)
IsCancel() bool
IsAck() bool
Fields() log.Fields
WithFields(fields log.Fields) Message
}
// headers is a struct with methods to work with SIP headers.
type headers struct {
mu sync.RWMutex
// The logical SIP headers attached to this message.
headers map[string][]Header
// The order the headers should be displayed in.
headerOrder []string
}
func newHeaders(hdrs []Header) *headers {
hs := new(headers)
hs.headers = make(map[string][]Header)
hs.headerOrder = make([]string, 0)
for _, header := range hdrs {
hs.AppendHeader(header)
}
return hs
}
func (hs *headers) String() string {
buffer := bytes.Buffer{}
hs.mu.RLock()
// Construct each header in turn and add it to the message.
for typeIdx, name := range hs.headerOrder {
headers := hs.headers[name]
for idx, header := range headers {
buffer.WriteString(header.String())
if typeIdx < len(hs.headerOrder) || idx < len(headers) {
buffer.WriteString("\r\n")
}
}
}
hs.mu.RUnlock()
return buffer.String()
}
// Add the given header.
func (hs *headers) AppendHeader(header Header) {
name := strings.ToLower(header.Name())
hs.mu.Lock()
if _, ok := hs.headers[name]; ok {
hs.headers[name] = append(hs.headers[name], header)
} else {
hs.headers[name] = []Header{header}
hs.headerOrder = append(hs.headerOrder, name)
}
hs.mu.Unlock()
}
// AddFrontHeader adds header to the front of header list
// if there is no header has h's name, add h to the font of all headers
// if there are some headers have h's name, add h to front of the sublist
func (hs *headers) PrependHeader(header Header) {
name := strings.ToLower(header.Name())
hs.mu.Lock()
if hdrs, ok := hs.headers[name]; ok {
hs.headers[name] = append([]Header{header}, hdrs...)
} else {
hs.headers[name] = []Header{header}
newOrder := make([]string, 1, len(hs.headerOrder)+1)
newOrder[0] = name
hs.headerOrder = append(newOrder, hs.headerOrder...)
}
hs.mu.Unlock()
}
func (hs *headers) PrependHeaderAfter(header Header, afterName string) {
headerName := strings.ToLower(header.Name())
afterName = strings.ToLower(afterName)
hs.mu.Lock()
if _, ok := hs.headers[afterName]; ok {
afterIdx := -1
headerIdx := -1
for i, name := range hs.headerOrder {
if name == afterName {
afterIdx = i
}
if name == headerName {
headerIdx = i
}
}
if headerIdx == -1 {
hs.headers[headerName] = []Header{header}
newOrder := make([]string, 0)
newOrder = append(newOrder, hs.headerOrder[:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:]...)
hs.headerOrder = newOrder
} else {
hs.headers[headerName] = append([]Header{header}, hs.headers[headerName]...)
newOrder := make([]string, 0)
if afterIdx < headerIdx {
newOrder = append(newOrder, hs.headerOrder[:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:headerIdx]...)
newOrder = append(newOrder, hs.headerOrder[headerIdx+1:]...)
} else {
newOrder = append(newOrder, hs.headerOrder[:headerIdx]...)
newOrder = append(newOrder, hs.headerOrder[headerIdx+1:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:]...)
}
hs.headerOrder = newOrder
}
hs.mu.Unlock()
} else {
hs.mu.Unlock()
hs.PrependHeader(header)
}
}
func (hs *headers) ReplaceHeaders(name string, headers []Header) {
name = strings.ToLower(name)
hs.mu.Lock()
if _, ok := hs.headers[name]; ok {
hs.headers[name] = headers
}
hs.mu.Unlock()
}
// Gets some headers.
func (hs *headers) Headers() []Header {
hdrs := make([]Header, 0)
hs.mu.RLock()
for _, key := range hs.headerOrder {
hdrs = append(hdrs, hs.headers[key]...)
}
hs.mu.RUnlock()
return hdrs
}
func (hs *headers) GetHeaders(name string) []Header {
name = strings.ToLower(name)
hs.mu.RLock()
defer hs.mu.RUnlock()
if hs.headers == nil {
hs.headers = map[string][]Header{}
hs.headerOrder = []string{}
}
if headers, ok := hs.headers[name]; ok {
return headers
}
return []Header{}
}
func (hs *headers) RemoveHeader(name string) {
name = strings.ToLower(name)
hs.mu.Lock()
delete(hs.headers, name)
// update order slice
for idx, entry := range hs.headerOrder {
if entry == name {
hs.headerOrder = append(hs.headerOrder[:idx], hs.headerOrder[idx+1:]...)
break
}
}
hs.mu.Unlock()
}
// CloneHeaders returns all cloned headers in slice.
func (hs *headers) CloneHeaders() []Header {
return cloneHeaders(hs)
}
func cloneHeaders(msg interface{ Headers() []Header }) []Header {
hdrs := make([]Header, 0)
for _, header := range msg.Headers() {
hdrs = append(hdrs, header.Clone())
}
return hdrs
}
func (hs *headers) CallID() (*CallID, bool) {
hdrs := hs.GetHeaders("Call-ID")
if len(hdrs) == 0 {
return nil, false
}
callId, ok := hdrs[0].(*CallID)
if !ok {
return nil, false
}
return callId, true
}
func (hs *headers) Via() (ViaHeader, bool) {
hdrs := hs.GetHeaders("Via")
if len(hdrs) == 0 {
return nil, false
}
via, ok := (hdrs[0]).(ViaHeader)
if !ok {
return nil, false
}
return via, true
}
func (hs *headers) ViaHop() (*ViaHop, bool) {
via, ok := hs.Via()
if !ok {
return nil, false
}
hops := []*ViaHop(via)
if len(hops) == 0 {
return nil, false
}
return hops[0], true
}
func (hs *headers) From() (*FromHeader, bool) {
hdrs := hs.GetHeaders("From")
if len(hdrs) == 0 {
return nil, false
}
from, ok := hdrs[0].(*FromHeader)
if !ok {
return nil, false
}
return from, true
}
func (hs *headers) To() (*ToHeader, bool) {
hdrs := hs.GetHeaders("To")
if len(hdrs) == 0 {
return nil, false
}
to, ok := hdrs[0].(*ToHeader)
if !ok {
return nil, false
}
return to, true
}
func (hs *headers) CSeq() (*CSeq, bool) {
hdrs := hs.GetHeaders("CSeq")
if len(hdrs) == 0 {
return nil, false
}
cseq, ok := hdrs[0].(*CSeq)
if !ok {
return nil, false
}
return cseq, true
}
func (hs *headers) ContentLength() (*ContentLength, bool) {
hdrs := hs.GetHeaders("Content-Length")
if len(hdrs) == 0 {
return nil, false
}
contentLength, ok := hdrs[0].(*ContentLength)
if !ok {
return nil, false
}
return contentLength, true
}
func (hs *headers) ContentType() (*ContentType, bool) {
hdrs := hs.GetHeaders("Content-Type")
if len(hdrs) == 0 {
return nil, false
}
contentType, ok := hdrs[0].(*ContentType)
if !ok {
return nil, false
}
return contentType, true
}
func (hs *headers) Contact() (*ContactHeader, bool) {
hdrs := hs.GetHeaders("Contact")
if len(hdrs) == 0 {
return nil, false
}
contactHeader, ok := hdrs[0].(*ContactHeader)
if !ok {
return nil, false
}
return contactHeader, true
}
// basic message implementation
type message struct {
// message headers
*headers
mu sync.RWMutex
messID MessageID
sipVersion string
body string
startLine func() string
tp string
src string
dest string
fields log.Fields
}
func (msg *message) MessageID() MessageID {
return msg.messID
}
func (msg *message) StartLine() string {
return msg.startLine()
}
func (msg *message) Fields() log.Fields {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.fields.WithFields(log.Fields{
"transport": msg.tp,
"source": msg.src,
"destination": msg.dest,
})
}
func (msg *message) String() string {
var buffer bytes.Buffer
// write message start line
buffer.WriteString(msg.StartLine() + "\r\n")
// Write the headers.
msg.mu.RLock()
buffer.WriteString(msg.headers.String())
msg.mu.RUnlock()
// message body
buffer.WriteString("\r\n" + msg.Body())
return buffer.String()
}
func (msg *message) SipVersion() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.sipVersion
}
func (msg *message) SetSipVersion(version string) {
msg.mu.Lock()
msg.sipVersion = version
msg.mu.Unlock()
}
func (msg *message) Body() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.body
}
// SetBody sets message body, calculates it length and add 'Content-Length' header.
func (msg *message) SetBody(body string, setContentLength bool) {
msg.mu.Lock()
msg.body = body
msg.mu.Unlock()
if setContentLength {
hdrs := msg.GetHeaders("Content-Length")
if len(hdrs) == 0 {
length := ContentLength(len(body))
msg.AppendHeader(&length)
} else {
length := ContentLength(len(body))
msg.ReplaceHeaders("Content-Length", []Header{&length})
}
}
}
func (msg *message) Transport() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.tp
}
func (msg *message) SetTransport(tp string) {
msg.mu.Lock()
msg.tp = strings.ToUpper(tp)
msg.mu.Unlock()
}
func (msg *message) Source() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.src
}
func (msg *message) SetSource(src string) {
msg.mu.Lock()
msg.src = src
msg.mu.Unlock()
}
func (msg *message) Destination() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.dest
}
func (msg *message) SetDestination(dest string) {
msg.mu.Lock()
msg.dest = dest
msg.mu.Unlock()
}
// Copy all headers of one type from one message to another.
// Appending to any headers that were already there.
func CopyHeaders(name string, from, to Message) {
name = strings.ToLower(name)
for _, h := range from.GetHeaders(name) {
to.AppendHeader(h.Clone())
}
}
func PrependCopyHeaders(name string, from, to Message) {
name = strings.ToLower(name)
for _, h := range from.GetHeaders(name) {
to.PrependHeader(h.Clone())
}
}
type MessageMapper func(msg Message) Message

View file

@ -0,0 +1,5 @@
# SIP parser
> Package implements SIP protocol parser compatible with [RFC 3261](https://tools.ietf.org/html/rfc3261)
Originally forked from [gossip](https://github.com/StefanKopieczek/gossip) library by @StefanKopieczek.

View file

@ -0,0 +1,26 @@
package parser
type Error interface {
error
// Syntax indicates that this is syntax error
Syntax() bool
}
type InvalidStartLineError string
func (err InvalidStartLineError) Syntax() bool { return true }
func (err InvalidStartLineError) Malformed() bool { return false }
func (err InvalidStartLineError) Broken() bool { return true }
func (err InvalidStartLineError) Error() string { return "parser.InvalidStartLineError: " + string(err) }
type InvalidMessageFormat string
func (err InvalidMessageFormat) Syntax() bool { return true }
func (err InvalidMessageFormat) Malformed() bool { return true }
func (err InvalidMessageFormat) Broken() bool { return true }
func (err InvalidMessageFormat) Error() string { return "parser.InvalidMessageFormat: " + string(err) }
type WriteError string
func (err WriteError) Syntax() bool { return false }
func (err WriteError) Error() string { return "parser.WriteError: " + string(err) }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,131 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package parser
import (
"bufio"
"bytes"
"fmt"
"io"
"sync"
"github.com/ghettovoice/gosip/log"
)
// parserBuffer is a specialized buffer for use in the parser.
// It is written to via the non-blocking Write.
// It exposes various blocking read methods, which wait until the requested
// data is available, and then return it.
type parserBuffer struct {
mu sync.RWMutex
writer io.Writer
buffer bytes.Buffer
// Wraps parserBuffer.pipeReader
reader *bufio.Reader
// Don't access this directly except when closing.
pipeReader *io.PipeReader
log log.Logger
}
// Create a new parserBuffer object (see struct comment for object details).
// Note that resources owned by the parserBuffer may not be able to be GCed
// until the Dispose() method is called.
func newParserBuffer(logger log.Logger) *parserBuffer {
var pb parserBuffer
pb.pipeReader, pb.writer = io.Pipe()
pb.reader = bufio.NewReader(pb.pipeReader)
pb.log = logger.
WithPrefix("parser.parserBuffer").
WithFields(log.Fields{
"parser_buffer_ptr": fmt.Sprintf("%p", &pb),
})
return &pb
}
func (pb *parserBuffer) Log() log.Logger {
return pb.log
}
func (pb *parserBuffer) Write(p []byte) (n int, err error) {
pb.mu.RLock()
defer pb.mu.RUnlock()
return pb.writer.Write(p)
}
// Block until the buffer contains at least one CRLF-terminated line.
// Return the line, excluding the terminal CRLF, and delete it from the buffer.
// Returns an error if the parserbuffer has been stopped.
func (pb *parserBuffer) NextLine() (response string, err error) {
var buffer bytes.Buffer
var data string
var b byte
// There has to be a better way!
for {
data, err = pb.reader.ReadString('\r')
if err != nil {
return
}
buffer.WriteString(data)
b, err = pb.reader.ReadByte()
if err != nil {
return
}
buffer.WriteByte(b)
if b == '\n' {
response = buffer.String()
response = response[:len(response)-2]
pb.Log().Tracef("return line '%s'", response)
return
}
}
}
// Block until the buffer contains at least n characters.
// Return precisely those n characters, then delete them from the buffer.
func (pb *parserBuffer) NextChunk(n int) (response string, err error) {
var data = make([]byte, n)
var read int
for total := 0; total < n; {
read, err = pb.reader.Read(data[total:])
total += read
if err != nil {
return
}
}
response = string(data)
pb.Log().Tracef("return chunk:\n%s", response)
return
}
// Stop the parser buffer.
func (pb *parserBuffer) Stop() {
pb.mu.RLock()
if err := pb.pipeReader.Close(); err != nil {
pb.Log().Errorf("parser pipe reader close failed: %s", err)
}
pb.mu.RUnlock()
pb.Log().Trace("parser buffer stopped")
}
func (pb *parserBuffer) Reset() {
pb.mu.Lock()
pb.pipeReader, pb.writer = io.Pipe()
pb.reader.Reset(pb.pipeReader)
pb.mu.Unlock()
}

View file

@ -0,0 +1,382 @@
package sip
import (
"bytes"
"fmt"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
)
// Request RFC 3261 - 7.1.
type Request interface {
Message
Method() RequestMethod
SetMethod(method RequestMethod)
Recipient() Uri
SetRecipient(recipient Uri)
/* Common Helpers */
IsInvite() bool
}
type request struct {
message
method RequestMethod
recipient Uri
}
func NewRequest(
messID MessageID,
method RequestMethod,
recipient Uri,
sipVersion string,
hdrs []Header,
body string,
fields log.Fields,
) Request {
req := new(request)
if messID == "" {
req.messID = NextMessageID()
} else {
req.messID = messID
}
req.startLine = req.StartLine
req.sipVersion = sipVersion
req.headers = newHeaders(hdrs)
req.method = method
req.recipient = recipient
req.body = body
req.fields = fields.WithFields(log.Fields{
"request_id": req.messID,
})
return req
}
func (req *request) Short() string {
if req == nil {
return "<nil>"
}
fields := log.Fields{
"method": req.Method(),
"recipient": req.Recipient(),
"transport": req.Transport(),
"source": req.Source(),
"destination": req.Destination(),
}
if cseq, ok := req.CSeq(); ok {
fields["sequence"] = cseq.SeqNo
}
fields = req.Fields().WithFields(fields)
return fmt.Sprintf("sip.Request<%s>", fields)
}
func (req *request) Method() RequestMethod {
req.mu.RLock()
defer req.mu.RUnlock()
return req.method
}
func (req *request) SetMethod(method RequestMethod) {
req.mu.Lock()
req.method = method
req.mu.Unlock()
}
func (req *request) Recipient() Uri {
req.mu.RLock()
defer req.mu.RUnlock()
return req.recipient
}
func (req *request) SetRecipient(recipient Uri) {
req.mu.Lock()
req.recipient = recipient
req.mu.Unlock()
}
// StartLine returns Request Line - RFC 2361 7.1.
func (req *request) StartLine() string {
var buffer bytes.Buffer
// Every SIP request starts with a Request Line - RFC 2361 7.1.
buffer.WriteString(
fmt.Sprintf(
"%s %s %s",
string(req.Method()),
req.Recipient(),
req.SipVersion(),
),
)
return buffer.String()
}
func (req *request) Clone() Message {
return cloneRequest(req, "", nil)
}
func (req *request) Fields() log.Fields {
return req.fields.WithFields(log.Fields{
"transport": req.Transport(),
"source": req.Source(),
"destination": req.Destination(),
})
}
func (req *request) WithFields(fields log.Fields) Message {
req.mu.Lock()
req.fields = req.fields.WithFields(fields)
req.mu.Unlock()
return req
}
func (req *request) IsInvite() bool {
return req.Method() == INVITE
}
func (req *request) IsAck() bool {
return req.Method() == ACK
}
func (req *request) IsCancel() bool {
return req.Method() == CANCEL
}
func (req *request) Transport() string {
if tp := req.message.Transport(); tp != "" {
return strings.ToUpper(tp)
}
var tp string
if viaHop, ok := req.ViaHop(); ok && viaHop.Transport != "" {
tp = viaHop.Transport
} else {
tp = DefaultProtocol
}
uri := req.Recipient()
if hdrs := req.GetHeaders("Route"); len(hdrs) > 0 {
routeHeader, ok := hdrs[0].(*RouteHeader)
if ok && len(routeHeader.Addresses) > 0 {
uri = routeHeader.Addresses[0]
}
}
if uri != nil {
if uri.UriParams() != nil {
if val, ok := uri.UriParams().Get("transport"); ok && !val.Equals("") {
tp = strings.ToUpper(val.String())
}
}
if uri.IsEncrypted() {
if tp == "TCP" {
tp = "TLS"
} else if tp == "WS" {
tp = "WSS"
}
}
}
if tp == "UDP" && len(req.String()) > int(MTU)-200 {
tp = "TCP"
}
return tp
}
func (req *request) Source() string {
if src := req.message.Source(); src != "" {
return src
}
viaHop, ok := req.ViaHop()
if !ok {
return ""
}
var (
host string
port Port
)
host = viaHop.Host
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = DefaultPort(req.Transport())
}
if viaHop.Params != nil {
if received, ok := viaHop.Params.Get("received"); ok && received.String() != "" {
host = received.String()
}
if rport, ok := viaHop.Params.Get("rport"); ok && rport != nil && rport.String() != "" {
if p, err := strconv.Atoi(rport.String()); err == nil {
port = Port(uint16(p))
}
}
}
return fmt.Sprintf("%v:%v", host, port)
}
func (req *request) Destination() string {
if dest := req.message.Destination(); dest != "" {
return dest
}
var uri *SipUri
if hdrs := req.GetHeaders("Route"); len(hdrs) > 0 {
routeHeader, ok := hdrs[0].(*RouteHeader)
if ok && len(routeHeader.Addresses) > 0 {
uri = routeHeader.Addresses[0].(*SipUri)
}
}
if uri == nil {
if u, ok := req.Recipient().(*SipUri); ok {
uri = u
} else {
return ""
}
}
host := uri.FHost
var port Port
if uri.FPort != nil {
port = *uri.FPort
} else {
port = DefaultPort(req.Transport())
}
return fmt.Sprintf("%v:%v", host, port)
}
// NewAckRequest creates ACK request for 2xx INVITE
// https://tools.ietf.org/html/rfc3261#section-13.2.2.4
func NewAckRequest(ackID MessageID, inviteRequest Request, inviteResponse Response, body string, fields log.Fields) Request {
recipient := inviteRequest.Recipient()
if contact, ok := inviteResponse.Contact(); ok {
// For ws and wss (like clients in browser), don't use Contact
if strings.Index(strings.ToLower(recipient.String()), "transport=ws") == -1 {
recipient = contact.Address
}
}
ackRequest := NewRequest(
ackID,
ACK,
recipient,
inviteRequest.SipVersion(),
[]Header{},
body,
inviteRequest.Fields().
WithFields(fields).
WithFields(log.Fields{
"invite_request_id": inviteRequest.MessageID(),
"invite_response_id": inviteResponse.MessageID(),
}),
)
CopyHeaders("Via", inviteRequest, ackRequest)
if inviteResponse.IsSuccess() {
// update branch, 2xx ACK is separate Tx
viaHop, _ := ackRequest.ViaHop()
viaHop.Params.Add("branch", String{Str: GenerateBranch()})
}
if len(inviteRequest.GetHeaders("Route")) > 0 {
CopyHeaders("Route", inviteRequest, ackRequest)
} else {
hdrs := inviteResponse.GetHeaders("Record-Route")
for i := len(hdrs) - 1; i >= 0; i-- {
h := hdrs[i]
uris := make([]Uri, 0)
for j := len(h.(*RecordRouteHeader).Addresses) - 1; j >= 0; j-- {
uris = append(uris, h.(*RecordRouteHeader).Addresses[j].Clone())
}
ackRequest.AppendHeader(&RouteHeader{
Addresses: uris,
})
}
}
maxForwardsHeader := MaxForwards(70)
ackRequest.AppendHeader(&maxForwardsHeader)
CopyHeaders("From", inviteRequest, ackRequest)
CopyHeaders("To", inviteResponse, ackRequest)
CopyHeaders("Call-ID", inviteRequest, ackRequest)
CopyHeaders("CSeq", inviteRequest, ackRequest)
cseq, _ := ackRequest.CSeq()
cseq.MethodName = ACK
ackRequest.SetBody("", true)
ackRequest.SetTransport(inviteRequest.Transport())
ackRequest.SetSource(inviteRequest.Source())
ackRequest.SetDestination(inviteRequest.Destination())
return ackRequest
}
func NewCancelRequest(cancelID MessageID, requestForCancel Request, fields log.Fields) Request {
cancelReq := NewRequest(
cancelID,
CANCEL,
requestForCancel.Recipient(),
requestForCancel.SipVersion(),
[]Header{},
"",
requestForCancel.Fields().
WithFields(fields).
WithFields(log.Fields{
"cancelling_request_id": requestForCancel.MessageID(),
}),
)
viaHop, _ := requestForCancel.ViaHop()
cancelReq.AppendHeader(ViaHeader{viaHop.Clone()})
CopyHeaders("Route", requestForCancel, cancelReq)
maxForwardsHeader := MaxForwards(70)
cancelReq.AppendHeader(&maxForwardsHeader)
CopyHeaders("From", requestForCancel, cancelReq)
CopyHeaders("To", requestForCancel, cancelReq)
CopyHeaders("Call-ID", requestForCancel, cancelReq)
CopyHeaders("CSeq", requestForCancel, cancelReq)
cseq, _ := cancelReq.CSeq()
cseq.MethodName = CANCEL
cancelReq.SetBody("", true)
cancelReq.SetTransport(requestForCancel.Transport())
cancelReq.SetSource(requestForCancel.Source())
cancelReq.SetDestination(requestForCancel.Destination())
return cancelReq
}
func cloneRequest(req Request, id MessageID, fields log.Fields) Request {
newFields := req.Fields()
if fields != nil {
newFields = newFields.WithFields(fields)
}
newReq := NewRequest(
id,
req.Method(),
req.Recipient().Clone(),
req.SipVersion(),
cloneHeaders(req),
req.Body(),
newFields,
)
newReq.SetTransport(req.Transport())
newReq.SetSource(req.Source())
newReq.SetDestination(req.Destination())
return newReq
}
func CopyRequest(req Request) Request {
return cloneRequest(req, req.MessageID(), nil)
}

View file

@ -0,0 +1,310 @@
package sip
import (
"bytes"
"fmt"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
)
// Response RFC 3261 - 7.2.
type Response interface {
Message
StatusCode() StatusCode
SetStatusCode(code StatusCode)
Reason() string
SetReason(reason string)
// Previous returns previous provisional responses
Previous() []Response
SetPrevious(responses []Response)
/* Common helpers */
IsProvisional() bool
IsSuccess() bool
IsRedirection() bool
IsClientError() bool
IsServerError() bool
IsGlobalError() bool
}
type response struct {
message
status StatusCode
reason string
previous []Response
}
func NewResponse(
messID MessageID,
sipVersion string,
statusCode StatusCode,
reason string,
hdrs []Header,
body string,
fields log.Fields,
) Response {
res := new(response)
if messID == "" {
res.messID = NextMessageID()
} else {
res.messID = messID
}
res.startLine = res.StartLine
res.sipVersion = sipVersion
res.headers = newHeaders(hdrs)
res.status = statusCode
res.reason = reason
res.body = body
res.fields = fields.WithFields(log.Fields{
"response_id": res.messID,
})
res.previous = make([]Response, 0)
return res
}
func (res *response) Short() string {
if res == nil {
return "<nil>"
}
fields := log.Fields{
"status": res.StatusCode(),
"reason": res.Reason(),
"transport": res.Transport(),
"source": res.Source(),
"destination": res.Destination(),
}
if cseq, ok := res.CSeq(); ok {
fields["method"] = cseq.MethodName
fields["sequence"] = cseq.SeqNo
}
fields = res.Fields().WithFields(fields)
return fmt.Sprintf("sip.Response<%s>", fields)
}
func (res *response) StatusCode() StatusCode {
res.mu.RLock()
defer res.mu.RUnlock()
return res.status
}
func (res *response) SetStatusCode(code StatusCode) {
res.mu.Lock()
res.status = code
res.mu.Unlock()
}
func (res *response) Reason() string {
res.mu.RLock()
defer res.mu.RUnlock()
return res.reason
}
func (res *response) SetReason(reason string) {
res.mu.Lock()
res.reason = reason
res.mu.Unlock()
}
func (res *response) Previous() []Response {
res.mu.RLock()
defer res.mu.RUnlock()
return res.previous
}
func (res *response) SetPrevious(responses []Response) {
res.mu.Lock()
res.previous = responses
res.mu.Unlock()
}
// StartLine returns Response Status Line - RFC 2361 7.2.
func (res *response) StartLine() string {
var buffer bytes.Buffer
// Every SIP response starts with a Status Line - RFC 2361 7.2.
buffer.WriteString(
fmt.Sprintf(
"%s %d %s",
res.SipVersion(),
res.StatusCode(),
res.Reason(),
),
)
return buffer.String()
}
func (res *response) Clone() Message {
return cloneResponse(res, "", nil)
}
func (res *response) Fields() log.Fields {
return res.fields.WithFields(log.Fields{
"transport": res.Transport(),
"source": res.Source(),
"destination": res.Destination(),
})
}
func (res *response) WithFields(fields log.Fields) Message {
res.mu.Lock()
res.fields = res.fields.WithFields(fields)
res.mu.Unlock()
return res
}
func (res *response) IsProvisional() bool {
return res.StatusCode() < 200
}
func (res *response) IsSuccess() bool {
return res.StatusCode() >= 200 && res.StatusCode() < 300
}
func (res *response) IsRedirection() bool {
return res.StatusCode() >= 300 && res.StatusCode() < 400
}
func (res *response) IsClientError() bool {
return res.StatusCode() >= 400 && res.StatusCode() < 500
}
func (res *response) IsServerError() bool {
return res.StatusCode() >= 500 && res.StatusCode() < 600
}
func (res *response) IsGlobalError() bool {
return res.StatusCode() >= 600
}
func (res *response) IsAck() bool {
if cseq, ok := res.CSeq(); ok {
return cseq.MethodName == ACK
}
return false
}
func (res *response) IsCancel() bool {
if cseq, ok := res.CSeq(); ok {
return cseq.MethodName == CANCEL
}
return false
}
func (res *response) Transport() string {
if tp := res.message.Transport(); tp != "" {
return strings.ToUpper(tp)
}
var tp string
if viaHop, ok := res.ViaHop(); ok && viaHop.Transport != "" {
tp = viaHop.Transport
} else {
tp = DefaultProtocol
}
return tp
}
func (res *response) Destination() string {
if dest := res.message.Destination(); dest != "" {
return dest
}
viaHop, ok := res.ViaHop()
if !ok {
return ""
}
var (
host string
port Port
)
host = viaHop.Host
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = DefaultPort(res.Transport())
}
if viaHop.Params != nil {
if received, ok := viaHop.Params.Get("received"); ok && received.String() != "" {
host = received.String()
}
if rport, ok := viaHop.Params.Get("rport"); ok && rport != nil && rport.String() != "" {
if p, err := strconv.Atoi(rport.String()); err == nil {
port = Port(uint16(p))
}
}
}
return fmt.Sprintf("%v:%v", host, port)
}
// RFC 3261 - 8.2.6
func NewResponseFromRequest(
resID MessageID,
req Request,
statusCode StatusCode,
reason string,
body string,
) Response {
res := NewResponse(
resID,
req.SipVersion(),
statusCode,
reason,
[]Header{},
"",
req.Fields(),
)
CopyHeaders("Record-Route", req, res)
CopyHeaders("Via", req, res)
CopyHeaders("From", req, res)
CopyHeaders("To", req, res)
CopyHeaders("Call-ID", req, res)
CopyHeaders("CSeq", req, res)
if statusCode == 100 {
CopyHeaders("Timestamp", req, res)
}
res.SetBody(body, true)
res.SetTransport(req.Transport())
res.SetSource(req.Destination())
res.SetDestination(req.Source())
return res
}
func cloneResponse(res Response, id MessageID, fields log.Fields) Response {
newFields := res.Fields()
if fields != nil {
newFields = newFields.WithFields(fields)
}
newRes := NewResponse(
id,
res.SipVersion(),
res.StatusCode(),
res.Reason(),
cloneHeaders(res),
res.Body(),
newFields,
)
newRes.SetPrevious(res.Previous())
newRes.SetTransport(res.Transport())
newRes.SetSource(res.Source())
newRes.SetDestination(res.Destination())
return newRes
}
func CopyResponse(res Response) Response {
return cloneResponse(res, res.MessageID(), nil)
}

View file

@ -0,0 +1,28 @@
package sip
type TransactionKey string
func (key TransactionKey) String() string {
return string(key)
}
type Transaction interface {
Origin() Request
Key() TransactionKey
String() string
Errors() <-chan error
Done() <-chan bool
}
type ServerTransaction interface {
Transaction
Respond(res Response) error
Acks() <-chan Request
Cancels() <-chan Request
}
type ClientTransaction interface {
Transaction
Responses() <-chan Response
Cancel() error
}

View file

@ -0,0 +1,8 @@
package sip
type Transport interface {
Messages() <-chan Message
Send(msg Message) error
IsReliable(network string) bool
IsStreamed(network string) bool
}

View file

@ -0,0 +1,234 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package timing
import (
"sync"
"time"
)
// Controls whether library calls should be mocked, or whether we should use the standard Go time library.
// If we're in Mock Mode, then time does not pass as normal, but only progresses when Elapse is called.
// False by default, indicating that we just call through to standard Go functions.
var MockMode = false
var currentTimeMock = time.Unix(0, 0)
var mockTimers = make([]*mockTimer, 0)
var mockTimerMu = new(sync.Mutex)
// Interface over Golang's built-in Timers, allowing them to be swapped out for mocked timers.
type Timer interface {
// Returns a channel which sends the current time immediately when the timer expires.
// Equivalent to time.Timer.C; however, we have to use a method here instead of a member since this is an interface.
C() <-chan time.Time
// Resets the timer such that it will expire in duration 'd' after the current time.
// Returns true if the timer had been active, and false if it had expired or been stopped.
Reset(d time.Duration) bool
// Stops the timer, preventing it from firing.
// Returns true if the timer had been active, and false if it had expired or been stopped.
Stop() bool
}
// Implementation of Timer that just wraps time.Timer.
type realTimer struct {
*time.Timer
}
func (t *realTimer) C() <-chan time.Time {
return t.Timer.C
}
func (t *realTimer) Reset(d time.Duration) bool {
t.Stop()
return t.Timer.Reset(d)
}
func (t *realTimer) Stop() bool {
// return t.Timer.Stop()
if !t.Timer.Stop() {
select {
case <-t.Timer.C:
return true
default:
return false
}
}
return true
}
// Implementation of Timer that mocks time.Timer, firing when the total elapsed time (as controlled by Elapse)
// exceeds the duration specified when the timer was constructed.
type mockTimer struct {
EndTime time.Time
Chan chan time.Time
fired bool
toRun func()
}
func (t *mockTimer) C() <-chan time.Time {
return t.Chan
}
func (t *mockTimer) Reset(d time.Duration) bool {
wasActive := removeMockTimer(t)
t.EndTime = currentTimeMock.Add(d)
if d > 0 {
mockTimerMu.Lock()
mockTimers = append(mockTimers, t)
mockTimerMu.Unlock()
} else {
// The new timer has an expiry time of 0.
// Fire it right away, and don't bother tracking it.
t.Chan <- currentTimeMock
}
return wasActive
}
func (t *mockTimer) Stop() bool {
if !removeMockTimer(t) {
select {
case <-t.Chan:
return true
default:
return false
}
}
return true
}
// Creates a new Timer; either a wrapper around a standard Go time.Timer, or a mocked-out Timer,
// depending on whether MockMode is set.
func NewTimer(d time.Duration) Timer {
if MockMode {
t := mockTimer{currentTimeMock.Add(d), make(chan time.Time, 1), false, nil}
if d == 0 {
t.Chan <- currentTimeMock
} else {
mockTimerMu.Lock()
mockTimers = append(mockTimers, &t)
mockTimerMu.Unlock()
}
return &t
} else {
return &realTimer{time.NewTimer(d)}
}
}
// See built-in time.After() function.
func After(d time.Duration) <-chan time.Time {
return NewTimer(d).C()
}
// See built-in time.AfterFunc() function.
func AfterFunc(d time.Duration, f func()) Timer {
if MockMode {
mockTimerMu.Lock()
t := mockTimer{currentTimeMock.Add(d), make(chan time.Time, 1), false, f}
mockTimerMu.Unlock()
if d == 0 {
go f()
t.Chan <- currentTimeMock
} else {
mockTimerMu.Lock()
mockTimers = append(mockTimers, &t)
mockTimerMu.Unlock()
}
return &t
} else {
return &realTimer{time.AfterFunc(d, f)}
}
}
// See built-in time.Sleep() function.
func Sleep(d time.Duration) {
<-After(d)
}
// Increment the current time by the given Duration.
// This function can only be called in Mock Mode, otherwise we will panic.
func Elapse(d time.Duration) {
requireMockMode()
mockTimerMu.Lock()
currentTimeMock = currentTimeMock.Add(d)
mockTimerMu.Unlock()
// Fire any timers whose time has come up.
mockTimerMu.Lock()
for _, t := range mockTimers {
t.fired = false
if !t.EndTime.After(currentTimeMock) {
if t.toRun != nil {
go t.toRun()
}
// Clear the channel if something is already in it.
select {
case <-t.Chan:
default:
}
t.Chan <- currentTimeMock
t.fired = true
}
}
mockTimerMu.Unlock()
// Stop tracking any fired timers.
remainingTimers := make([]*mockTimer, 0)
mockTimerMu.Lock()
for _, t := range mockTimers {
if !t.fired {
remainingTimers = append(remainingTimers, t)
}
}
mockTimers = remainingTimers
mockTimerMu.Unlock()
}
// Returns the current time.
// If Mock Mode is set, this will be the sum of all Durations passed into Elapse calls;
// otherwise it will be the true system time.
func Now() time.Time {
if MockMode {
return currentTimeMock
} else {
return time.Now()
}
}
// Shortcut method to enforce that Mock Mode is enabled.
func requireMockMode() {
if !MockMode {
panic("This method requires MockMode to be enabled")
}
}
// Utility method to remove a mockTimer from the list of outstanding timers.
func removeMockTimer(t *mockTimer) bool {
// First, find the index of the timer in our list.
found := false
var idx int
var elt *mockTimer
mockTimerMu.Lock()
for idx, elt = range mockTimers {
if elt == t {
found = true
break
}
}
mockTimerMu.Unlock()
if found {
mockTimerMu.Lock()
// We found the given timer. Remove it.
mockTimers = append(mockTimers[:idx], mockTimers[idx+1:]...)
mockTimerMu.Unlock()
return true
} else {
// The timer was not present, indicating that it was already expired.
return false
}
}

View file

@ -0,0 +1,219 @@
package transport
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
)
var (
bufferSize uint16 = 65535 - 20 - 8 // IPv4 max size - IPv4 Header size - UDP Header size
)
// Wrapper around net.Conn.
type Connection interface {
net.Conn
Key() ConnectionKey
Network() string
Streamed() bool
String() string
ReadFrom(buf []byte) (num int, raddr net.Addr, err error)
WriteTo(buf []byte, raddr net.Addr) (num int, err error)
}
// Connection implementation.
type connection struct {
baseConn net.Conn
key ConnectionKey
network string
laddr net.Addr
raddr net.Addr
streamed bool
mu sync.RWMutex
log log.Logger
}
func NewConnection(baseConn net.Conn, key ConnectionKey, network string, logger log.Logger) Connection {
var stream bool
switch baseConn.(type) {
case net.PacketConn:
stream = false
default:
stream = true
}
conn := &connection{
baseConn: baseConn,
key: key,
network: network,
laddr: baseConn.LocalAddr(),
raddr: baseConn.RemoteAddr(),
streamed: stream,
}
conn.log = logger.
WithPrefix("transport.Connection").
WithFields(log.Fields{
"connection_ptr": fmt.Sprintf("%p", conn),
"connection_key": conn.Key(),
})
return conn
}
func (conn *connection) String() string {
if conn == nil {
return "<nil>"
}
fields := conn.Log().Fields().WithFields(log.Fields{
"key": conn.Key(),
"network": conn.Network(),
"local_addr": conn.LocalAddr(),
"remote_addr": conn.RemoteAddr(),
})
return fmt.Sprintf("transport.Connection<%s>", fields)
}
func (conn *connection) Log() log.Logger {
return conn.log
}
func (conn *connection) Key() ConnectionKey {
return conn.key
}
func (conn *connection) Streamed() bool {
return conn.streamed
}
func (conn *connection) Network() string {
return strings.ToUpper(conn.network)
}
func (conn *connection) Read(buf []byte) (int, error) {
var (
num int
err error
)
num, err = conn.baseConn.Read(buf)
if err != nil {
return num, &ConnectionError{
err,
"read",
conn.Network(),
fmt.Sprintf("%v", conn.RemoteAddr()),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("read %d bytes %s <- %s:\n%s", num, conn.LocalAddr(), conn.RemoteAddr(), buf[:num])
return num, err
}
func (conn *connection) ReadFrom(buf []byte) (num int, raddr net.Addr, err error) {
num, raddr, err = conn.baseConn.(net.PacketConn).ReadFrom(buf)
if err != nil {
return num, raddr, &ConnectionError{
err,
"read",
conn.Network(),
fmt.Sprintf("%v", raddr),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("read %d bytes %s <- %s:\n%s", num, conn.LocalAddr(), raddr, buf[:num])
return num, raddr, err
}
func (conn *connection) Write(buf []byte) (int, error) {
var (
num int
err error
)
num, err = conn.baseConn.Write(buf)
if err != nil {
return num, &ConnectionError{
err,
"write",
conn.Network(),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%v", conn.RemoteAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("write %d bytes %s -> %s:\n%s", num, conn.LocalAddr(), conn.RemoteAddr(), buf[:num])
return num, err
}
func (conn *connection) WriteTo(buf []byte, raddr net.Addr) (num int, err error) {
num, err = conn.baseConn.(net.PacketConn).WriteTo(buf, raddr)
if err != nil {
return num, &ConnectionError{
err,
"write",
conn.Network(),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%v", raddr),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("write %d bytes %s -> %s:\n%s", num, conn.LocalAddr(), raddr, buf[:num])
return num, err
}
func (conn *connection) LocalAddr() net.Addr {
return conn.baseConn.LocalAddr()
}
func (conn *connection) RemoteAddr() net.Addr {
return conn.baseConn.RemoteAddr()
}
func (conn *connection) Close() error {
err := conn.baseConn.Close()
if err != nil {
return &ConnectionError{
err,
"close",
conn.Network(),
"",
"",
fmt.Sprintf("%p", conn),
}
}
conn.Log().Trace("connection closed")
return nil
}
func (conn *connection) SetDeadline(t time.Time) error {
return conn.baseConn.SetDeadline(t)
}
func (conn *connection) SetReadDeadline(t time.Time) error {
return conn.baseConn.SetReadDeadline(t)
}
func (conn *connection) SetWriteDeadline(t time.Time) error {
return conn.baseConn.SetWriteDeadline(t)
}

View file

@ -0,0 +1,802 @@
package transport
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
"github.com/ghettovoice/gosip/sip/parser"
"github.com/ghettovoice/gosip/timing"
"github.com/ghettovoice/gosip/util"
)
type ConnectionKey string
func (key ConnectionKey) String() string {
return string(key)
}
// ConnectionPool used for active connection management.
type ConnectionPool interface {
Done() <-chan struct{}
String() string
Put(connection Connection, ttl time.Duration) error
Get(key ConnectionKey) (Connection, error)
All() []Connection
Drop(key ConnectionKey) error
DropAll() error
Length() int
}
// ConnectionHandler serves associated connection, i.e. parses
// incoming data, manages expiry time & etc.
type ConnectionHandler interface {
Cancel()
Done() <-chan struct{}
String() string
Key() ConnectionKey
Connection() Connection
// Expiry returns connection expiry time.
Expiry() time.Time
Expired() bool
// Update updates connection expiry time.
// TODO put later to allow runtime update
// Update(conn Connection, ttl time.Duration)
// Manage runs connection serving.
Serve()
}
type connectionPool struct {
store map[ConnectionKey]ConnectionHandler
msgMapper sip.MessageMapper
output chan<- sip.Message
errs chan<- error
cancel <-chan struct{}
done chan struct{}
hmess chan sip.Message
herrs chan error
hwg sync.WaitGroup
mu sync.RWMutex
log log.Logger
}
func NewConnectionPool(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) ConnectionPool {
pool := &connectionPool{
store: make(map[ConnectionKey]ConnectionHandler),
msgMapper: msgMapper,
output: output,
errs: errs,
cancel: cancel,
done: make(chan struct{}),
hmess: make(chan sip.Message),
herrs: make(chan error),
}
pool.log = logger.
WithPrefix("transport.ConnectionPool").
WithFields(log.Fields{
"connection_pool_ptr": fmt.Sprintf("%p", pool),
})
go func() {
<-pool.cancel
pool.dispose()
}()
go pool.serveHandlers()
return pool
}
func (pool *connectionPool) String() string {
if pool == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ConnectionPool<%s>", pool.Log().Fields())
}
func (pool *connectionPool) Log() log.Logger {
return pool.log
}
func (pool *connectionPool) Done() <-chan struct{} {
return pool.done
}
// Put adds new connection to pool or updates TTL of existing connection
// TTL - 0 - unlimited; 1 - ... - time to live in pool
func (pool *connectionPool) Put(connection Connection, ttl time.Duration) error {
select {
case <-pool.cancel:
return &PoolError{
fmt.Errorf("connection pool closed"),
"get connection",
pool.String(),
}
default:
}
key := connection.Key()
if key == "" {
return &PoolError{
fmt.Errorf("empty connection key"),
"put connection",
pool.String(),
}
}
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.put(key, connection, ttl)
}
func (pool *connectionPool) Get(key ConnectionKey) (Connection, error) {
pool.mu.RLock()
defer pool.mu.RUnlock()
return pool.getConnection(key)
}
func (pool *connectionPool) Drop(key ConnectionKey) error {
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.drop(key)
}
func (pool *connectionPool) DropAll() error {
pool.mu.Lock()
for key := range pool.store {
if err := pool.drop(key); err != nil {
pool.Log().Errorf("drop connection %s failed: %s", key, err)
}
}
pool.mu.Unlock()
return nil
}
func (pool *connectionPool) All() []Connection {
pool.mu.RLock()
conns := make([]Connection, 0)
for _, handler := range pool.store {
conns = append(conns, handler.Connection())
}
pool.mu.RUnlock()
return conns
}
func (pool *connectionPool) Length() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.store)
}
func (pool *connectionPool) dispose() {
// clean pool
pool.DropAll()
pool.hwg.Wait()
// stop serveHandlers goroutine
close(pool.hmess)
close(pool.herrs)
close(pool.done)
}
func (pool *connectionPool) serveHandlers() {
pool.Log().Debug("begin serve connection handlers")
defer pool.Log().Debug("stop serve connection handlers")
for {
logger := pool.Log()
select {
case msg, ok := <-pool.hmess:
// cancel signal, serveStore exists
if !ok {
return
}
if msg == nil {
continue
}
logger = logger.WithFields(msg.Fields())
logger.Trace("passing up SIP message")
select {
case <-pool.cancel:
return
case pool.output <- msg:
logger.Trace("SIP message passed up")
continue
}
case err, ok := <-pool.herrs:
// cancel signal, serveStore exists
if !ok {
return
}
if err == nil {
continue
}
// on ConnectionHandleError we should drop handler in some cases
// all other possible errors ignored because in pool.herrs should be only ConnectionHandlerErrors
// so ConnectionPool passes up only Network (when connection falls) and MalformedMessage errors
var herr *ConnectionHandlerError
if !errors.As(err, &herr) {
// all other possible errors
logger.Tracef("ignore non connection error: %s", err)
continue
}
pool.mu.RLock()
handler, gerr := pool.get(herr.Key)
pool.mu.RUnlock()
if gerr != nil {
// ignore, handler already dropped out
logger.Tracef("ignore error from already dropped out connection %s: %s", herr.Key, gerr)
continue
}
logger = logger.WithFields(log.Fields{
"connection_handler": handler.String(),
})
if herr.Expired() {
// handler expired, drop it from pool and continue without emitting error
if handler.Expired() {
// connection expired
logger.Debug("connection expired, drop it and go further")
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// Due to a race condition, the socket has been updated since this expiry happened.
// Ignore the expiry since we already have a new socket for this address.
logger.Trace("ignore spurious connection expiry")
}
continue
} else if herr.EOF() {
select {
case <-pool.cancel:
return
default:
}
// remote endpoint closed
logger.Debugf("connection EOF: %s; drop it and go further", herr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
var connErr *ConnectionError
if errors.As(herr.Err, &connErr) {
pool.errs <- herr.Err
}
continue
} else if herr.Network() {
// connection broken or closed
logger.Debugf("connection network error: %s; drop it and pass the error up", herr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// syntax errors, malformed message errors and other
logger.Tracef("connection error: %s; pass the error up", herr)
}
// send initial error
select {
case <-pool.cancel:
return
case pool.errs <- herr.Err:
logger.Trace("error passed up")
continue
}
}
}
}
func (pool *connectionPool) put(key ConnectionKey, conn Connection, ttl time.Duration) error {
if _, err := pool.get(key); err == nil {
return &PoolError{
fmt.Errorf("key %s already exists in the pool", key),
"put connection",
pool.String(),
}
}
// wrap to handler
handler := NewConnectionHandler(
conn,
ttl,
pool.hmess,
pool.herrs,
pool.msgMapper,
pool.Log(),
)
logger := log.AddFieldsFrom(pool.Log(), handler)
logger.Tracef("put connection to the pool with TTL = %s", ttl)
pool.store[handler.Key()] = handler
// start serving
pool.hwg.Add(1)
go handler.Serve()
go func() {
<-handler.Done()
pool.hwg.Done()
}()
return nil
}
func (pool *connectionPool) drop(key ConnectionKey) error {
// check existence in pool
handler, err := pool.get(key)
if err != nil {
return err
}
handler.Cancel()
logger := log.AddFieldsFrom(pool.Log(), handler)
logger.Trace("drop connection from the pool")
// modify store
delete(pool.store, key)
return nil
}
func (pool *connectionPool) get(key ConnectionKey) (ConnectionHandler, error) {
if handler, ok := pool.store[key]; ok {
return handler, nil
}
return nil, &PoolError{
fmt.Errorf("connection %s not found in the pool", key),
"get connection",
pool.String(),
}
}
func (pool *connectionPool) getConnection(key ConnectionKey) (Connection, error) {
var conn Connection
handler, err := pool.get(key)
if err == nil {
conn = handler.Connection()
}
return conn, err
}
// connectionHandler actually serves associated connection
type connectionHandler struct {
connection Connection
msgMapper sip.MessageMapper
timer timing.Timer
ttl time.Duration
expiry time.Time
output chan<- sip.Message
errs chan<- error
cancelOnce sync.Once
canceled chan struct{}
done chan struct{}
addrs util.ElasticChan
log log.Logger
}
func NewConnectionHandler(
conn Connection,
ttl time.Duration,
output chan<- sip.Message,
errs chan<- error,
msgMapper sip.MessageMapper,
logger log.Logger,
) ConnectionHandler {
handler := &connectionHandler{
connection: conn,
msgMapper: msgMapper,
output: output,
errs: errs,
canceled: make(chan struct{}),
done: make(chan struct{}),
ttl: ttl,
}
handler.log = logger.
WithPrefix("transport.ConnectionHandler").
WithFields(log.Fields{
"connection_handler_ptr": fmt.Sprintf("%p", handler),
"connection_ptr": fmt.Sprintf("%p", conn),
"connection_key": conn.Key(),
"connection_network": conn.Network(),
})
// handler.Update(ttl)
if ttl > 0 {
handler.expiry = time.Now().Add(ttl)
handler.timer = timing.NewTimer(ttl)
} else {
handler.expiry = time.Time{}
handler.timer = timing.NewTimer(0)
if !handler.timer.Stop() {
<-handler.timer.C()
}
}
if handler.msgMapper == nil {
handler.msgMapper = func(msg sip.Message) sip.Message {
return msg
}
}
return handler
}
func (handler *connectionHandler) String() string {
if handler == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ConnectionHandler<%s>", handler.Log().Fields())
}
func (handler *connectionHandler) Log() log.Logger {
return handler.log
}
func (handler *connectionHandler) Key() ConnectionKey {
return handler.connection.Key()
}
func (handler *connectionHandler) Connection() Connection {
return handler.connection
}
func (handler *connectionHandler) Expiry() time.Time {
return handler.expiry
}
func (handler *connectionHandler) Expired() bool {
return !handler.Expiry().IsZero() && handler.Expiry().Before(time.Now())
}
// resets the timeout timer.
// func (handler *connectionHandler) Update(ttl time.Duration) {
// if ttl > 0 {
// expiryTime := timing.Now().Put(ttl)
// handler.Log().Debugf("set %s expiry time to %s", handler, expiryTime)
// handler.expiry = expiryTime
//
// if handler.timer == nil {
// handler.timer = timing.NewTimer(ttl)
// } else {
// handler.timer.Reset(ttl)
// }
// } else {
// handler.Log().Debugf("set %s unlimited expiry time")
// handler.expiry = time.Time{}
//
// if handler.timer == nil {
// handler.timer = timing.NewTimer(0)
// }
// handler.timer.Stop()
// }
// }
// Serve is connection serving loop.
// Waits for the connection to expire, and notifies the pool when it does.
func (handler *connectionHandler) Serve() {
defer close(handler.done)
handler.Log().Debug("begin serve connection")
defer handler.Log().Debug("stop serve connection")
// start connection serving goroutines
msgs, errs := handler.readConnection()
handler.pipeOutputs(msgs, errs)
}
func (handler *connectionHandler) readConnection() (<-chan sip.Message, <-chan error) {
msgs := make(chan sip.Message)
errs := make(chan error)
streamed := handler.Connection().Streamed()
var (
pktPrs *parser.PacketParser
strPrs parser.Parser
)
if streamed {
strPrs = parser.NewParser(msgs, errs, streamed, handler.Log())
} else {
pktPrs = parser.NewPacketParser(handler.Log())
}
var raddr net.Addr
if streamed {
raddr = handler.Connection().RemoteAddr()
} else {
handler.addrs.Init()
handler.addrs.SetLog(handler.Log())
handler.addrs.Run()
}
go func() {
defer func() {
handler.Connection().Close()
if streamed {
strPrs.Stop()
} else {
pktPrs.Stop()
}
if !streamed {
handler.addrs.Stop()
}
close(msgs)
close(errs)
}()
handler.Log().Debug("begin read connection")
defer handler.Log().Debug("stop read connection")
buf := make([]byte, bufferSize)
var (
num int
err error
)
for {
// wait for data
if streamed {
num, err = handler.Connection().Read(buf)
} else {
num, raddr, err = handler.Connection().ReadFrom(buf)
}
if err != nil {
//// if we get timeout error just go further and try read on the next iteration
//var netErr net.Error
//if errors.As(err, &netErr) {
// if netErr.Timeout() || netErr.Temporary() {
// handler.Log().Tracef(
// "connection read failed due to timeout or temporary unavailable reason: %s, sleep by %s",
// err,
// netErrRetryTime,
// )
//
// time.Sleep(netErrRetryTime)
//
// continue
// }
//}
// broken or closed connection
// so send error and exit
handler.handleError(err, fmt.Sprintf("%v", raddr))
return
}
data := buf[:num]
// skip empty udp packets
if len(bytes.Trim(data, "\x00")) == 0 {
handler.Log().Tracef("skip empty data: %#v", data)
continue
}
// parse received data
if streamed {
if _, err := strPrs.Write(data); err != nil {
handler.handleError(err, fmt.Sprintf("%v", raddr))
}
} else {
if msg, err := pktPrs.ParseMessage(data); err == nil {
handler.handleMessage(msg, fmt.Sprintf("%v", raddr))
} else {
handler.handleError(err, fmt.Sprintf("%v", raddr))
}
}
}
}()
return msgs, errs
}
func (handler *connectionHandler) pipeOutputs(msgs <-chan sip.Message, errs <-chan error) {
streamed := handler.Connection().Streamed()
handler.Log().Debug("begin pipe outputs")
defer handler.Log().Debug("stop pipe outputs")
for {
select {
case <-handler.timer.C():
var raddr string
if streamed {
raddr = fmt.Sprintf("%v", handler.Connection().RemoteAddr())
}
if handler.Expiry().IsZero() {
// handler expiryTime is zero only when TTL = 0 (unlimited handler)
// so we must not get here with zero expiryTime
handler.Log().Panic("fires expiry timer with ZERO expiryTime")
}
// pass up to the pool
handler.handleError(ExpireError("connection expired"), raddr)
case msg, ok := <-msgs:
if !ok {
return
}
handler.handleMessage(msg, handler.getRemoteAddr())
case err, ok := <-errs:
if !ok {
return
}
handler.handleError(err, handler.getRemoteAddr())
}
}
}
func (handler *connectionHandler) getRemoteAddr() string {
if handler.Connection().Streamed() {
return fmt.Sprintf("%v", handler.Connection().RemoteAddr())
} else {
// use non-blocking read because remote address already should be here
// or error occurred in read connection goroutine
select {
case v := <-handler.addrs.Out:
return v.(string)
default:
return "<nil>"
}
}
}
func (handler *connectionHandler) handleMessage(msg sip.Message, raddr string) {
msg.SetDestination(handler.Connection().LocalAddr().String())
rhost, rport, _ := net.SplitHostPort(raddr)
switch msg := msg.(type) {
case sip.Request:
// RFC 3261 - 18.2.1
viaHop, ok := msg.ViaHop()
if !ok {
handler.Log().Warn("ignore message without 'Via' header")
return
}
if rhost != "" && rhost != viaHop.Host {
viaHop.Params.Add("received", sip.String{Str: rhost})
}
// rfc3581
if viaHop.Params.Has("rport") {
viaHop.Params.Add("rport", sip.String{Str: rport})
}
if !handler.Connection().Streamed() {
if !viaHop.Params.Has("rport") {
var port sip.Port
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = sip.DefaultPort(handler.Connection().Network())
}
raddr = fmt.Sprintf("%s:%d", rhost, port)
}
}
msg.SetTransport(handler.connection.Network())
msg.SetSource(raddr)
case sip.Response:
// Set Remote Address as response source
msg.SetTransport(handler.connection.Network())
msg.SetSource(raddr)
}
msg = handler.msgMapper(msg.WithFields(log.Fields{
"connection_key": handler.Connection().Key(),
"received_at": time.Now(),
}))
// pass up
handler.output <- msg
if !handler.Expiry().IsZero() {
handler.expiry = time.Now().Add(handler.ttl)
handler.timer.Reset(handler.ttl)
}
}
func (handler *connectionHandler) handleError(err error, raddr string) {
if isSyntaxError(err) {
handler.Log().Tracef("ignore error: %s", err)
return
}
err = &ConnectionHandlerError{
err,
handler.Key(),
fmt.Sprintf("%p", handler),
handler.Connection().Network(),
fmt.Sprintf("%v", handler.Connection().LocalAddr()),
raddr,
}
select {
case <-handler.canceled:
case handler.errs <- err:
}
}
func isSyntaxError(err error) bool {
var perr parser.Error
if errors.As(err, &perr) && perr.Syntax() {
return true
}
var merr sip.MessageError
if errors.As(err, &merr) && merr.Broken() {
return true
}
return false
}
// Cancel simply calls runtime provided cancel function.
func (handler *connectionHandler) Cancel() {
handler.cancelOnce.Do(func() {
close(handler.canceled)
handler.Connection().Close()
handler.Log().Debug("connection handler canceled")
})
}
func (handler *connectionHandler) Done() <-chan struct{} {
return handler.done
}

View file

@ -0,0 +1,477 @@
package transport
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strings"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// Layer is responsible for the actual transmission of messages - RFC 3261 - 18.
type Layer interface {
Cancel()
Done() <-chan struct{}
Messages() <-chan sip.Message
Errors() <-chan error
// Listen starts listening on `addr` for each registered protocol.
Listen(network string, addr string, options ...ListenOption) error
// Send sends message on suitable protocol.
Send(msg sip.Message) error
String() string
IsReliable(network string) bool
IsStreamed(network string) bool
}
var protocolFactory ProtocolFactory = func(
network string,
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) (Protocol, error) {
switch strings.ToLower(network) {
case "udp":
return NewUdpProtocol(output, errs, cancel, msgMapper, logger), nil
case "tcp":
return NewTcpProtocol(output, errs, cancel, msgMapper, logger), nil
case "tls":
return NewTlsProtocol(output, errs, cancel, msgMapper, logger), nil
case "ws":
return NewWsProtocol(output, errs, cancel, msgMapper, logger), nil
case "wss":
return NewWssProtocol(output, errs, cancel, msgMapper, logger), nil
default:
return nil, UnsupportedProtocolError(fmt.Sprintf("protocol %s is not supported", network))
}
}
// SetProtocolFactory replaces default protocol factory
func SetProtocolFactory(factory ProtocolFactory) {
protocolFactory = factory
}
// GetProtocolFactory returns default protocol factory
func GetProtocolFactory() ProtocolFactory {
return protocolFactory
}
// TransportLayer implementation.
type layer struct {
protocols *protocolStore
listenPorts map[string][]sip.Port
ip net.IP
dnsResolver *net.Resolver
msgMapper sip.MessageMapper
msgs chan sip.Message
errs chan error
pmsgs chan sip.Message
perrs chan error
canceled chan struct{}
done chan struct{}
wg sync.WaitGroup
cancelOnce sync.Once
log log.Logger
}
// NewLayer creates transport layer.
// - ip - host IP
// - dnsAddr - DNS server address, default is 127.0.0.1:53
func NewLayer(
ip net.IP,
dnsResolver *net.Resolver,
msgMapper sip.MessageMapper,
logger log.Logger,
) Layer {
tpl := &layer{
protocols: newProtocolStore(),
listenPorts: make(map[string][]sip.Port),
ip: ip,
dnsResolver: dnsResolver,
msgMapper: msgMapper,
msgs: make(chan sip.Message),
errs: make(chan error),
pmsgs: make(chan sip.Message),
perrs: make(chan error),
canceled: make(chan struct{}),
done: make(chan struct{}),
}
tpl.log = logger.
WithPrefix("transport.Layer").
WithFields(map[string]interface{}{
"transport_layer_ptr": fmt.Sprintf("%p", tpl),
})
go tpl.serveProtocols()
return tpl
}
func (tpl *layer) String() string {
if tpl == nil {
return "<nil>"
}
return fmt.Sprintf("transport.Layer<%s>", tpl.Log().Fields())
}
func (tpl *layer) Log() log.Logger {
return tpl.log
}
func (tpl *layer) Cancel() {
select {
case <-tpl.canceled:
return
default:
}
tpl.cancelOnce.Do(func() {
close(tpl.canceled)
tpl.Log().Debug("transport layer canceled")
})
}
func (tpl *layer) Done() <-chan struct{} {
return tpl.done
}
func (tpl *layer) Messages() <-chan sip.Message {
return tpl.msgs
}
func (tpl *layer) Errors() <-chan error {
return tpl.errs
}
func (tpl *layer) IsReliable(network string) bool {
if protocol, ok := tpl.protocols.get(protocolKey(network)); ok && protocol.Reliable() {
return true
}
return false
}
func (tpl *layer) IsStreamed(network string) bool {
if protocol, ok := tpl.protocols.get(protocolKey(network)); ok && protocol.Streamed() {
return true
}
return false
}
func (tpl *layer) Listen(network string, addr string, options ...ListenOption) error {
select {
case <-tpl.canceled:
return fmt.Errorf("transport layer is canceled")
default:
}
protocol, err := tpl.getProtocol(network)
if err != nil {
return err
}
target, err := NewTargetFromAddr(addr)
if err != nil {
return err
}
target = FillTargetHostAndPort(protocol.Network(), target)
err = protocol.Listen(target, options...)
if err == nil {
if _, ok := tpl.listenPorts[protocol.Network()]; !ok {
if tpl.listenPorts[protocol.Network()] == nil {
tpl.listenPorts[protocol.Network()] = make([]sip.Port, 0)
}
tpl.listenPorts[protocol.Network()] = append(tpl.listenPorts[protocol.Network()], *target.Port)
}
}
return err
}
func (tpl *layer) Send(msg sip.Message) error {
select {
case <-tpl.canceled:
return fmt.Errorf("transport layer is canceled")
default:
}
viaHop, ok := msg.ViaHop()
if !ok {
return &sip.MalformedMessageError{
Err: fmt.Errorf("missing required 'Via' header"),
Msg: msg.String(),
}
}
switch msg := msg.(type) {
// RFC 3261 - 18.1.1.
case sip.Request:
network := msg.Transport()
// rewrite sent-by transport
viaHop.Transport = strings.ToUpper(network)
viaHop.Host = tpl.ip.String()
protocol, err := tpl.getProtocol(network)
if err != nil {
return err
}
// rewrite sent-by port
if viaHop.Port == nil {
if ports, ok := tpl.listenPorts[network]; ok {
port := ports[rand.Intn(len(ports))]
viaHop.Port = &port
} else {
defPort := sip.DefaultPort(network)
viaHop.Port = &defPort
}
}
target, err := NewTargetFromAddr(msg.Destination())
if err != nil {
return fmt.Errorf("build address target for %s: %w", msg.Destination(), err)
}
// dns srv lookup
if net.ParseIP(target.Host) == nil {
ctx := context.Background()
proto := strings.ToLower(network)
if _, addrs, err := tpl.dnsResolver.LookupSRV(ctx, "sip", proto, target.Host); err == nil && len(addrs) > 0 {
addr := addrs[0]
addrStr := fmt.Sprintf("%s:%d", addr.Target[:len(addr.Target)-1], addr.Port)
switch network {
case "UDP":
if addr, err := net.ResolveUDPAddr("udp", addrStr); err == nil {
port := sip.Port(addr.Port)
if addr.IP.To4() == nil {
target.Host = fmt.Sprintf("[%v]", addr.IP.String())
} else {
target.Host = addr.IP.String()
}
target.Port = &port
}
case "TLS":
fallthrough
case "WS":
fallthrough
case "WSS":
fallthrough
case "TCP":
if addr, err := net.ResolveTCPAddr("tcp", addrStr); err == nil {
port := sip.Port(addr.Port)
if addr.IP.To4() == nil {
target.Host = fmt.Sprintf("[%v]", addr.IP.String())
} else {
target.Host = addr.IP.String()
}
target.Port = &port
}
}
}
}
logger := log.AddFieldsFrom(tpl.Log(), protocol, msg)
logger.Debugf("sending SIP request:\n%s", msg)
if err = protocol.Send(target, msg); err != nil {
return fmt.Errorf("send SIP message through %s protocol to %s: %w", protocol.Network(), target.Addr(), err)
}
return nil
// RFC 3261 - 18.2.2.
case sip.Response:
// resolve protocol from Via
protocol, err := tpl.getProtocol(msg.Transport())
if err != nil {
return err
}
target, err := NewTargetFromAddr(msg.Destination())
if err != nil {
return fmt.Errorf("build address target for %s: %w", msg.Destination(), err)
}
logger := log.AddFieldsFrom(tpl.Log(), protocol, msg)
logger.Debugf("sending SIP response:\n%s", msg)
if err = protocol.Send(target, msg); err != nil {
return fmt.Errorf("send SIP message through %s protocol to %s: %w", protocol.Network(), target.Addr(), err)
}
return nil
default:
return &sip.UnsupportedMessageError{
Err: fmt.Errorf("unsupported message %s", msg.Short()),
Msg: msg.String(),
}
}
}
func (tpl *layer) getProtocol(network string) (Protocol, error) {
network = strings.ToLower(network)
return tpl.protocols.getOrPutNew(protocolKey(network), func() (Protocol, error) {
return protocolFactory(
network,
tpl.pmsgs,
tpl.perrs,
tpl.canceled,
tpl.msgMapper,
tpl.Log(),
)
})
}
func (tpl *layer) serveProtocols() {
defer func() {
tpl.dispose()
close(tpl.done)
}()
tpl.Log().Debug("begin serve protocols")
defer tpl.Log().Debug("stop serve protocols")
for {
select {
case <-tpl.canceled:
return
case msg := <-tpl.pmsgs:
tpl.handleMessage(msg)
case err := <-tpl.perrs:
tpl.handlerError(err)
}
}
}
func (tpl *layer) dispose() {
tpl.Log().Debug("disposing...")
// wait for protocols
for _, protocol := range tpl.protocols.all() {
tpl.protocols.drop(protocolKey(protocol.Network()))
<-protocol.Done()
}
tpl.listenPorts = make(map[string][]sip.Port)
close(tpl.pmsgs)
close(tpl.perrs)
close(tpl.msgs)
close(tpl.errs)
}
// handles incoming message from protocol
// should be called inside goroutine for non-blocking forwarding
func (tpl *layer) handleMessage(msg sip.Message) {
logger := tpl.Log().WithFields(msg.Fields())
logger.Debugf("received SIP message:\n%s", msg)
logger.Trace("passing up SIP message...")
// pass up message
select {
case <-tpl.canceled:
case tpl.msgs <- msg:
logger.Trace("SIP message passed up")
}
}
func (tpl *layer) handlerError(err error) {
// TODO: implement re-connection strategy for listeners
var terr Error
if errors.As(err, &terr) {
// currently log
tpl.Log().Warnf("SIP transport error: %s", err)
}
logger := tpl.Log().WithFields(log.Fields{
"sip_error": err.Error(),
})
logger.Trace("passing up error...")
select {
case <-tpl.canceled:
case tpl.errs <- err:
logger.Trace("error passed up")
}
}
type protocolKey string
// Thread-safe protocols pool.
type protocolStore struct {
protocols map[protocolKey]Protocol
mu sync.RWMutex
}
func newProtocolStore() *protocolStore {
return &protocolStore{
protocols: make(map[protocolKey]Protocol),
}
}
func (store *protocolStore) put(key protocolKey, protocol Protocol) {
store.mu.Lock()
store.protocols[key] = protocol
store.mu.Unlock()
}
func (store *protocolStore) get(key protocolKey) (Protocol, bool) {
store.mu.RLock()
defer store.mu.RUnlock()
protocol, ok := store.protocols[key]
return protocol, ok
}
func (store *protocolStore) getOrPutNew(key protocolKey, factory func() (Protocol, error)) (Protocol, error) {
store.mu.Lock()
defer store.mu.Unlock()
protocol, ok := store.protocols[key]
if ok {
return protocol, nil
}
var err error
protocol, err = factory()
if err != nil {
return nil, err
}
store.protocols[key] = protocol
return protocol, nil
}
func (store *protocolStore) drop(key protocolKey) bool {
if _, ok := store.get(key); !ok {
return false
}
store.mu.Lock()
defer store.mu.Unlock()
delete(store.protocols, key)
return true
}
func (store *protocolStore) all() []Protocol {
all := make([]Protocol, 0)
store.mu.RLock()
defer store.mu.RUnlock()
for _, protocol := range store.protocols {
all = append(all, protocol)
}
return all
}

View file

@ -0,0 +1,498 @@
package transport
import (
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"github.com/ghettovoice/gosip/log"
)
type ListenerKey string
func (key ListenerKey) String() string {
return string(key)
}
type ListenerPool interface {
log.Loggable
Done() <-chan struct{}
String() string
Put(key ListenerKey, listener net.Listener) error
Get(key ListenerKey) (net.Listener, error)
All() []net.Listener
Drop(key ListenerKey) error
DropAll() error
Length() int
}
type ListenerHandler interface {
log.Loggable
Cancel()
Done() <-chan struct{}
String() string
Key() ListenerKey
Listener() net.Listener
Serve()
// TODO implement later, runtime replace of the net.Listener in handler
// Update(ls net.Listener)
}
type listenerPool struct {
hwg sync.WaitGroup
mu sync.RWMutex
store map[ListenerKey]ListenerHandler
output chan<- Connection
errs chan<- error
cancel <-chan struct{}
done chan struct{}
hconns chan Connection
herrs chan error
log log.Logger
}
func NewListenerPool(
output chan<- Connection,
errs chan<- error,
cancel <-chan struct{},
logger log.Logger,
) ListenerPool {
pool := &listenerPool{
store: make(map[ListenerKey]ListenerHandler),
output: output,
errs: errs,
cancel: cancel,
done: make(chan struct{}),
hconns: make(chan Connection),
herrs: make(chan error),
}
pool.log = logger.
WithPrefix("transport.ListenerPool").
WithFields(log.Fields{
"listener_pool_ptr": fmt.Sprintf("%p", pool),
})
go func() {
<-pool.cancel
pool.dispose()
}()
go pool.serveHandlers()
return pool
}
func (pool *listenerPool) String() string {
if pool == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ListenerPool<%s>", pool.Log().Fields())
}
func (pool *listenerPool) Log() log.Logger {
return pool.log
}
// Done returns channel that resolves when pool gracefully completes it work.
func (pool *listenerPool) Done() <-chan struct{} {
return pool.done
}
func (pool *listenerPool) Put(key ListenerKey, listener net.Listener) error {
select {
case <-pool.cancel:
return &PoolError{
fmt.Errorf("listener pool closed"),
"put listener",
pool.String(),
}
default:
}
if key == "" {
return &PoolError{
fmt.Errorf("empty listener key"),
"put listener",
pool.String(),
}
}
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.put(key, listener)
}
func (pool *listenerPool) Get(key ListenerKey) (net.Listener, error) {
pool.mu.RLock()
defer pool.mu.RUnlock()
return pool.getListener(key)
}
func (pool *listenerPool) Drop(key ListenerKey) error {
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.drop(key)
}
func (pool *listenerPool) DropAll() error {
pool.mu.Lock()
for key := range pool.store {
if err := pool.drop(key); err != nil {
pool.Log().Errorf("drop listener %s failed: %s", key, err)
}
}
pool.mu.Unlock()
return nil
}
func (pool *listenerPool) All() []net.Listener {
pool.mu.RLock()
listns := make([]net.Listener, 0)
for _, handler := range pool.store {
listns = append(listns, handler.Listener())
}
pool.mu.RUnlock()
return listns
}
func (pool *listenerPool) Length() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.store)
}
func (pool *listenerPool) dispose() {
// clean pool
pool.DropAll()
pool.hwg.Wait()
// stop serveHandlers goroutine
close(pool.hconns)
close(pool.herrs)
close(pool.done)
}
func (pool *listenerPool) serveHandlers() {
pool.Log().Debug("start serve listener handlers")
defer pool.Log().Debug("stop serve listener handlers")
for {
logger := pool.Log()
select {
case conn, ok := <-pool.hconns:
if !ok {
return
}
if conn == nil {
continue
}
logger = log.AddFieldsFrom(logger, conn)
logger.Trace("passing up connection")
select {
case <-pool.cancel:
return
case pool.output <- conn:
logger.Trace("connection passed up")
}
case err, ok := <-pool.herrs:
if !ok {
return
}
if err == nil {
continue
}
var lerr *ListenerHandlerError
if errors.As(err, &lerr) {
pool.mu.RLock()
handler, gerr := pool.get(lerr.Key)
pool.mu.RUnlock()
if gerr == nil {
logger = logger.WithFields(handler.Log().Fields())
if lerr.Network() {
// listener broken or closed, should be dropped
logger.Debugf("listener network error: %s; drop it and go further", lerr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// other
logger.Tracef("listener error: %s; pass the error up", lerr)
}
} else {
// ignore, handler already dropped out
logger.Tracef("ignore error from already dropped out listener %s: %s", lerr.Key, lerr)
continue
}
} else {
// all other possible errors
logger.Tracef("ignore non listener error: %s", err)
continue
}
select {
case <-pool.cancel:
return
case pool.errs <- err:
logger.Trace("error passed up")
}
}
}
}
func (pool *listenerPool) put(key ListenerKey, listener net.Listener) error {
if _, err := pool.get(key); err == nil {
return &PoolError{
fmt.Errorf("key %s already exists in the pool", key),
"put listener",
pool.String(),
}
}
// wrap to handler
handler := NewListenerHandler(key, listener, pool.hconns, pool.herrs, pool.Log())
pool.Log().WithFields(handler.Log().Fields()).Trace("put listener to the pool")
// update store
pool.store[handler.Key()] = handler
// start serving
pool.hwg.Add(1)
go handler.Serve()
go func() {
<-handler.Done()
pool.hwg.Done()
}()
return nil
}
func (pool *listenerPool) drop(key ListenerKey) error {
// check existence in pool
handler, err := pool.get(key)
if err != nil {
return err
}
handler.Cancel()
pool.Log().WithFields(handler.Log().Fields()).Trace("drop listener from the pool")
// modify store
delete(pool.store, key)
return nil
}
func (pool *listenerPool) get(key ListenerKey) (ListenerHandler, error) {
if handler, ok := pool.store[key]; ok {
return handler, nil
}
return nil, &PoolError{
fmt.Errorf("listenr %s not found in the pool", key),
"get listener",
pool.String(),
}
}
func (pool *listenerPool) getListener(key ListenerKey) (net.Listener, error) {
if handler, err := pool.get(key); err == nil {
return handler.Listener(), nil
} else {
return nil, err
}
}
type listenerHandler struct {
key ListenerKey
listener net.Listener
output chan<- Connection
errs chan<- error
cancelOnce sync.Once
canceled chan struct{}
done chan struct{}
log log.Logger
}
func NewListenerHandler(
key ListenerKey,
listener net.Listener,
output chan<- Connection,
errs chan<- error,
logger log.Logger,
) ListenerHandler {
handler := &listenerHandler{
key: key,
listener: listener,
output: output,
errs: errs,
canceled: make(chan struct{}),
done: make(chan struct{}),
}
handler.log = logger.
WithPrefix("transport.ListenerHandler").
WithFields(log.Fields{
"listener_handler_ptr": fmt.Sprintf("%p", handler),
"listener_ptr": fmt.Sprintf("%p", listener),
"listener_key": key,
})
return handler
}
func (handler *listenerHandler) String() string {
if handler == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ListenerHandler<%s>", handler.Log().Fields())
}
func (handler *listenerHandler) Log() log.Logger {
return handler.log
}
func (handler *listenerHandler) Key() ListenerKey {
return handler.key
}
func (handler *listenerHandler) Listener() net.Listener {
return handler.listener
}
func (handler *listenerHandler) Serve() {
defer close(handler.done)
handler.Log().Debug("begin serve listener")
defer handler.Log().Debugf("stop serve listener")
wg := &sync.WaitGroup{}
wg.Add(1)
go handler.acceptConnections(wg)
wg.Wait()
}
func (handler *listenerHandler) acceptConnections(wg *sync.WaitGroup) {
defer func() {
handler.Listener().Close()
wg.Done()
}()
handler.Log().Debug("begin accept connections")
defer handler.Log().Debug("stop accept connections")
for {
// wait for the new connection
baseConn, err := handler.Listener().Accept()
if err != nil {
//// if we get timeout error just go further and try accept on the next iteration
//var netErr net.Error
//if errors.As(err, &netErr) {
// if netErr.Timeout() || netErr.Temporary() {
// handler.Log().Warnf("listener timeout or temporary unavailable, sleep by %s", netErrRetryTime)
//
// time.Sleep(netErrRetryTime)
//
// continue
// }
//}
// broken or closed listener
// pass up error and exit
err = &ListenerHandlerError{
err,
handler.Key(),
fmt.Sprintf("%p", handler),
listenerNetwork(handler.Listener()),
handler.Listener().Addr().String(),
}
select {
case <-handler.canceled:
case handler.errs <- err:
}
return
}
var network string
switch bc := baseConn.(type) {
case *tls.Conn:
network = "tls"
case *wsConn:
if _, ok := bc.Conn.(*tls.Conn); ok {
network = "wss"
} else {
network = "ws"
}
default:
network = strings.ToLower(baseConn.RemoteAddr().Network())
}
key := ConnectionKey(network + ":" + baseConn.RemoteAddr().String())
handler.output <- NewConnection(baseConn, key, network, handler.Log())
}
}
// Cancel stops serving.
// blocked until Serve completes
func (handler *listenerHandler) Cancel() {
handler.cancelOnce.Do(func() {
close(handler.canceled)
handler.Listener().Close()
handler.Log().Debug("listener handler canceled")
})
}
// Done returns channel that resolves when handler gracefully completes it work.
func (handler *listenerHandler) Done() <-chan struct{} {
return handler.done
}
func listenerNetwork(ls net.Listener) string {
if val, ok := ls.(interface{ Network() string }); ok {
return val.Network()
}
switch ls.(type) {
case *net.TCPListener:
return "tcp"
case *net.UnixListener:
return "unix"
default:
return ""
}
}

View file

@ -0,0 +1,90 @@
package transport
import (
"net"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
// TODO migrate other factories to functional arguments
type Options struct {
MessageMapper sip.MessageMapper
Logger log.Logger
}
type LayerOption interface {
ApplyLayer(opts *LayerOptions)
}
type LayerOptions struct {
Options
DNSResolver *net.Resolver
}
type ProtocolOption interface {
ApplyProtocol(opts *ProtocolOptions)
}
type ProtocolOptions struct {
Options
}
func WithMessageMapper(mapper sip.MessageMapper) interface {
LayerOption
ProtocolOption
} {
return withMessageMapper{mapper}
}
type withMessageMapper struct {
mapper sip.MessageMapper
}
func (o withMessageMapper) ApplyLayer(opts *LayerOptions) {
opts.MessageMapper = o.mapper
}
func (o withMessageMapper) ApplyProtocol(opts *ProtocolOptions) {
opts.MessageMapper = o.mapper
}
func WithLogger(logger log.Logger) interface {
LayerOption
ProtocolOption
} {
return withLogger{logger}
}
type withLogger struct {
logger log.Logger
}
func (o withLogger) ApplyLayer(opts *LayerOptions) {
opts.Logger = o.logger
}
func (o withLogger) ApplyProtocol(opts *ProtocolOptions) {
opts.Logger = o.logger
}
func WithDNSResolver(resolver *net.Resolver) LayerOption {
return withDnsResolver{resolver}
}
type withDnsResolver struct {
resolver *net.Resolver
}
func (o withDnsResolver) ApplyLayer(opts *LayerOptions) {
opts.DNSResolver = o.resolver
}
// Listen method options
type ListenOption interface {
ApplyListen(opts *ListenOptions)
}
type ListenOptions struct {
TLSConfig TLSConfig
}

View file

@ -0,0 +1,71 @@
package transport
import (
"fmt"
"strings"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
const (
//netErrRetryTime = 5 * time.Second
sockTTL = time.Hour
)
// Protocol implements network specific features.
type Protocol interface {
Done() <-chan struct{}
Network() string
Reliable() bool
Streamed() bool
Listen(target *Target, options ...ListenOption) error
Send(target *Target, msg sip.Message) error
String() string
}
type ProtocolFactory func(
network string,
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) (Protocol, error)
type protocol struct {
network string
reliable bool
streamed bool
log log.Logger
}
func (pr *protocol) Log() log.Logger {
return pr.log
}
func (pr *protocol) String() string {
if pr == nil {
return "<nil>"
}
fields := pr.Log().Fields().WithFields(log.Fields{
"network": pr.network,
})
return fmt.Sprintf("transport.Protocol<%s>", fields)
}
func (pr *protocol) Network() string {
return strings.ToUpper(pr.network)
}
func (pr *protocol) Reliable() bool {
return pr.reliable
}
func (pr *protocol) Streamed() bool {
return pr.streamed
}

View file

@ -0,0 +1,209 @@
package transport
import (
"fmt"
"net"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type tcpListener struct {
net.Listener
network string
}
func (l *tcpListener) Network() string {
return strings.ToUpper(l.network)
}
// TCP protocol implementation
type tcpProtocol struct {
protocol
listeners ListenerPool
connections ConnectionPool
conns chan Connection
listen func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error)
dial func(addr *net.TCPAddr) (net.Conn, error)
resolveAddr func(addr string) (*net.TCPAddr, error)
}
func NewTcpProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(tcpProtocol)
p.network = "tcp"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
// TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = p.defaultListen
p.dial = p.defaultDial
p.resolveAddr = p.defaultResolveAddr
// pipe listener and connection pools
go p.pipePools()
return p
}
func (p *tcpProtocol) defaultListen(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
return net.ListenTCP(p.network, addr)
}
func (p *tcpProtocol) defaultDial(addr *net.TCPAddr) (net.Conn, error) {
return net.DialTCP(p.network, nil, addr)
}
func (p *tcpProtocol) defaultResolveAddr(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr(p.network, addr)
}
func (p *tcpProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
// piping new connections to connection pool for serving
func (p *tcpProtocol) pipePools() {
defer close(p.conns)
p.Log().Debug("start pipe pools")
defer p.Log().Debug("stop pipe pools")
for {
select {
case <-p.listeners.Done():
return
case conn := <-p.conns:
logger := log.AddFieldsFrom(p.Log(), conn)
if err := p.connections.Put(conn, sockTTL); err != nil {
// TODO should it be passed up to UA?
logger.Errorf("put %s connection to the pool failed: %s", conn.Key(), err)
conn.Close()
continue
}
}
}
}
func (p *tcpProtocol) Listen(target *Target, options ...ListenOption) error {
target = FillTargetHostAndPort(p.Network(), target)
laddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
listener, err := p.listen(laddr, options...)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), target.Addr())
// index listeners by local address
// should live infinitely
key := ListenerKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, target.Port))
err = p.listeners.Put(key, &tcpListener{
Listener: listener,
network: p.network,
})
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s listener to the pool", key),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err // should be nil here
}
func (p *tcpProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
// validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// resolve remote address
raddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// find or create connection
conn, err := p.getOrCreateConnection(raddr)
if err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("get or create %s connection", p.Network()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
// send message
_, err = conn.Write([]byte(msg.String()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err
}
func (p *tcpProtocol) getOrCreateConnection(raddr *net.TCPAddr) (Connection, error) {
key := ConnectionKey(p.network + ":" + raddr.String())
conn, err := p.connections.Get(key)
if err != nil {
p.Log().Debugf("connection for remote address %s %s not found, create a new one", p.Network(), raddr)
tcpConn, err := p.dial(raddr)
if err != nil {
return nil, fmt.Errorf("dial to %s %s: %w", p.Network(), raddr, err)
}
conn = NewConnection(tcpConn, key, p.network, p.Log())
if err := p.connections.Put(conn, sockTTL); err != nil {
return conn, fmt.Errorf("put %s connection to the pool: %w", conn.Key(), err)
}
}
return conn, nil
}

View file

@ -0,0 +1,67 @@
package transport
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type tlsProtocol struct {
tcpProtocol
}
func NewTlsProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(tlsProtocol)
p.network = "tls"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
if len(options) == 0 {
return net.ListenTCP("tcp", addr)
}
optsHash := ListenOptions{}
for _, opt := range options {
opt.ApplyListen(&optsHash)
}
cert, err := tls.LoadX509KeyPair(optsHash.TLSConfig.Cert, optsHash.TLSConfig.Key)
if err != nil {
return nil, fmt.Errorf("load TLS certficate %s: %w", optsHash.TLSConfig.Cert, err)
}
return tls.Listen("tcp", addr.String(), &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
p.dial = func(addr *net.TCPAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return nil
},
})
}
p.resolveAddr = func(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", addr)
}
//pipe listener and connection pools
go p.pipePools()
return p
}

View file

@ -0,0 +1,357 @@
// transport package implements SIP transport layer.
package transport
import (
"errors"
"fmt"
"io"
"net"
"regexp"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
const (
MTU = sip.MTU
DefaultHost = sip.DefaultHost
DefaultProtocol = sip.DefaultProtocol
DefaultUdpPort = sip.DefaultUdpPort
DefaultTcpPort = sip.DefaultTcpPort
DefaultTlsPort = sip.DefaultTlsPort
DefaultWsPort = sip.DefaultWsPort
DefaultWssPort = sip.DefaultWssPort
)
// Target endpoint
type Target struct {
Host string
Port *sip.Port
}
func (trg *Target) Addr() string {
var (
host string
port sip.Port
)
if strings.TrimSpace(trg.Host) != "" {
host = trg.Host
} else {
host = DefaultHost
}
if trg.Port != nil {
port = *trg.Port
}
return fmt.Sprintf("%v:%v", host, port)
}
func (trg *Target) String() string {
if trg == nil {
return "<nil>"
}
fields := log.Fields{
"target_addr": trg.Addr(),
}
return fmt.Sprintf("transport.Target<%s>", fields)
}
func NewTarget(host string, port int) *Target {
cport := sip.Port(port)
return &Target{Host: host, Port: &cport}
}
func NewTargetFromAddr(addr string) (*Target, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
iport, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return NewTarget(host, iport), nil
}
// Fills endpoint target with default values.
func FillTargetHostAndPort(network string, target *Target) *Target {
if strings.TrimSpace(target.Host) == "" {
target.Host = DefaultHost
}
if target.Port == nil {
p := sip.DefaultPort(network)
target.Port = &p
}
return target
}
// Transport error
type Error interface {
net.Error
// Network indicates network level errors
Network() bool
}
func isNetwork(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return true
} else {
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe)
}
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func isTemporary(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Temporary()
}
return false
}
func isCanceled(err error) bool {
var cancelErr sip.CancelError
if errors.As(err, &cancelErr) {
return cancelErr.Canceled()
}
return false
}
func isExpired(err error) bool {
var expiryErr sip.ExpireError
if errors.As(err, &expiryErr) {
return expiryErr.Expired()
}
return false
}
// Connection level error.
type ConnectionError struct {
Err error
Op string
Net string
Source string
Dest string
ConnPtr string
}
func (err *ConnectionError) Unwrap() error { return err.Err }
func (err *ConnectionError) Network() bool { return isNetwork(err.Err) }
func (err *ConnectionError) Timeout() bool { return isTimeout(err.Err) }
func (err *ConnectionError) Temporary() bool { return isTemporary(err.Err) }
func (err *ConnectionError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"network": "???",
"connection_ptr": "???",
"source": "???",
"destination": "???",
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.ConnPtr != "" {
fields["connection_ptr"] = err.ConnPtr
}
if err.Source != "" {
fields["source"] = err.Source
}
if err.Dest != "" {
fields["destination"] = err.Dest
}
return fmt.Sprintf("transport.ConnectionError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type ExpireError string
func (err ExpireError) Network() bool { return false }
func (err ExpireError) Timeout() bool { return true }
func (err ExpireError) Temporary() bool { return false }
func (err ExpireError) Canceled() bool { return false }
func (err ExpireError) Expired() bool { return true }
func (err ExpireError) Error() string { return "transport.ExpireError: " + string(err) }
// Net Protocol level error
type ProtocolError struct {
Err error
Op string
ProtoPtr string
}
func (err *ProtocolError) Unwrap() error { return err.Err }
func (err *ProtocolError) Network() bool { return isNetwork(err.Err) }
func (err *ProtocolError) Timeout() bool { return isTimeout(err.Err) }
func (err *ProtocolError) Temporary() bool { return isTemporary(err.Err) }
func (err *ProtocolError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"protocol_ptr": "???",
}
if err.ProtoPtr != "" {
fields["protocol_ptr"] = err.ProtoPtr
}
return fmt.Sprintf("transport.ProtocolError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type ConnectionHandlerError struct {
Err error
Key ConnectionKey
HandlerPtr string
Net string
LAddr string
RAddr string
}
func (err *ConnectionHandlerError) Unwrap() error { return err.Err }
func (err *ConnectionHandlerError) Network() bool { return isNetwork(err.Err) }
func (err *ConnectionHandlerError) Timeout() bool { return isTimeout(err.Err) }
func (err *ConnectionHandlerError) Temporary() bool { return isTemporary(err.Err) }
func (err *ConnectionHandlerError) Canceled() bool { return isCanceled(err.Err) }
func (err *ConnectionHandlerError) Expired() bool { return isExpired(err.Err) }
func (err *ConnectionHandlerError) EOF() bool {
if err.Err == io.EOF {
return true
}
ok, _ := regexp.MatchString("(?i)eof", err.Err.Error())
return ok
}
func (err *ConnectionHandlerError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"handler_ptr": "???",
"network": "???",
"local_addr": "???",
"remote_addr": "???",
}
if err.HandlerPtr != "" {
fields["handler_ptr"] = err.HandlerPtr
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.LAddr != "" {
fields["local_addr"] = err.LAddr
}
if err.RAddr != "" {
fields["remote_addr"] = err.RAddr
}
return fmt.Sprintf("transport.ConnectionHandlerError<%s>: %s", fields, err.Err)
}
type ListenerHandlerError struct {
Err error
Key ListenerKey
HandlerPtr string
Net string
Addr string
}
func (err *ListenerHandlerError) Unwrap() error { return err.Err }
func (err *ListenerHandlerError) Network() bool { return isNetwork(err.Err) }
func (err *ListenerHandlerError) Timeout() bool { return isTimeout(err.Err) }
func (err *ListenerHandlerError) Temporary() bool { return isTemporary(err.Err) }
func (err *ListenerHandlerError) Canceled() bool { return isCanceled(err.Err) }
func (err *ListenerHandlerError) Expired() bool { return isExpired(err.Err) }
func (err *ListenerHandlerError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"handler_ptr": "???",
"network": "???",
"local_addr": "???",
"remote_addr": "???",
}
if err.HandlerPtr != "" {
fields["handler_ptr"] = err.HandlerPtr
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.Addr != "" {
fields["local_addr"] = err.Addr
}
return fmt.Sprintf("transport.ListenerHandlerError<%s>: %s", fields, err.Err)
}
type PoolError struct {
Err error
Op string
Pool string
}
func (err *PoolError) Unwrap() error { return err.Err }
func (err *PoolError) Network() bool { return isNetwork(err.Err) }
func (err *PoolError) Timeout() bool { return isTimeout(err.Err) }
func (err *PoolError) Temporary() bool { return isTemporary(err.Err) }
func (err *PoolError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"pool": "???",
}
if err.Pool != "" {
fields["pool"] = err.Pool
}
return fmt.Sprintf("transport.PoolError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type UnsupportedProtocolError string
func (err UnsupportedProtocolError) Network() bool { return false }
func (err UnsupportedProtocolError) Timeout() bool { return false }
func (err UnsupportedProtocolError) Temporary() bool { return false }
func (err UnsupportedProtocolError) Error() string {
return "transport.UnsupportedProtocolError: " + string(err)
}
//TLSConfig for TLS and WSS only
type TLSConfig struct {
Domain string
Cert string
Key string
Pass string
}
func (c TLSConfig) ApplyListen(opts *ListenOptions) {
opts.TLSConfig.Domain = c.Domain
opts.TLSConfig.Cert = c.Cert
opts.TLSConfig.Key = c.Key
opts.TLSConfig.Pass = c.Pass
}

View file

@ -0,0 +1,138 @@
package transport
import (
"fmt"
"net"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
// UDP protocol implementation
type udpProtocol struct {
protocol
connections ConnectionPool
}
func NewUdpProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(udpProtocol)
p.network = "udp"
p.reliable = false
p.streamed = false
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
// TODO: add separate errs chan to listen errors from pool for reconnection?
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
return p
}
func (p *udpProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
func (p *udpProtocol) Listen(target *Target, options ...ListenOption) error {
// fill empty target props with default values
target = FillTargetHostAndPort(p.Network(), target)
// resolve local UDP endpoint
laddr, err := net.ResolveUDPAddr(p.network, target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// create UDP connection
udpConn, err := net.ListenUDP(p.network, laddr)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), laddr),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), laddr)
// register new connection
// index by local address, TTL=0 - unlimited expiry time
key := ConnectionKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, laddr.Port))
conn := NewConnection(udpConn, key, p.network, p.Log())
err = p.connections.Put(conn, 0)
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s connection to the pool", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err // should be nil here
}
func (p *udpProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
// validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// resolve remote address
raddr, err := net.ResolveUDPAddr(p.network, target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
_, port, err := net.SplitHostPort(msg.Source())
if err != nil {
return &ProtocolError{
Err: err,
Op: "resolve source port",
ProtoPtr: fmt.Sprintf("%p", p),
}
}
for _, conn := range p.connections.All() {
parts := strings.Split(string(conn.Key()), ":")
if parts[2] == port {
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
if _, err = conn.WriteTo([]byte(msg.String()), raddr); err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return nil
}
}
return &ProtocolError{
fmt.Errorf("connection on port %s not found", port),
"search connection",
fmt.Sprintf("%p", p),
}
}

View file

@ -0,0 +1,301 @@
package transport
import (
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
var (
wsSubProtocol = "sip"
)
type wsConn struct {
net.Conn
client bool
}
func (wc *wsConn) Read(b []byte) (n int, err error) {
var msg []byte
var op ws.OpCode
if wc.client {
msg, op, err = wsutil.ReadServerData(wc.Conn)
} else {
msg, op, err = wsutil.ReadClientData(wc.Conn)
}
if err != nil {
// handle error
var wsErr wsutil.ClosedError
if errors.As(err, &wsErr) {
return n, io.EOF
}
return n, err
}
if op == ws.OpClose {
return n, io.EOF
}
copy(b, msg)
return len(msg), err
}
func (wc *wsConn) Write(b []byte) (n int, err error) {
if wc.client {
err = wsutil.WriteClientMessage(wc.Conn, ws.OpText, b)
} else {
err = wsutil.WriteServerMessage(wc.Conn, ws.OpText, b)
}
if err != nil {
// handle error
var wsErr wsutil.ClosedError
if errors.As(err, &wsErr) {
return n, io.EOF
}
return n, err
}
return len(b), nil
}
type wsListener struct {
net.Listener
network string
u ws.Upgrader
log log.Logger
}
func NewWsListener(listener net.Listener, network string, log log.Logger) *wsListener {
l := &wsListener{
Listener: listener,
network: network,
log: log,
}
l.u.Protocol = func(val []byte) bool {
return string(val) == wsSubProtocol
}
return l
}
func (l *wsListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, fmt.Errorf("accept new connection: %w", err)
}
if _, err = l.u.Upgrade(conn); err == nil {
conn = &wsConn{
Conn: conn,
client: false,
}
} else {
l.log.Warnf("fallback to simple TCP connection due to WS upgrade error: %s", err)
err = nil
}
return conn, err
}
func (l *wsListener) Network() string {
return strings.ToUpper(l.network)
}
type wsProtocol struct {
protocol
listeners ListenerPool
connections ConnectionPool
conns chan Connection
listen func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error)
resolveAddr func(addr string) (*net.TCPAddr, error)
dialer ws.Dialer
}
func NewWsProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(wsProtocol)
p.network = "ws"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = p.defaultListen
p.resolveAddr = p.defaultResolveAddr
p.dialer.Protocols = []string{wsSubProtocol}
p.dialer.Timeout = time.Minute
//pipe listener and connection pools
go p.pipePools()
return p
}
func (p *wsProtocol) defaultListen(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
return net.ListenTCP("tcp", addr)
}
func (p *wsProtocol) defaultResolveAddr(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", addr)
}
func (p *wsProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
//piping new connections to connection pool for serving
func (p *wsProtocol) pipePools() {
defer close(p.conns)
p.Log().Debug("start pipe pools")
defer p.Log().Debug("stop pipe pools")
for {
select {
case <-p.listeners.Done():
return
case conn := <-p.conns:
logger := log.AddFieldsFrom(p.Log(), conn)
if err := p.connections.Put(conn, sockTTL); err != nil {
// TODO should it be passed up to UA?
logger.Errorf("put %s connection to the pool failed: %s", conn.Key(), err)
conn.Close()
continue
}
}
}
}
func (p *wsProtocol) Listen(target *Target, options ...ListenOption) error {
target = FillTargetHostAndPort(p.Network(), target)
laddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
listener, err := p.listen(laddr, options...)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), target.Addr())
//index listeners by local address
// should live infinitely
key := ListenerKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, target.Port))
err = p.listeners.Put(key, NewWsListener(listener, p.network, p.Log()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s listener to the pool", key),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err //should be nil here
}
func (p *wsProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
//validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
//resolve remote address
raddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
//find or create connection
conn, err := p.getOrCreateConnection(raddr)
if err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("get or create %s connection", p.Network()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
//send message
_, err = conn.Write([]byte(msg.String()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err
}
func (p *wsProtocol) getOrCreateConnection(raddr *net.TCPAddr) (Connection, error) {
key := ConnectionKey(p.network + ":" + raddr.String())
conn, err := p.connections.Get(key)
if err != nil {
p.Log().Debugf("connection for address %s %s not found; create a new one", p.Network(), raddr)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
url := fmt.Sprintf("%s://%s", p.network, raddr)
baseConn, _, _, err := p.dialer.Dial(ctx, url)
if err == nil {
baseConn = &wsConn{
Conn: baseConn,
client: true,
}
} else {
if baseConn == nil {
return nil, fmt.Errorf("dial to %s %s: %w", p.Network(), raddr, err)
}
p.Log().Warnf("fallback to TCP connection due to WS upgrade error: %s", err)
}
conn = NewConnection(baseConn, key, p.network, p.Log())
if err := p.connections.Put(conn, sockTTL); err != nil {
return conn, fmt.Errorf("put %s connection to the pool: %w", conn.Key(), err)
}
}
return conn, nil
}

View file

@ -0,0 +1,66 @@
package transport
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type wssProtocol struct {
wsProtocol
}
func NewWssProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(wssProtocol)
p.network = "wss"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
if len(options) == 0 {
return net.ListenTCP("tcp", addr)
}
optsHash := ListenOptions{}
for _, opt := range options {
opt.ApplyListen(&optsHash)
}
cert, err := tls.LoadX509KeyPair(optsHash.TLSConfig.Cert, optsHash.TLSConfig.Key)
if err != nil {
return nil, fmt.Errorf("load TLS certficate %s: %w", optsHash.TLSConfig.Cert, err)
}
return tls.Listen("tcp", addr.String(), &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
p.resolveAddr = p.defaultResolveAddr
p.dialer.Protocols = []string{wsSubProtocol}
p.dialer.Timeout = time.Minute
p.dialer.TLSConfig = &tls.Config{
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return nil
},
}
//pipe listener and connection pools
go p.pipePools()
return p
}

View file

@ -0,0 +1,147 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import (
"fmt"
"sync"
"github.com/ghettovoice/gosip/log"
)
// The buffer size of the primitive input and output chans.
const c_ELASTIC_CHANSIZE = 3
// A dynamic channel that does not block on send, but has an unlimited buffer capacity.
// ElasticChan uses a dynamic slice to buffer signals received on the input channel until
// the output channel is ready to process them.
type ElasticChan struct {
In chan interface{}
Out chan interface{}
buffer []interface{}
stopped bool
done chan struct{}
log log.Logger
logMu sync.RWMutex
}
// Initialise the Elastic channel, and start the management goroutine.
func (c *ElasticChan) Init() {
c.In = make(chan interface{}, c_ELASTIC_CHANSIZE)
c.Out = make(chan interface{}, c_ELASTIC_CHANSIZE)
c.buffer = make([]interface{}, 0)
c.done = make(chan struct{})
}
func (c *ElasticChan) Run() {
go c.manage()
}
func (c *ElasticChan) Stop() {
select {
case <-c.done:
return
default:
}
logger := c.Log()
if logger != nil {
logger.Trace("stopping elastic chan...")
}
close(c.In)
<-c.done
if logger != nil {
logger.Trace("elastic chan stopped")
}
}
func (c *ElasticChan) Log() log.Logger {
c.logMu.RLock()
defer c.logMu.RUnlock()
return c.log
}
func (c *ElasticChan) SetLog(logger log.Logger) {
c.logMu.Lock()
c.log = logger.
WithPrefix("util.ElasticChan").
WithFields(log.Fields{
"elastic_chan_ptr": fmt.Sprintf("%p", c),
})
c.logMu.Unlock()
}
// Poll for input from one end of the channel and add it to the buffer.
// Also poll sending buffered signals out over the output chan.
// TODO: add cancel chan
func (c *ElasticChan) manage() {
defer close(c.done)
loop:
for {
logger := c.Log()
if len(c.buffer) > 0 {
// The buffer has something in it, so try to send as well as
// receive.
// (Receive first in order to minimize blocked Send() calls).
select {
case in, ok := <-c.In:
if !ok {
if logger != nil {
logger.Trace("elastic chan will dispose")
}
break loop
}
c.Log().Tracef("ElasticChan %p gets '%v'", c, in)
c.buffer = append(c.buffer, in)
case c.Out <- c.buffer[0]:
c.Log().Tracef("ElasticChan %p sends '%v'", c, c.buffer[0])
c.buffer = c.buffer[1:]
}
} else {
// The buffer is empty, so there's nothing to send.
// Just wait to receive.
in, ok := <-c.In
if !ok {
if logger != nil {
logger.Trace("elastic chan will dispose")
}
break loop
}
c.Log().Tracef("ElasticChan %p gets '%v'", c, in)
c.buffer = append(c.buffer, in)
}
}
c.dispose()
}
func (c *ElasticChan) dispose() {
logger := c.Log()
if logger != nil {
logger.Trace("elastic chan disposing...")
}
for len(c.buffer) > 0 {
select {
case c.Out <- c.buffer[0]:
c.buffer = c.buffer[1:]
default:
}
}
if logger != nil {
logger.Trace("elastic chan disposed")
}
}

View file

@ -0,0 +1,104 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import (
"errors"
"net"
"sync"
)
// Check two string pointers for equality as follows:
// - If neither pointer is nil, check equality of the underlying strings.
// - If either pointer is nil, return true if and only if they both are.
func StrPtrEq(a *string, b *string) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
// Check two uint16 pointers for equality as follows:
// - If neither pointer is nil, check equality of the underlying uint16s.
// - If either pointer is nil, return true if and only if they both are.
func Uint16PtrEq(a *uint16, b *uint16) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
func Coalesce(arg1 interface{}, arg2 interface{}, args ...interface{}) interface{} {
all := append([]interface{}{arg1, arg2}, args...)
for _, arg := range all {
if arg != nil {
return arg
}
}
return nil
}
func Noop() {}
func MergeErrs(chs ...<-chan error) <-chan error {
wg := new(sync.WaitGroup)
out := make(chan error)
pipe := func(ch <-chan error) {
defer wg.Done()
for err := range ch {
out <- err
}
}
wg.Add(len(chs))
for _, ch := range chs {
go pipe(ch)
}
go func() {
wg.Wait()
close(out)
}()
return out
}
func ResolveSelfIP() (net.IP, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip, nil
}
}
return nil, errors.New("server not connected to any network")
}

View file

@ -0,0 +1,38 @@
package util
import (
"math/rand"
"time"
)
const (
letterBytes = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// https://github.com/kpbird/golang_random_string
func RandString(n int) string {
output := make([]byte, n)
// We will take n bytes, one byte for each character of output.
randomness := make([]byte, n)
// read all random
_, err := rand.Read(randomness)
if err != nil {
panic(err)
}
l := len(letterBytes)
// fill output
for pos := range output {
// get random item
random := randomness[pos]
// random % 64
randomPos := random % uint8(l)
// put into output
output[pos] = letterBytes[randomPos]
}
return string(output)
}

View file

@ -0,0 +1,75 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import "sync"
// Simple semaphore implementation.
// Any number of calls to Acquire() can be made; these will not block.
// If the semaphore has been acquired more times than it has been released, it is called 'blocked'.
// Otherwise, it is called 'free'.
type Semaphore interface {
// Take a semaphore lock.
Acquire()
// Release an acquired semaphore lock.
// This should only be called when the semaphore is blocked, otherwise behaviour is undefined
Release()
// Block execution until the semaphore is free.
Wait()
// Clean up the semaphore object.
Dispose()
}
func NewSemaphore() Semaphore {
sem := new(semaphore)
sem.cond = sync.NewCond(&sync.Mutex{})
go func(s *semaphore) {
select {
case <-s.stop:
return
case <-s.acquired:
s.locks += 1
case <-s.released:
s.locks -= 1
if s.locks == 0 {
s.cond.Broadcast()
}
}
}(sem)
return sem
}
// Concrete implementation of Semaphore.
type semaphore struct {
held bool
locks int
acquired chan bool
released chan bool
stop chan bool
cond *sync.Cond
}
// Implements Semaphore.Acquire()
func (sem *semaphore) Acquire() {
sem.acquired <- true
}
// Implements Semaphore.Release()
func (sem *semaphore) Release() {
sem.released <- true
}
// Implements Semaphore.Wait()
func (sem *semaphore) Wait() {
sem.cond.L.Lock()
for sem.locks != 0 {
sem.cond.Wait()
}
}
// Implements Semaphore.Dispose()
func (sem *semaphore) Dispose() {
sem.stop <- true
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017 Sergey Kamardin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,63 @@
# httphead.[go](https://golang.org)
[![GoDoc][godoc-image]][godoc-url]
> Tiny HTTP header value parsing library in go.
## Overview
This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars.
## Install
```shell
go get github.com/gobwas/httphead
```
## Example
The example below shows how multiple-choise HTTP header value could be parsed with this library:
```go
options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil)
fmt.Println(options, ok)
// Output: [{foo map[bar:1]} {baz map[]}] true
```
The low-level example below shows how to optimize keys skipping and selection
of some key:
```go
// The right part of full header line like:
// X-My-Header: key;foo=bar;baz,key;baz
header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`)
// We want to search key "foo" with an "a" parameter that equal to "2".
var (
foo = []byte(`foo`)
a = []byte(`a`)
v = []byte(`2`)
)
var found bool
httphead.ScanOptions(header, func(i int, key, param, value []byte) Control {
if !bytes.Equal(key, foo) {
return ControlSkip
}
if !bytes.Equal(param, a) {
if bytes.Equal(value, v) {
// Found it!
found = true
return ControlBreak
}
return ControlSkip
}
return ControlContinue
})
```
For more usage examples please see [docs][godoc-url] or package tests.
[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/httphead
[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master
[travis-url]: https://travis-ci.org/gobwas/httphead

View file

@ -0,0 +1,200 @@
package httphead
import (
"bytes"
)
// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan()
// method.
func ScanCookie(data []byte, it func(key, value []byte) bool) bool {
return DefaultCookieScanner.Scan(data, it)
}
// DefaultCookieScanner is a CookieScanner which is used by ScanCookie().
// Note that it is intended to have the same behavior as http.Request.Cookies()
// has.
var DefaultCookieScanner = CookieScanner{}
// CookieScanner contains options for scanning cookie pairs.
// See https://tools.ietf.org/html/rfc6265#section-4.1.1
type CookieScanner struct {
// DisableNameValidation disables name validation of a cookie. If false,
// only RFC2616 "tokens" are accepted.
DisableNameValidation bool
// DisableValueValidation disables value validation of a cookie. If false,
// only RFC6265 "cookie-octet" characters are accepted.
//
// Note that Strict option also affects validation of a value.
//
// If Strict is false, then scanner begins to allow space and comma
// characters inside the value for better compatibility with non standard
// cookies implementations.
DisableValueValidation bool
// BreakOnPairError sets scanner to immediately return after first pair syntax
// validation error.
// If false, scanner will try to skip invalid pair bytes and go ahead.
BreakOnPairError bool
// Strict enables strict RFC6265 mode scanning. It affects name and value
// validation, as also some other rules.
// If false, it is intended to bring the same behavior as
// http.Request.Cookies().
Strict bool
}
// Scan maps data to name and value pairs. Usually data represents value of the
// Cookie header.
func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool {
lexer := &Scanner{data: data}
const (
statePair = iota
stateBefore
)
state := statePair
for lexer.Buffered() > 0 {
switch state {
case stateBefore:
// Pairs separated by ";" and space, according to the RFC6265:
// cookie-pair *( ";" SP cookie-pair )
//
// Cookie pairs MUST be separated by (";" SP). So our only option
// here is to fail as syntax error.
a, b := lexer.Peek2()
if a != ';' {
return false
}
state = statePair
advance := 1
if b == ' ' {
advance++
} else if c.Strict {
return false
}
lexer.Advance(advance)
case statePair:
if !lexer.FetchUntil(';') {
return false
}
var value []byte
name := lexer.Bytes()
if i := bytes.IndexByte(name, '='); i != -1 {
value = name[i+1:]
name = name[:i]
} else if c.Strict {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !c.Strict {
trimLeft(name)
}
if !c.DisableNameValidation && !ValidCookieName(name) {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !c.Strict {
value = trimRight(value)
}
value = stripQuotes(value)
if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !it(name, value) {
return true
}
nextPair:
state = stateBefore
}
}
return true
}
// ValidCookieValue reports whether given value is a valid RFC6265
// "cookie-octet" bytes.
//
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
// ; US-ASCII characters excluding CTLs,
// ; whitespace DQUOTE, comma, semicolon,
// ; and backslash
//
// Note that the false strict parameter disables errors on space 0x20 and comma
// 0x2c. This could be useful to bring some compatibility with non-compliant
// clients/servers in the real world.
// It acts the same as standard library cookie parser if strict is false.
func ValidCookieValue(value []byte, strict bool) bool {
if len(value) == 0 {
return true
}
for _, c := range value {
switch c {
case '"', ';', '\\':
return false
case ',', ' ':
if strict {
return false
}
default:
if c <= 0x20 {
return false
}
if c >= 0x7f {
return false
}
}
}
return true
}
// ValidCookieName reports wheter given bytes is a valid RFC2616 "token" bytes.
func ValidCookieName(name []byte) bool {
for _, c := range name {
if !OctetTypes[c].IsToken() {
return false
}
}
return true
}
func stripQuotes(bts []byte) []byte {
if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' {
return bts[1:last]
}
return bts
}
func trimLeft(p []byte) []byte {
var i int
for i < len(p) && OctetTypes[p[i]].IsSpace() {
i++
}
return p[i:]
}
func trimRight(p []byte) []byte {
j := len(p)
for j > 0 && OctetTypes[p[j-1]].IsSpace() {
j--
}
return p[:j]
}

View file

@ -0,0 +1,3 @@
module github.com/gobwas/httphead
go 1.15

View file

@ -0,0 +1,275 @@
package httphead
import (
"bufio"
"bytes"
)
// Version contains protocol major and minor version.
type Version struct {
Major int
Minor int
}
// RequestLine contains parameters parsed from the first request line.
type RequestLine struct {
Method []byte
URI []byte
Version Version
}
// ResponseLine contains parameters parsed from the first response line.
type ResponseLine struct {
Version Version
Status int
Reason []byte
}
// SplitRequestLine splits given slice of bytes into three chunks without
// parsing.
func SplitRequestLine(line []byte) (method, uri, version []byte) {
return split3(line, ' ')
}
// ParseRequestLine parses http request line like "GET / HTTP/1.0".
func ParseRequestLine(line []byte) (r RequestLine, ok bool) {
var i int
for i = 0; i < len(line); i++ {
c := line[i]
if !OctetTypes[c].IsToken() {
if i > 0 && c == ' ' {
break
}
return
}
}
if i == len(line) {
return
}
var proto []byte
r.Method = line[:i]
r.URI, proto = split2(line[i+1:], ' ')
if len(r.URI) == 0 {
return
}
if major, minor, ok := ParseVersion(proto); ok {
r.Version.Major = major
r.Version.Minor = minor
return r, true
}
return r, false
}
// SplitResponseLine splits given slice of bytes into three chunks without
// parsing.
func SplitResponseLine(line []byte) (version, status, reason []byte) {
return split3(line, ' ')
}
// ParseResponseLine parses first response line into ResponseLine struct.
func ParseResponseLine(line []byte) (r ResponseLine, ok bool) {
var (
proto []byte
status []byte
)
proto, status, r.Reason = split3(line, ' ')
if major, minor, ok := ParseVersion(proto); ok {
r.Version.Major = major
r.Version.Minor = minor
} else {
return r, false
}
if n, ok := IntFromASCII(status); ok {
r.Status = n
} else {
return r, false
}
// TODO(gobwas): parse here r.Reason fot TEXT rule:
// TEXT = <any OCTET except CTLs,
// but including LWS>
return r, true
}
var (
httpVersion10 = []byte("HTTP/1.0")
httpVersion11 = []byte("HTTP/1.1")
httpVersionPrefix = []byte("HTTP/")
)
// ParseVersion parses major and minor version of HTTP protocol.
// It returns parsed values and true if parse is ok.
func ParseVersion(bts []byte) (major, minor int, ok bool) {
switch {
case bytes.Equal(bts, httpVersion11):
return 1, 1, true
case bytes.Equal(bts, httpVersion10):
return 1, 0, true
case len(bts) < 8:
return
case !bytes.Equal(bts[:5], httpVersionPrefix):
return
}
bts = bts[5:]
dot := bytes.IndexByte(bts, '.')
if dot == -1 {
return
}
major, ok = IntFromASCII(bts[:dot])
if !ok {
return
}
minor, ok = IntFromASCII(bts[dot+1:])
if !ok {
return
}
return major, minor, true
}
// ReadLine reads line from br. It reads until '\n' and returns bytes without
// '\n' or '\r\n' at the end.
// It returns err if and only if line does not end in '\n'. Note that read
// bytes returned in any case of error.
//
// It is much like the textproto/Reader.ReadLine() except the thing that it
// returns raw bytes, instead of string. That is, it avoids copying bytes read
// from br.
//
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
// safe with future I/O operations on br.
//
// We could control I/O operations on br and do not need to make additional
// copy for safety.
func ReadLine(br *bufio.Reader) ([]byte, error) {
var line []byte
for {
bts, err := br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
// Copy bytes because next read will discard them.
line = append(line, bts...)
continue
}
// Avoid copy of single read.
if line == nil {
line = bts
} else {
line = append(line, bts...)
}
if err != nil {
return line, err
}
// Size of line is at least 1.
// In other case bufio.ReadSlice() returns error.
n := len(line)
// Cut '\n' or '\r\n'.
if n > 1 && line[n-2] == '\r' {
line = line[:n-2]
} else {
line = line[:n-1]
}
return line, nil
}
}
// ParseHeaderLine parses HTTP header as key-value pair. It returns parsed
// values and true if parse is ok.
func ParseHeaderLine(line []byte) (k, v []byte, ok bool) {
colon := bytes.IndexByte(line, ':')
if colon == -1 {
return
}
k = trim(line[:colon])
for _, c := range k {
if !OctetTypes[c].IsToken() {
return nil, nil, false
}
}
v = trim(line[colon+1:])
return k, v, true
}
// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities
// to an integer.
func IntFromASCII(bts []byte) (ret int, ok bool) {
// ASCII numbers all start with the high-order bits 0011.
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
// bits and interpret them directly as an integer.
var n int
if n = len(bts); n < 1 {
return 0, false
}
for i := 0; i < n; i++ {
if bts[i]&0xf0 != 0x30 {
return 0, false
}
ret += int(bts[i]&0xf) * pow(10, n-i-1)
}
return ret, true
}
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
)
// CanonicalizeHeaderKey is like standard textproto/CanonicalMIMEHeaderKey,
// except that it operates with slice of bytes and modifies it inplace without
// copying.
func CanonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
// pow for integers implementation.
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
func pow(a, b int) int {
p := 1
for b > 0 {
if b&1 != 0 {
p *= a
}
b >>= 1
a *= a
}
return p
}
func split3(p []byte, sep byte) (p1, p2, p3 []byte) {
a := bytes.IndexByte(p, sep)
b := bytes.IndexByte(p[a+1:], sep)
if a == -1 || b == -1 {
return p, nil, nil
}
b += a + 1
return p[:a], p[a+1 : b], p[b+1:]
}
func split2(p []byte, sep byte) (p1, p2 []byte) {
i := bytes.IndexByte(p, sep)
if i == -1 {
return p, nil
}
return p[:i], p[i+1:]
}
func trim(p []byte) []byte {
var i, j int
for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); {
i++
}
for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); {
j--
}
return p[i:j]
}

View file

@ -0,0 +1,331 @@
// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible
// text protocols headers.
//
// That is, this package first aim is to bring ability to easily parse
// constructions, described here https://tools.ietf.org/html/rfc2616#section-2
package httphead
import (
"bytes"
"strings"
)
// ScanTokens parses data in this form:
//
// list = 1#token
//
// It returns false if data is malformed.
func ScanTokens(data []byte, it func([]byte) bool) bool {
lexer := &Scanner{data: data}
var ok bool
for lexer.Next() {
switch lexer.Type() {
case ItemToken:
ok = true
if !it(lexer.Bytes()) {
return true
}
case ItemSeparator:
if !isComma(lexer.Bytes()) {
return false
}
default:
return false
}
}
return ok && !lexer.err
}
// ParseOptions parses all header options and appends it to given slice of
// Option. It returns flag of successful (wellformed input) parsing.
//
// Note that appended options are all consist of subslices of data. That is,
// mutation of data will mutate appended options.
func ParseOptions(data []byte, options []Option) ([]Option, bool) {
var i int
index := -1
return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control {
if idx != index {
index = idx
i = len(options)
options = append(options, Option{Name: name})
}
if attr != nil {
options[i].Parameters.Set(attr, val)
}
return ControlContinue
})
}
// SelectFlag encodes way of options selection.
type SelectFlag byte
// String represetns flag as string.
func (f SelectFlag) String() string {
var flags [2]string
var n int
if f&SelectCopy != 0 {
flags[n] = "copy"
n++
}
if f&SelectUnique != 0 {
flags[n] = "unique"
n++
}
return "[" + strings.Join(flags[:n], "|") + "]"
}
const (
// SelectCopy causes selector to copy selected option before appending it
// to resulting slice.
// If SelectCopy flag is not passed to selector, then appended options will
// contain sub-slices of the initial data.
SelectCopy SelectFlag = 1 << iota
// SelectUnique causes selector to append only not yet existing option to
// resulting slice. Unique is checked by comparing option names.
SelectUnique
)
// OptionSelector contains configuration for selecting Options from header value.
type OptionSelector struct {
// Check is a filter function that applied to every Option that possibly
// could be selected.
// If Check is nil all options will be selected.
Check func(Option) bool
// Flags contains flags for options selection.
Flags SelectFlag
// Alloc used to allocate slice of bytes when selector is configured with
// SelectCopy flag. It will be called with number of bytes needed for copy
// of single Option.
// If Alloc is nil make is used.
Alloc func(n int) []byte
}
// Select parses header data and appends it to given slice of Option.
// It also returns flag of successful (wellformed input) parsing.
func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) {
var current Option
var has bool
index := -1
alloc := s.Alloc
if alloc == nil {
alloc = defaultAlloc
}
check := s.Check
if check == nil {
check = defaultCheck
}
ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control {
if idx != index {
if has && check(current) {
if s.Flags&SelectCopy != 0 {
current = current.Copy(alloc(current.Size()))
}
options = append(options, current)
has = false
}
if s.Flags&SelectUnique != 0 {
for i := len(options) - 1; i >= 0; i-- {
if bytes.Equal(options[i].Name, name) {
return ControlSkip
}
}
}
index = idx
current = Option{Name: name}
has = true
}
if attr != nil {
current.Parameters.Set(attr, val)
}
return ControlContinue
})
if has && check(current) {
if s.Flags&SelectCopy != 0 {
current = current.Copy(alloc(current.Size()))
}
options = append(options, current)
}
return options, ok
}
func defaultAlloc(n int) []byte { return make([]byte, n) }
func defaultCheck(Option) bool { return true }
// Control represents operation that scanner should perform.
type Control byte
const (
// ControlContinue causes scanner to continue scan tokens.
ControlContinue Control = iota
// ControlBreak causes scanner to stop scan tokens.
ControlBreak
// ControlSkip causes scanner to skip current entity.
ControlSkip
)
// ScanOptions parses data in this form:
//
// values = 1#value
// value = token *( ";" param )
// param = token [ "=" (token | quoted-string) ]
//
// It calls given callback with the index of the option, option itself and its
// parameter (attribute and its value, both could be nil). Index is useful when
// header contains multiple choises for the same named option.
//
// Given callback should return one of the defined Control* values.
// ControlSkip means that passed key is not in caller's interest. That is, all
// parameters of that key will be skipped.
// ControlBreak means that no more keys and parameters should be parsed. That
// is, it must break parsing immediately.
// ControlContinue means that caller want to receive next parameter and its
// value or the next key.
//
// It returns false if data is malformed.
func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool {
lexer := &Scanner{data: data}
var ok bool
var state int
const (
stateKey = iota
stateParamBeforeName
stateParamName
stateParamBeforeValue
stateParamValue
)
var (
index int
key, param, value []byte
mustCall bool
)
for lexer.Next() {
var (
call bool
growIndex int
)
t := lexer.Type()
v := lexer.Bytes()
switch t {
case ItemToken:
switch state {
case stateKey, stateParamBeforeName:
key = v
state = stateParamBeforeName
mustCall = true
case stateParamName:
param = v
state = stateParamBeforeValue
mustCall = true
case stateParamValue:
value = v
state = stateParamBeforeName
call = true
default:
return false
}
case ItemString:
if state != stateParamValue {
return false
}
value = v
state = stateParamBeforeName
call = true
case ItemSeparator:
switch {
case isComma(v) && state == stateKey:
// Nothing to do.
case isComma(v) && state == stateParamBeforeName:
state = stateKey
// Make call only if we have not called this key yet.
call = mustCall
if !call {
// If we have already called callback with the key
// that just ended.
index++
} else {
// Else grow the index after calling callback.
growIndex = 1
}
case isComma(v) && state == stateParamBeforeValue:
state = stateKey
growIndex = 1
call = true
case isSemicolon(v) && state == stateParamBeforeName:
state = stateParamName
case isSemicolon(v) && state == stateParamBeforeValue:
state = stateParamName
call = true
case isEquality(v) && state == stateParamBeforeValue:
state = stateParamValue
default:
return false
}
default:
return false
}
if call {
switch it(index, key, param, value) {
case ControlBreak:
// User want to stop to parsing parameters.
return true
case ControlSkip:
// User want to skip current param.
state = stateKey
lexer.SkipEscaped(',')
case ControlContinue:
// User is interested in rest of parameters.
// Nothing to do.
default:
panic("unexpected control value")
}
ok = true
param = nil
value = nil
mustCall = false
index += growIndex
}
}
if mustCall {
ok = true
it(index, key, param, value)
}
return ok && !lexer.err
}
func isComma(b []byte) bool {
return len(b) == 1 && b[0] == ','
}
func isSemicolon(b []byte) bool {
return len(b) == 1 && b[0] == ';'
}
func isEquality(b []byte) bool {
return len(b) == 1 && b[0] == '='
}

View file

@ -0,0 +1,360 @@
package httphead
import (
"bytes"
)
// ItemType encodes type of the lexing token.
type ItemType int
const (
// ItemUndef reports that token is undefined.
ItemUndef ItemType = iota
// ItemToken reports that token is RFC2616 token.
ItemToken
// ItemSeparator reports that token is RFC2616 separator.
ItemSeparator
// ItemString reports that token is RFC2616 quouted string.
ItemString
// ItemComment reports that token is RFC2616 comment.
ItemComment
// ItemOctet reports that token is octet slice.
ItemOctet
)
// Scanner represents header tokens scanner.
// See https://tools.ietf.org/html/rfc2616#section-2
type Scanner struct {
data []byte
pos int
itemType ItemType
itemBytes []byte
err bool
}
// NewScanner creates new RFC2616 data scanner.
func NewScanner(data []byte) *Scanner {
return &Scanner{data: data}
}
// Next scans for next token. It returns true on successful scanning, and false
// on error or EOF.
func (l *Scanner) Next() bool {
c, ok := l.nextChar()
if !ok {
return false
}
switch c {
case '"': // quoted-string;
return l.fetchQuotedString()
case '(': // comment;
return l.fetchComment()
case '\\', ')': // unexpected chars;
l.err = true
return false
default:
return l.fetchToken()
}
}
// FetchUntil fetches ItemOctet from current scanner position to first
// occurence of the c or to the end of the underlying data.
func (l *Scanner) FetchUntil(c byte) bool {
l.resetItem()
if l.pos == len(l.data) {
return false
}
return l.fetchOctet(c)
}
// Peek reads byte at current position without advancing it. On end of data it
// returns 0.
func (l *Scanner) Peek() byte {
if l.pos == len(l.data) {
return 0
}
return l.data[l.pos]
}
// Peek2 reads two first bytes at current position without advancing it.
// If there not enough data it returs 0.
func (l *Scanner) Peek2() (a, b byte) {
if l.pos == len(l.data) {
return 0, 0
}
if l.pos+1 == len(l.data) {
return l.data[l.pos], 0
}
return l.data[l.pos], l.data[l.pos+1]
}
// Buffered reporst how many bytes there are left to scan.
func (l *Scanner) Buffered() int {
return len(l.data) - l.pos
}
// Advance moves current position index at n bytes. It returns true on
// successful move.
func (l *Scanner) Advance(n int) bool {
l.pos += n
if l.pos > len(l.data) {
l.pos = len(l.data)
return false
}
return true
}
// Skip skips all bytes until first occurence of c.
func (l *Scanner) Skip(c byte) {
if l.err {
return
}
// Reset scanner state.
l.resetItem()
if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += i + 1
}
}
// SkipEscaped skips all bytes until first occurence of non-escaped c.
func (l *Scanner) SkipEscaped(c byte) {
if l.err {
return
}
// Reset scanner state.
l.resetItem()
if i := ScanUntil(l.data[l.pos:], c); i == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += i + 1
}
}
// Type reports current token type.
func (l *Scanner) Type() ItemType {
return l.itemType
}
// Bytes returns current token bytes.
func (l *Scanner) Bytes() []byte {
return l.itemBytes
}
func (l *Scanner) nextChar() (byte, bool) {
// Reset scanner state.
l.resetItem()
if l.err {
return 0, false
}
l.pos += SkipSpace(l.data[l.pos:])
if l.pos == len(l.data) {
return 0, false
}
return l.data[l.pos], true
}
func (l *Scanner) resetItem() {
l.itemType = ItemUndef
l.itemBytes = nil
}
func (l *Scanner) fetchOctet(c byte) bool {
i := l.pos
if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += j
}
l.itemType = ItemOctet
l.itemBytes = l.data[i:l.pos]
return true
}
func (l *Scanner) fetchToken() bool {
n, t := ScanToken(l.data[l.pos:])
if n == -1 {
l.err = true
return false
}
l.itemType = t
l.itemBytes = l.data[l.pos : l.pos+n]
l.pos += n
return true
}
func (l *Scanner) fetchQuotedString() (ok bool) {
l.pos++
n := ScanUntil(l.data[l.pos:], '"')
if n == -1 {
l.err = true
return false
}
l.itemType = ItemString
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
l.pos += n + 1
return true
}
func (l *Scanner) fetchComment() (ok bool) {
l.pos++
n := ScanPairGreedy(l.data[l.pos:], '(', ')')
if n == -1 {
l.err = true
return false
}
l.itemType = ItemComment
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
l.pos += n + 1
return true
}
// ScanUntil scans for first non-escaped character c in given data.
// It returns index of matched c and -1 if c is not found.
func ScanUntil(data []byte, c byte) (n int) {
for {
i := bytes.IndexByte(data[n:], c)
if i == -1 {
return -1
}
n += i
if n == 0 || data[n-1] != '\\' {
break
}
n++
}
return
}
// ScanPairGreedy scans for complete pair of opening and closing chars in greedy manner.
// Note that first opening byte must not be present in data.
func ScanPairGreedy(data []byte, open, close byte) (n int) {
var m int
opened := 1
for {
i := bytes.IndexByte(data[n:], close)
if i == -1 {
return -1
}
n += i
// If found index is not escaped then it is the end.
if n == 0 || data[n-1] != '\\' {
opened--
}
for m < i {
j := bytes.IndexByte(data[m:i], open)
if j == -1 {
break
}
m += j + 1
opened++
}
if opened == 0 {
break
}
n++
m = n
}
return
}
// RemoveByte returns data without c. If c is not present in data it returns
// the same slice. If not, it copies data without c.
func RemoveByte(data []byte, c byte) []byte {
j := bytes.IndexByte(data, c)
if j == -1 {
return data
}
n := len(data) - 1
// If character is present, than allocate slice with n-1 capacity. That is,
// resulting bytes could be at most n-1 length.
result := make([]byte, n)
k := copy(result, data[:j])
for i := j + 1; i < n; {
j = bytes.IndexByte(data[i:], c)
if j != -1 {
k += copy(result[k:], data[i:i+j])
i = i + j + 1
} else {
k += copy(result[k:], data[i:])
break
}
}
return result[:k]
}
// SkipSpace skips spaces and lws-sequences from p.
// It returns number ob bytes skipped.
func SkipSpace(p []byte) (n int) {
for len(p) > 0 {
switch {
case len(p) >= 3 &&
p[0] == '\r' &&
p[1] == '\n' &&
OctetTypes[p[2]].IsSpace():
p = p[3:]
n += 3
case OctetTypes[p[0]].IsSpace():
p = p[1:]
n++
default:
return
}
}
return
}
// ScanToken scan for next token in p. It returns length of the token and its
// type. It do not trim p.
func ScanToken(p []byte) (n int, t ItemType) {
if len(p) == 0 {
return 0, ItemUndef
}
c := p[0]
switch {
case OctetTypes[c].IsSeparator():
return 1, ItemSeparator
case OctetTypes[c].IsToken():
for n = 1; n < len(p); n++ {
c := p[n]
if !OctetTypes[c].IsToken() {
break
}
}
return n, ItemToken
default:
return -1, ItemUndef
}
}

View file

@ -0,0 +1,83 @@
package httphead
// OctetType desribes character type.
//
// From the "Basic Rules" chapter of RFC2616
// See https://tools.ietf.org/html/rfc2616#section-2.2
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// UPALPHA = <any US-ASCII uppercase letter "A".."Z">
// LOALPHA = <any US-ASCII lowercase letter "a".."z">
// ALPHA = UPALPHA | LOALPHA
// DIGIT = <any US-ASCII digit "0".."9">
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
//
// Many HTTP/1.1 header field values consist of words separated by LWS
// or special characters. These special characters MUST be in a quoted
// string to be used within a parameter value (as defined in section
// 3.6).
//
// token = 1*<any CHAR except CTLs or separators>
// separators = "(" | ")" | "<" | ">" | "@"
// | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "="
// | "{" | "}" | SP | HT
type OctetType byte
// IsChar reports whether octet is CHAR.
func (t OctetType) IsChar() bool { return t&octetChar != 0 }
// IsControl reports whether octet is CTL.
func (t OctetType) IsControl() bool { return t&octetControl != 0 }
// IsSeparator reports whether octet is separator.
func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 }
// IsSpace reports whether octet is space (SP or HT).
func (t OctetType) IsSpace() bool { return t&octetSpace != 0 }
// IsToken reports whether octet is token.
func (t OctetType) IsToken() bool { return t&octetToken != 0 }
const (
octetChar OctetType = 1 << iota
octetControl
octetSpace
octetSeparator
octetToken
)
// OctetTypes is a table of octets.
var OctetTypes [256]OctetType
func init() {
for c := 32; c < 256; c++ {
var t OctetType
if c <= 127 {
t |= octetChar
}
if 0 <= c && c <= 31 || c == 127 {
t |= octetControl
}
switch c {
case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\':
t |= octetSeparator
case ' ', '\t':
t |= octetSpace | octetSeparator
}
if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() {
t |= octetToken
}
OctetTypes[c] = t
}
}

View file

@ -0,0 +1,193 @@
package httphead
import (
"bytes"
"sort"
)
// Option represents a header option.
type Option struct {
Name []byte
Parameters Parameters
}
// Size returns number of bytes need to be allocated for use in opt.Copy.
func (opt Option) Size() int {
return len(opt.Name) + opt.Parameters.bytes
}
// Copy copies all underlying []byte slices into p and returns new Option.
// Note that p must be at least of opt.Size() length.
func (opt Option) Copy(p []byte) Option {
n := copy(p, opt.Name)
opt.Name = p[:n]
opt.Parameters, p = opt.Parameters.Copy(p[n:])
return opt
}
// Clone is a shorthand for making slice of opt.Size() sequenced with Copy()
// call.
func (opt Option) Clone() Option {
return opt.Copy(make([]byte, opt.Size()))
}
// String represents option as a string.
func (opt Option) String() string {
return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}"
}
// NewOption creates named option with given parameters.
func NewOption(name string, params map[string]string) Option {
p := Parameters{}
for k, v := range params {
p.Set([]byte(k), []byte(v))
}
return Option{
Name: []byte(name),
Parameters: p,
}
}
// Equal reports whether option is equal to b.
func (opt Option) Equal(b Option) bool {
if bytes.Equal(opt.Name, b.Name) {
return opt.Parameters.Equal(b.Parameters)
}
return false
}
// Parameters represents option's parameters.
type Parameters struct {
pos int
bytes int
arr [8]pair
dyn []pair
}
// Equal reports whether a equal to b.
func (p Parameters) Equal(b Parameters) bool {
switch {
case p.dyn == nil && b.dyn == nil:
case p.dyn != nil && b.dyn != nil:
default:
return false
}
ad, bd := p.data(), b.data()
if len(ad) != len(bd) {
return false
}
sort.Sort(pairs(ad))
sort.Sort(pairs(bd))
for i := 0; i < len(ad); i++ {
av, bv := ad[i], bd[i]
if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) {
return false
}
}
return true
}
// Size returns number of bytes that needed to copy p.
func (p *Parameters) Size() int {
return p.bytes
}
// Copy copies all underlying []byte slices into dst and returns new
// Parameters.
// Note that dst must be at least of p.Size() length.
func (p *Parameters) Copy(dst []byte) (Parameters, []byte) {
ret := Parameters{
pos: p.pos,
bytes: p.bytes,
}
if p.dyn != nil {
ret.dyn = make([]pair, len(p.dyn))
for i, v := range p.dyn {
ret.dyn[i], dst = v.copy(dst)
}
} else {
for i, p := range p.arr {
ret.arr[i], dst = p.copy(dst)
}
}
return ret, dst
}
// Get returns value by key and flag about existence such value.
func (p *Parameters) Get(key string) (value []byte, ok bool) {
for _, v := range p.data() {
if string(v.key) == key {
return v.value, true
}
}
return nil, false
}
// Set sets value by key.
func (p *Parameters) Set(key, value []byte) {
p.bytes += len(key) + len(value)
if p.pos < len(p.arr) {
p.arr[p.pos] = pair{key, value}
p.pos++
return
}
if p.dyn == nil {
p.dyn = make([]pair, len(p.arr), len(p.arr)+1)
copy(p.dyn, p.arr[:])
}
p.dyn = append(p.dyn, pair{key, value})
}
// ForEach iterates over parameters key-value pairs and calls cb for each one.
func (p *Parameters) ForEach(cb func(k, v []byte) bool) {
for _, v := range p.data() {
if !cb(v.key, v.value) {
break
}
}
}
// String represents parameters as a string.
func (p *Parameters) String() (ret string) {
ret = "["
for i, v := range p.data() {
if i > 0 {
ret += " "
}
ret += string(v.key) + ":" + string(v.value)
}
return ret + "]"
}
func (p *Parameters) data() []pair {
if p.dyn != nil {
return p.dyn
}
return p.arr[:p.pos]
}
type pair struct {
key, value []byte
}
func (p pair) copy(dst []byte) (pair, []byte) {
n := copy(dst, p.key)
p.key = dst[:n]
m := n + copy(dst[n:], p.value)
p.value = dst[n:m]
dst = dst[m:]
return p, dst
}
type pairs []pair
func (p pairs) Len() int { return len(p) }
func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 }
func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] }

View file

@ -0,0 +1,101 @@
package httphead
import "io"
var (
comma = []byte{','}
equality = []byte{'='}
semicolon = []byte{';'}
quote = []byte{'"'}
escape = []byte{'\\'}
)
// WriteOptions write options list to the dest.
// It uses the same form as {Scan,Parse}Options functions:
// values = 1#value
// value = token *( ";" param )
// param = token [ "=" (token | quoted-string) ]
//
// It wraps valuse into the quoted-string sequence if it contains any
// non-token characters.
func WriteOptions(dest io.Writer, options []Option) (n int, err error) {
w := writer{w: dest}
for i, opt := range options {
if i > 0 {
w.write(comma)
}
writeTokenSanitized(&w, opt.Name)
for _, p := range opt.Parameters.data() {
w.write(semicolon)
writeTokenSanitized(&w, p.key)
if len(p.value) != 0 {
w.write(equality)
writeTokenSanitized(&w, p.value)
}
}
}
return w.result()
}
// writeTokenSanitized writes token as is or as quouted string if it contains
// non-token characters.
//
// Note that is is not expects LWS sequnces be in s, cause LWS is used only as
// header field continuation:
// "A CRLF is allowed in the definition of TEXT only as part of a header field
// continuation. It is expected that the folding LWS will be replaced with a
// single SP before interpretation of the TEXT value."
// See https://tools.ietf.org/html/rfc2616#section-2
//
// That is we sanitizing s for writing, so there could not be any header field
// continuation.
// That is any CRLF will be escaped as any other control characters not allowd in TEXT.
func writeTokenSanitized(bw *writer, bts []byte) {
var qt bool
var pos int
for i := 0; i < len(bts); i++ {
c := bts[i]
if !OctetTypes[c].IsToken() && !qt {
qt = true
bw.write(quote)
}
if OctetTypes[c].IsControl() || c == '"' {
if !qt {
qt = true
bw.write(quote)
}
bw.write(bts[pos:i])
bw.write(escape)
bw.write(bts[i : i+1])
pos = i + 1
}
}
if !qt {
bw.write(bts)
} else {
bw.write(bts[pos:])
bw.write(quote)
}
}
type writer struct {
w io.Writer
n int
err error
}
func (w *writer) write(p []byte) {
if w.err != nil {
return
}
var n int
n, w.err = w.w.Write(p)
w.n += n
return
}
func (w *writer) result() (int, error) {
return w.n, w.err
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2019 Sergey Kamardin <gobwas@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,107 @@
# pool
[![GoDoc][godoc-image]][godoc-url]
> Tiny memory reuse helpers for Go.
## generic
Without use of subpackages, `pool` allows to reuse any struct distinguishable
by size in generic way:
```go
package main
import "github.com/gobwas/pool"
func main() {
x, n := pool.Get(100) // Returns object with size 128 or nil.
if x == nil {
// Create x somehow with knowledge that n is 128.
}
defer pool.Put(x, n)
// Work with x.
}
```
Pool allows you to pass specific options for constructing custom pool:
```go
package main
import "github.com/gobwas/pool"
func main() {
p := pool.Custom(
pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two.
pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512].
pool.WithSize(65536), // Will reuse object with size 65536.
)
x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reusing by the pool.
defer pool.Put(x, n) // Will not reuse x.
// Work with x.
}
```
Note that there are few non-generic pooling implementations inside subpackages.
## pbytes
Subpackage `pbytes` is intended for `[]byte` reuse.
```go
package main
import "github.com/gobwas/pool/pbytes"
func main() {
bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128).
defer pbytes.Put(bts)
// Work with bts.
}
```
You can also create your own range for pooling:
```go
package main
import "github.com/gobwas/pool/pbytes"
func main() {
// Reuse only slices whose capacity is 128, 256, 512 or 1024.
pool := pbytes.New(128, 1024)
bts := pool.GetCap(100) // Returns make([]byte, 0, 128).
defer pool.Put(bts)
// Work with bts.
}
```
## pbufio
Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse.
```go
package main
import "github.com/gobwas/pool/pbufio"
func main() {
bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128).
defer pbufio.PutWriter(bw)
// Work with bw.
}
```
Like with `pbytes`, you can also create pool with custom reuse bounds.
[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/pool

View file

@ -0,0 +1,87 @@
package pool
import (
"sync"
"github.com/gobwas/pool/internal/pmath"
)
var DefaultPool = New(128, 65536)
// Get pulls object whose generic size is at least of given size. It also
// returns a real size of x for further pass to Put(). It returns -1 as real
// size for nil x. Size >-1 does not mean that x is non-nil, so checks must be
// done.
//
// Note that size could be ceiled to the next power of two.
//
// Get is a wrapper around DefaultPool.Get().
func Get(size int) (interface{}, int) { return DefaultPool.Get(size) }
// Put takes x and its size for future reuse.
// Put is a wrapper around DefaultPool.Put().
func Put(x interface{}, size int) { DefaultPool.Put(x, size) }
// Pool contains logic of reusing objects distinguishable by size in generic
// way.
type Pool struct {
pool map[int]*sync.Pool
size func(int) int
}
// New creates new Pool that reuses objects which size is in logarithmic range
// [min, max].
//
// Note that it is a shortcut for Custom() constructor with Options provided by
// WithLogSizeMapping() and WithLogSizeRange(min, max) calls.
func New(min, max int) *Pool {
return Custom(
WithLogSizeMapping(),
WithLogSizeRange(min, max),
)
}
// Custom creates new Pool with given options.
func Custom(opts ...Option) *Pool {
p := &Pool{
pool: make(map[int]*sync.Pool),
size: pmath.Identity,
}
c := (*poolConfig)(p)
for _, opt := range opts {
opt(c)
}
return p
}
// Get pulls object whose generic size is at least of given size.
// It also returns a real size of x for further pass to Put() even if x is nil.
// Note that size could be ceiled to the next power of two.
func (p *Pool) Get(size int) (interface{}, int) {
n := p.size(size)
if pool := p.pool[n]; pool != nil {
return pool.Get(), n
}
return nil, size
}
// Put takes x and its size for future reuse.
func (p *Pool) Put(x interface{}, size int) {
if pool := p.pool[size]; pool != nil {
pool.Put(x)
}
}
type poolConfig Pool
// AddSize adds size n to the map.
func (p *poolConfig) AddSize(n int) {
p.pool[n] = new(sync.Pool)
}
// SetSizeMapping sets up incoming size mapping function.
func (p *poolConfig) SetSizeMapping(size func(int) int) {
p.size = size
}

View file

@ -0,0 +1,65 @@
package pmath
const (
bitsize = 32 << (^uint(0) >> 63)
maxint = int(1<<(bitsize-1) - 1)
maxintHeadBit = 1 << (bitsize - 2)
)
// LogarithmicRange iterates from ceiled to power of two min to max,
// calling cb on each iteration.
func LogarithmicRange(min, max int, cb func(int)) {
if min == 0 {
min = 1
}
for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 {
cb(n)
}
}
// IsPowerOfTwo reports whether given integer is a power of two.
func IsPowerOfTwo(n int) bool {
return n&(n-1) == 0
}
// Identity is identity.
func Identity(n int) int {
return n
}
// CeilToPowerOfTwo returns the least power of two integer value greater than
// or equal to n.
func CeilToPowerOfTwo(n int) int {
if n&maxintHeadBit != 0 && n > maxintHeadBit {
panic("argument is too large")
}
if n <= 2 {
return n
}
n--
n = fillBits(n)
n++
return n
}
// FloorToPowerOfTwo returns the greatest power of two integer value less than
// or equal to n.
func FloorToPowerOfTwo(n int) int {
if n <= 2 {
return n
}
n = fillBits(n)
n >>= 1
n++
return n
}
func fillBits(n int) int {
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
return n
}

View file

@ -0,0 +1,43 @@
package pool
import "github.com/gobwas/pool/internal/pmath"
// Option configures pool.
type Option func(Config)
// Config describes generic pool configuration.
type Config interface {
AddSize(n int)
SetSizeMapping(func(int) int)
}
// WithSizeLogRange returns an Option that will add logarithmic range of
// pooling sizes containing [min, max] values.
func WithLogSizeRange(min, max int) Option {
return func(c Config) {
pmath.LogarithmicRange(min, max, func(n int) {
c.AddSize(n)
})
}
}
// WithSize returns an Option that will add given pooling size to the pool.
func WithSize(n int) Option {
return func(c Config) {
c.AddSize(n)
}
}
func WithSizeMapping(sz func(int) int) Option {
return func(c Config) {
c.SetSizeMapping(sz)
}
}
func WithLogSizeMapping() Option {
return WithSizeMapping(pmath.CeilToPowerOfTwo)
}
func WithIdentitySizeMapping() Option {
return WithSizeMapping(pmath.Identity)
}

View file

@ -0,0 +1,106 @@
// Package pbufio contains tools for pooling bufio.Reader and bufio.Writers.
package pbufio
import (
"bufio"
"io"
"github.com/gobwas/pool"
)
var (
DefaultWriterPool = NewWriterPool(256, 65536)
DefaultReaderPool = NewReaderPool(256, 65536)
)
// GetWriter returns bufio.Writer whose buffer has at least size bytes.
// Note that size could be ceiled to the next power of two.
// GetWriter is a wrapper around DefaultWriterPool.Get().
func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) }
// PutWriter takes bufio.Writer for future reuse.
// It does not reuse bufio.Writer which underlying buffer size is not power of
// PutWriter is a wrapper around DefaultWriterPool.Put().
func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) }
// GetReader returns bufio.Reader whose buffer has at least size bytes. It returns
// its capacity for further pass to Put().
// Note that size could be ceiled to the next power of two.
// GetReader is a wrapper around DefaultReaderPool.Get().
func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) }
// PutReader takes bufio.Reader and its size for future reuse.
// It does not reuse bufio.Reader if size is not power of two or is out of pool
// min/max range.
// PutReader is a wrapper around DefaultReaderPool.Put().
func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) }
// WriterPool contains logic of *bufio.Writer reuse with various size.
type WriterPool struct {
pool *pool.Pool
}
// NewWriterPool creates new WriterPool that reuses writers which size is in
// logarithmic range [min, max].
func NewWriterPool(min, max int) *WriterPool {
return &WriterPool{pool.New(min, max)}
}
// CustomWriterPool creates new WriterPool with given options.
func CustomWriterPool(opts ...pool.Option) *WriterPool {
return &WriterPool{pool.Custom(opts...)}
}
// Get returns bufio.Writer whose buffer has at least size bytes.
func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer {
v, n := wp.pool.Get(size)
if v != nil {
bw := v.(*bufio.Writer)
bw.Reset(w)
return bw
}
return bufio.NewWriterSize(w, n)
}
// Put takes ownership of bufio.Writer for further reuse.
func (wp *WriterPool) Put(bw *bufio.Writer) {
// Should reset even if we do Reset() inside Get().
// This is done to prevent locking underlying io.Writer from GC.
bw.Reset(nil)
wp.pool.Put(bw, writerSize(bw))
}
// ReaderPool contains logic of *bufio.Reader reuse with various size.
type ReaderPool struct {
pool *pool.Pool
}
// NewReaderPool creates new ReaderPool that reuses writers which size is in
// logarithmic range [min, max].
func NewReaderPool(min, max int) *ReaderPool {
return &ReaderPool{pool.New(min, max)}
}
// CustomReaderPool creates new ReaderPool with given options.
func CustomReaderPool(opts ...pool.Option) *ReaderPool {
return &ReaderPool{pool.Custom(opts...)}
}
// Get returns bufio.Reader whose buffer has at least size bytes.
func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader {
v, n := rp.pool.Get(size)
if v != nil {
br := v.(*bufio.Reader)
br.Reset(r)
return br
}
return bufio.NewReaderSize(r, n)
}
// Put takes ownership of bufio.Reader for further reuse.
func (rp *ReaderPool) Put(br *bufio.Reader) {
// Should reset even if we do Reset() inside Get().
// This is done to prevent locking underlying io.Reader from GC.
br.Reset(nil)
rp.pool.Put(br, readerSize(br))
}

View file

@ -0,0 +1,13 @@
// +build go1.10
package pbufio
import "bufio"
func writerSize(bw *bufio.Writer) int {
return bw.Size()
}
func readerSize(br *bufio.Reader) int {
return br.Size()
}

View file

@ -0,0 +1,27 @@
// +build !go1.10
package pbufio
import "bufio"
func writerSize(bw *bufio.Writer) int {
return bw.Available() + bw.Buffered()
}
// readerSize returns buffer size of the given buffered reader.
// NOTE: current workaround implementation resets underlying io.Reader.
func readerSize(br *bufio.Reader) int {
br.Reset(sizeReader)
br.ReadByte()
n := br.Buffered() + 1
br.Reset(nil)
return n
}
var sizeReader optimisticReader
type optimisticReader struct{}
func (optimisticReader) Read(p []byte) (int, error) {
return len(p), nil
}

View file

@ -0,0 +1,24 @@
// Package pbytes contains tools for pooling byte pool.
// Note that by default it reuse slices with capacity from 128 to 65536 bytes.
package pbytes
// DefaultPool is used by pacakge level functions.
var DefaultPool = New(128, 65536)
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
// Get is a wrapper around DefaultPool.Get().
func Get(n, c int) []byte { return DefaultPool.Get(n, c) }
// GetCap returns probably reused slice of bytes with at least capacity of n.
// GetCap is a wrapper around DefaultPool.GetCap().
func GetCap(c int) []byte { return DefaultPool.GetCap(c) }
// GetLen returns probably reused slice of bytes with at least capacity of n
// and exactly len of n.
// GetLen is a wrapper around DefaultPool.GetLen().
func GetLen(n int) []byte { return DefaultPool.GetLen(n) }
// Put returns given slice to reuse pool.
// Put is a wrapper around DefaultPool.Put().
func Put(p []byte) { DefaultPool.Put(p) }

View file

@ -0,0 +1,59 @@
// +build !pool_sanitize
package pbytes
import "github.com/gobwas/pool"
// Pool contains logic of reusing byte slices of various size.
type Pool struct {
pool *pool.Pool
}
// New creates new Pool that reuses slices which size is in logarithmic range
// [min, max].
//
// Note that it is a shortcut for Custom() constructor with Options provided by
// pool.WithLogSizeMapping() and pool.WithLogSizeRange(min, max) calls.
func New(min, max int) *Pool {
return &Pool{pool.New(min, max)}
}
// New creates new Pool with given options.
func Custom(opts ...pool.Option) *Pool {
return &Pool{pool.Custom(opts...)}
}
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
func (p *Pool) Get(n, c int) []byte {
if n > c {
panic("requested length is greater than capacity")
}
v, x := p.pool.Get(c)
if v != nil {
bts := v.([]byte)
bts = bts[:n]
return bts
}
return make([]byte, n, x)
}
// Put returns given slice to reuse pool.
// It does not reuse bytes whose size is not power of two or is out of pool
// min/max range.
func (p *Pool) Put(bts []byte) {
p.pool.Put(bts, cap(bts))
}
// GetCap returns probably reused slice of bytes with at least capacity of n.
func (p *Pool) GetCap(c int) []byte {
return p.Get(0, c)
}
// GetLen returns probably reused slice of bytes with at least capacity of n
// and exactly len of n.
func (p *Pool) GetLen(n int) []byte {
return p.Get(n, n)
}

View file

@ -0,0 +1,121 @@
// +build pool_sanitize
package pbytes
import (
"reflect"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const magic = uint64(0x777742)
type guard struct {
magic uint64
size int
owners int32
}
const guardSize = int(unsafe.Sizeof(guard{}))
type Pool struct {
min, max int
}
func New(min, max int) *Pool {
return &Pool{min, max}
}
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
func (p *Pool) Get(n, c int) []byte {
if n > c {
panic("requested length is greater than capacity")
}
pageSize := syscall.Getpagesize()
pages := (c+guardSize)/pageSize + 1
size := pages * pageSize
bts := alloc(size)
g := (*guard)(unsafe.Pointer(&bts[0]))
*g = guard{
magic: magic,
size: size,
owners: 1,
}
return bts[guardSize : guardSize+n]
}
func (p *Pool) GetCap(c int) []byte { return p.Get(0, c) }
func (p *Pool) GetLen(n int) []byte { return Get(n, n) }
// Put returns given slice to reuse pool.
func (p *Pool) Put(bts []byte) {
hdr := *(*reflect.SliceHeader)(unsafe.Pointer(&bts))
ptr := hdr.Data - uintptr(guardSize)
g := (*guard)(unsafe.Pointer(ptr))
if g.magic != magic {
panic("unknown slice returned to the pool")
}
if n := atomic.AddInt32(&g.owners, -1); n < 0 {
panic("multiple Put() detected")
}
// Disable read and write on bytes memory pages. This will cause panic on
// incorrect access to returned slice.
mprotect(ptr, false, false, g.size)
runtime.SetFinalizer(&bts, func(b *[]byte) {
mprotect(ptr, true, true, g.size)
free(*(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: ptr,
Len: g.size,
Cap: g.size,
})))
})
}
func alloc(n int) []byte {
b, err := unix.Mmap(-1, 0, n, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC, unix.MAP_SHARED|unix.MAP_ANONYMOUS)
if err != nil {
panic(err.Error())
}
return b
}
func free(b []byte) {
if err := unix.Munmap(b); err != nil {
panic(err.Error())
}
}
func mprotect(ptr uintptr, r, w bool, size int) {
// Need to avoid "EINVAL addr is not a valid pointer,
// or not a multiple of PAGESIZE."
start := ptr & ^(uintptr(syscall.Getpagesize() - 1))
prot := uintptr(syscall.PROT_EXEC)
switch {
case r && w:
prot |= syscall.PROT_READ | syscall.PROT_WRITE
case r:
prot |= syscall.PROT_READ
case w:
prot |= syscall.PROT_WRITE
}
_, _, err := syscall.Syscall(syscall.SYS_MPROTECT,
start, uintptr(size), prot,
)
if err != 0 {
panic(err.Error())
}
}

View file

@ -0,0 +1,25 @@
// Package pool contains helpers for pooling structures distinguishable by
// size.
//
// Quick example:
//
// import "github.com/gobwas/pool"
//
// func main() {
// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,6,8,16,32,64).
// p := pool.New(0, 64)
//
// buf, n := p.Get(10) // Returns buffer with 16 capacity.
// if buf == nil {
// buf = bytes.NewBuffer(make([]byte, n))
// }
// defer p.Put(buf, n)
//
// // Work with buf.
// }
//
// There are non-generic implementations for pooling:
// - pool/pbytes for []byte reuse;
// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse;
//
package pool

View file

@ -0,0 +1,5 @@
bin/
reports/
cpu.out
mem.out
ws.test

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2018 Sergey Kamardin <gobwas@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,54 @@
BENCH ?=.
BENCH_BASE?=master
clean:
rm -f bin/reporter
rm -fr autobahn/report/*
bin/reporter:
go build -o bin/reporter ./autobahn
bin/gocovmerge:
go build -o bin/gocovmerge github.com/wadey/gocovmerge
.PHONY: autobahn
autobahn: clean bin/reporter
./autobahn/script/test.sh --build --follow-logs
bin/reporter $(PWD)/autobahn/report/index.json
.PHONY: autobahn/report
autobahn/report: bin/reporter
./bin/reporter -http localhost:5555 ./autobahn/report/index.json
test:
go test -coverprofile=ws.coverage .
go test -coverprofile=wsutil.coverage ./wsutil
go test -coverprofile=wsfalte.coverage ./wsflate
# No statemenets to cover in ./tests (there are only tests).
go test ./tests
cover: bin/gocovmerge test autobahn
bin/gocovmerge ws.coverage wsutil.coverage wsflate.coverage autobahn/report/server.coverage > total.coverage
benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
benchcmp:
if [ ! -z "$(shell git status -s)" ]; then\
echo "could not compare with $(BENCH_BASE) found unstaged changes";\
exit 1;\
fi;\
if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
echo "comparing the same branches";\
exit 1;\
fi;\
echo "benchmarking $(BENCH_BRANCH)...";\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
echo "benchmarking $(BENCH_BASE)...";\
git checkout -q $(BENCH_BASE);\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
git checkout -q $(BENCH_BRANCH);\
echo "\nresults:";\
echo "========\n";\
benchcmp $(BENCH_OLD) $(BENCH_NEW);\

View file

@ -0,0 +1,450 @@
# ws
[![GoDoc][godoc-image]][godoc-url]
[![CI][ci-badge]][ci-url]
> [RFC6455][rfc-url] WebSocket implementation in Go.
# Features
- Zero-copy upgrade
- No intermediate allocations during I/O
- Low-level API which allows to build your own logic of packet handling and
buffers reuse
- High-level wrappers and helpers around API in `wsutil` package, which allow
to start fast without digging the protocol internals
# Documentation
[GoDoc][godoc-url].
# Why
Existing WebSocket implementations do not allow users to reuse I/O buffers
between connections in clear way. This library aims to export efficient
low-level interface for working with the protocol without forcing only one way
it could be used.
By the way, if you want get the higher-level tools, you can use `wsutil`
package.
# Status
Library is tagged as `v1*` so its API must not be broken during some
improvements or refactoring.
This implementation of RFC6455 passes [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
about 78% coverage.
# Examples
Example applications using `ws` are developed in separate repository
[ws-examples](https://github.com/gobwas/ws-examples).
# Usage
The higher-level example of WebSocket echo server:
```go
package main
import (
"net/http"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
msg, op, err := wsutil.ReadClientData(conn)
if err != nil {
// handle error
}
err = wsutil.WriteServerMessage(conn, op, msg)
if err != nil {
// handle error
}
}
}()
}))
}
```
Lower-level, but still high-level example:
```go
import (
"net/http"
"io"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
var (
state = ws.StateServerSide
reader = wsutil.NewReader(conn, state)
writer = wsutil.NewWriter(conn, state, ws.OpText)
)
for {
header, err := reader.NextFrame()
if err != nil {
// handle error
}
// Reset writer to write frame with right operation code.
writer.Reset(conn, state, header.OpCode)
if _, err = io.Copy(writer, reader); err != nil {
// handle error
}
if err = writer.Flush(); err != nil {
// handle error
}
}
}()
}))
}
```
We can apply the same pattern to read and write structured responses through a JSON encoder and decoder.:
```go
...
var (
r = wsutil.NewReader(conn, ws.StateServerSide)
w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
decoder = json.NewDecoder(r)
encoder = json.NewEncoder(w)
)
for {
hdr, err = r.NextFrame()
if err != nil {
return err
}
if hdr.OpCode == ws.OpClose {
return io.EOF
}
var req Request
if err := decoder.Decode(&req); err != nil {
return err
}
var resp Response
if err := encoder.Encode(&resp); err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
}
...
```
The lower-level example without `wsutil`:
```go
package main
import (
"net"
"io"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = ws.Upgrade(conn)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
header, err := ws.ReadHeader(conn)
if err != nil {
// handle error
}
payload := make([]byte, header.Length)
_, err = io.ReadFull(conn, payload)
if err != nil {
// handle error
}
if header.Masked {
ws.Cipher(payload, header.Mask, 0)
}
// Reset the Masked flag, server frames must not be masked as
// RFC6455 says.
header.Masked = false
if err := ws.WriteHeader(conn, header); err != nil {
// handle error
}
if _, err := conn.Write(payload); err != nil {
// handle error
}
if header.OpCode == ws.OpClose {
return
}
}
}()
}
}
```
# Zero-copy upgrade
Zero-copy upgrade helps to avoid unnecessary allocations and copying while
handling HTTP Upgrade request.
Processing of all non-websocket headers is made in place with use of registered
user callbacks whose arguments are only valid until callback returns.
The simple example looks like this:
```go
package main
import (
"net"
"log"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
u := ws.Upgrader{
OnHeader: func(key, value []byte) (err error) {
log.Printf("non-websocket header: %q=%q", key, value)
return
},
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = u.Upgrade(conn)
if err != nil {
// handle error
}
}
}
```
Usage of `ws.Upgrader` here brings ability to control incoming connections on
tcp level and simply not to accept them by some logic.
Zero-copy upgrade is for high-load services which have to control many
resources such as connections buffers.
The real life example could be like this:
```go
package main
import (
"fmt"
"io"
"log"
"net"
"net/http"
"runtime"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
// Prepare handshake header writer from http.Header mapping.
header := ws.HandshakeHeaderHTTP(http.Header{
"X-Go-Version": []string{runtime.Version()},
})
u := ws.Upgrader{
OnHost: func(host []byte) error {
if string(host) == "github.com" {
return nil
}
return ws.RejectConnectionError(
ws.RejectionStatus(403),
ws.RejectionHeader(ws.HandshakeHeaderString(
"X-Want-Host: github.com\r\n",
)),
)
},
OnHeader: func(key, value []byte) error {
if string(key) != "Cookie" {
return nil
}
ok := httphead.ScanCookie(value, func(key, value []byte) bool {
// Check session here or do some other stuff with cookies.
// Maybe copy some values for future use.
return true
})
if ok {
return nil
}
return ws.RejectConnectionError(
ws.RejectionReason("bad cookie"),
ws.RejectionStatus(400),
)
},
OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
return header, nil
},
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
}
}
}
```
# Compression
There is a `ws/wsflate` package to support [Permessage-Deflate Compression
Extension][rfc-pmce].
It provides minimalistic I/O wrappers to be used in conjunction with any
deflate implementation (for example, the standard library's
[compress/flate][compress/flate].
```go
package main
import (
"bytes"
"log"
"net"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
e := wsflate.Extension{
// We are using default parameters here since we use
// wsflate.{Compress,Decompress}Frame helpers below in the code.
// This assumes that we use standard compress/flate package as flate
// implementation.
Parameters: wsflate.DefaultParameters,
}
u := ws.Upgrader{
Negotiate: e.Negotiate,
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
// Reset extension after previous upgrades.
e.Reset()
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
continue
}
if _, ok := e.Accepted(); !ok {
log.Printf("didn't negotiate compression for %s", conn.RemoteAddr())
conn.Close()
continue
}
go func() {
defer conn.Close()
for {
frame, err := ws.ReadFrame(conn)
if err != nil {
// Handle error.
return
}
frame = ws.UnmaskFrameInPlace(frame)
frame, err = wsflate.DecompressFrame(frame)
if err != nil {
// Handle error.
return
}
// Do something with frame...
ack := ws.NewTextFrame([]byte("this is an acknowledgement"))
ack, err = wsflate.CompressFrame(ack)
if err != nil {
// Handle error.
return
}
if err = ws.WriteFrame(conn, ack); err != nil {
// Handle error.
return
}
}
}()
}
}
```
[rfc-url]: https://tools.ietf.org/html/rfc6455
[rfc-pmce]: https://tools.ietf.org/html/rfc7692#section-7
[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/ws
[compress/flate]: https://golang.org/pkg/compress/flate/
[ci-badge]: https://github.com/gobwas/ws/workflows/CI/badge.svg
[ci-url]: https://github.com/gobwas/ws/actions?query=workflow%3ACI

View file

@ -0,0 +1,145 @@
package ws
import "unicode/utf8"
// State represents state of websocket endpoint.
// It used by some functions to be more strict when checking compatibility with RFC6455.
type State uint8
const (
// StateServerSide means that endpoint (caller) is a server.
StateServerSide State = 0x1 << iota
// StateClientSide means that endpoint (caller) is a client.
StateClientSide
// StateExtended means that extension was negotiated during handshake.
StateExtended
// StateFragmented means that endpoint (caller) has received fragmented
// frame and waits for continuation parts.
StateFragmented
)
// Is checks whether the s has v enabled.
func (s State) Is(v State) bool {
return uint8(s)&uint8(v) != 0
}
// Set enables v state on s.
func (s State) Set(v State) State {
return s | v
}
// Clear disables v state on s.
func (s State) Clear(v State) State {
return s & (^v)
}
// ServerSide reports whether states represents server side.
func (s State) ServerSide() bool { return s.Is(StateServerSide) }
// ClientSide reports whether state represents client side.
func (s State) ClientSide() bool { return s.Is(StateClientSide) }
// Extended reports whether state is extended.
func (s State) Extended() bool { return s.Is(StateExtended) }
// Fragmented reports whether state is fragmented.
func (s State) Fragmented() bool { return s.Is(StateFragmented) }
// ProtocolError describes error during checking/parsing websocket frames or
// headers.
type ProtocolError string
// Error implements error interface.
func (p ProtocolError) Error() string { return string(p) }
// Errors used by the protocol checkers.
var (
ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
)
// CheckHeader checks h to contain valid header data for given state s.
//
// Note that zero state (0) means that state is clean,
// neither server or client side, nor fragmented, nor extended.
func CheckHeader(h Header, s State) error {
if h.OpCode.IsReserved() {
return ErrProtocolOpCodeReserved
}
if h.OpCode.IsControl() {
if h.Length > MaxControlFramePayloadSize {
return ErrProtocolControlPayloadOverflow
}
if !h.Fin {
return ErrProtocolControlNotFinal
}
}
switch {
// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
// non-zero values. If a nonzero value is received and none of the
// negotiated extensions defines the meaning of such a nonzero value, the
// receiving endpoint MUST _Fail the WebSocket Connection_.
case h.Rsv != 0 && !s.Extended():
return ErrProtocolNonZeroRsv
// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
// status code 1002 (protocol error) as defined in Section 7.4.1.
case s.ServerSide() && !h.Masked:
return ErrProtocolMaskRequired
case s.ClientSide() && h.Masked:
return ErrProtocolMaskUnexpected
// [RFC6455]: See detailed explanation in 5.4 section.
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
return ErrProtocolContinuationExpected
case !s.Fragmented() && h.OpCode == OpContinuation:
return ErrProtocolContinuationUnexpected
default:
return nil
}
}
// CheckCloseFrameData checks received close information
// to be valid RFC6455 compatible close info.
//
// Note that code.Empty() or code.IsAppLevel() will raise error.
//
// If endpoint sends close frame without status code (with frame.Length = 0),
// application should not check its payload.
func CheckCloseFrameData(code StatusCode, reason string) error {
switch {
case code.IsNotUsed():
return ErrProtocolStatusCodeNotInUse
case code.IsProtocolReserved():
return ErrProtocolStatusCodeApplicationLevel
case code == StatusNoMeaningYet:
return ErrProtocolStatusCodeNoMeaning
case code.IsProtocolSpec() && !code.IsProtocolDefined():
return ErrProtocolStatusCodeUnknown
case !utf8.ValidString(reason):
return ErrProtocolInvalidUTF8
default:
return nil
}
}

View file

@ -0,0 +1,61 @@
package ws
import (
"encoding/binary"
)
// Cipher applies XOR cipher to the payload using mask.
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
//
// To convert masked data into unmasked data, or vice versa, the following
// algorithm is applied. The same algorithm applies regardless of the
// direction of the translation, e.g., the same steps are applied to
// mask the data as to unmask the data.
func Cipher(payload []byte, mask [4]byte, offset int) {
n := len(payload)
if n < 8 {
for i := 0; i < n; i++ {
payload[i] ^= mask[(offset+i)%4]
}
return
}
// Calculate position in mask due to previously processed bytes number.
mpos := offset % 4
// Count number of bytes will processed one by one from the beginning of payload.
ln := remain[mpos]
// Count number of bytes will processed one by one from the end of payload.
// This is done to process payload by 8 bytes in each iteration of main loop.
rn := (n - ln) % 8
for i := 0; i < ln; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
for i := n - rn; i < n; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
// NOTE: we use here binary.LittleEndian regardless of what is real
// endianess on machine is. To do so, we have to use binary.LittleEndian in
// the masking loop below as well.
var (
m = binary.LittleEndian.Uint32(mask[:])
m2 = uint64(m)<<32 | uint64(m)
)
// Skip already processed right part.
// Get number of uint64 parts remaining to process.
n = (n - ln - rn) >> 3
for i := 0; i < n; i++ {
var (
j = ln + (i << 3)
chunk = payload[j : j+8]
)
p := binary.LittleEndian.Uint64(chunk)
p = p ^ m2
binary.LittleEndian.PutUint64(chunk, p)
}
}
// remain maps position in masking key [0,4) to number
// of bytes that need to be processed manually inside Cipher().
var remain = [4]int{0, 3, 2, 1}

View file

@ -0,0 +1,563 @@
package ws
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by Dialer.
const (
DefaultClientReadBufferSize = 4096
DefaultClientWriteBufferSize = 4096
)
// Handshake represents handshake result.
type Handshake struct {
// Protocol is the subprotocol selected during handshake.
Protocol string
// Extensions is the list of negotiated extensions.
Extensions []httphead.Option
}
// Errors used by the websocket client.
var (
ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
)
// DefaultDialer is dialer that holds no options and is used by Dial function.
var DefaultDialer Dialer
// Dial is like Dialer{}.Dial().
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
return DefaultDialer.Dial(ctx, urlstr)
}
// Dialer contains options for establishing websocket connection to an url.
type Dialer struct {
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
// They used to read and write http data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then default value is used.
ReadBufferSize, WriteBufferSize int
// Timeout is the maximum amount of time a Dial() will wait for a connect
// and an handshake to complete.
//
// The default is no timeout.
Timeout time.Duration
// Protocols is the list of subprotocols that the client wants to speak,
// ordered by preference.
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
Protocols []string
// Extensions is the list of extensions that client wants to speak.
//
// Note that if server decides to use some of this extensions, Dial() will
// return Handshake struct containing a slice of items, which are the
// shallow copies of the items from this list. That is, internals of
// Extensions items are shared during Dial().
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
// See https://tools.ietf.org/html/rfc6455#section-9.1
Extensions []httphead.Option
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake request.
//
// It used instead of any key-value mappings to avoid allocations in user
// land.
Header HandshakeHeader
// OnStatusError is the callback that will be called after receiving non
// "101 Continue" HTTP response status. It receives an io.Reader object
// representing server response bytes. That is, it gives ability to parse
// HTTP response somehow (probably with http.ReadResponse call) and make a
// decision of further logic.
//
// The arguments are only valid until the callback returns.
OnStatusError func(status int, reason []byte, resp io.Reader)
// OnHeader is the callback that will be called after successful parsing of
// header, that is not used during WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// Returned value could be used to prevent processing response.
OnHeader func(key, value []byte) (err error)
// NetDial is the function that is used to get plain tcp connection.
// If it is not nil, then it is used instead of net.Dialer.
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
// TLSClient is the callback that will be called after successful dial with
// received connection and its remote host name. If it is nil, then the
// default tls.Client() will be used.
// If it is not nil, then TLSConfig field is ignored.
TLSClient func(conn net.Conn, hostname string) net.Conn
// TLSConfig is passed to tls.Client() to start TLS over established
// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
// non-nil and its ServerName is empty, then for every Dial() it will be
// cloned and appropriate ServerName will be set.
TLSConfig *tls.Config
// WrapConn is the optional callback that will be called when connection is
// ready for an i/o. That is, it will be called after successful dial and
// TLS initialization (for "wss" schemes). It may be helpful for different
// user land purposes such as end to end encryption.
//
// Note that for debugging purposes of an http handshake (e.g. sent request
// and received response), there is an wsutil.DebugDialer struct.
WrapConn func(conn net.Conn) net.Conn
}
// Dial connects to the url host and upgrades connection to WebSocket.
//
// If server has sent frames right after successful handshake then returned
// buffer will be non-nil. In other cases buffer is always nil. For better
// memory efficiency received non-nil bufio.Reader should be returned to the
// inner pool with PutReader() function after use.
//
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
// If you want to dial non-ascii host name, take care of its name serialization
// avoiding bad request issues. For more info see net/http Request.Write()
// implementation, especially cleanHost() function.
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
u, err := url.ParseRequestURI(urlstr)
if err != nil {
return
}
// Prepare context to dial with. Initially it is the same as original, but
// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
// we use more shorter context for dial.
dialctx := ctx
var deadline time.Time
if t := d.Timeout; t != 0 {
deadline = time.Now().Add(t)
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
var cancel context.CancelFunc
dialctx, cancel = context.WithDeadline(ctx, deadline)
defer cancel()
}
}
if conn, err = d.dial(dialctx, u); err != nil {
return
}
defer func() {
if err != nil {
conn.Close()
}
}()
if ctx == context.Background() {
// No need to start I/O interrupter goroutine which is not zero-cost.
conn.SetDeadline(deadline)
defer conn.SetDeadline(noDeadline)
} else {
// Context could be canceled or its deadline could be exceeded.
// Start the interrupter goroutine to handle context cancelation.
done := setupContextDeadliner(ctx, conn)
defer func() {
// Map Upgrade() error to a possible context expiration error. That
// is, even if Upgrade() err is nil, context could be already
// expired and connection be "poisoned" by SetDeadline() call.
// In that case we must not return ctx.Err() error.
done(&err)
}()
}
br, hs, err = d.Upgrade(conn, u)
return
}
var (
// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
// Dialer.NetDial is not provided.
netEmptyDialer net.Dialer
// tlsEmptyConfig is an empty tls.Config used as default one.
tlsEmptyConfig tls.Config
)
func tlsDefaultConfig() *tls.Config {
return &tlsEmptyConfig
}
func hostport(host string, defaultPort string) (hostname, addr string) {
var (
colon = strings.LastIndexByte(host, ':')
bracket = strings.IndexByte(host, ']')
)
if colon > bracket {
return host[:colon], host
}
return host, host + defaultPort
}
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
dial := d.NetDial
if dial == nil {
dial = netEmptyDialer.DialContext
}
switch u.Scheme {
case "ws":
_, addr := hostport(u.Host, ":80")
conn, err = dial(ctx, "tcp", addr)
case "wss":
hostname, addr := hostport(u.Host, ":443")
conn, err = dial(ctx, "tcp", addr)
if err != nil {
return
}
tlsClient := d.TLSClient
if tlsClient == nil {
tlsClient = d.tlsClient
}
conn = tlsClient(conn, hostname)
default:
return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
}
if wrap := d.WrapConn; wrap != nil {
conn = wrap(conn)
}
return
}
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
config := d.TLSConfig
if config == nil {
config = tlsDefaultConfig()
}
if config.ServerName == "" {
config = tlsCloneConfig(config)
config.ServerName = hostname
}
// Do not make conn.Handshake() here because downstairs we will prepare
// i/o on this conn with proper context's timeout handling.
return tls.Client(conn, config)
}
var (
// This variables are set like in net/net.go.
// noDeadline is just zero value for readability.
noDeadline = time.Time{}
// aLongTimeAgo is a non-zero time, far in the past, used for immediate
// cancelation of dials.
aLongTimeAgo = time.Unix(42, 0)
)
// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
// url u and reads a response from it.
//
// It is a caller responsibility to manage I/O deadlines on conn.
//
// It returns handshake info and some bytes which could be written by the peer
// right after response and be caught by us during buffered read.
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
// headerSeen constants helps to report whether or not some header was seen
// during reading request bytes.
const (
headerSeenUpgrade = 1 << iota
headerSeenConnection
headerSeenSecAccept
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecAccept
)
br = pbufio.GetReader(conn,
nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
)
defer func() {
pbufio.PutWriter(bw)
if br.Buffered() == 0 || err != nil {
// Server does not wrote additional bytes to the connection or
// error occurred. That is, no reason to return buffer.
pbufio.PutReader(br)
br = nil
}
}()
nonce := make([]byte, nonceSize)
initNonce(nonce)
httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
if err = bw.Flush(); err != nil {
return
}
// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
sl, err := readLine(br)
if err != nil {
return
}
// Begin validation of the response.
// See https://tools.ietf.org/html/rfc6455#section-4.2.2
// Parse request line data like HTTP version, uri and method.
resp, err := httpParseResponseLine(sl)
if err != nil {
return
}
// Even if RFC says "1.1 or higher" without mentioning the part of the
// version, we apply it only to minor part.
if resp.major != 1 || resp.minor < 1 {
err = ErrHandshakeBadProtocol
return
}
if resp.status != 101 {
err = StatusError(resp.status)
if onStatusError := d.OnStatusError; onStatusError != nil {
// Invoke callback with multireader of status-line bytes br.
onStatusError(resp.status, resp.reason,
io.MultiReader(
bytes.NewReader(sl),
strings.NewReader(crlf),
br,
),
)
}
return
}
// If response status is 101 then we expect all technical headers to be
// valid. If not, then we stop processing response without giving user
// ability to read non-technical headers. That is, we do not distinguish
// technical errors (such as parsing error) and protocol errors.
var headerSeen byte
for {
line, e := readLine(br)
if e != nil {
err = e
return
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedResponse
return
}
switch btsToString(k) {
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
return
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
// Note that as RFC6455 says:
// > A |Connection| header field with value "Upgrade".
// That is, in server side, "Connection" header could contain
// multiple token. But in response it must contains exactly one.
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
err = ErrHandshakeBadConnection
return
}
case headerSecAcceptCanonical:
headerSeen |= headerSeenSecAccept
if !checkAcceptFromNonce(v, nonce) {
err = ErrHandshakeBadSecAccept
return
}
case headerSecProtocolCanonical:
// RFC6455 1.3:
// "The server selects one or none of the acceptable protocols
// and echoes that value in its handshake to indicate that it has
// selected that protocol."
for _, want := range d.Protocols {
if string(v) == want {
hs.Protocol = want
break
}
}
if hs.Protocol == "" {
// Server echoed subprotocol that is not present in client
// requested protocols.
err = ErrHandshakeBadSubProtocol
return
}
case headerSecExtensionsCanonical:
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
if err != nil {
return
}
default:
if onHeader := d.OnHeader; onHeader != nil {
if e := onHeader(k, v); e != nil {
err = e
return
}
}
}
}
if err == nil && headerSeen != headerSeenAll {
switch {
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecAccept == 0:
err = ErrHandshakeBadSecAccept
default:
panic("unknown headers state")
}
}
return
}
// PutReader returns bufio.Reader instance to the inner reuse pool.
// It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
// contains unprocessed buffered data, that was sent by the server quickly
// right after handshake.
func PutReader(br *bufio.Reader) {
pbufio.PutReader(br)
}
// StatusError contains an unexpected status-line code from the server.
type StatusError int
func (s StatusError) Error() string {
return "unexpected HTTP response status: " + strconv.Itoa(int(s))
}
func isTimeoutError(err error) bool {
t, ok := err.(net.Error)
return ok && t.Timeout()
}
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
if len(selected) == 0 {
return received, nil
}
var (
index int
option httphead.Option
err error
)
index = -1
match := func() (ok bool) {
for _, want := range wanted {
// A server accepts one or more extensions by including a
// |Sec-WebSocket-Extensions| header field containing one or more
// extensions that were requested by the client.
//
// The interpretation of any extension parameters, and what
// constitutes a valid response by a server to a requested set of
// parameters by a client, will be defined by each such extension.
if bytes.Equal(option.Name, want.Name) {
// Check parsed extension to be present in client
// requested extensions. We move matched extension
// from client list to avoid allocation.
received = append(received, option)
return true
}
}
return false
}
ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
// Met next option.
index = i
if i != 0 && !match() {
// Server returned non-requested extension.
err = ErrHandshakeBadExtensions
return httphead.ControlBreak
}
option = httphead.Option{Name: name}
}
if attr != nil {
option.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
err = ErrMalformedResponse
return received, err
}
if !match() {
return received, ErrHandshakeBadExtensions
}
return received, err
}
// setupContextDeadliner is a helper function that starts connection I/O
// interrupter goroutine.
//
// Started goroutine calls SetDeadline() with long time ago value when context
// become expired to make any I/O operations failed. It returns done function
// that stops started goroutine and maps error received from conn I/O methods
// to possible context expiration error.
//
// In concern with possible SetDeadline() call inside interrupter goroutine,
// caller passes pointer to its I/O error (even if it is nil) to done(&err).
// That is, even if I/O error is nil, context could be already expired and
// connection "poisoned" by SetDeadline() call. In that case done(&err) will
// store at *err ctx.Err() result. If err is caused not by timeout, it will
// leaved untouched.
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
var (
quit = make(chan struct{})
interrupt = make(chan error, 1)
)
go func() {
select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Cancel i/o immediately.
conn.SetDeadline(aLongTimeAgo)
interrupt <- ctx.Err()
}
}()
return func(err *error) {
close(quit)
// If ctx.Err() is non-nil and the original err is net.Error with
// Timeout() == true, then it means that I/O was canceled by us by
// SetDeadline(aLongTimeAgo) call, or by somebody else previously
// by conn.SetDeadline(x).
//
// Even on race condition when both deadlines are expired
// (SetDeadline() made not by us and context's), we prefer ctx.Err() to
// be returned.
if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
*err = ctxErr
}
}
}

View file

@ -0,0 +1,35 @@
// +build !go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
// NOTE: we copying SessionTicketsDisabled and SessionTicketKey here
// without calling inner c.initOnceServer somehow because we only could get
// here from the ws.Dialer code, which is obviously a client and makes
// tls.Client() when it gets new net.Conn.
return &tls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
}
}

View file

@ -0,0 +1,9 @@
// +build go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
return c.Clone()
}

View file

@ -0,0 +1,81 @@
/*
Package ws implements a client and server for the WebSocket protocol as
specified in RFC 6455.
The main purpose of this package is to provide simple low-level API for
efficient work with protocol.
Overview.
Upgrade to WebSocket (or WebSocket handshake) can be done in two ways.
The first way is to use `net/http` server:
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
})
The second and much more efficient way is so-called "zero-copy upgrade". It
avoids redundant allocations and copying of not used headers or other request
data. User decides by himself which data should be copied.
ln, err := net.Listen("tcp", ":8080")
if err != nil {
// handle error
}
conn, err := ln.Accept()
if err != nil {
// handle error
}
handshake, err := ws.Upgrade(conn)
if err != nil {
// handle error
}
For customization details see `ws.Upgrader` documentation.
After WebSocket handshake you can work with connection in multiple ways.
That is, `ws` does not force the only one way of how to work with WebSocket:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
buf := make([]byte, header.Length)
_, err := io.ReadFull(conn, buf)
if err != nil {
// handle err
}
resp := ws.NewBinaryFrame([]byte("hello, world!"))
if err := ws.WriteFrame(conn, frame); err != nil {
// handle err
}
As you can see, it stream friendly:
const N = 42
ws.WriteHeader(ws.Header{
Fin: true,
Length: N,
OpCode: ws.OpBinary,
})
io.CopyN(conn, rand.Reader, N)
Or:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
io.CopyN(ioutil.Discard, conn, header.Length)
For more info see the documentation.
*/
package ws

View file

@ -0,0 +1,54 @@
package ws
// RejectOption represents an option used to control the way connection is
// rejected.
type RejectOption func(*rejectConnectionError)
// RejectionReason returns an option that makes connection to be rejected with
// given reason.
func RejectionReason(reason string) RejectOption {
return func(err *rejectConnectionError) {
err.reason = reason
}
}
// RejectionStatus returns an option that makes connection to be rejected with
// given HTTP status code.
func RejectionStatus(code int) RejectOption {
return func(err *rejectConnectionError) {
err.code = code
}
}
// RejectionHeader returns an option that makes connection to be rejected with
// given HTTP headers.
func RejectionHeader(h HandshakeHeader) RejectOption {
return func(err *rejectConnectionError) {
err.header = h
}
}
// RejectConnectionError constructs an error that could be used to control the way
// handshake is rejected by Upgrader.
func RejectConnectionError(options ...RejectOption) error {
err := new(rejectConnectionError)
for _, opt := range options {
opt(err)
}
return err
}
// rejectConnectionError represents a rejection of upgrade error.
//
// It can be returned by Upgrader's On* hooks to control the way WebSocket
// handshake is rejected.
type rejectConnectionError struct {
reason string
code int
header HandshakeHeader
}
// Error implements error interface.
func (r *rejectConnectionError) Error() string {
return r.reason
}

View file

@ -0,0 +1,420 @@
package ws
import (
"bytes"
"encoding/binary"
"math/rand"
)
// Constants defined by specification.
const (
// All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
MaxControlFramePayloadSize = 125
)
// OpCode represents operation code.
type OpCode byte
// Operation codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-5.2
const (
OpContinuation OpCode = 0x0
OpText OpCode = 0x1
OpBinary OpCode = 0x2
OpClose OpCode = 0x8
OpPing OpCode = 0x9
OpPong OpCode = 0xa
)
// IsControl checks whether the c is control operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.5
func (c OpCode) IsControl() bool {
// RFC6455: Control frames are identified by opcodes where
// the most significant bit of the opcode is 1.
//
// Note that OpCode is only 4 bit length.
return c&0x8 != 0
}
// IsData checks whether the c is data operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.6
func (c OpCode) IsData() bool {
// RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
// where the most significant bit of the opcode is 0.
//
// Note that OpCode is only 4 bit length.
return c&0x8 == 0
}
// IsReserved checks whether the c is reserved operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func (c OpCode) IsReserved() bool {
// RFC6455:
// %x3-7 are reserved for further non-control frames
// %xB-F are reserved for further control frames
return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
}
// StatusCode represents the encoded reason for closure of websocket connection.
//
// There are few helper methods on StatusCode that helps to define a range in
// which given code is lay in. accordingly to ranges defined in specification.
//
// See https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode uint16
// StatusCodeRange describes range of StatusCode values.
type StatusCodeRange struct {
Min, Max StatusCode
}
// Status code ranges defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.2
var (
StatusRangeNotInUse = StatusCodeRange{0, 999}
StatusRangeProtocol = StatusCodeRange{1000, 2999}
StatusRangeApplication = StatusCodeRange{3000, 3999}
StatusRangePrivate = StatusCodeRange{4000, 4999}
)
// Status codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.1
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
StatusNoMeaningYet StatusCode = 1004
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExt StatusCode = 1010
StatusInternalServerError StatusCode = 1011
StatusTLSHandshake StatusCode = 1015
// StatusAbnormalClosure is a special code designated for use in
// applications.
StatusAbnormalClosure StatusCode = 1006
// StatusNoStatusRcvd is a special code designated for use in applications.
StatusNoStatusRcvd StatusCode = 1005
)
// In reports whether the code is defined in given range.
func (s StatusCode) In(r StatusCodeRange) bool {
return r.Min <= s && s <= r.Max
}
// Empty reports whether the code is empty.
// Empty code has no any meaning neither app level codes nor other.
// This method is useful just to check that code is golang default value 0.
func (s StatusCode) Empty() bool {
return s == 0
}
// IsNotUsed reports whether the code is predefined in not used range.
func (s StatusCode) IsNotUsed() bool {
return s.In(StatusRangeNotInUse)
}
// IsApplicationSpec reports whether the code should be defined by
// application, framework or libraries specification.
func (s StatusCode) IsApplicationSpec() bool {
return s.In(StatusRangeApplication)
}
// IsPrivateSpec reports whether the code should be defined privately.
func (s StatusCode) IsPrivateSpec() bool {
return s.In(StatusRangePrivate)
}
// IsProtocolSpec reports whether the code should be defined by protocol specification.
func (s StatusCode) IsProtocolSpec() bool {
return s.In(StatusRangeProtocol)
}
// IsProtocolDefined reports whether the code is already defined by protocol specification.
func (s StatusCode) IsProtocolDefined() bool {
switch s {
case StatusNormalClosure,
StatusGoingAway,
StatusProtocolError,
StatusUnsupportedData,
StatusInvalidFramePayloadData,
StatusPolicyViolation,
StatusMessageTooBig,
StatusMandatoryExt,
StatusInternalServerError,
StatusNoStatusRcvd,
StatusAbnormalClosure,
StatusTLSHandshake:
return true
}
return false
}
// IsProtocolReserved reports whether the code is defined by protocol specification
// to be reserved only for application usage purpose.
func (s StatusCode) IsProtocolReserved() bool {
switch s {
// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
// Close control frame by an endpoint.
case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return true
default:
return false
}
}
// Compiled control frames for common use cases.
// For construct-serialize optimizations.
var (
CompiledPing = MustCompileFrame(NewPingFrame(nil))
CompiledPong = MustCompileFrame(NewPongFrame(nil))
CompiledClose = MustCompileFrame(NewCloseFrame(nil))
CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure)
CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway)
CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError)
CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData)
CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet)
CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation)
CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig)
CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt)
CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError)
CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake)
)
// Header represents websocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Header struct {
Fin bool
Rsv byte
OpCode OpCode
Masked bool
Mask [4]byte
Length int64
}
// Rsv1 reports whether the header has first rsv bit set.
func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
// Rsv2 reports whether the header has second rsv bit set.
func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
// Rsv3 reports whether the header has third rsv bit set.
func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
// Rsv creates rsv byte representation from bits.
func Rsv(r1, r2, r3 bool) (rsv byte) {
if r1 {
rsv |= bit5
}
if r2 {
rsv |= bit6
}
if r3 {
rsv |= bit7
}
return rsv
}
// RsvBits returns rsv bits from bytes representation.
func RsvBits(rsv byte) (r1, r2, r3 bool) {
r1 = rsv&bit5 != 0
r2 = rsv&bit6 != 0
r3 = rsv&bit7 != 0
return
}
// Frame represents websocket frame.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Frame struct {
Header Header
Payload []byte
}
// NewFrame creates frame with given operation code,
// flag of completeness and payload bytes.
func NewFrame(op OpCode, fin bool, p []byte) Frame {
return Frame{
Header: Header{
Fin: fin,
OpCode: op,
Length: int64(len(p)),
},
Payload: p,
}
}
// NewTextFrame creates text frame with p as payload.
// Note that p is not copied.
func NewTextFrame(p []byte) Frame {
return NewFrame(OpText, true, p)
}
// NewBinaryFrame creates binary frame with p as payload.
// Note that p is not copied.
func NewBinaryFrame(p []byte) Frame {
return NewFrame(OpBinary, true, p)
}
// NewPingFrame creates ping frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPingFrame(p []byte) Frame {
return NewFrame(OpPing, true, p)
}
// NewPongFrame creates pong frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPongFrame(p []byte) Frame {
return NewFrame(OpPong, true, p)
}
// NewCloseFrame creates close frame with given close body.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewCloseFrame(p []byte) Frame {
return NewFrame(OpClose, true, p)
}
// NewCloseFrameBody encodes a closure code and a reason into a binary
// representation.
//
// It returns slice which is at most MaxControlFramePayloadSize bytes length.
// If the reason is too big it will be cropped to fit the limit defined by the
// spec.
//
// See https://tools.ietf.org/html/rfc6455#section-5.5
func NewCloseFrameBody(code StatusCode, reason string) []byte {
n := min(2+len(reason), MaxControlFramePayloadSize)
p := make([]byte, n)
crop := min(MaxControlFramePayloadSize-2, len(reason))
PutCloseFrameBody(p, code, reason[:crop])
return p
}
// PutCloseFrameBody encodes code and reason into buf.
//
// It will panic if the buffer is too small to accommodate a code or a reason.
//
// PutCloseFrameBody does not check buffer to be RFC compliant, but note that
// by RFC it must be at most MaxControlFramePayloadSize.
func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
_ = p[1+len(reason)]
binary.BigEndian.PutUint16(p, uint16(code))
copy(p[2:], reason)
}
// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For less allocations you could use MaskFrameInPlace or construct frame manually.
func MaskFrame(f Frame) Frame {
return MaskFrameWith(f, NewMask())
}
// MaskFrameWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For less allocations you could use MaskFrameInPlaceWith or construct frame manually.
func MaskFrameWith(f Frame, mask [4]byte) Frame {
// TODO(gobwas): check CopyCipher ws copy() Cipher().
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return MaskFrameInPlaceWith(f, mask)
}
// MaskFrameInPlace masks frame and returns frame with masked payload and Mask
// header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlace(f Frame) Frame {
return MaskFrameInPlaceWith(f, NewMask())
}
var zeroMask [4]byte
// UnmaskFrame unmasks frame and returns frame with unmasked payload and Mask
// header's field cleared.
// Note that it copies f payload.
func UnmaskFrame(f Frame) Frame {
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return UnmaskFrameInPlace(f)
}
// UnmaskFrameInPlace unmasks frame and returns frame with unmasked payload and
// Mask header's field cleared.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func UnmaskFrameInPlace(f Frame) Frame {
Cipher(f.Payload, f.Header.Mask, 0)
f.Header.Masked = false
f.Header.Mask = zeroMask
return f
}
// MaskFrameInPlaceWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame {
f.Header.Masked = true
f.Header.Mask = m
Cipher(f.Payload, m, 0)
return f
}
// NewMask creates new random mask.
func NewMask() (ret [4]byte) {
binary.BigEndian.PutUint32(ret[:], rand.Uint32())
return
}
// CompileFrame returns byte representation of given frame.
// In terms of memory consumption it is useful to precompile static frames
// which are often used.
func CompileFrame(f Frame) (bts []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 16))
err = WriteFrame(buf, f)
bts = buf.Bytes()
return
}
// MustCompileFrame is like CompileFrame but panics if frame can not be
// encoded.
func MustCompileFrame(f Frame) []byte {
bts, err := CompileFrame(f)
if err != nil {
panic(err)
}
return bts
}
func makeCloseFrame(code StatusCode) Frame {
return NewCloseFrame(NewCloseFrameBody(code, ""))
}
var (
closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure)
closeFrameGoingAway = makeCloseFrame(StatusGoingAway)
closeFrameProtocolError = makeCloseFrame(StatusProtocolError)
closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData)
closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet)
closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData)
closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation)
closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig)
closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt)
closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError)
closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake)
)

View file

@ -0,0 +1,9 @@
module github.com/gobwas/ws
go 1.15
require (
github.com/gobwas/httphead v0.1.0
github.com/gobwas/pool v0.2.1
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d // indirect
)

View file

@ -0,0 +1,6 @@
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d h1:MiWWjyhUzZ+jvhZvloX6ZrUsdEghn8a64Upd8EMHglE=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View file

@ -0,0 +1,504 @@
package ws
import (
"bufio"
"bytes"
"io"
"net/http"
"net/textproto"
"net/url"
"strconv"
"github.com/gobwas/httphead"
)
const (
crlf = "\r\n"
colonAndSpace = ": "
commaAndSpace = ", "
)
const (
textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
)
var (
textHeadBadRequest = statusText(http.StatusBadRequest)
textHeadInternalServerError = statusText(http.StatusInternalServerError)
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
)
var (
headerHost = "Host"
headerUpgrade = "Upgrade"
headerConnection = "Connection"
headerSecVersion = "Sec-WebSocket-Version"
headerSecProtocol = "Sec-WebSocket-Protocol"
headerSecExtensions = "Sec-WebSocket-Extensions"
headerSecKey = "Sec-WebSocket-Key"
headerSecAccept = "Sec-WebSocket-Accept"
headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
)
var (
specHeaderValueUpgrade = []byte("websocket")
specHeaderValueConnection = []byte("Upgrade")
specHeaderValueConnectionLower = []byte("upgrade")
specHeaderValueSecVersion = []byte("13")
)
var (
httpVersion1_0 = []byte("HTTP/1.0")
httpVersion1_1 = []byte("HTTP/1.1")
httpVersionPrefix = []byte("HTTP/")
)
type httpRequestLine struct {
method, uri []byte
major, minor int
}
type httpResponseLine struct {
major, minor int
status int
reason []byte
}
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
var proto []byte
req.method, req.uri, proto = bsplit3(line, ' ')
var ok bool
req.major, req.minor, ok = httpParseVersion(proto)
if !ok {
err = ErrMalformedRequest
return
}
return
}
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
var (
proto []byte
status []byte
)
proto, status, resp.reason = bsplit3(line, ' ')
var ok bool
resp.major, resp.minor, ok = httpParseVersion(proto)
if !ok {
return resp, ErrMalformedResponse
}
var convErr error
resp.status, convErr = asciiToInt(status)
if convErr != nil {
return resp, ErrMalformedResponse
}
return resp, nil
}
// httpParseVersion parses major and minor version of HTTP protocol. It returns
// parsed values and true if parse is ok.
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
switch {
case bytes.Equal(bts, httpVersion1_0):
return 1, 0, true
case bytes.Equal(bts, httpVersion1_1):
return 1, 1, true
case len(bts) < 8:
return
case !bytes.Equal(bts[:5], httpVersionPrefix):
return
}
bts = bts[5:]
dot := bytes.IndexByte(bts, '.')
if dot == -1 {
return
}
var err error
major, err = asciiToInt(bts[:dot])
if err != nil {
return
}
minor, err = asciiToInt(bts[dot+1:])
if err != nil {
return
}
return major, minor, true
}
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
// values and true if parse is ok.
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
colon := bytes.IndexByte(line, ':')
if colon == -1 {
return
}
k = btrim(line[:colon])
// TODO(gobwas): maybe use just lower here?
canonicalizeHeaderKey(k)
v = btrim(line[colon+1:])
return k, v, true
}
// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
// that key is already canonical. This helps to increase performance.
func httpGetHeader(h http.Header, key string) string {
if h == nil {
return ""
}
v := h[key]
if len(v) == 0 {
return ""
}
return v[0]
}
// The request MAY include a header field with the name
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
// comma-separated subprotocol the client wishes to speak, ordered by
// preference. The elements that comprise this value MUST be non-empty strings
// with characters in the range U+0021 to U+007E not including separator
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
// for the value of this header field is 1#token, where the definitions of
// constructs and rules are as given in [RFC2616].
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
if check(btsToString(v)) {
ret = string(v)
return false
}
return true
})
return
}
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
var selected []byte
ok = httphead.ScanTokens(h, func(v []byte) bool {
if check(v) {
selected = v
return false
}
return true
})
if ok && selected != nil {
return string(selected), true
}
return
}
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
s := httphead.OptionSelector{
Flags: httphead.SelectCopy,
Check: check,
}
return s.Select(h, selected)
}
func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
if in.Size() == 0 {
return dest, nil
}
opt, err := f(in)
if err != nil {
return nil, err
}
if opt.Size() > 0 {
dest = append(dest, opt)
}
return dest, nil
}
func negotiateExtensions(
h []byte, dest []httphead.Option,
f func(httphead.Option) (httphead.Option, error),
) (_ []httphead.Option, err error) {
index := -1
var current httphead.Option
ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
dest, err = negotiateMaybe(current, dest, f)
if err != nil {
return httphead.ControlBreak
}
index = i
current = httphead.Option{Name: name}
}
if attr != nil {
current.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
return nil, ErrMalformedRequest
}
return negotiateMaybe(current, dest, f)
}
func httpWriteHeader(bw *bufio.Writer, key, value string) {
httpWriteHeaderKey(bw, key)
bw.WriteString(value)
bw.WriteString(crlf)
}
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
httpWriteHeaderKey(bw, key)
bw.Write(value)
bw.WriteString(crlf)
}
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
bw.WriteString(key)
bw.WriteString(colonAndSpace)
}
func httpWriteUpgradeRequest(
bw *bufio.Writer,
u *url.URL,
nonce []byte,
protocols []string,
extensions []httphead.Option,
header HandshakeHeader,
) {
bw.WriteString("GET ")
bw.WriteString(u.RequestURI())
bw.WriteString(" HTTP/1.1\r\n")
httpWriteHeader(bw, headerHost, u.Host)
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
// NOTE: write nonce bytes as a string to prevent heap allocation
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer which in turn
// will lead to p escape.
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
if len(protocols) > 0 {
httpWriteHeaderKey(bw, headerSecProtocol)
for i, p := range protocols {
if i > 0 {
bw.WriteString(commaAndSpace)
}
bw.WriteString(p)
}
bw.WriteString(crlf)
}
if len(extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, extensions)
bw.WriteString(crlf)
}
if header != nil {
header.WriteTo(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
bw.WriteString(textHeadUpgrade)
httpWriteHeaderKey(bw, headerSecAccept)
writeAccept(bw, nonce)
bw.WriteString(crlf)
if hs.Protocol != "" {
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
}
if len(hs.Extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, hs.Extensions)
bw.WriteString(crlf)
}
if header != nil {
header(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
switch code {
case http.StatusBadRequest:
bw.WriteString(textHeadBadRequest)
case http.StatusInternalServerError:
bw.WriteString(textHeadInternalServerError)
case http.StatusUpgradeRequired:
bw.WriteString(textHeadUpgradeRequired)
default:
writeStatusText(bw, code)
}
// Write custom headers.
if header != nil {
header(bw)
}
switch err {
case ErrHandshakeBadProtocol:
bw.WriteString(textTailErrHandshakeBadProtocol)
case ErrHandshakeBadMethod:
bw.WriteString(textTailErrHandshakeBadMethod)
case ErrHandshakeBadHost:
bw.WriteString(textTailErrHandshakeBadHost)
case ErrHandshakeBadUpgrade:
bw.WriteString(textTailErrHandshakeBadUpgrade)
case ErrHandshakeBadConnection:
bw.WriteString(textTailErrHandshakeBadConnection)
case ErrHandshakeBadSecAccept:
bw.WriteString(textTailErrHandshakeBadSecAccept)
case ErrHandshakeBadSecKey:
bw.WriteString(textTailErrHandshakeBadSecKey)
case ErrHandshakeBadSecVersion:
bw.WriteString(textTailErrHandshakeBadSecVersion)
case ErrHandshakeUpgradeRequired:
bw.WriteString(textTailErrUpgradeRequired)
case nil:
bw.WriteString(crlf)
default:
writeErrorText(bw, err)
}
}
func writeStatusText(bw *bufio.Writer, code int) {
bw.WriteString("HTTP/1.1 ")
bw.WriteString(strconv.Itoa(code))
bw.WriteByte(' ')
bw.WriteString(http.StatusText(code))
bw.WriteString(crlf)
bw.WriteString("Content-Type: text/plain; charset=utf-8")
bw.WriteString(crlf)
}
func writeErrorText(bw *bufio.Writer, err error) {
body := err.Error()
bw.WriteString("Content-Length: ")
bw.WriteString(strconv.Itoa(len(body)))
bw.WriteString(crlf)
bw.WriteString(crlf)
bw.WriteString(body)
}
// httpError is like the http.Error with WebSocket context exception.
func httpError(w http.ResponseWriter, body string, code int) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(code)
w.Write([]byte(body))
}
// statusText is a non-performant status text generator.
// NOTE: Used only to generate constants.
func statusText(code int) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeStatusText(bw, code)
bw.Flush()
return buf.String()
}
// errorText is a non-performant error text generator.
// NOTE: Used only to generate constants.
func errorText(err error) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeErrorText(bw, err)
bw.Flush()
return buf.String()
}
// HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer.
type HandshakeHeader interface {
io.WriterTo
}
// HandshakeHeaderString is an adapter to allow the use of headers represented
// by ordinary string as HandshakeHeader.
type HandshakeHeaderString string
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
n, err := io.WriteString(w, string(s))
return int64(n), err
}
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
// by ordinary slice of bytes as HandshakeHeader.
type HandshakeHeaderBytes []byte
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b)
return int64(n), err
}
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
// ordinary function as HandshakeHeader.
type HandshakeHeaderFunc func(io.Writer) (int64, error)
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
return f(w)
}
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
// HandshakeHeader.
type HandshakeHeaderHTTP http.Header
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
wr := writer{w: w}
err := http.Header(h).Write(&wr)
return wr.n, err
}
type writer struct {
n int64
w io.Writer
}
func (w *writer) WriteString(s string) (int, error) {
n, err := io.WriteString(w.w, s)
w.n += int64(n)
return n, err
}
func (w *writer) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.n += int64(n)
return n, err
}

View file

@ -0,0 +1,80 @@
package ws
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"math/rand"
)
const (
// RFC6455: The value of this header field MUST be a nonce consisting of a
// randomly selected 16-byte value that has been base64-encoded (see
// Section 4 of [RFC4648]). The nonce MUST be selected randomly for each
// connection.
nonceKeySize = 16
nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
// RFC6455: The value of this header field is constructed by concatenating
// /key/, defined above in step 4 in Section 4.2.2, with the string
// "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
// concatenated value to obtain a 20-byte value and base64- encoding (see
// Section 4 of [RFC4648]) this 20-byte hash.
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
)
// initNonce fills given slice with random base64-encoded nonce bytes.
func initNonce(dst []byte) {
// NOTE: bts does not escape.
bts := make([]byte, nonceKeySize)
if _, err := rand.Read(bts); err != nil {
panic(fmt.Sprintf("rand read error: %s", err))
}
base64.StdEncoding.Encode(dst, bts)
}
// checkAcceptFromNonce reports whether given accept bytes are valid for given
// nonce bytes.
func checkAcceptFromNonce(accept, nonce []byte) bool {
if len(accept) != acceptSize {
return false
}
// NOTE: expect does not escape.
expect := make([]byte, acceptSize)
initAcceptFromNonce(expect, nonce)
return bytes.Equal(expect, accept)
}
// initAcceptFromNonce fills given slice with accept bytes generated from given
// nonce bytes. Given buffer should be exactly acceptSize bytes.
func initAcceptFromNonce(accept, nonce []byte) {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
if len(accept) != acceptSize {
panic("accept buffer is invalid")
}
if len(nonce) != nonceSize {
panic("nonce is invalid")
}
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], nonce)
copy(p[nonceSize:], magic)
sum := sha1.Sum(p)
base64.StdEncoding.Encode(accept, sum[:])
return
}
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
accept := make([]byte, acceptSize)
initAcceptFromNonce(accept, nonce)
// NOTE: write accept bytes as a string to prevent heap allocation
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer which in turn
// will lead to p escape.
return bw.WriteString(btsToString(accept))
}

View file

@ -0,0 +1,147 @@
package ws
import (
"encoding/binary"
"fmt"
"io"
)
// Errors used by frame reader.
var (
ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
)
// ReadHeader reads a frame header from r.
func ReadHeader(r io.Reader) (h Header, err error) {
// Make slice of bytes with capacity 12 that could hold any header.
//
// The maximum header size is 14, but due to the 2 hop reads,
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
// So 14 - 2 = 12.
bts := make([]byte, 2, MaxHeaderSize-2)
// Prepare to hold first 2 bytes to choose size of next read.
_, err = io.ReadFull(r, bts)
if err != nil {
return
}
h.Fin = bts[0]&bit0 != 0
h.Rsv = (bts[0] & 0x70) >> 4
h.OpCode = OpCode(bts[0] & 0x0f)
var extra int
if bts[1]&bit0 != 0 {
h.Masked = true
extra += 4
}
length := bts[1] & 0x7f
switch {
case length < 126:
h.Length = int64(length)
case length == 126:
extra += 2
case length == 127:
extra += 8
default:
err = ErrHeaderLengthUnexpected
return
}
if extra == 0 {
return
}
// Increase len of bts to extra bytes need to read.
// Overwrite first 2 bytes that was read before.
bts = bts[:extra]
_, err = io.ReadFull(r, bts)
if err != nil {
return
}
switch {
case length == 126:
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
bts = bts[2:]
case length == 127:
if bts[0]&0x80 != 0 {
err = ErrHeaderLengthMSB
return
}
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
bts = bts[8:]
}
if h.Masked {
copy(h.Mask[:], bts)
}
return
}
// ReadFrame reads a frame from r.
// It is not designed for high optimized use case cause it makes allocation
// for frame.Header.Length size inside to read frame payload into.
//
// Note that ReadFrame does not unmask payload.
func ReadFrame(r io.Reader) (f Frame, err error) {
f.Header, err = ReadHeader(r)
if err != nil {
return
}
if f.Header.Length > 0 {
// int(f.Header.Length) is safe here cause we have
// checked it for overflow above in ReadHeader.
f.Payload = make([]byte, int(f.Header.Length))
_, err = io.ReadFull(r, f.Payload)
}
return
}
// MustReadFrame is like ReadFrame but panics if frame can not be read.
func MustReadFrame(r io.Reader) Frame {
f, err := ReadFrame(r)
if err != nil {
panic(err)
}
return f
}
// ParseCloseFrameData parses close frame status code and closure reason if any provided.
// If there is no status code in the payload
// the empty status code is returned (code.Empty()) with empty string as a reason.
func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
// We returning empty StatusCode here, preventing the situation
// when endpoint really sent code 1005 and we should return ProtocolError on that.
//
// In other words, we ignoring this rule [RFC6455:7.1.5]:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
return
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = string(payload[2:])
return
}
// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing
// that it does not copies payload bytes into reason, but prepares unsafe cast.
func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
return
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = btsToString(payload[2:])
return
}

View file

@ -0,0 +1,663 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by ConnUpgrader.
const (
DefaultServerReadBufferSize = 4096
DefaultServerWriteBufferSize = 512
)
// Errors used by both client and server when preparing WebSocket handshake.
var (
ErrHandshakeBadProtocol = RejectConnectionError(
RejectionStatus(http.StatusHTTPVersionNotSupported),
RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")),
)
ErrHandshakeBadMethod = RejectConnectionError(
RejectionStatus(http.StatusMethodNotAllowed),
RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")),
)
ErrHandshakeBadHost = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
)
ErrHandshakeBadUpgrade = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
)
ErrHandshakeBadConnection = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
)
ErrHandshakeBadSecAccept = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
)
ErrHandshakeBadSecKey = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
)
ErrHandshakeBadSecVersion = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
)
// ErrMalformedResponse is returned by Dialer to indicate that server response
// can not be parsed.
var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
// ErrMalformedRequest is returned when HTTP request can not be parsed.
var ErrMalformedRequest = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason("malformed HTTP request"),
)
// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
// connection is rejected because given WebSocket version is malformed.
//
// According to RFC6455:
// If this version does not match a version understood by the server, the
// server MUST abort the WebSocket handshake described in this section and
// instead send an appropriate HTTP error code (such as 426 Upgrade Required)
// and a |Sec-WebSocket-Version| header field indicating the version(s) the
// server is capable of understanding.
var ErrHandshakeUpgradeRequired = RejectConnectionError(
RejectionStatus(http.StatusUpgradeRequired),
RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
// ErrNotHijacker is an error returned when http.ResponseWriter does not
// implement http.Hijacker interface.
var ErrNotHijacker = RejectConnectionError(
RejectionStatus(http.StatusInternalServerError),
RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
)
// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
// UpgradeHTTP function.
var DefaultHTTPUpgrader HTTPUpgrader
// UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
return DefaultHTTPUpgrader.Upgrade(r, w)
}
// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
// function.
var DefaultUpgrader Upgrader
// Upgrade is like Upgrader{}.Upgrade().
func Upgrade(conn io.ReadWriter) (Handshake, error) {
return DefaultUpgrader.Upgrade(conn)
}
// HTTPUpgrader contains options for upgrading connection to websocket from
// net/http Handler arguments.
type HTTPUpgrader struct {
// Timeout is the maximum amount of time an Upgrade() will spent while
// writing handshake response.
//
// The default is no timeout.
Timeout time.Duration
// Header is an optional http.Header mapping that could be used to
// write additional headers to the handshake response.
//
// Note that if present, it will be written in any result of handshake.
Header http.Header
// Protocol is the select function that is used to select subprotocol from
// list requested by client. If this field is set, then the first matched
// protocol is sent to a client as negotiated.
Protocol func(string) bool
// Extension is the select function that is used to select extensions from
// list requested by client. If this field is set, then the all matched
// extensions are sent to a client as negotiated.
//
// DEPRECATED. Use Negotiate instead.
Extension func(httphead.Option) bool
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
}
// Upgrade upgrades http connection to the websocket connection.
//
// It hijacks net.Conn from w and returns received net.Conn and
// bufio.ReadWriter. On successful handshake it returns Handshake struct
// describing handshake info.
func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
// Hijack connection first to get the ability to write rejection errors the
// same way as in Upgrader.
hj, ok := w.(http.Hijacker)
if ok {
conn, rw, err = hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return
}
// See https://tools.ietf.org/html/rfc6455#section-4.1
// The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
var nonce string
if r.Method != http.MethodGet {
err = ErrHandshakeBadMethod
} else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
err = ErrHandshakeBadProtocol
} else if r.Host == "" {
err = ErrHandshakeBadHost
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
err = ErrHandshakeBadUpgrade
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
err = ErrHandshakeBadConnection
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
err = ErrHandshakeBadSecKey
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
// According to RFC6455:
//
// If this version does not match a version understood by the server,
// the server MUST abort the WebSocket handshake described in this
// section and instead send an appropriate HTTP error code (such as 426
// Upgrade Required) and a |Sec-WebSocket-Version| header field
// indicating the version(s) the server is capable of understanding.
//
// So we branching here cause empty or not present version does not
// meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if version is really invalid we sent 426 status, if it
// not present or empty it is 400.
if v != "" {
err = ErrHandshakeUpgradeRequired
} else {
err = ErrHandshakeBadSecVersion
}
}
if check := u.Protocol; err == nil && check != nil {
ps := r.Header[headerSecProtocolCanonical]
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
var ok bool
hs.Protocol, ok = strSelectProtocol(ps[i], check)
if !ok {
err = ErrMalformedRequest
}
}
}
if f := u.Negotiate; err == nil && f != nil {
for _, h := range r.Header[headerSecExtensionsCanonical] {
hs.Extensions, err = negotiateExtensions(strToBytes(h), hs.Extensions, f)
if err != nil {
break
}
}
}
// DEPRECATED path.
if check := u.Extension; err == nil && check != nil && u.Negotiate == nil {
xs := r.Header[headerSecExtensionsCanonical]
for i := 0; i < len(xs) && err == nil; i++ {
var ok bool
hs.Extensions, ok = btsSelectExtensions(strToBytes(xs[i]), hs.Extensions, check)
if !ok {
err = ErrMalformedRequest
}
}
}
// Clear deadlines set by server.
conn.SetDeadline(noDeadline)
if t := u.Timeout; t != 0 {
conn.SetWriteDeadline(time.Now().Add(t))
defer conn.SetWriteDeadline(noDeadline)
}
var header handshakeHeader
if h := u.Header; h != nil {
header[0] = HandshakeHeaderHTTP(h)
}
if err == nil {
httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
err = rw.Writer.Flush()
} else {
var code int
if rej, ok := err.(*rejectConnectionError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
rw.Writer.Flush()
}
return
}
// Upgrader contains options for upgrading connection to websocket.
type Upgrader struct {
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
// They used to read and write http data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then default value is used.
//
// Usually it is useful to set read buffer size bigger than write buffer
// size because incoming request could contain long header values, such as
// Cookie. Response, in other way, could be big only if user write multiple
// custom headers. Usually response takes less than 256 bytes.
ReadBufferSize, WriteBufferSize int
// Protocol is a select function that is used to select subprotocol
// from list requested by client. If this field is set, then the first matched
// protocol is sent to a client as negotiated.
//
// The argument is only valid until the callback returns.
Protocol func([]byte) bool
// ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
// Note that returned bytes must be valid until Upgrade returns.
// If ProtocolCustom is set, it used instead of Protocol function.
ProtocolCustom func([]byte) (string, bool)
// Extension is a select function that is used to select extensions
// from list requested by client. If this field is set, then the all matched
// extensions are sent to a client as negotiated.
//
// Note that Extension may be called multiple times and implementations
// must track uniqueness of accepted extensions manually.
//
// The argument is only valid until the callback returns.
//
// According to the RFC6455 order of extensions passed by a client is
// significant. That is, returning true from this function means that no
// other extension with the same name should be checked because server
// accepted the most preferable extension right now:
// "Note that the order of extensions is significant. Any interactions between
// multiple extensions MAY be defined in the documents defining the extensions.
// In the absence of such definitions, the interpretation is that the header
// fields listed by the client in its request represent a preference of the
// header fields it wishes to use, with the first options listed being most
// preferable."
//
// DEPRECATED. Use Negotiate instead.
Extension func(httphead.Option) bool
// ExtensionCustom allow user to parse Sec-WebSocket-Extensions header
// manually.
//
// If ExtensionCustom() decides to accept received extension, it must
// append appropriate option to the given slice of httphead.Option.
// It returns results of append() to the given slice and a flag that
// reports whether given header value is wellformed or not.
//
// Note that ExtensionCustom may be called multiple times and
// implementations must track uniqueness of accepted extensions manually.
//
// Note that returned options should be valid until Upgrade returns.
// If ExtensionCustom is set, it used instead of Extension function.
ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake response.
//
// It used instead of any key-value mappings to avoid allocations in user
// land.
//
// Note that if present, it will be written in any result of handshake.
Header HandshakeHeader
// OnRequest is a callback that will be called after request line
// successful parsing.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnRequest func(uri []byte) error
// OnHost is a callback that will be called after "Host" header successful
// parsing.
//
// It is separated from OnHeader callback because the Host header must be
// present in each request since HTTP/1.1. Thus Host header is non-optional
// and required for every WebSocket handshake.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHost func(host []byte) error
// OnHeader is a callback that will be called after successful parsing of
// header, that is not used during WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHeader func(key, value []byte) error
// OnBeforeUpgrade is a callback that will be called before sending
// successful upgrade response.
//
// Setting OnBeforeUpgrade allows user to make final application-level
// checks and decide whether this connection is allowed to successfully
// upgrade to WebSocket.
//
// It must return non-nil either HandshakeHeader or error and never both.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnBeforeUpgrade func() (header HandshakeHeader, err error)
}
// Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
// as connection with incoming HTTP Upgrade request.
//
// It is a caller responsibility to manage i/o timeouts on conn.
//
// Non-nil error means that request for the WebSocket upgrade is invalid or
// malformed and usually connection should be closed.
// Even when error is non-nil Upgrade will write appropriate response into
// connection in compliance with RFC.
func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
// headerSeen constants helps to report whether or not some header was seen
// during reading request bytes.
const (
headerSeenHost = 1 << iota
headerSeenUpgrade
headerSeenConnection
headerSeenSecVersion
headerSeenSecKey
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenHost |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecVersion |
headerSeenSecKey
)
// Prepare I/O buffers.
// TODO(gobwas): make it configurable.
br := pbufio.GetReader(conn,
nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
)
defer func() {
pbufio.PutReader(br)
pbufio.PutWriter(bw)
}()
// Read HTTP request line like "GET /ws HTTP/1.1".
rl, err := readLine(br)
if err != nil {
return
}
// Parse request line data like HTTP version, uri and method.
req, err := httpParseRequestLine(rl)
if err != nil {
return
}
// Prepare stack-based handshake header list.
header := handshakeHeader{
0: u.Header,
}
// Parse and check HTTP request.
// As RFC6455 says:
// The client's opening handshake consists of the following parts. If the
// server, while reading the handshake, finds that the client did not
// send a handshake that matches the description below (note that as per
// [RFC2616], the order of the header fields is not important), including
// but not limited to any violations of the ABNF grammar specified for
// the components of the handshake, the server MUST stop processing the
// client's handshake and return an HTTP response with an appropriate
// error code (such as 400 Bad Request).
//
// See https://tools.ietf.org/html/rfc6455#section-4.2.1
// An HTTP/1.1 or higher GET request, including a "Request-URI".
//
// Even if RFC says "1.1 or higher" without mentioning the part of the
// version, we apply it only to minor part.
switch {
case req.major != 1 || req.minor < 1:
// Abort processing the whole request because we do not even know how
// to actually parse it.
err = ErrHandshakeBadProtocol
case btsToString(req.method) != http.MethodGet:
err = ErrHandshakeBadMethod
default:
if onRequest := u.OnRequest; onRequest != nil {
err = onRequest(req.uri)
}
}
// Start headers read/parse loop.
var (
// headerSeen reports which header was seen by setting corresponding
// bit on.
headerSeen byte
nonce = make([]byte, nonceSize)
)
for err == nil {
line, e := readLine(br)
if e != nil {
return hs, e
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedRequest
break
}
switch btsToString(k) {
case headerHostCanonical:
headerSeen |= headerSeenHost
if onHost := u.OnHost; onHost != nil {
err = onHost(v)
}
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
err = ErrHandshakeBadConnection
}
case headerSecVersionCanonical:
headerSeen |= headerSeenSecVersion
if !bytes.Equal(v, specHeaderValueSecVersion) {
err = ErrHandshakeUpgradeRequired
}
case headerSecKeyCanonical:
headerSeen |= headerSeenSecKey
if len(v) != nonceSize {
err = ErrHandshakeBadSecKey
} else {
copy(nonce[:], v)
}
case headerSecProtocolCanonical:
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Protocol, ok = custom(v)
} else {
hs.Protocol, ok = btsSelectProtocol(v, check)
}
if !ok {
err = ErrMalformedRequest
}
}
case headerSecExtensionsCanonical:
if f := u.Negotiate; err == nil && f != nil {
hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f)
}
// DEPRECATED path.
if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Extensions, ok = custom(v, hs.Extensions)
} else {
hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
}
if !ok {
err = ErrMalformedRequest
}
}
default:
if onHeader := u.OnHeader; onHeader != nil {
err = onHeader(k, v)
}
}
}
switch {
case err == nil && headerSeen != headerSeenAll:
switch {
case headerSeen&headerSeenHost == 0:
// As RFC2616 says:
// A client MUST include a Host header field in all HTTP/1.1
// request messages. If the requested URI does not include an
// Internet host name for the service being requested, then the
// Host header field MUST be given with an empty value. An
// HTTP/1.1 proxy MUST ensure that any request message it
// forwards does contain an appropriate Host header field that
// identifies the service being requested by the proxy. All
// Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
// Request) status code to any HTTP/1.1 request message which
// lacks a Host header field.
err = ErrHandshakeBadHost
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecVersion == 0:
// In case of empty or not present version we do not send 426 status,
// because it does not meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if version is really invalid we sent 426 status as above, if it
// not present it is 400.
err = ErrHandshakeBadSecVersion
case headerSeen&headerSeenSecKey == 0:
err = ErrHandshakeBadSecKey
default:
panic("unknown headers state")
}
case err == nil && u.OnBeforeUpgrade != nil:
header[1], err = u.OnBeforeUpgrade()
}
if err != nil {
var code int
if rej, ok := err.(*rejectConnectionError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(bw, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
bw.Flush()
return
}
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
err = bw.Flush()
return
}
type handshakeHeader [2]HandshakeHeader
func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
for i := 0; i < len(hs) && err == nil; i++ {
if h := hs[i]; h != nil {
var m int64
m, err = h.WriteTo(w)
n += m
}
}
return n, err
}

View file

View file

@ -0,0 +1,214 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"reflect"
"unsafe"
"github.com/gobwas/httphead"
)
// SelectFromSlice creates accept function that could be used as Protocol/Extension
// select during upgrade.
func SelectFromSlice(accept []string) func(string) bool {
if len(accept) > 16 {
mp := make(map[string]struct{}, len(accept))
for _, p := range accept {
mp[p] = struct{}{}
}
return func(p string) bool {
_, ok := mp[p]
return ok
}
}
return func(p string) bool {
for _, ok := range accept {
if p == ok {
return true
}
}
return false
}
}
// SelectEqual creates accept function that could be used as Protocol/Extension
// select during upgrade.
func SelectEqual(v string) func(string) bool {
return func(p string) bool {
return v == p
}
}
func strToBytes(str string) (bts []byte) {
s := (*reflect.StringHeader)(unsafe.Pointer(&str))
b := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
b.Data = s.Data
b.Len = s.Len
b.Cap = s.Len
return
}
func btsToString(bts []byte) (str string) {
return *(*string)(unsafe.Pointer(&bts))
}
// asciiToInt converts bytes to int.
func asciiToInt(bts []byte) (ret int, err error) {
// ASCII numbers all start with the high-order bits 0011.
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
// bits and interpret them directly as an integer.
var n int
if n = len(bts); n < 1 {
return 0, fmt.Errorf("converting empty bytes to int")
}
for i := 0; i < n; i++ {
if bts[i]&0xf0 != 0x30 {
return 0, fmt.Errorf("%s is not a numeric character", string(bts[i]))
}
ret += int(bts[i]&0xf) * pow(10, n-i-1)
}
return ret, nil
}
// pow for integers implementation.
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
func pow(a, b int) int {
p := 1
for b > 0 {
if b&1 != 0 {
p *= a
}
b >>= 1
a *= a
}
return p
}
func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) {
a := bytes.IndexByte(bts, sep)
b := bytes.IndexByte(bts[a+1:], sep)
if a == -1 || b == -1 {
return bts, nil, nil
}
b += a + 1
return bts[:a], bts[a+1 : b], bts[b+1:]
}
func btrim(bts []byte) []byte {
var i, j int
for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); {
i++
}
for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); {
j--
}
return bts[i:j]
}
func strHasToken(header, token string) (has bool) {
return btsHasToken(strToBytes(header), strToBytes(token))
}
func btsHasToken(header, token []byte) (has bool) {
httphead.ScanTokens(header, func(v []byte) bool {
has = bytes.EqualFold(v, token)
return !has
})
return
}
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
toLower8 = uint64(toLower) |
uint64(toLower)<<8 |
uint64(toLower)<<16 |
uint64(toLower)<<24 |
uint64(toLower)<<32 |
uint64(toLower)<<40 |
uint64(toLower)<<48 |
uint64(toLower)<<56
)
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
// that it operates with slice of bytes and modifies it inplace without copying.
func canonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
// readLine reads line from br. It reads until '\n' and returns bytes without
// '\n' or '\r\n' at the end.
// It returns err if and only if line does not end in '\n'. Note that read
// bytes returned in any case of error.
//
// It is much like the textproto/Reader.ReadLine() except the thing that it
// returns raw bytes, instead of string. That is, it avoids copying bytes read
// from br.
//
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
// safe with future I/O operations on br.
//
// We could control I/O operations on br and do not need to make additional
// copy for safety.
//
// NOTE: it may return copied flag to notify that returned buffer is safe to
// use.
func readLine(br *bufio.Reader) ([]byte, error) {
var line []byte
for {
bts, err := br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
// Copy bytes because next read will discard them.
line = append(line, bts...)
continue
}
// Avoid copy of single read.
if line == nil {
line = bts
} else {
line = append(line, bts...)
}
if err != nil {
return line, err
}
// Size of line is at least 1.
// In other case bufio.ReadSlice() returns error.
n := len(line)
// Cut '\n' or '\r\n'.
if n > 1 && line[n-2] == '\r' {
line = line[:n-2]
} else {
line = line[:n-1]
}
return line, nil
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func nonZero(a, b int) int {
if a != 0 {
return a
}
return b
}

View file

@ -0,0 +1,104 @@
package ws
import (
"encoding/binary"
"io"
)
// Header size length bounds in bytes.
const (
MaxHeaderSize = 14
MinHeaderSize = 2
)
const (
bit0 = 0x80
bit1 = 0x40
bit2 = 0x20
bit3 = 0x10
bit4 = 0x08
bit5 = 0x04
bit6 = 0x02
bit7 = 0x01
len7 = int64(125)
len16 = int64(^(uint16(0)))
len64 = int64(^(uint64(0)) >> 1)
)
// HeaderSize returns number of bytes that are needed to encode given header.
// It returns -1 if header is malformed.
func HeaderSize(h Header) (n int) {
switch {
case h.Length < 126:
n = 2
case h.Length <= len16:
n = 4
case h.Length <= len64:
n = 10
default:
return -1
}
if h.Masked {
n += len(h.Mask)
}
return n
}
// WriteHeader writes header binary representation into w.
func WriteHeader(w io.Writer, h Header) error {
// Make slice of bytes with capacity 14 that could hold any header.
bts := make([]byte, MaxHeaderSize)
if h.Fin {
bts[0] |= bit0
}
bts[0] |= h.Rsv << 4
bts[0] |= byte(h.OpCode)
var n int
switch {
case h.Length <= len7:
bts[1] = byte(h.Length)
n = 2
case h.Length <= len16:
bts[1] = 126
binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
n = 4
case h.Length <= len64:
bts[1] = 127
binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
n = 10
default:
return ErrHeaderLengthUnexpected
}
if h.Masked {
bts[1] |= bit0
n += copy(bts[n:], h.Mask[:])
}
_, err := w.Write(bts[:n])
return err
}
// WriteFrame writes frame binary representation into w.
func WriteFrame(w io.Writer, f Frame) error {
err := WriteHeader(w, f.Header)
if err != nil {
return err
}
_, err = w.Write(f.Payload)
return err
}
// MustWriteFrame is like WriteFrame but panics if frame can not be read.
func MustWriteFrame(w io.Writer, f Frame) {
if err := WriteFrame(w, f); err != nil {
panic(err)
}
}

View file

@ -0,0 +1,72 @@
package wsutil
import (
"io"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// CipherReader implements io.Reader that applies xor-cipher to the bytes read
// from source.
// It could help to unmask WebSocket frame payload on the fly.
type CipherReader struct {
r io.Reader
mask [4]byte
pos int
}
// NewCipherReader creates xor-cipher reader from r with given mask.
func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
return &CipherReader{r, mask, 0}
}
// Reset resets CipherReader to read from r with given mask.
func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
c.r = r
c.mask = mask
c.pos = 0
}
// Read implements io.Reader interface. It applies mask given during
// initialization to every read byte.
func (c *CipherReader) Read(p []byte) (n int, err error) {
n, err = c.r.Read(p)
ws.Cipher(p[:n], c.mask, c.pos)
c.pos += n
return
}
// CipherWriter implements io.Writer that applies xor-cipher to the bytes
// written to the destination writer. It does not modify the original bytes.
type CipherWriter struct {
w io.Writer
mask [4]byte
pos int
}
// NewCipherWriter creates xor-cipher writer to w with given mask.
func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
return &CipherWriter{w, mask, 0}
}
// Reset reset CipherWriter to write to w with given mask.
func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
c.w = w
c.mask = mask
c.pos = 0
}
// Write implements io.Writer interface. It applies masking during
// initialization to every sent byte. It does not modify original slice.
func (c *CipherWriter) Write(p []byte) (n int, err error) {
cp := pbytes.GetLen(len(p))
defer pbytes.Put(cp)
copy(cp, p)
ws.Cipher(cp, c.mask, c.pos)
n, err = c.w.Write(cp)
c.pos += n
return
}

View file

@ -0,0 +1,146 @@
package wsutil
import (
"bufio"
"bytes"
"context"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/gobwas/ws"
)
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
// handshake. That is, it gives ability to receive copied HTTP request and
// response bytes that made inside Dialer.Dial().
//
// Note that it must not be used in production applications that requires
// Dial() to be efficient.
type DebugDialer struct {
// Dialer contains WebSocket connection establishment options.
Dialer ws.Dialer
// OnRequest and OnResponse are the callbacks that will be called with the
// HTTP request and response respectively.
OnRequest, OnResponse func([]byte)
}
// Dial connects to the url host and upgrades connection to WebSocket. It makes
// it by calling d.Dialer.Dial().
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
// Need to copy Dialer to prevent original object mutation.
dialer := d.Dialer
var (
reqBuf bytes.Buffer
resBuf bytes.Buffer
resContentLength int64
)
userWrap := dialer.WrapConn
dialer.WrapConn = func(c net.Conn) net.Conn {
if userWrap != nil {
c = userWrap(c)
}
// Save the pointer to the raw connection.
conn = c
var (
r io.Reader = conn
w io.Writer = conn
)
if d.OnResponse != nil {
r = &prefetchResponseReader{
source: conn,
buffer: &resBuf,
contentLength: &resContentLength,
}
}
if d.OnRequest != nil {
w = io.MultiWriter(conn, &reqBuf)
}
return rwConn{conn, r, w}
}
_, br, hs, err = dialer.Dial(ctx, urlstr)
if onRequest := d.OnRequest; onRequest != nil {
onRequest(reqBuf.Bytes())
}
if onResponse := d.OnResponse; onResponse != nil {
// We must split response inside buffered bytes from other received
// bytes from server.
p := resBuf.Bytes()
n := bytes.Index(p, headEnd)
h := n + len(headEnd) // Head end index.
n = h + int(resContentLength) // Body end index.
onResponse(p[:n])
if br != nil {
// If br is non-nil, then it mean two things. First is that
// handshake is OK and server has sent additional bytes probably
// immediate sent frames (or weird but possible response body).
// Second, the bad one, is that br buffer's source is now rwConn
// instance from above WrapConn call. It is incorrect, so we must
// fix it.
var r io.Reader = conn
if len(p) > h {
// Buffer contains more than just HTTP headers bytes.
r = io.MultiReader(
bytes.NewReader(p[h:]),
conn,
)
}
br.Reset(r)
// Must make br.Buffered() to be non-zero.
br.Peek(len(p[h:]))
}
}
return conn, br, hs, err
}
type rwConn struct {
net.Conn
r io.Reader
w io.Writer
}
func (rwc rwConn) Read(p []byte) (int, error) {
return rwc.r.Read(p)
}
func (rwc rwConn) Write(p []byte) (int, error) {
return rwc.w.Write(p)
}
var headEnd = []byte("\r\n\r\n")
type prefetchResponseReader struct {
source io.Reader // Original connection source.
reader io.Reader // Wrapped reader used to read from by clients.
buffer *bytes.Buffer
contentLength *int64
}
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
if r.reader == nil {
resp, err := http.ReadResponse(bufio.NewReader(
io.TeeReader(r.source, r.buffer),
), nil)
if err == nil {
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
bts := r.buffer.Bytes()
r.reader = io.MultiReader(
bytes.NewReader(bts),
r.source,
)
}
return r.reader.Read(p)
}

View file

@ -0,0 +1,29 @@
package wsutil
// RecvExtension is an interface for clearing fragment header RSV bits.
type RecvExtension interface {
BitsRecv(seq int, rsv byte) (byte, error)
}
// RecvExtensionFunc is an adapter to allow the use of ordinary functions as
// RecvExtension.
type RecvExtensionFunc func(int, byte) (byte, error)
// BitsRecv implements RecvExtension.
func (fn RecvExtensionFunc) BitsRecv(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
}
// SendExtension is an interface for setting fragment header RSV bits.
type SendExtension interface {
BitsSend(seq int, rsv byte) (byte, error)
}
// SendExtensionFunc is an adapter to allow the use of ordinary functions as
// SendExtension.
type SendExtensionFunc func(int, byte) (byte, error)
// BitsSend implements SendExtension.
func (fn SendExtensionFunc) BitsSend(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
}

View file

@ -0,0 +1,219 @@
package wsutil
import (
"errors"
"io"
"io/ioutil"
"strconv"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// ClosedError returned when peer has closed the connection with appropriate
// code and a textual reason.
type ClosedError struct {
Code ws.StatusCode
Reason string
}
// Error implements error interface.
func (err ClosedError) Error() string {
return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason
}
// ControlHandler contains logic of handling control frames.
//
// The intentional way to use it is to read the next frame header from the
// connection, optionally check its validity via ws.CheckHeader() and if it is
// not a ws.OpText of ws.OpBinary (or ws.OpContinuation) pass it to Handle()
// method.
//
// That is, passed header should be checked to get rid of unexpected errors.
//
// The Handle() method will read out all control frame payload (if any) and
// write necessary bytes as a rfc compatible response.
type ControlHandler struct {
Src io.Reader
Dst io.Writer
State ws.State
// DisableSrcCiphering disables unmasking payload data read from Src.
// It is useful when wsutil.Reader is used or when frame payload already
// pulled and ciphered out from the connection (and introduced by
// bytes.Reader, for example).
DisableSrcCiphering bool
}
// ErrNotControlFrame is returned by ControlHandler to indicate that given
// header could not be handled.
var ErrNotControlFrame = errors.New("not a control frame")
// Handle handles control frames regarding to the c.State and writes responses
// to the c.Dst when needed.
//
// It returns ErrNotControlFrame when given header is not of ws.OpClose,
// ws.OpPing or ws.OpPong operation code.
func (c ControlHandler) Handle(h ws.Header) error {
switch h.OpCode {
case ws.OpPing:
return c.HandlePing(h)
case ws.OpPong:
return c.HandlePong(h)
case ws.OpClose:
return c.HandleClose(h)
}
return ErrNotControlFrame
}
// HandlePing handles ping frame and writes specification compatible response
// to the c.Dst.
func (c ControlHandler) HandlePing(h ws.Header) error {
if h.Length == 0 {
// The most common case when ping is empty.
// Note that when sending masked frame the mask for empty payload is
// just four zero bytes.
return ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpPong,
Masked: c.State.ClientSide(),
})
}
// In other way reply with Pong frame with copied payload.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
_, err := io.Copy(w, r)
if err == nil {
err = w.Flush()
}
return err
}
// HandlePong handles pong frame by discarding it.
func (c ControlHandler) HandlePong(h ws.Header) error {
if h.Length == 0 {
return nil
}
buf := pbytes.GetLen(int(h.Length))
defer pbytes.Put(buf)
// Discard pong message according to the RFC6455:
// A Pong frame MAY be sent unsolicited. This serves as a
// unidirectional heartbeat. A response to an unsolicited Pong frame
// is not expected.
_, err := io.CopyBuffer(ioutil.Discard, c.Src, buf)
return err
}
// HandleClose handles close frame, makes protocol validity checks and writes
// specification compatible response to the c.Dst.
func (c ControlHandler) HandleClose(h ws.Header) error {
if h.Length == 0 {
err := ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpClose,
Masked: c.State.ClientSide(),
})
if err != nil {
return err
}
// Due to RFC, we should interpret the code as no status code
// received:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
//
// See https://tools.ietf.org/html/rfc6455#section-7.1.5
return ClosedError{
Code: ws.StatusNoStatusRcvd,
}
}
// Prepare bytes both for reading reason and sending response.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Get the subslice to read the frame payload out.
subp := p[:h.Length]
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
if _, err := io.ReadFull(r, subp); err != nil {
return err
}
code, reason := ws.ParseCloseFrameData(subp)
if err := ws.CheckCloseFrameData(code, reason); err != nil {
// Here we could not use the prepared bytes because there is no
// guarantee that it may fit our protocol error closure code and a
// reason.
c.closeWithProtocolError(err)
return err
}
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p)
// RFC6455#5.5.1:
// If an endpoint receives a Close frame and did not previously
// send a Close frame, the endpoint MUST send a Close frame in
// response. (When sending a Close frame in response, the endpoint
// typically echoes the status code it received.)
_, err := w.Write(p[:2])
if err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
return ClosedError{
Code: code,
Reason: reason,
}
}
func (c ControlHandler) closeWithProtocolError(reason error) error {
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
ws.StatusProtocolError, reason.Error(),
))
if c.State.ClientSide() {
ws.MaskFrameInPlace(f)
}
return ws.WriteFrame(c.Dst, f)
}

View file

@ -0,0 +1,279 @@
package wsutil
import (
"bytes"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
// Message represents a message from peer, that could be presented in one or
// more frames. That is, it contains payload of all message fragments and
// operation code of initial frame for this message.
type Message struct {
OpCode ws.OpCode
Payload []byte
}
// ReadMessage is a helper function that reads next message from r. It appends
// received message(s) to the third argument and returns the result of it and
// an error if some failure happened. That is, it probably could receive more
// than one message when peer sending fragmented message in multiple frames and
// want to send some control frame between fragments. Then returned slice will
// contain those control frames at first, and then result of gluing fragments.
//
// TODO(gobwas): add DefaultReader with buffer size options.
func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) {
rd := Reader{
Source: r,
State: s,
CheckUTF8: true,
OnIntermediate: func(hdr ws.Header, src io.Reader) error {
bts, err := ioutil.ReadAll(src)
if err != nil {
return err
}
m = append(m, Message{hdr.OpCode, bts})
return nil
},
}
h, err := rd.NextFrame()
if err != nil {
return m, err
}
var p []byte
if h.Fin {
// No more frames will be read. Use fixed sized buffer to read payload.
p = make([]byte, h.Length)
// It is not possible to receive io.EOF here because Reader does not
// return EOF if frame payload was successfully fetched.
// Thus we consistent here with io.Reader behavior.
_, err = io.ReadFull(&rd, p)
} else {
// Frame is fragmented, thus use ioutil.ReadAll behavior.
var buf bytes.Buffer
_, err = buf.ReadFrom(&rd)
p = buf.Bytes()
}
if err != nil {
return m, err
}
return append(m, Message{h.OpCode, p}), nil
}
// ReadClientMessage reads next message from r, considering that caller
// represents server side.
// It is a shortcut for ReadMessage(r, ws.StateServerSide, m)
func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) {
return ReadMessage(r, ws.StateServerSide, m)
}
// ReadServerMessage reads next message from r, considering that caller
// represents client side.
// It is a shortcut for ReadMessage(r, ws.StateClientSide, m)
func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) {
return ReadMessage(r, ws.StateClientSide, m)
}
// ReadData is a helper function that reads next data (non-control) message
// from rw.
// It takes care on handling all control frames. It will write response on
// control frames to the write part of rw. It blocks until some data frame
// will be received.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) {
return readData(rw, s, ws.OpText|ws.OpBinary)
}
// ReadClientData reads next data message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
return ReadData(rw, ws.StateServerSide)
}
// ReadClientText reads next text message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
// It discards received binary messages.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadClientText(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateServerSide, ws.OpText)
return p, err
}
// ReadClientBinary reads next binary message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
// It discards received text messages.
//
// Note this may handle and write control frames into the writer part of a given
// io.ReadWriter.
func ReadClientBinary(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary)
return p, err
}
// ReadServerData reads next data message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
return ReadData(rw, ws.StateClientSide)
}
// ReadServerText reads next text message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
// It discards received binary messages.
//
// Note this may handle and write control frames into the writer part of a given
// io.ReadWriter.
func ReadServerText(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateClientSide, ws.OpText)
return p, err
}
// ReadServerBinary reads next binary message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
// It discards received text messages.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadServerBinary(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary)
return p, err
}
// WriteMessage is a helper function that writes message to the w. It
// constructs single frame with given operation code and payload.
// It uses given state to prepare side-dependent things, like cipher
// payload bytes from client to server. It will not mutate p bytes if
// cipher must be made.
//
// If you want to write message in fragmented frames, use Writer instead.
func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
return writeFrame(w, s, op, true, p)
}
// WriteServerMessage writes message to w, considering that caller
// represents server side.
func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error {
return WriteMessage(w, ws.StateServerSide, op, p)
}
// WriteServerText is the same as WriteServerMessage with
// ws.OpText.
func WriteServerText(w io.Writer, p []byte) error {
return WriteServerMessage(w, ws.OpText, p)
}
// WriteServerBinary is the same as WriteServerMessage with
// ws.OpBinary.
func WriteServerBinary(w io.Writer, p []byte) error {
return WriteServerMessage(w, ws.OpBinary, p)
}
// WriteClientMessage writes message to w, considering that caller
// represents client side.
func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error {
return WriteMessage(w, ws.StateClientSide, op, p)
}
// WriteClientText is the same as WriteClientMessage with
// ws.OpText.
func WriteClientText(w io.Writer, p []byte) error {
return WriteClientMessage(w, ws.OpText, p)
}
// WriteClientBinary is the same as WriteClientMessage with
// ws.OpBinary.
func WriteClientBinary(w io.Writer, p []byte) error {
return WriteClientMessage(w, ws.OpBinary, p)
}
// HandleClientControlMessage handles control frame from conn and writes
// response when needed.
//
// It considers that caller represents server side.
func HandleClientControlMessage(conn io.Writer, msg Message) error {
return HandleControlMessage(conn, ws.StateServerSide, msg)
}
// HandleServerControlMessage handles control frame from conn and writes
// response when needed.
//
// It considers that caller represents client side.
func HandleServerControlMessage(conn io.Writer, msg Message) error {
return HandleControlMessage(conn, ws.StateClientSide, msg)
}
// HandleControlMessage handles message which was read by ReadMessage()
// functions.
//
// That is, it is expected, that payload is already unmasked and frame header
// were checked by ws.CheckHeader() call.
func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error {
return (ControlHandler{
DisableSrcCiphering: true,
Src: bytes.NewReader(msg.Payload),
Dst: conn,
State: state,
}).Handle(ws.Header{
Length: int64(len(msg.Payload)),
OpCode: msg.OpCode,
Fin: true,
Masked: state.ServerSide(),
})
}
// ControlFrameHandler returns FrameHandlerFunc for handling control frames.
// For more info see ControlHandler docs.
func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc {
return func(h ws.Header, r io.Reader) error {
return (ControlHandler{
DisableSrcCiphering: true,
Src: r,
Dst: w,
State: state,
}).Handle(h)
}
}
func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) {
controlHandler := ControlFrameHandler(rw, s)
rd := Reader{
Source: rw,
State: s,
CheckUTF8: true,
SkipHeaderCheck: false,
OnIntermediate: controlHandler,
}
for {
hdr, err := rd.NextFrame()
if err != nil {
return nil, 0, err
}
if hdr.OpCode.IsControl() {
if err := controlHandler(hdr, &rd); err != nil {
return nil, 0, err
}
continue
}
if hdr.OpCode&want == 0 {
if err := rd.Discard(); err != nil {
return nil, 0, err
}
continue
}
bts, err := ioutil.ReadAll(&rd)
return bts, hdr.OpCode, err
}
}

View file

@ -0,0 +1,280 @@
package wsutil
import (
"errors"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
// ErrNoFrameAdvance means that Reader's Read() method was called without
// preceding NextFrame() call.
var ErrNoFrameAdvance = errors.New("no frame advance")
// FrameHandlerFunc handles parsed frame header and its body represented by
// io.Reader.
//
// Note that reader represents already unmasked body.
type FrameHandlerFunc func(ws.Header, io.Reader) error
// Reader is a wrapper around source io.Reader which represents WebSocket
// connection. It contains options for reading messages from source.
//
// Reader implements io.Reader, which Read() method reads payload of incoming
// WebSocket frames. It also takes care on fragmented frames and possibly
// intermediate control frames between them.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
Source io.Reader
State ws.State
// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
SkipHeaderCheck bool
// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
CheckUTF8 bool
// Extensions is a list of negotiated extensions for reader Source.
// It is used to meet the specs and clear appropriate bits in fragment
// header RSV segment.
Extensions []RecvExtension
// TODO(gobwas): add max frame size limit here.
OnContinuation FrameHandlerFunc
OnIntermediate FrameHandlerFunc
opCode ws.OpCode // Used to store message op code on fragmentation.
frame io.Reader // Used to as frame reader.
raw io.LimitedReader // Used to discard frames without cipher.
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
fseq int // Fragment sequence in message counter.
}
// NewReader creates new frame reader that reads from r keeping given state to
// make some protocol validity checks when it needed.
func NewReader(r io.Reader, s ws.State) *Reader {
return &Reader{
Source: r,
State: s,
}
}
// NewClientSideReader is a helper function that calls NewReader with r and
// ws.StateClientSide.
func NewClientSideReader(r io.Reader) *Reader {
return NewReader(r, ws.StateClientSide)
}
// NewServerSideReader is a helper function that calls NewReader with r and
// ws.StateServerSide.
func NewServerSideReader(r io.Reader) *Reader {
return NewReader(r, ws.StateServerSide)
}
// Read implements io.Reader. It reads the next message payload into p.
// It takes care on fragmented messages.
//
// The error is io.EOF only if all of message bytes were read.
// If an io.EOF happens during reading some but not all the message bytes
// Read() returns io.ErrUnexpectedEOF.
//
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
// reading next message bytes.
func (r *Reader) Read(p []byte) (n int, err error) {
if r.frame == nil {
if !r.fragmented() {
// Every new Read() must be preceded by NextFrame() call.
return 0, ErrNoFrameAdvance
}
// Read next continuation or intermediate control frame.
_, err := r.NextFrame()
if err != nil {
return 0, err
}
if r.frame == nil {
// We handled intermediate control and now got nothing to read.
return 0, nil
}
}
n, err = r.frame.Read(p)
if err != nil && err != io.EOF {
return
}
if err == nil && r.raw.N != 0 {
return n, nil
}
// EOF condition (either err is io.EOF or r.raw.N is zero).
switch {
case r.raw.N != 0:
err = io.ErrUnexpectedEOF
case r.fragmented():
err = nil
r.resetFragment()
case r.CheckUTF8 && !r.utf8.Valid():
// NOTE: check utf8 only when full message received, since partial
// reads may be invalid.
n = r.utf8.Accepted()
err = ErrInvalidUTF8
default:
r.reset()
err = io.EOF
}
return
}
// Discard discards current message unread bytes.
// It discards all frames of fragmented message.
func (r *Reader) Discard() (err error) {
for {
_, err = io.Copy(ioutil.Discard, &r.raw)
if err != nil {
break
}
if !r.fragmented() {
break
}
if _, err = r.NextFrame(); err != nil {
break
}
}
r.reset()
return err
}
// NextFrame prepares r to read next message. It returns received frame header
// and non-nil error on failure.
//
// Note that next NextFrame() call must be done after receiving or discarding
// all current message bytes.
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
hdr, err = ws.ReadHeader(r.Source)
if err == io.EOF && r.fragmented() {
// If we are in fragmented state EOF means that is was totally
// unexpected.
//
// NOTE: This is necessary to prevent callers such that
// ioutil.ReadAll to receive some amount of bytes without an error.
// ReadAll() ignores an io.EOF error, thus caller may think that
// whole message fetched, but actually only part of it.
err = io.ErrUnexpectedEOF
}
if err == nil && !r.SkipHeaderCheck {
err = ws.CheckHeader(hdr, r.State)
}
if err != nil {
return hdr, err
}
// Save raw reader to use it on discarding frame without ciphering and
// other streaming checks.
r.raw = io.LimitedReader{
R: r.Source,
N: hdr.Length,
}
frame := io.Reader(&r.raw)
if hdr.Masked {
frame = NewCipherReader(frame, hdr.Mask)
}
for _, ext := range r.Extensions {
hdr.Rsv, err = ext.BitsRecv(r.fseq, hdr.Rsv)
if err != nil {
return hdr, err
}
}
if r.fragmented() {
if hdr.OpCode.IsControl() {
if cb := r.OnIntermediate; cb != nil {
err = cb(hdr, frame)
}
if err == nil {
// Ensure that src is empty.
_, err = io.Copy(ioutil.Discard, &r.raw)
}
return
}
} else {
r.opCode = hdr.OpCode
}
if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
r.utf8.Source = frame
frame = &r.utf8
}
// Save reader with ciphering and other streaming checks.
r.frame = frame
if hdr.OpCode == ws.OpContinuation {
if cb := r.OnContinuation; cb != nil {
err = cb(hdr, frame)
}
}
if hdr.Fin {
r.State = r.State.Clear(ws.StateFragmented)
r.fseq = 0
} else {
r.State = r.State.Set(ws.StateFragmented)
r.fseq++
}
return
}
func (r *Reader) fragmented() bool {
return r.State.Fragmented()
}
func (r *Reader) resetFragment() {
r.raw = io.LimitedReader{}
r.frame = nil
// Reset source of the UTF8Reader, but not the state.
r.utf8.Source = nil
}
func (r *Reader) reset() {
r.raw = io.LimitedReader{}
r.frame = nil
r.utf8 = UTF8Reader{}
r.fseq = 0
r.opCode = 0
}
// NextReader prepares next message read from r. It returns header that
// describes the message and io.Reader to read message's payload. It returns
// non-nil error when it is not possible to read message's initial frame.
//
// Note that next NextReader() on the same r should be done after reading all
// bytes from previously returned io.Reader. For more performant way to discard
// message use Reader and its Discard() method.
//
// Note that it will not handle any "intermediate" frames, that possibly could
// be received between text/binary continuation frames. That is, if peer sent
// text/binary frame with fin flag "false", then it could send ping frame, and
// eventually remaining part of text/binary frame with fin "true" with
// NextReader() the ping frame will be dropped without any notice. To handle
// this rare, but possible situation (and if you do not know exactly which
// frames peer could send), you could use Reader with OnIntermediate field set.
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
rd := &Reader{
Source: r,
State: s,
}
header, err := rd.NextFrame()
if err != nil {
return header, nil, err
}
return header, rd, nil
}

Some files were not shown because too many files have changed in this diff Show more