1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

GB28181: Support GB28181-2016 protocol. v5.0.74 (#3201)

01. Support GB config as StreamCaster.
02. Support disable GB by --gb28181=off.
03. Add utests for SIP examples.
04. Wireshark plugin to decode TCP/9000 as rtp.rfc4571
05. Support MPEGPS program stream codec.
06. Add utest for PS stream codec.
07. Decode MPEGPS packet stream.
08. Carry RTP and PS packet as helper in PS message.
09. Support recover from error mode.
10. Support process by a pack of PS/TS messages.
11. Add statistic for recovered and msgs dropped.
12. Recover from err position fastly.
13. Define state machine for GB session.
14. Bind context to GB session.
15. Re-invite when media disconnected.
16. Update GitHub actions with GB28181.
17. Support parse CANDIDATE by env or pip.
18. Support mux GB28181 to RTMP.
19. Support regression test by srs-bench.
This commit is contained in:
Winlin 2022-10-06 17:40:58 +08:00 committed by GitHub
parent 9c81a0e1bd
commit 5a420ece3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
298 changed files with 43343 additions and 763 deletions

View file

@ -0,0 +1,5 @@
bin/
reports/
cpu.out
mem.out
ws.test

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2018 Sergey Kamardin <gobwas@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,54 @@
BENCH ?=.
BENCH_BASE?=master
clean:
rm -f bin/reporter
rm -fr autobahn/report/*
bin/reporter:
go build -o bin/reporter ./autobahn
bin/gocovmerge:
go build -o bin/gocovmerge github.com/wadey/gocovmerge
.PHONY: autobahn
autobahn: clean bin/reporter
./autobahn/script/test.sh --build --follow-logs
bin/reporter $(PWD)/autobahn/report/index.json
.PHONY: autobahn/report
autobahn/report: bin/reporter
./bin/reporter -http localhost:5555 ./autobahn/report/index.json
test:
go test -coverprofile=ws.coverage .
go test -coverprofile=wsutil.coverage ./wsutil
go test -coverprofile=wsfalte.coverage ./wsflate
# No statemenets to cover in ./tests (there are only tests).
go test ./tests
cover: bin/gocovmerge test autobahn
bin/gocovmerge ws.coverage wsutil.coverage wsflate.coverage autobahn/report/server.coverage > total.coverage
benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
benchcmp:
if [ ! -z "$(shell git status -s)" ]; then\
echo "could not compare with $(BENCH_BASE) found unstaged changes";\
exit 1;\
fi;\
if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
echo "comparing the same branches";\
exit 1;\
fi;\
echo "benchmarking $(BENCH_BRANCH)...";\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
echo "benchmarking $(BENCH_BASE)...";\
git checkout -q $(BENCH_BASE);\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
git checkout -q $(BENCH_BRANCH);\
echo "\nresults:";\
echo "========\n";\
benchcmp $(BENCH_OLD) $(BENCH_NEW);\

View file

@ -0,0 +1,450 @@
# ws
[![GoDoc][godoc-image]][godoc-url]
[![CI][ci-badge]][ci-url]
> [RFC6455][rfc-url] WebSocket implementation in Go.
# Features
- Zero-copy upgrade
- No intermediate allocations during I/O
- Low-level API which allows to build your own logic of packet handling and
buffers reuse
- High-level wrappers and helpers around API in `wsutil` package, which allow
to start fast without digging the protocol internals
# Documentation
[GoDoc][godoc-url].
# Why
Existing WebSocket implementations do not allow users to reuse I/O buffers
between connections in clear way. This library aims to export efficient
low-level interface for working with the protocol without forcing only one way
it could be used.
By the way, if you want get the higher-level tools, you can use `wsutil`
package.
# Status
Library is tagged as `v1*` so its API must not be broken during some
improvements or refactoring.
This implementation of RFC6455 passes [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
about 78% coverage.
# Examples
Example applications using `ws` are developed in separate repository
[ws-examples](https://github.com/gobwas/ws-examples).
# Usage
The higher-level example of WebSocket echo server:
```go
package main
import (
"net/http"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
msg, op, err := wsutil.ReadClientData(conn)
if err != nil {
// handle error
}
err = wsutil.WriteServerMessage(conn, op, msg)
if err != nil {
// handle error
}
}
}()
}))
}
```
Lower-level, but still high-level example:
```go
import (
"net/http"
"io"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
var (
state = ws.StateServerSide
reader = wsutil.NewReader(conn, state)
writer = wsutil.NewWriter(conn, state, ws.OpText)
)
for {
header, err := reader.NextFrame()
if err != nil {
// handle error
}
// Reset writer to write frame with right operation code.
writer.Reset(conn, state, header.OpCode)
if _, err = io.Copy(writer, reader); err != nil {
// handle error
}
if err = writer.Flush(); err != nil {
// handle error
}
}
}()
}))
}
```
We can apply the same pattern to read and write structured responses through a JSON encoder and decoder.:
```go
...
var (
r = wsutil.NewReader(conn, ws.StateServerSide)
w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
decoder = json.NewDecoder(r)
encoder = json.NewEncoder(w)
)
for {
hdr, err = r.NextFrame()
if err != nil {
return err
}
if hdr.OpCode == ws.OpClose {
return io.EOF
}
var req Request
if err := decoder.Decode(&req); err != nil {
return err
}
var resp Response
if err := encoder.Encode(&resp); err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
}
...
```
The lower-level example without `wsutil`:
```go
package main
import (
"net"
"io"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = ws.Upgrade(conn)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
header, err := ws.ReadHeader(conn)
if err != nil {
// handle error
}
payload := make([]byte, header.Length)
_, err = io.ReadFull(conn, payload)
if err != nil {
// handle error
}
if header.Masked {
ws.Cipher(payload, header.Mask, 0)
}
// Reset the Masked flag, server frames must not be masked as
// RFC6455 says.
header.Masked = false
if err := ws.WriteHeader(conn, header); err != nil {
// handle error
}
if _, err := conn.Write(payload); err != nil {
// handle error
}
if header.OpCode == ws.OpClose {
return
}
}
}()
}
}
```
# Zero-copy upgrade
Zero-copy upgrade helps to avoid unnecessary allocations and copying while
handling HTTP Upgrade request.
Processing of all non-websocket headers is made in place with use of registered
user callbacks whose arguments are only valid until callback returns.
The simple example looks like this:
```go
package main
import (
"net"
"log"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
u := ws.Upgrader{
OnHeader: func(key, value []byte) (err error) {
log.Printf("non-websocket header: %q=%q", key, value)
return
},
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = u.Upgrade(conn)
if err != nil {
// handle error
}
}
}
```
Usage of `ws.Upgrader` here brings ability to control incoming connections on
tcp level and simply not to accept them by some logic.
Zero-copy upgrade is for high-load services which have to control many
resources such as connections buffers.
The real life example could be like this:
```go
package main
import (
"fmt"
"io"
"log"
"net"
"net/http"
"runtime"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
// Prepare handshake header writer from http.Header mapping.
header := ws.HandshakeHeaderHTTP(http.Header{
"X-Go-Version": []string{runtime.Version()},
})
u := ws.Upgrader{
OnHost: func(host []byte) error {
if string(host) == "github.com" {
return nil
}
return ws.RejectConnectionError(
ws.RejectionStatus(403),
ws.RejectionHeader(ws.HandshakeHeaderString(
"X-Want-Host: github.com\r\n",
)),
)
},
OnHeader: func(key, value []byte) error {
if string(key) != "Cookie" {
return nil
}
ok := httphead.ScanCookie(value, func(key, value []byte) bool {
// Check session here or do some other stuff with cookies.
// Maybe copy some values for future use.
return true
})
if ok {
return nil
}
return ws.RejectConnectionError(
ws.RejectionReason("bad cookie"),
ws.RejectionStatus(400),
)
},
OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
return header, nil
},
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
}
}
}
```
# Compression
There is a `ws/wsflate` package to support [Permessage-Deflate Compression
Extension][rfc-pmce].
It provides minimalistic I/O wrappers to be used in conjunction with any
deflate implementation (for example, the standard library's
[compress/flate][compress/flate].
```go
package main
import (
"bytes"
"log"
"net"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
e := wsflate.Extension{
// We are using default parameters here since we use
// wsflate.{Compress,Decompress}Frame helpers below in the code.
// This assumes that we use standard compress/flate package as flate
// implementation.
Parameters: wsflate.DefaultParameters,
}
u := ws.Upgrader{
Negotiate: e.Negotiate,
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
// Reset extension after previous upgrades.
e.Reset()
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
continue
}
if _, ok := e.Accepted(); !ok {
log.Printf("didn't negotiate compression for %s", conn.RemoteAddr())
conn.Close()
continue
}
go func() {
defer conn.Close()
for {
frame, err := ws.ReadFrame(conn)
if err != nil {
// Handle error.
return
}
frame = ws.UnmaskFrameInPlace(frame)
frame, err = wsflate.DecompressFrame(frame)
if err != nil {
// Handle error.
return
}
// Do something with frame...
ack := ws.NewTextFrame([]byte("this is an acknowledgement"))
ack, err = wsflate.CompressFrame(ack)
if err != nil {
// Handle error.
return
}
if err = ws.WriteFrame(conn, ack); err != nil {
// Handle error.
return
}
}
}()
}
}
```
[rfc-url]: https://tools.ietf.org/html/rfc6455
[rfc-pmce]: https://tools.ietf.org/html/rfc7692#section-7
[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/ws
[compress/flate]: https://golang.org/pkg/compress/flate/
[ci-badge]: https://github.com/gobwas/ws/workflows/CI/badge.svg
[ci-url]: https://github.com/gobwas/ws/actions?query=workflow%3ACI

View file

@ -0,0 +1,145 @@
package ws
import "unicode/utf8"
// State represents state of websocket endpoint.
// It used by some functions to be more strict when checking compatibility with RFC6455.
type State uint8
const (
// StateServerSide means that endpoint (caller) is a server.
StateServerSide State = 0x1 << iota
// StateClientSide means that endpoint (caller) is a client.
StateClientSide
// StateExtended means that extension was negotiated during handshake.
StateExtended
// StateFragmented means that endpoint (caller) has received fragmented
// frame and waits for continuation parts.
StateFragmented
)
// Is checks whether the s has v enabled.
func (s State) Is(v State) bool {
return uint8(s)&uint8(v) != 0
}
// Set enables v state on s.
func (s State) Set(v State) State {
return s | v
}
// Clear disables v state on s.
func (s State) Clear(v State) State {
return s & (^v)
}
// ServerSide reports whether states represents server side.
func (s State) ServerSide() bool { return s.Is(StateServerSide) }
// ClientSide reports whether state represents client side.
func (s State) ClientSide() bool { return s.Is(StateClientSide) }
// Extended reports whether state is extended.
func (s State) Extended() bool { return s.Is(StateExtended) }
// Fragmented reports whether state is fragmented.
func (s State) Fragmented() bool { return s.Is(StateFragmented) }
// ProtocolError describes error during checking/parsing websocket frames or
// headers.
type ProtocolError string
// Error implements error interface.
func (p ProtocolError) Error() string { return string(p) }
// Errors used by the protocol checkers.
var (
ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
)
// CheckHeader checks h to contain valid header data for given state s.
//
// Note that zero state (0) means that state is clean,
// neither server or client side, nor fragmented, nor extended.
func CheckHeader(h Header, s State) error {
if h.OpCode.IsReserved() {
return ErrProtocolOpCodeReserved
}
if h.OpCode.IsControl() {
if h.Length > MaxControlFramePayloadSize {
return ErrProtocolControlPayloadOverflow
}
if !h.Fin {
return ErrProtocolControlNotFinal
}
}
switch {
// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
// non-zero values. If a nonzero value is received and none of the
// negotiated extensions defines the meaning of such a nonzero value, the
// receiving endpoint MUST _Fail the WebSocket Connection_.
case h.Rsv != 0 && !s.Extended():
return ErrProtocolNonZeroRsv
// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
// status code 1002 (protocol error) as defined in Section 7.4.1.
case s.ServerSide() && !h.Masked:
return ErrProtocolMaskRequired
case s.ClientSide() && h.Masked:
return ErrProtocolMaskUnexpected
// [RFC6455]: See detailed explanation in 5.4 section.
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
return ErrProtocolContinuationExpected
case !s.Fragmented() && h.OpCode == OpContinuation:
return ErrProtocolContinuationUnexpected
default:
return nil
}
}
// CheckCloseFrameData checks received close information
// to be valid RFC6455 compatible close info.
//
// Note that code.Empty() or code.IsAppLevel() will raise error.
//
// If endpoint sends close frame without status code (with frame.Length = 0),
// application should not check its payload.
func CheckCloseFrameData(code StatusCode, reason string) error {
switch {
case code.IsNotUsed():
return ErrProtocolStatusCodeNotInUse
case code.IsProtocolReserved():
return ErrProtocolStatusCodeApplicationLevel
case code == StatusNoMeaningYet:
return ErrProtocolStatusCodeNoMeaning
case code.IsProtocolSpec() && !code.IsProtocolDefined():
return ErrProtocolStatusCodeUnknown
case !utf8.ValidString(reason):
return ErrProtocolInvalidUTF8
default:
return nil
}
}

View file

@ -0,0 +1,61 @@
package ws
import (
"encoding/binary"
)
// Cipher applies XOR cipher to the payload using mask.
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
//
// To convert masked data into unmasked data, or vice versa, the following
// algorithm is applied. The same algorithm applies regardless of the
// direction of the translation, e.g., the same steps are applied to
// mask the data as to unmask the data.
func Cipher(payload []byte, mask [4]byte, offset int) {
n := len(payload)
if n < 8 {
for i := 0; i < n; i++ {
payload[i] ^= mask[(offset+i)%4]
}
return
}
// Calculate position in mask due to previously processed bytes number.
mpos := offset % 4
// Count number of bytes will processed one by one from the beginning of payload.
ln := remain[mpos]
// Count number of bytes will processed one by one from the end of payload.
// This is done to process payload by 8 bytes in each iteration of main loop.
rn := (n - ln) % 8
for i := 0; i < ln; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
for i := n - rn; i < n; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
// NOTE: we use here binary.LittleEndian regardless of what is real
// endianess on machine is. To do so, we have to use binary.LittleEndian in
// the masking loop below as well.
var (
m = binary.LittleEndian.Uint32(mask[:])
m2 = uint64(m)<<32 | uint64(m)
)
// Skip already processed right part.
// Get number of uint64 parts remaining to process.
n = (n - ln - rn) >> 3
for i := 0; i < n; i++ {
var (
j = ln + (i << 3)
chunk = payload[j : j+8]
)
p := binary.LittleEndian.Uint64(chunk)
p = p ^ m2
binary.LittleEndian.PutUint64(chunk, p)
}
}
// remain maps position in masking key [0,4) to number
// of bytes that need to be processed manually inside Cipher().
var remain = [4]int{0, 3, 2, 1}

View file

@ -0,0 +1,563 @@
package ws
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by Dialer.
const (
DefaultClientReadBufferSize = 4096
DefaultClientWriteBufferSize = 4096
)
// Handshake represents handshake result.
type Handshake struct {
// Protocol is the subprotocol selected during handshake.
Protocol string
// Extensions is the list of negotiated extensions.
Extensions []httphead.Option
}
// Errors used by the websocket client.
var (
ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
)
// DefaultDialer is dialer that holds no options and is used by Dial function.
var DefaultDialer Dialer
// Dial is like Dialer{}.Dial().
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
return DefaultDialer.Dial(ctx, urlstr)
}
// Dialer contains options for establishing websocket connection to an url.
type Dialer struct {
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
// They used to read and write http data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then default value is used.
ReadBufferSize, WriteBufferSize int
// Timeout is the maximum amount of time a Dial() will wait for a connect
// and an handshake to complete.
//
// The default is no timeout.
Timeout time.Duration
// Protocols is the list of subprotocols that the client wants to speak,
// ordered by preference.
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
Protocols []string
// Extensions is the list of extensions that client wants to speak.
//
// Note that if server decides to use some of this extensions, Dial() will
// return Handshake struct containing a slice of items, which are the
// shallow copies of the items from this list. That is, internals of
// Extensions items are shared during Dial().
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
// See https://tools.ietf.org/html/rfc6455#section-9.1
Extensions []httphead.Option
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake request.
//
// It used instead of any key-value mappings to avoid allocations in user
// land.
Header HandshakeHeader
// OnStatusError is the callback that will be called after receiving non
// "101 Continue" HTTP response status. It receives an io.Reader object
// representing server response bytes. That is, it gives ability to parse
// HTTP response somehow (probably with http.ReadResponse call) and make a
// decision of further logic.
//
// The arguments are only valid until the callback returns.
OnStatusError func(status int, reason []byte, resp io.Reader)
// OnHeader is the callback that will be called after successful parsing of
// header, that is not used during WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// Returned value could be used to prevent processing response.
OnHeader func(key, value []byte) (err error)
// NetDial is the function that is used to get plain tcp connection.
// If it is not nil, then it is used instead of net.Dialer.
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
// TLSClient is the callback that will be called after successful dial with
// received connection and its remote host name. If it is nil, then the
// default tls.Client() will be used.
// If it is not nil, then TLSConfig field is ignored.
TLSClient func(conn net.Conn, hostname string) net.Conn
// TLSConfig is passed to tls.Client() to start TLS over established
// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
// non-nil and its ServerName is empty, then for every Dial() it will be
// cloned and appropriate ServerName will be set.
TLSConfig *tls.Config
// WrapConn is the optional callback that will be called when connection is
// ready for an i/o. That is, it will be called after successful dial and
// TLS initialization (for "wss" schemes). It may be helpful for different
// user land purposes such as end to end encryption.
//
// Note that for debugging purposes of an http handshake (e.g. sent request
// and received response), there is an wsutil.DebugDialer struct.
WrapConn func(conn net.Conn) net.Conn
}
// Dial connects to the url host and upgrades connection to WebSocket.
//
// If server has sent frames right after successful handshake then returned
// buffer will be non-nil. In other cases buffer is always nil. For better
// memory efficiency received non-nil bufio.Reader should be returned to the
// inner pool with PutReader() function after use.
//
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
// If you want to dial non-ascii host name, take care of its name serialization
// avoiding bad request issues. For more info see net/http Request.Write()
// implementation, especially cleanHost() function.
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
u, err := url.ParseRequestURI(urlstr)
if err != nil {
return
}
// Prepare context to dial with. Initially it is the same as original, but
// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
// we use more shorter context for dial.
dialctx := ctx
var deadline time.Time
if t := d.Timeout; t != 0 {
deadline = time.Now().Add(t)
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
var cancel context.CancelFunc
dialctx, cancel = context.WithDeadline(ctx, deadline)
defer cancel()
}
}
if conn, err = d.dial(dialctx, u); err != nil {
return
}
defer func() {
if err != nil {
conn.Close()
}
}()
if ctx == context.Background() {
// No need to start I/O interrupter goroutine which is not zero-cost.
conn.SetDeadline(deadline)
defer conn.SetDeadline(noDeadline)
} else {
// Context could be canceled or its deadline could be exceeded.
// Start the interrupter goroutine to handle context cancelation.
done := setupContextDeadliner(ctx, conn)
defer func() {
// Map Upgrade() error to a possible context expiration error. That
// is, even if Upgrade() err is nil, context could be already
// expired and connection be "poisoned" by SetDeadline() call.
// In that case we must not return ctx.Err() error.
done(&err)
}()
}
br, hs, err = d.Upgrade(conn, u)
return
}
var (
// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
// Dialer.NetDial is not provided.
netEmptyDialer net.Dialer
// tlsEmptyConfig is an empty tls.Config used as default one.
tlsEmptyConfig tls.Config
)
func tlsDefaultConfig() *tls.Config {
return &tlsEmptyConfig
}
func hostport(host string, defaultPort string) (hostname, addr string) {
var (
colon = strings.LastIndexByte(host, ':')
bracket = strings.IndexByte(host, ']')
)
if colon > bracket {
return host[:colon], host
}
return host, host + defaultPort
}
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
dial := d.NetDial
if dial == nil {
dial = netEmptyDialer.DialContext
}
switch u.Scheme {
case "ws":
_, addr := hostport(u.Host, ":80")
conn, err = dial(ctx, "tcp", addr)
case "wss":
hostname, addr := hostport(u.Host, ":443")
conn, err = dial(ctx, "tcp", addr)
if err != nil {
return
}
tlsClient := d.TLSClient
if tlsClient == nil {
tlsClient = d.tlsClient
}
conn = tlsClient(conn, hostname)
default:
return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
}
if wrap := d.WrapConn; wrap != nil {
conn = wrap(conn)
}
return
}
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
config := d.TLSConfig
if config == nil {
config = tlsDefaultConfig()
}
if config.ServerName == "" {
config = tlsCloneConfig(config)
config.ServerName = hostname
}
// Do not make conn.Handshake() here because downstairs we will prepare
// i/o on this conn with proper context's timeout handling.
return tls.Client(conn, config)
}
var (
// This variables are set like in net/net.go.
// noDeadline is just zero value for readability.
noDeadline = time.Time{}
// aLongTimeAgo is a non-zero time, far in the past, used for immediate
// cancelation of dials.
aLongTimeAgo = time.Unix(42, 0)
)
// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
// url u and reads a response from it.
//
// It is a caller responsibility to manage I/O deadlines on conn.
//
// It returns handshake info and some bytes which could be written by the peer
// right after response and be caught by us during buffered read.
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
// headerSeen constants helps to report whether or not some header was seen
// during reading request bytes.
const (
headerSeenUpgrade = 1 << iota
headerSeenConnection
headerSeenSecAccept
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecAccept
)
br = pbufio.GetReader(conn,
nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
)
defer func() {
pbufio.PutWriter(bw)
if br.Buffered() == 0 || err != nil {
// Server does not wrote additional bytes to the connection or
// error occurred. That is, no reason to return buffer.
pbufio.PutReader(br)
br = nil
}
}()
nonce := make([]byte, nonceSize)
initNonce(nonce)
httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
if err = bw.Flush(); err != nil {
return
}
// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
sl, err := readLine(br)
if err != nil {
return
}
// Begin validation of the response.
// See https://tools.ietf.org/html/rfc6455#section-4.2.2
// Parse request line data like HTTP version, uri and method.
resp, err := httpParseResponseLine(sl)
if err != nil {
return
}
// Even if RFC says "1.1 or higher" without mentioning the part of the
// version, we apply it only to minor part.
if resp.major != 1 || resp.minor < 1 {
err = ErrHandshakeBadProtocol
return
}
if resp.status != 101 {
err = StatusError(resp.status)
if onStatusError := d.OnStatusError; onStatusError != nil {
// Invoke callback with multireader of status-line bytes br.
onStatusError(resp.status, resp.reason,
io.MultiReader(
bytes.NewReader(sl),
strings.NewReader(crlf),
br,
),
)
}
return
}
// If response status is 101 then we expect all technical headers to be
// valid. If not, then we stop processing response without giving user
// ability to read non-technical headers. That is, we do not distinguish
// technical errors (such as parsing error) and protocol errors.
var headerSeen byte
for {
line, e := readLine(br)
if e != nil {
err = e
return
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedResponse
return
}
switch btsToString(k) {
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
return
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
// Note that as RFC6455 says:
// > A |Connection| header field with value "Upgrade".
// That is, in server side, "Connection" header could contain
// multiple token. But in response it must contains exactly one.
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
err = ErrHandshakeBadConnection
return
}
case headerSecAcceptCanonical:
headerSeen |= headerSeenSecAccept
if !checkAcceptFromNonce(v, nonce) {
err = ErrHandshakeBadSecAccept
return
}
case headerSecProtocolCanonical:
// RFC6455 1.3:
// "The server selects one or none of the acceptable protocols
// and echoes that value in its handshake to indicate that it has
// selected that protocol."
for _, want := range d.Protocols {
if string(v) == want {
hs.Protocol = want
break
}
}
if hs.Protocol == "" {
// Server echoed subprotocol that is not present in client
// requested protocols.
err = ErrHandshakeBadSubProtocol
return
}
case headerSecExtensionsCanonical:
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
if err != nil {
return
}
default:
if onHeader := d.OnHeader; onHeader != nil {
if e := onHeader(k, v); e != nil {
err = e
return
}
}
}
}
if err == nil && headerSeen != headerSeenAll {
switch {
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecAccept == 0:
err = ErrHandshakeBadSecAccept
default:
panic("unknown headers state")
}
}
return
}
// PutReader returns bufio.Reader instance to the inner reuse pool.
// It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
// contains unprocessed buffered data, that was sent by the server quickly
// right after handshake.
func PutReader(br *bufio.Reader) {
pbufio.PutReader(br)
}
// StatusError contains an unexpected status-line code from the server.
type StatusError int
func (s StatusError) Error() string {
return "unexpected HTTP response status: " + strconv.Itoa(int(s))
}
func isTimeoutError(err error) bool {
t, ok := err.(net.Error)
return ok && t.Timeout()
}
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
if len(selected) == 0 {
return received, nil
}
var (
index int
option httphead.Option
err error
)
index = -1
match := func() (ok bool) {
for _, want := range wanted {
// A server accepts one or more extensions by including a
// |Sec-WebSocket-Extensions| header field containing one or more
// extensions that were requested by the client.
//
// The interpretation of any extension parameters, and what
// constitutes a valid response by a server to a requested set of
// parameters by a client, will be defined by each such extension.
if bytes.Equal(option.Name, want.Name) {
// Check parsed extension to be present in client
// requested extensions. We move matched extension
// from client list to avoid allocation.
received = append(received, option)
return true
}
}
return false
}
ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
// Met next option.
index = i
if i != 0 && !match() {
// Server returned non-requested extension.
err = ErrHandshakeBadExtensions
return httphead.ControlBreak
}
option = httphead.Option{Name: name}
}
if attr != nil {
option.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
err = ErrMalformedResponse
return received, err
}
if !match() {
return received, ErrHandshakeBadExtensions
}
return received, err
}
// setupContextDeadliner is a helper function that starts connection I/O
// interrupter goroutine.
//
// Started goroutine calls SetDeadline() with long time ago value when context
// become expired to make any I/O operations failed. It returns done function
// that stops started goroutine and maps error received from conn I/O methods
// to possible context expiration error.
//
// In concern with possible SetDeadline() call inside interrupter goroutine,
// caller passes pointer to its I/O error (even if it is nil) to done(&err).
// That is, even if I/O error is nil, context could be already expired and
// connection "poisoned" by SetDeadline() call. In that case done(&err) will
// store at *err ctx.Err() result. If err is caused not by timeout, it will
// leaved untouched.
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
var (
quit = make(chan struct{})
interrupt = make(chan error, 1)
)
go func() {
select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Cancel i/o immediately.
conn.SetDeadline(aLongTimeAgo)
interrupt <- ctx.Err()
}
}()
return func(err *error) {
close(quit)
// If ctx.Err() is non-nil and the original err is net.Error with
// Timeout() == true, then it means that I/O was canceled by us by
// SetDeadline(aLongTimeAgo) call, or by somebody else previously
// by conn.SetDeadline(x).
//
// Even on race condition when both deadlines are expired
// (SetDeadline() made not by us and context's), we prefer ctx.Err() to
// be returned.
if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
*err = ctxErr
}
}
}

View file

@ -0,0 +1,35 @@
// +build !go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
// NOTE: we copying SessionTicketsDisabled and SessionTicketKey here
// without calling inner c.initOnceServer somehow because we only could get
// here from the ws.Dialer code, which is obviously a client and makes
// tls.Client() when it gets new net.Conn.
return &tls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
}
}

View file

@ -0,0 +1,9 @@
// +build go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
return c.Clone()
}

View file

@ -0,0 +1,81 @@
/*
Package ws implements a client and server for the WebSocket protocol as
specified in RFC 6455.
The main purpose of this package is to provide simple low-level API for
efficient work with protocol.
Overview.
Upgrade to WebSocket (or WebSocket handshake) can be done in two ways.
The first way is to use `net/http` server:
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
})
The second and much more efficient way is so-called "zero-copy upgrade". It
avoids redundant allocations and copying of not used headers or other request
data. User decides by himself which data should be copied.
ln, err := net.Listen("tcp", ":8080")
if err != nil {
// handle error
}
conn, err := ln.Accept()
if err != nil {
// handle error
}
handshake, err := ws.Upgrade(conn)
if err != nil {
// handle error
}
For customization details see `ws.Upgrader` documentation.
After WebSocket handshake you can work with connection in multiple ways.
That is, `ws` does not force the only one way of how to work with WebSocket:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
buf := make([]byte, header.Length)
_, err := io.ReadFull(conn, buf)
if err != nil {
// handle err
}
resp := ws.NewBinaryFrame([]byte("hello, world!"))
if err := ws.WriteFrame(conn, frame); err != nil {
// handle err
}
As you can see, it stream friendly:
const N = 42
ws.WriteHeader(ws.Header{
Fin: true,
Length: N,
OpCode: ws.OpBinary,
})
io.CopyN(conn, rand.Reader, N)
Or:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
io.CopyN(ioutil.Discard, conn, header.Length)
For more info see the documentation.
*/
package ws

View file

@ -0,0 +1,54 @@
package ws
// RejectOption represents an option used to control the way connection is
// rejected.
type RejectOption func(*rejectConnectionError)
// RejectionReason returns an option that makes connection to be rejected with
// given reason.
func RejectionReason(reason string) RejectOption {
return func(err *rejectConnectionError) {
err.reason = reason
}
}
// RejectionStatus returns an option that makes connection to be rejected with
// given HTTP status code.
func RejectionStatus(code int) RejectOption {
return func(err *rejectConnectionError) {
err.code = code
}
}
// RejectionHeader returns an option that makes connection to be rejected with
// given HTTP headers.
func RejectionHeader(h HandshakeHeader) RejectOption {
return func(err *rejectConnectionError) {
err.header = h
}
}
// RejectConnectionError constructs an error that could be used to control the way
// handshake is rejected by Upgrader.
func RejectConnectionError(options ...RejectOption) error {
err := new(rejectConnectionError)
for _, opt := range options {
opt(err)
}
return err
}
// rejectConnectionError represents a rejection of upgrade error.
//
// It can be returned by Upgrader's On* hooks to control the way WebSocket
// handshake is rejected.
type rejectConnectionError struct {
reason string
code int
header HandshakeHeader
}
// Error implements error interface.
func (r *rejectConnectionError) Error() string {
return r.reason
}

View file

@ -0,0 +1,420 @@
package ws
import (
"bytes"
"encoding/binary"
"math/rand"
)
// Constants defined by specification.
const (
// All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
MaxControlFramePayloadSize = 125
)
// OpCode represents operation code.
type OpCode byte
// Operation codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-5.2
const (
OpContinuation OpCode = 0x0
OpText OpCode = 0x1
OpBinary OpCode = 0x2
OpClose OpCode = 0x8
OpPing OpCode = 0x9
OpPong OpCode = 0xa
)
// IsControl checks whether the c is control operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.5
func (c OpCode) IsControl() bool {
// RFC6455: Control frames are identified by opcodes where
// the most significant bit of the opcode is 1.
//
// Note that OpCode is only 4 bit length.
return c&0x8 != 0
}
// IsData checks whether the c is data operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.6
func (c OpCode) IsData() bool {
// RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
// where the most significant bit of the opcode is 0.
//
// Note that OpCode is only 4 bit length.
return c&0x8 == 0
}
// IsReserved checks whether the c is reserved operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func (c OpCode) IsReserved() bool {
// RFC6455:
// %x3-7 are reserved for further non-control frames
// %xB-F are reserved for further control frames
return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
}
// StatusCode represents the encoded reason for closure of websocket connection.
//
// There are few helper methods on StatusCode that helps to define a range in
// which given code is lay in. accordingly to ranges defined in specification.
//
// See https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode uint16
// StatusCodeRange describes range of StatusCode values.
type StatusCodeRange struct {
Min, Max StatusCode
}
// Status code ranges defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.2
var (
StatusRangeNotInUse = StatusCodeRange{0, 999}
StatusRangeProtocol = StatusCodeRange{1000, 2999}
StatusRangeApplication = StatusCodeRange{3000, 3999}
StatusRangePrivate = StatusCodeRange{4000, 4999}
)
// Status codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.1
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
StatusNoMeaningYet StatusCode = 1004
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExt StatusCode = 1010
StatusInternalServerError StatusCode = 1011
StatusTLSHandshake StatusCode = 1015
// StatusAbnormalClosure is a special code designated for use in
// applications.
StatusAbnormalClosure StatusCode = 1006
// StatusNoStatusRcvd is a special code designated for use in applications.
StatusNoStatusRcvd StatusCode = 1005
)
// In reports whether the code is defined in given range.
func (s StatusCode) In(r StatusCodeRange) bool {
return r.Min <= s && s <= r.Max
}
// Empty reports whether the code is empty.
// Empty code has no any meaning neither app level codes nor other.
// This method is useful just to check that code is golang default value 0.
func (s StatusCode) Empty() bool {
return s == 0
}
// IsNotUsed reports whether the code is predefined in not used range.
func (s StatusCode) IsNotUsed() bool {
return s.In(StatusRangeNotInUse)
}
// IsApplicationSpec reports whether the code should be defined by
// application, framework or libraries specification.
func (s StatusCode) IsApplicationSpec() bool {
return s.In(StatusRangeApplication)
}
// IsPrivateSpec reports whether the code should be defined privately.
func (s StatusCode) IsPrivateSpec() bool {
return s.In(StatusRangePrivate)
}
// IsProtocolSpec reports whether the code should be defined by protocol specification.
func (s StatusCode) IsProtocolSpec() bool {
return s.In(StatusRangeProtocol)
}
// IsProtocolDefined reports whether the code is already defined by protocol specification.
func (s StatusCode) IsProtocolDefined() bool {
switch s {
case StatusNormalClosure,
StatusGoingAway,
StatusProtocolError,
StatusUnsupportedData,
StatusInvalidFramePayloadData,
StatusPolicyViolation,
StatusMessageTooBig,
StatusMandatoryExt,
StatusInternalServerError,
StatusNoStatusRcvd,
StatusAbnormalClosure,
StatusTLSHandshake:
return true
}
return false
}
// IsProtocolReserved reports whether the code is defined by protocol specification
// to be reserved only for application usage purpose.
func (s StatusCode) IsProtocolReserved() bool {
switch s {
// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
// Close control frame by an endpoint.
case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return true
default:
return false
}
}
// Compiled control frames for common use cases.
// For construct-serialize optimizations.
var (
CompiledPing = MustCompileFrame(NewPingFrame(nil))
CompiledPong = MustCompileFrame(NewPongFrame(nil))
CompiledClose = MustCompileFrame(NewCloseFrame(nil))
CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure)
CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway)
CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError)
CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData)
CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet)
CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation)
CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig)
CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt)
CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError)
CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake)
)
// Header represents websocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Header struct {
Fin bool
Rsv byte
OpCode OpCode
Masked bool
Mask [4]byte
Length int64
}
// Rsv1 reports whether the header has first rsv bit set.
func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
// Rsv2 reports whether the header has second rsv bit set.
func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
// Rsv3 reports whether the header has third rsv bit set.
func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
// Rsv creates rsv byte representation from bits.
func Rsv(r1, r2, r3 bool) (rsv byte) {
if r1 {
rsv |= bit5
}
if r2 {
rsv |= bit6
}
if r3 {
rsv |= bit7
}
return rsv
}
// RsvBits returns rsv bits from bytes representation.
func RsvBits(rsv byte) (r1, r2, r3 bool) {
r1 = rsv&bit5 != 0
r2 = rsv&bit6 != 0
r3 = rsv&bit7 != 0
return
}
// Frame represents websocket frame.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Frame struct {
Header Header
Payload []byte
}
// NewFrame creates frame with given operation code,
// flag of completeness and payload bytes.
func NewFrame(op OpCode, fin bool, p []byte) Frame {
return Frame{
Header: Header{
Fin: fin,
OpCode: op,
Length: int64(len(p)),
},
Payload: p,
}
}
// NewTextFrame creates text frame with p as payload.
// Note that p is not copied.
func NewTextFrame(p []byte) Frame {
return NewFrame(OpText, true, p)
}
// NewBinaryFrame creates binary frame with p as payload.
// Note that p is not copied.
func NewBinaryFrame(p []byte) Frame {
return NewFrame(OpBinary, true, p)
}
// NewPingFrame creates ping frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPingFrame(p []byte) Frame {
return NewFrame(OpPing, true, p)
}
// NewPongFrame creates pong frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPongFrame(p []byte) Frame {
return NewFrame(OpPong, true, p)
}
// NewCloseFrame creates close frame with given close body.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewCloseFrame(p []byte) Frame {
return NewFrame(OpClose, true, p)
}
// NewCloseFrameBody encodes a closure code and a reason into a binary
// representation.
//
// It returns slice which is at most MaxControlFramePayloadSize bytes length.
// If the reason is too big it will be cropped to fit the limit defined by the
// spec.
//
// See https://tools.ietf.org/html/rfc6455#section-5.5
func NewCloseFrameBody(code StatusCode, reason string) []byte {
n := min(2+len(reason), MaxControlFramePayloadSize)
p := make([]byte, n)
crop := min(MaxControlFramePayloadSize-2, len(reason))
PutCloseFrameBody(p, code, reason[:crop])
return p
}
// PutCloseFrameBody encodes code and reason into buf.
//
// It will panic if the buffer is too small to accommodate a code or a reason.
//
// PutCloseFrameBody does not check buffer to be RFC compliant, but note that
// by RFC it must be at most MaxControlFramePayloadSize.
func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
_ = p[1+len(reason)]
binary.BigEndian.PutUint16(p, uint16(code))
copy(p[2:], reason)
}
// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For less allocations you could use MaskFrameInPlace or construct frame manually.
func MaskFrame(f Frame) Frame {
return MaskFrameWith(f, NewMask())
}
// MaskFrameWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For less allocations you could use MaskFrameInPlaceWith or construct frame manually.
func MaskFrameWith(f Frame, mask [4]byte) Frame {
// TODO(gobwas): check CopyCipher ws copy() Cipher().
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return MaskFrameInPlaceWith(f, mask)
}
// MaskFrameInPlace masks frame and returns frame with masked payload and Mask
// header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlace(f Frame) Frame {
return MaskFrameInPlaceWith(f, NewMask())
}
var zeroMask [4]byte
// UnmaskFrame unmasks frame and returns frame with unmasked payload and Mask
// header's field cleared.
// Note that it copies f payload.
func UnmaskFrame(f Frame) Frame {
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return UnmaskFrameInPlace(f)
}
// UnmaskFrameInPlace unmasks frame and returns frame with unmasked payload and
// Mask header's field cleared.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func UnmaskFrameInPlace(f Frame) Frame {
Cipher(f.Payload, f.Header.Mask, 0)
f.Header.Masked = false
f.Header.Mask = zeroMask
return f
}
// MaskFrameInPlaceWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame {
f.Header.Masked = true
f.Header.Mask = m
Cipher(f.Payload, m, 0)
return f
}
// NewMask creates new random mask.
func NewMask() (ret [4]byte) {
binary.BigEndian.PutUint32(ret[:], rand.Uint32())
return
}
// CompileFrame returns byte representation of given frame.
// In terms of memory consumption it is useful to precompile static frames
// which are often used.
func CompileFrame(f Frame) (bts []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 16))
err = WriteFrame(buf, f)
bts = buf.Bytes()
return
}
// MustCompileFrame is like CompileFrame but panics if frame can not be
// encoded.
func MustCompileFrame(f Frame) []byte {
bts, err := CompileFrame(f)
if err != nil {
panic(err)
}
return bts
}
func makeCloseFrame(code StatusCode) Frame {
return NewCloseFrame(NewCloseFrameBody(code, ""))
}
var (
closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure)
closeFrameGoingAway = makeCloseFrame(StatusGoingAway)
closeFrameProtocolError = makeCloseFrame(StatusProtocolError)
closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData)
closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet)
closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData)
closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation)
closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig)
closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt)
closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError)
closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake)
)

View file

@ -0,0 +1,9 @@
module github.com/gobwas/ws
go 1.15
require (
github.com/gobwas/httphead v0.1.0
github.com/gobwas/pool v0.2.1
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d // indirect
)

View file

@ -0,0 +1,6 @@
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d h1:MiWWjyhUzZ+jvhZvloX6ZrUsdEghn8a64Upd8EMHglE=
golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View file

@ -0,0 +1,504 @@
package ws
import (
"bufio"
"bytes"
"io"
"net/http"
"net/textproto"
"net/url"
"strconv"
"github.com/gobwas/httphead"
)
const (
crlf = "\r\n"
colonAndSpace = ": "
commaAndSpace = ", "
)
const (
textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
)
var (
textHeadBadRequest = statusText(http.StatusBadRequest)
textHeadInternalServerError = statusText(http.StatusInternalServerError)
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
)
var (
headerHost = "Host"
headerUpgrade = "Upgrade"
headerConnection = "Connection"
headerSecVersion = "Sec-WebSocket-Version"
headerSecProtocol = "Sec-WebSocket-Protocol"
headerSecExtensions = "Sec-WebSocket-Extensions"
headerSecKey = "Sec-WebSocket-Key"
headerSecAccept = "Sec-WebSocket-Accept"
headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
)
var (
specHeaderValueUpgrade = []byte("websocket")
specHeaderValueConnection = []byte("Upgrade")
specHeaderValueConnectionLower = []byte("upgrade")
specHeaderValueSecVersion = []byte("13")
)
var (
httpVersion1_0 = []byte("HTTP/1.0")
httpVersion1_1 = []byte("HTTP/1.1")
httpVersionPrefix = []byte("HTTP/")
)
type httpRequestLine struct {
method, uri []byte
major, minor int
}
type httpResponseLine struct {
major, minor int
status int
reason []byte
}
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
var proto []byte
req.method, req.uri, proto = bsplit3(line, ' ')
var ok bool
req.major, req.minor, ok = httpParseVersion(proto)
if !ok {
err = ErrMalformedRequest
return
}
return
}
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
var (
proto []byte
status []byte
)
proto, status, resp.reason = bsplit3(line, ' ')
var ok bool
resp.major, resp.minor, ok = httpParseVersion(proto)
if !ok {
return resp, ErrMalformedResponse
}
var convErr error
resp.status, convErr = asciiToInt(status)
if convErr != nil {
return resp, ErrMalformedResponse
}
return resp, nil
}
// httpParseVersion parses major and minor version of HTTP protocol. It returns
// parsed values and true if parse is ok.
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
switch {
case bytes.Equal(bts, httpVersion1_0):
return 1, 0, true
case bytes.Equal(bts, httpVersion1_1):
return 1, 1, true
case len(bts) < 8:
return
case !bytes.Equal(bts[:5], httpVersionPrefix):
return
}
bts = bts[5:]
dot := bytes.IndexByte(bts, '.')
if dot == -1 {
return
}
var err error
major, err = asciiToInt(bts[:dot])
if err != nil {
return
}
minor, err = asciiToInt(bts[dot+1:])
if err != nil {
return
}
return major, minor, true
}
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
// values and true if parse is ok.
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
colon := bytes.IndexByte(line, ':')
if colon == -1 {
return
}
k = btrim(line[:colon])
// TODO(gobwas): maybe use just lower here?
canonicalizeHeaderKey(k)
v = btrim(line[colon+1:])
return k, v, true
}
// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
// that key is already canonical. This helps to increase performance.
func httpGetHeader(h http.Header, key string) string {
if h == nil {
return ""
}
v := h[key]
if len(v) == 0 {
return ""
}
return v[0]
}
// The request MAY include a header field with the name
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
// comma-separated subprotocol the client wishes to speak, ordered by
// preference. The elements that comprise this value MUST be non-empty strings
// with characters in the range U+0021 to U+007E not including separator
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
// for the value of this header field is 1#token, where the definitions of
// constructs and rules are as given in [RFC2616].
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
if check(btsToString(v)) {
ret = string(v)
return false
}
return true
})
return
}
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
var selected []byte
ok = httphead.ScanTokens(h, func(v []byte) bool {
if check(v) {
selected = v
return false
}
return true
})
if ok && selected != nil {
return string(selected), true
}
return
}
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
s := httphead.OptionSelector{
Flags: httphead.SelectCopy,
Check: check,
}
return s.Select(h, selected)
}
func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
if in.Size() == 0 {
return dest, nil
}
opt, err := f(in)
if err != nil {
return nil, err
}
if opt.Size() > 0 {
dest = append(dest, opt)
}
return dest, nil
}
func negotiateExtensions(
h []byte, dest []httphead.Option,
f func(httphead.Option) (httphead.Option, error),
) (_ []httphead.Option, err error) {
index := -1
var current httphead.Option
ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
dest, err = negotiateMaybe(current, dest, f)
if err != nil {
return httphead.ControlBreak
}
index = i
current = httphead.Option{Name: name}
}
if attr != nil {
current.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
return nil, ErrMalformedRequest
}
return negotiateMaybe(current, dest, f)
}
func httpWriteHeader(bw *bufio.Writer, key, value string) {
httpWriteHeaderKey(bw, key)
bw.WriteString(value)
bw.WriteString(crlf)
}
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
httpWriteHeaderKey(bw, key)
bw.Write(value)
bw.WriteString(crlf)
}
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
bw.WriteString(key)
bw.WriteString(colonAndSpace)
}
func httpWriteUpgradeRequest(
bw *bufio.Writer,
u *url.URL,
nonce []byte,
protocols []string,
extensions []httphead.Option,
header HandshakeHeader,
) {
bw.WriteString("GET ")
bw.WriteString(u.RequestURI())
bw.WriteString(" HTTP/1.1\r\n")
httpWriteHeader(bw, headerHost, u.Host)
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
// NOTE: write nonce bytes as a string to prevent heap allocation
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer which in turn
// will lead to p escape.
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
if len(protocols) > 0 {
httpWriteHeaderKey(bw, headerSecProtocol)
for i, p := range protocols {
if i > 0 {
bw.WriteString(commaAndSpace)
}
bw.WriteString(p)
}
bw.WriteString(crlf)
}
if len(extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, extensions)
bw.WriteString(crlf)
}
if header != nil {
header.WriteTo(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
bw.WriteString(textHeadUpgrade)
httpWriteHeaderKey(bw, headerSecAccept)
writeAccept(bw, nonce)
bw.WriteString(crlf)
if hs.Protocol != "" {
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
}
if len(hs.Extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, hs.Extensions)
bw.WriteString(crlf)
}
if header != nil {
header(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
switch code {
case http.StatusBadRequest:
bw.WriteString(textHeadBadRequest)
case http.StatusInternalServerError:
bw.WriteString(textHeadInternalServerError)
case http.StatusUpgradeRequired:
bw.WriteString(textHeadUpgradeRequired)
default:
writeStatusText(bw, code)
}
// Write custom headers.
if header != nil {
header(bw)
}
switch err {
case ErrHandshakeBadProtocol:
bw.WriteString(textTailErrHandshakeBadProtocol)
case ErrHandshakeBadMethod:
bw.WriteString(textTailErrHandshakeBadMethod)
case ErrHandshakeBadHost:
bw.WriteString(textTailErrHandshakeBadHost)
case ErrHandshakeBadUpgrade:
bw.WriteString(textTailErrHandshakeBadUpgrade)
case ErrHandshakeBadConnection:
bw.WriteString(textTailErrHandshakeBadConnection)
case ErrHandshakeBadSecAccept:
bw.WriteString(textTailErrHandshakeBadSecAccept)
case ErrHandshakeBadSecKey:
bw.WriteString(textTailErrHandshakeBadSecKey)
case ErrHandshakeBadSecVersion:
bw.WriteString(textTailErrHandshakeBadSecVersion)
case ErrHandshakeUpgradeRequired:
bw.WriteString(textTailErrUpgradeRequired)
case nil:
bw.WriteString(crlf)
default:
writeErrorText(bw, err)
}
}
func writeStatusText(bw *bufio.Writer, code int) {
bw.WriteString("HTTP/1.1 ")
bw.WriteString(strconv.Itoa(code))
bw.WriteByte(' ')
bw.WriteString(http.StatusText(code))
bw.WriteString(crlf)
bw.WriteString("Content-Type: text/plain; charset=utf-8")
bw.WriteString(crlf)
}
func writeErrorText(bw *bufio.Writer, err error) {
body := err.Error()
bw.WriteString("Content-Length: ")
bw.WriteString(strconv.Itoa(len(body)))
bw.WriteString(crlf)
bw.WriteString(crlf)
bw.WriteString(body)
}
// httpError is like the http.Error with WebSocket context exception.
func httpError(w http.ResponseWriter, body string, code int) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(code)
w.Write([]byte(body))
}
// statusText is a non-performant status text generator.
// NOTE: Used only to generate constants.
func statusText(code int) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeStatusText(bw, code)
bw.Flush()
return buf.String()
}
// errorText is a non-performant error text generator.
// NOTE: Used only to generate constants.
func errorText(err error) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeErrorText(bw, err)
bw.Flush()
return buf.String()
}
// HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer.
type HandshakeHeader interface {
io.WriterTo
}
// HandshakeHeaderString is an adapter to allow the use of headers represented
// by ordinary string as HandshakeHeader.
type HandshakeHeaderString string
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
n, err := io.WriteString(w, string(s))
return int64(n), err
}
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
// by ordinary slice of bytes as HandshakeHeader.
type HandshakeHeaderBytes []byte
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b)
return int64(n), err
}
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
// ordinary function as HandshakeHeader.
type HandshakeHeaderFunc func(io.Writer) (int64, error)
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
return f(w)
}
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
// HandshakeHeader.
type HandshakeHeaderHTTP http.Header
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
wr := writer{w: w}
err := http.Header(h).Write(&wr)
return wr.n, err
}
type writer struct {
n int64
w io.Writer
}
func (w *writer) WriteString(s string) (int, error) {
n, err := io.WriteString(w.w, s)
w.n += int64(n)
return n, err
}
func (w *writer) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.n += int64(n)
return n, err
}

View file

@ -0,0 +1,80 @@
package ws
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"math/rand"
)
const (
// RFC6455: The value of this header field MUST be a nonce consisting of a
// randomly selected 16-byte value that has been base64-encoded (see
// Section 4 of [RFC4648]). The nonce MUST be selected randomly for each
// connection.
nonceKeySize = 16
nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
// RFC6455: The value of this header field is constructed by concatenating
// /key/, defined above in step 4 in Section 4.2.2, with the string
// "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
// concatenated value to obtain a 20-byte value and base64- encoding (see
// Section 4 of [RFC4648]) this 20-byte hash.
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
)
// initNonce fills given slice with random base64-encoded nonce bytes.
func initNonce(dst []byte) {
// NOTE: bts does not escape.
bts := make([]byte, nonceKeySize)
if _, err := rand.Read(bts); err != nil {
panic(fmt.Sprintf("rand read error: %s", err))
}
base64.StdEncoding.Encode(dst, bts)
}
// checkAcceptFromNonce reports whether given accept bytes are valid for given
// nonce bytes.
func checkAcceptFromNonce(accept, nonce []byte) bool {
if len(accept) != acceptSize {
return false
}
// NOTE: expect does not escape.
expect := make([]byte, acceptSize)
initAcceptFromNonce(expect, nonce)
return bytes.Equal(expect, accept)
}
// initAcceptFromNonce fills given slice with accept bytes generated from given
// nonce bytes. Given buffer should be exactly acceptSize bytes.
func initAcceptFromNonce(accept, nonce []byte) {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
if len(accept) != acceptSize {
panic("accept buffer is invalid")
}
if len(nonce) != nonceSize {
panic("nonce is invalid")
}
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], nonce)
copy(p[nonceSize:], magic)
sum := sha1.Sum(p)
base64.StdEncoding.Encode(accept, sum[:])
return
}
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
accept := make([]byte, acceptSize)
initAcceptFromNonce(accept, nonce)
// NOTE: write accept bytes as a string to prevent heap allocation
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer which in turn
// will lead to p escape.
return bw.WriteString(btsToString(accept))
}

View file

@ -0,0 +1,147 @@
package ws
import (
"encoding/binary"
"fmt"
"io"
)
// Errors used by frame reader.
var (
ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
)
// ReadHeader reads a frame header from r.
func ReadHeader(r io.Reader) (h Header, err error) {
// Make slice of bytes with capacity 12 that could hold any header.
//
// The maximum header size is 14, but due to the 2 hop reads,
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
// So 14 - 2 = 12.
bts := make([]byte, 2, MaxHeaderSize-2)
// Prepare to hold first 2 bytes to choose size of next read.
_, err = io.ReadFull(r, bts)
if err != nil {
return
}
h.Fin = bts[0]&bit0 != 0
h.Rsv = (bts[0] & 0x70) >> 4
h.OpCode = OpCode(bts[0] & 0x0f)
var extra int
if bts[1]&bit0 != 0 {
h.Masked = true
extra += 4
}
length := bts[1] & 0x7f
switch {
case length < 126:
h.Length = int64(length)
case length == 126:
extra += 2
case length == 127:
extra += 8
default:
err = ErrHeaderLengthUnexpected
return
}
if extra == 0 {
return
}
// Increase len of bts to extra bytes need to read.
// Overwrite first 2 bytes that was read before.
bts = bts[:extra]
_, err = io.ReadFull(r, bts)
if err != nil {
return
}
switch {
case length == 126:
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
bts = bts[2:]
case length == 127:
if bts[0]&0x80 != 0 {
err = ErrHeaderLengthMSB
return
}
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
bts = bts[8:]
}
if h.Masked {
copy(h.Mask[:], bts)
}
return
}
// ReadFrame reads a frame from r.
// It is not designed for high optimized use case cause it makes allocation
// for frame.Header.Length size inside to read frame payload into.
//
// Note that ReadFrame does not unmask payload.
func ReadFrame(r io.Reader) (f Frame, err error) {
f.Header, err = ReadHeader(r)
if err != nil {
return
}
if f.Header.Length > 0 {
// int(f.Header.Length) is safe here cause we have
// checked it for overflow above in ReadHeader.
f.Payload = make([]byte, int(f.Header.Length))
_, err = io.ReadFull(r, f.Payload)
}
return
}
// MustReadFrame is like ReadFrame but panics if frame can not be read.
func MustReadFrame(r io.Reader) Frame {
f, err := ReadFrame(r)
if err != nil {
panic(err)
}
return f
}
// ParseCloseFrameData parses close frame status code and closure reason if any provided.
// If there is no status code in the payload
// the empty status code is returned (code.Empty()) with empty string as a reason.
func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
// We returning empty StatusCode here, preventing the situation
// when endpoint really sent code 1005 and we should return ProtocolError on that.
//
// In other words, we ignoring this rule [RFC6455:7.1.5]:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
return
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = string(payload[2:])
return
}
// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing
// that it does not copies payload bytes into reason, but prepares unsafe cast.
func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
return
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = btsToString(payload[2:])
return
}

View file

@ -0,0 +1,663 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by ConnUpgrader.
const (
DefaultServerReadBufferSize = 4096
DefaultServerWriteBufferSize = 512
)
// Errors used by both client and server when preparing WebSocket handshake.
var (
ErrHandshakeBadProtocol = RejectConnectionError(
RejectionStatus(http.StatusHTTPVersionNotSupported),
RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")),
)
ErrHandshakeBadMethod = RejectConnectionError(
RejectionStatus(http.StatusMethodNotAllowed),
RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")),
)
ErrHandshakeBadHost = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
)
ErrHandshakeBadUpgrade = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
)
ErrHandshakeBadConnection = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
)
ErrHandshakeBadSecAccept = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
)
ErrHandshakeBadSecKey = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
)
ErrHandshakeBadSecVersion = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
)
// ErrMalformedResponse is returned by Dialer to indicate that server response
// can not be parsed.
var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
// ErrMalformedRequest is returned when HTTP request can not be parsed.
var ErrMalformedRequest = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason("malformed HTTP request"),
)
// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
// connection is rejected because given WebSocket version is malformed.
//
// According to RFC6455:
// If this version does not match a version understood by the server, the
// server MUST abort the WebSocket handshake described in this section and
// instead send an appropriate HTTP error code (such as 426 Upgrade Required)
// and a |Sec-WebSocket-Version| header field indicating the version(s) the
// server is capable of understanding.
var ErrHandshakeUpgradeRequired = RejectConnectionError(
RejectionStatus(http.StatusUpgradeRequired),
RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
// ErrNotHijacker is an error returned when http.ResponseWriter does not
// implement http.Hijacker interface.
var ErrNotHijacker = RejectConnectionError(
RejectionStatus(http.StatusInternalServerError),
RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
)
// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
// UpgradeHTTP function.
var DefaultHTTPUpgrader HTTPUpgrader
// UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
return DefaultHTTPUpgrader.Upgrade(r, w)
}
// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
// function.
var DefaultUpgrader Upgrader
// Upgrade is like Upgrader{}.Upgrade().
func Upgrade(conn io.ReadWriter) (Handshake, error) {
return DefaultUpgrader.Upgrade(conn)
}
// HTTPUpgrader contains options for upgrading connection to websocket from
// net/http Handler arguments.
type HTTPUpgrader struct {
// Timeout is the maximum amount of time an Upgrade() will spent while
// writing handshake response.
//
// The default is no timeout.
Timeout time.Duration
// Header is an optional http.Header mapping that could be used to
// write additional headers to the handshake response.
//
// Note that if present, it will be written in any result of handshake.
Header http.Header
// Protocol is the select function that is used to select subprotocol from
// list requested by client. If this field is set, then the first matched
// protocol is sent to a client as negotiated.
Protocol func(string) bool
// Extension is the select function that is used to select extensions from
// list requested by client. If this field is set, then the all matched
// extensions are sent to a client as negotiated.
//
// DEPRECATED. Use Negotiate instead.
Extension func(httphead.Option) bool
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
}
// Upgrade upgrades http connection to the websocket connection.
//
// It hijacks net.Conn from w and returns received net.Conn and
// bufio.ReadWriter. On successful handshake it returns Handshake struct
// describing handshake info.
func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
// Hijack connection first to get the ability to write rejection errors the
// same way as in Upgrader.
hj, ok := w.(http.Hijacker)
if ok {
conn, rw, err = hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return
}
// See https://tools.ietf.org/html/rfc6455#section-4.1
// The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
var nonce string
if r.Method != http.MethodGet {
err = ErrHandshakeBadMethod
} else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
err = ErrHandshakeBadProtocol
} else if r.Host == "" {
err = ErrHandshakeBadHost
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
err = ErrHandshakeBadUpgrade
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
err = ErrHandshakeBadConnection
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
err = ErrHandshakeBadSecKey
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
// According to RFC6455:
//
// If this version does not match a version understood by the server,
// the server MUST abort the WebSocket handshake described in this
// section and instead send an appropriate HTTP error code (such as 426
// Upgrade Required) and a |Sec-WebSocket-Version| header field
// indicating the version(s) the server is capable of understanding.
//
// So we branching here cause empty or not present version does not
// meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if version is really invalid we sent 426 status, if it
// not present or empty it is 400.
if v != "" {
err = ErrHandshakeUpgradeRequired
} else {
err = ErrHandshakeBadSecVersion
}
}
if check := u.Protocol; err == nil && check != nil {
ps := r.Header[headerSecProtocolCanonical]
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
var ok bool
hs.Protocol, ok = strSelectProtocol(ps[i], check)
if !ok {
err = ErrMalformedRequest
}
}
}
if f := u.Negotiate; err == nil && f != nil {
for _, h := range r.Header[headerSecExtensionsCanonical] {
hs.Extensions, err = negotiateExtensions(strToBytes(h), hs.Extensions, f)
if err != nil {
break
}
}
}
// DEPRECATED path.
if check := u.Extension; err == nil && check != nil && u.Negotiate == nil {
xs := r.Header[headerSecExtensionsCanonical]
for i := 0; i < len(xs) && err == nil; i++ {
var ok bool
hs.Extensions, ok = btsSelectExtensions(strToBytes(xs[i]), hs.Extensions, check)
if !ok {
err = ErrMalformedRequest
}
}
}
// Clear deadlines set by server.
conn.SetDeadline(noDeadline)
if t := u.Timeout; t != 0 {
conn.SetWriteDeadline(time.Now().Add(t))
defer conn.SetWriteDeadline(noDeadline)
}
var header handshakeHeader
if h := u.Header; h != nil {
header[0] = HandshakeHeaderHTTP(h)
}
if err == nil {
httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
err = rw.Writer.Flush()
} else {
var code int
if rej, ok := err.(*rejectConnectionError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
rw.Writer.Flush()
}
return
}
// Upgrader contains options for upgrading connection to websocket.
type Upgrader struct {
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
// They used to read and write http data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then default value is used.
//
// Usually it is useful to set read buffer size bigger than write buffer
// size because incoming request could contain long header values, such as
// Cookie. Response, in other way, could be big only if user write multiple
// custom headers. Usually response takes less than 256 bytes.
ReadBufferSize, WriteBufferSize int
// Protocol is a select function that is used to select subprotocol
// from list requested by client. If this field is set, then the first matched
// protocol is sent to a client as negotiated.
//
// The argument is only valid until the callback returns.
Protocol func([]byte) bool
// ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
// Note that returned bytes must be valid until Upgrade returns.
// If ProtocolCustom is set, it used instead of Protocol function.
ProtocolCustom func([]byte) (string, bool)
// Extension is a select function that is used to select extensions
// from list requested by client. If this field is set, then the all matched
// extensions are sent to a client as negotiated.
//
// Note that Extension may be called multiple times and implementations
// must track uniqueness of accepted extensions manually.
//
// The argument is only valid until the callback returns.
//
// According to the RFC6455 order of extensions passed by a client is
// significant. That is, returning true from this function means that no
// other extension with the same name should be checked because server
// accepted the most preferable extension right now:
// "Note that the order of extensions is significant. Any interactions between
// multiple extensions MAY be defined in the documents defining the extensions.
// In the absence of such definitions, the interpretation is that the header
// fields listed by the client in its request represent a preference of the
// header fields it wishes to use, with the first options listed being most
// preferable."
//
// DEPRECATED. Use Negotiate instead.
Extension func(httphead.Option) bool
// ExtensionCustom allow user to parse Sec-WebSocket-Extensions header
// manually.
//
// If ExtensionCustom() decides to accept received extension, it must
// append appropriate option to the given slice of httphead.Option.
// It returns results of append() to the given slice and a flag that
// reports whether given header value is wellformed or not.
//
// Note that ExtensionCustom may be called multiple times and
// implementations must track uniqueness of accepted extensions manually.
//
// Note that returned options should be valid until Upgrade returns.
// If ExtensionCustom is set, it used instead of Extension function.
ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake response.
//
// It used instead of any key-value mappings to avoid allocations in user
// land.
//
// Note that if present, it will be written in any result of handshake.
Header HandshakeHeader
// OnRequest is a callback that will be called after request line
// successful parsing.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnRequest func(uri []byte) error
// OnHost is a callback that will be called after "Host" header successful
// parsing.
//
// It is separated from OnHeader callback because the Host header must be
// present in each request since HTTP/1.1. Thus Host header is non-optional
// and required for every WebSocket handshake.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHost func(host []byte) error
// OnHeader is a callback that will be called after successful parsing of
// header, that is not used during WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHeader func(key, value []byte) error
// OnBeforeUpgrade is a callback that will be called before sending
// successful upgrade response.
//
// Setting OnBeforeUpgrade allows user to make final application-level
// checks and decide whether this connection is allowed to successfully
// upgrade to WebSocket.
//
// It must return non-nil either HandshakeHeader or error and never both.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnBeforeUpgrade func() (header HandshakeHeader, err error)
}
// Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
// as connection with incoming HTTP Upgrade request.
//
// It is a caller responsibility to manage i/o timeouts on conn.
//
// Non-nil error means that request for the WebSocket upgrade is invalid or
// malformed and usually connection should be closed.
// Even when error is non-nil Upgrade will write appropriate response into
// connection in compliance with RFC.
func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
// headerSeen constants helps to report whether or not some header was seen
// during reading request bytes.
const (
headerSeenHost = 1 << iota
headerSeenUpgrade
headerSeenConnection
headerSeenSecVersion
headerSeenSecKey
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenHost |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecVersion |
headerSeenSecKey
)
// Prepare I/O buffers.
// TODO(gobwas): make it configurable.
br := pbufio.GetReader(conn,
nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
)
defer func() {
pbufio.PutReader(br)
pbufio.PutWriter(bw)
}()
// Read HTTP request line like "GET /ws HTTP/1.1".
rl, err := readLine(br)
if err != nil {
return
}
// Parse request line data like HTTP version, uri and method.
req, err := httpParseRequestLine(rl)
if err != nil {
return
}
// Prepare stack-based handshake header list.
header := handshakeHeader{
0: u.Header,
}
// Parse and check HTTP request.
// As RFC6455 says:
// The client's opening handshake consists of the following parts. If the
// server, while reading the handshake, finds that the client did not
// send a handshake that matches the description below (note that as per
// [RFC2616], the order of the header fields is not important), including
// but not limited to any violations of the ABNF grammar specified for
// the components of the handshake, the server MUST stop processing the
// client's handshake and return an HTTP response with an appropriate
// error code (such as 400 Bad Request).
//
// See https://tools.ietf.org/html/rfc6455#section-4.2.1
// An HTTP/1.1 or higher GET request, including a "Request-URI".
//
// Even if RFC says "1.1 or higher" without mentioning the part of the
// version, we apply it only to minor part.
switch {
case req.major != 1 || req.minor < 1:
// Abort processing the whole request because we do not even know how
// to actually parse it.
err = ErrHandshakeBadProtocol
case btsToString(req.method) != http.MethodGet:
err = ErrHandshakeBadMethod
default:
if onRequest := u.OnRequest; onRequest != nil {
err = onRequest(req.uri)
}
}
// Start headers read/parse loop.
var (
// headerSeen reports which header was seen by setting corresponding
// bit on.
headerSeen byte
nonce = make([]byte, nonceSize)
)
for err == nil {
line, e := readLine(br)
if e != nil {
return hs, e
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedRequest
break
}
switch btsToString(k) {
case headerHostCanonical:
headerSeen |= headerSeenHost
if onHost := u.OnHost; onHost != nil {
err = onHost(v)
}
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
err = ErrHandshakeBadConnection
}
case headerSecVersionCanonical:
headerSeen |= headerSeenSecVersion
if !bytes.Equal(v, specHeaderValueSecVersion) {
err = ErrHandshakeUpgradeRequired
}
case headerSecKeyCanonical:
headerSeen |= headerSeenSecKey
if len(v) != nonceSize {
err = ErrHandshakeBadSecKey
} else {
copy(nonce[:], v)
}
case headerSecProtocolCanonical:
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Protocol, ok = custom(v)
} else {
hs.Protocol, ok = btsSelectProtocol(v, check)
}
if !ok {
err = ErrMalformedRequest
}
}
case headerSecExtensionsCanonical:
if f := u.Negotiate; err == nil && f != nil {
hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f)
}
// DEPRECATED path.
if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Extensions, ok = custom(v, hs.Extensions)
} else {
hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
}
if !ok {
err = ErrMalformedRequest
}
}
default:
if onHeader := u.OnHeader; onHeader != nil {
err = onHeader(k, v)
}
}
}
switch {
case err == nil && headerSeen != headerSeenAll:
switch {
case headerSeen&headerSeenHost == 0:
// As RFC2616 says:
// A client MUST include a Host header field in all HTTP/1.1
// request messages. If the requested URI does not include an
// Internet host name for the service being requested, then the
// Host header field MUST be given with an empty value. An
// HTTP/1.1 proxy MUST ensure that any request message it
// forwards does contain an appropriate Host header field that
// identifies the service being requested by the proxy. All
// Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
// Request) status code to any HTTP/1.1 request message which
// lacks a Host header field.
err = ErrHandshakeBadHost
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecVersion == 0:
// In case of empty or not present version we do not send 426 status,
// because it does not meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if version is really invalid we sent 426 status as above, if it
// not present it is 400.
err = ErrHandshakeBadSecVersion
case headerSeen&headerSeenSecKey == 0:
err = ErrHandshakeBadSecKey
default:
panic("unknown headers state")
}
case err == nil && u.OnBeforeUpgrade != nil:
header[1], err = u.OnBeforeUpgrade()
}
if err != nil {
var code int
if rej, ok := err.(*rejectConnectionError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(bw, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
bw.Flush()
return
}
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
err = bw.Flush()
return
}
type handshakeHeader [2]HandshakeHeader
func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
for i := 0; i < len(hs) && err == nil; i++ {
if h := hs[i]; h != nil {
var m int64
m, err = h.WriteTo(w)
n += m
}
}
return n, err
}

View file

View file

@ -0,0 +1,214 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"reflect"
"unsafe"
"github.com/gobwas/httphead"
)
// SelectFromSlice creates accept function that could be used as Protocol/Extension
// select during upgrade.
func SelectFromSlice(accept []string) func(string) bool {
if len(accept) > 16 {
mp := make(map[string]struct{}, len(accept))
for _, p := range accept {
mp[p] = struct{}{}
}
return func(p string) bool {
_, ok := mp[p]
return ok
}
}
return func(p string) bool {
for _, ok := range accept {
if p == ok {
return true
}
}
return false
}
}
// SelectEqual creates accept function that could be used as Protocol/Extension
// select during upgrade.
func SelectEqual(v string) func(string) bool {
return func(p string) bool {
return v == p
}
}
func strToBytes(str string) (bts []byte) {
s := (*reflect.StringHeader)(unsafe.Pointer(&str))
b := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
b.Data = s.Data
b.Len = s.Len
b.Cap = s.Len
return
}
func btsToString(bts []byte) (str string) {
return *(*string)(unsafe.Pointer(&bts))
}
// asciiToInt converts bytes to int.
func asciiToInt(bts []byte) (ret int, err error) {
// ASCII numbers all start with the high-order bits 0011.
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
// bits and interpret them directly as an integer.
var n int
if n = len(bts); n < 1 {
return 0, fmt.Errorf("converting empty bytes to int")
}
for i := 0; i < n; i++ {
if bts[i]&0xf0 != 0x30 {
return 0, fmt.Errorf("%s is not a numeric character", string(bts[i]))
}
ret += int(bts[i]&0xf) * pow(10, n-i-1)
}
return ret, nil
}
// pow for integers implementation.
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
func pow(a, b int) int {
p := 1
for b > 0 {
if b&1 != 0 {
p *= a
}
b >>= 1
a *= a
}
return p
}
func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) {
a := bytes.IndexByte(bts, sep)
b := bytes.IndexByte(bts[a+1:], sep)
if a == -1 || b == -1 {
return bts, nil, nil
}
b += a + 1
return bts[:a], bts[a+1 : b], bts[b+1:]
}
func btrim(bts []byte) []byte {
var i, j int
for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); {
i++
}
for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); {
j--
}
return bts[i:j]
}
func strHasToken(header, token string) (has bool) {
return btsHasToken(strToBytes(header), strToBytes(token))
}
func btsHasToken(header, token []byte) (has bool) {
httphead.ScanTokens(header, func(v []byte) bool {
has = bytes.EqualFold(v, token)
return !has
})
return
}
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
toLower8 = uint64(toLower) |
uint64(toLower)<<8 |
uint64(toLower)<<16 |
uint64(toLower)<<24 |
uint64(toLower)<<32 |
uint64(toLower)<<40 |
uint64(toLower)<<48 |
uint64(toLower)<<56
)
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
// that it operates with slice of bytes and modifies it inplace without copying.
func canonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
// readLine reads line from br. It reads until '\n' and returns bytes without
// '\n' or '\r\n' at the end.
// It returns err if and only if line does not end in '\n'. Note that read
// bytes returned in any case of error.
//
// It is much like the textproto/Reader.ReadLine() except the thing that it
// returns raw bytes, instead of string. That is, it avoids copying bytes read
// from br.
//
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
// safe with future I/O operations on br.
//
// We could control I/O operations on br and do not need to make additional
// copy for safety.
//
// NOTE: it may return copied flag to notify that returned buffer is safe to
// use.
func readLine(br *bufio.Reader) ([]byte, error) {
var line []byte
for {
bts, err := br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
// Copy bytes because next read will discard them.
line = append(line, bts...)
continue
}
// Avoid copy of single read.
if line == nil {
line = bts
} else {
line = append(line, bts...)
}
if err != nil {
return line, err
}
// Size of line is at least 1.
// In other case bufio.ReadSlice() returns error.
n := len(line)
// Cut '\n' or '\r\n'.
if n > 1 && line[n-2] == '\r' {
line = line[:n-2]
} else {
line = line[:n-1]
}
return line, nil
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func nonZero(a, b int) int {
if a != 0 {
return a
}
return b
}

View file

@ -0,0 +1,104 @@
package ws
import (
"encoding/binary"
"io"
)
// Header size length bounds in bytes.
const (
MaxHeaderSize = 14
MinHeaderSize = 2
)
const (
bit0 = 0x80
bit1 = 0x40
bit2 = 0x20
bit3 = 0x10
bit4 = 0x08
bit5 = 0x04
bit6 = 0x02
bit7 = 0x01
len7 = int64(125)
len16 = int64(^(uint16(0)))
len64 = int64(^(uint64(0)) >> 1)
)
// HeaderSize returns number of bytes that are needed to encode given header.
// It returns -1 if header is malformed.
func HeaderSize(h Header) (n int) {
switch {
case h.Length < 126:
n = 2
case h.Length <= len16:
n = 4
case h.Length <= len64:
n = 10
default:
return -1
}
if h.Masked {
n += len(h.Mask)
}
return n
}
// WriteHeader writes header binary representation into w.
func WriteHeader(w io.Writer, h Header) error {
// Make slice of bytes with capacity 14 that could hold any header.
bts := make([]byte, MaxHeaderSize)
if h.Fin {
bts[0] |= bit0
}
bts[0] |= h.Rsv << 4
bts[0] |= byte(h.OpCode)
var n int
switch {
case h.Length <= len7:
bts[1] = byte(h.Length)
n = 2
case h.Length <= len16:
bts[1] = 126
binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
n = 4
case h.Length <= len64:
bts[1] = 127
binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
n = 10
default:
return ErrHeaderLengthUnexpected
}
if h.Masked {
bts[1] |= bit0
n += copy(bts[n:], h.Mask[:])
}
_, err := w.Write(bts[:n])
return err
}
// WriteFrame writes frame binary representation into w.
func WriteFrame(w io.Writer, f Frame) error {
err := WriteHeader(w, f.Header)
if err != nil {
return err
}
_, err = w.Write(f.Payload)
return err
}
// MustWriteFrame is like WriteFrame but panics if frame can not be read.
func MustWriteFrame(w io.Writer, f Frame) {
if err := WriteFrame(w, f); err != nil {
panic(err)
}
}

View file

@ -0,0 +1,72 @@
package wsutil
import (
"io"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// CipherReader implements io.Reader that applies xor-cipher to the bytes read
// from source.
// It could help to unmask WebSocket frame payload on the fly.
type CipherReader struct {
r io.Reader
mask [4]byte
pos int
}
// NewCipherReader creates xor-cipher reader from r with given mask.
func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
return &CipherReader{r, mask, 0}
}
// Reset resets CipherReader to read from r with given mask.
func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
c.r = r
c.mask = mask
c.pos = 0
}
// Read implements io.Reader interface. It applies mask given during
// initialization to every read byte.
func (c *CipherReader) Read(p []byte) (n int, err error) {
n, err = c.r.Read(p)
ws.Cipher(p[:n], c.mask, c.pos)
c.pos += n
return
}
// CipherWriter implements io.Writer that applies xor-cipher to the bytes
// written to the destination writer. It does not modify the original bytes.
type CipherWriter struct {
w io.Writer
mask [4]byte
pos int
}
// NewCipherWriter creates xor-cipher writer to w with given mask.
func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
return &CipherWriter{w, mask, 0}
}
// Reset reset CipherWriter to write to w with given mask.
func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
c.w = w
c.mask = mask
c.pos = 0
}
// Write implements io.Writer interface. It applies masking during
// initialization to every sent byte. It does not modify original slice.
func (c *CipherWriter) Write(p []byte) (n int, err error) {
cp := pbytes.GetLen(len(p))
defer pbytes.Put(cp)
copy(cp, p)
ws.Cipher(cp, c.mask, c.pos)
n, err = c.w.Write(cp)
c.pos += n
return
}

View file

@ -0,0 +1,146 @@
package wsutil
import (
"bufio"
"bytes"
"context"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/gobwas/ws"
)
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
// handshake. That is, it gives ability to receive copied HTTP request and
// response bytes that made inside Dialer.Dial().
//
// Note that it must not be used in production applications that requires
// Dial() to be efficient.
type DebugDialer struct {
// Dialer contains WebSocket connection establishment options.
Dialer ws.Dialer
// OnRequest and OnResponse are the callbacks that will be called with the
// HTTP request and response respectively.
OnRequest, OnResponse func([]byte)
}
// Dial connects to the url host and upgrades connection to WebSocket. It makes
// it by calling d.Dialer.Dial().
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
// Need to copy Dialer to prevent original object mutation.
dialer := d.Dialer
var (
reqBuf bytes.Buffer
resBuf bytes.Buffer
resContentLength int64
)
userWrap := dialer.WrapConn
dialer.WrapConn = func(c net.Conn) net.Conn {
if userWrap != nil {
c = userWrap(c)
}
// Save the pointer to the raw connection.
conn = c
var (
r io.Reader = conn
w io.Writer = conn
)
if d.OnResponse != nil {
r = &prefetchResponseReader{
source: conn,
buffer: &resBuf,
contentLength: &resContentLength,
}
}
if d.OnRequest != nil {
w = io.MultiWriter(conn, &reqBuf)
}
return rwConn{conn, r, w}
}
_, br, hs, err = dialer.Dial(ctx, urlstr)
if onRequest := d.OnRequest; onRequest != nil {
onRequest(reqBuf.Bytes())
}
if onResponse := d.OnResponse; onResponse != nil {
// We must split response inside buffered bytes from other received
// bytes from server.
p := resBuf.Bytes()
n := bytes.Index(p, headEnd)
h := n + len(headEnd) // Head end index.
n = h + int(resContentLength) // Body end index.
onResponse(p[:n])
if br != nil {
// If br is non-nil, then it mean two things. First is that
// handshake is OK and server has sent additional bytes probably
// immediate sent frames (or weird but possible response body).
// Second, the bad one, is that br buffer's source is now rwConn
// instance from above WrapConn call. It is incorrect, so we must
// fix it.
var r io.Reader = conn
if len(p) > h {
// Buffer contains more than just HTTP headers bytes.
r = io.MultiReader(
bytes.NewReader(p[h:]),
conn,
)
}
br.Reset(r)
// Must make br.Buffered() to be non-zero.
br.Peek(len(p[h:]))
}
}
return conn, br, hs, err
}
type rwConn struct {
net.Conn
r io.Reader
w io.Writer
}
func (rwc rwConn) Read(p []byte) (int, error) {
return rwc.r.Read(p)
}
func (rwc rwConn) Write(p []byte) (int, error) {
return rwc.w.Write(p)
}
var headEnd = []byte("\r\n\r\n")
type prefetchResponseReader struct {
source io.Reader // Original connection source.
reader io.Reader // Wrapped reader used to read from by clients.
buffer *bytes.Buffer
contentLength *int64
}
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
if r.reader == nil {
resp, err := http.ReadResponse(bufio.NewReader(
io.TeeReader(r.source, r.buffer),
), nil)
if err == nil {
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
bts := r.buffer.Bytes()
r.reader = io.MultiReader(
bytes.NewReader(bts),
r.source,
)
}
return r.reader.Read(p)
}

View file

@ -0,0 +1,29 @@
package wsutil
// RecvExtension is an interface for clearing fragment header RSV bits.
type RecvExtension interface {
BitsRecv(seq int, rsv byte) (byte, error)
}
// RecvExtensionFunc is an adapter to allow the use of ordinary functions as
// RecvExtension.
type RecvExtensionFunc func(int, byte) (byte, error)
// BitsRecv implements RecvExtension.
func (fn RecvExtensionFunc) BitsRecv(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
}
// SendExtension is an interface for setting fragment header RSV bits.
type SendExtension interface {
BitsSend(seq int, rsv byte) (byte, error)
}
// SendExtensionFunc is an adapter to allow the use of ordinary functions as
// SendExtension.
type SendExtensionFunc func(int, byte) (byte, error)
// BitsSend implements SendExtension.
func (fn SendExtensionFunc) BitsSend(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
}

View file

@ -0,0 +1,219 @@
package wsutil
import (
"errors"
"io"
"io/ioutil"
"strconv"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// ClosedError returned when peer has closed the connection with appropriate
// code and a textual reason.
type ClosedError struct {
Code ws.StatusCode
Reason string
}
// Error implements error interface.
func (err ClosedError) Error() string {
return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason
}
// ControlHandler contains logic of handling control frames.
//
// The intentional way to use it is to read the next frame header from the
// connection, optionally check its validity via ws.CheckHeader() and if it is
// not a ws.OpText of ws.OpBinary (or ws.OpContinuation) pass it to Handle()
// method.
//
// That is, passed header should be checked to get rid of unexpected errors.
//
// The Handle() method will read out all control frame payload (if any) and
// write necessary bytes as a rfc compatible response.
type ControlHandler struct {
Src io.Reader
Dst io.Writer
State ws.State
// DisableSrcCiphering disables unmasking payload data read from Src.
// It is useful when wsutil.Reader is used or when frame payload already
// pulled and ciphered out from the connection (and introduced by
// bytes.Reader, for example).
DisableSrcCiphering bool
}
// ErrNotControlFrame is returned by ControlHandler to indicate that given
// header could not be handled.
var ErrNotControlFrame = errors.New("not a control frame")
// Handle handles control frames regarding to the c.State and writes responses
// to the c.Dst when needed.
//
// It returns ErrNotControlFrame when given header is not of ws.OpClose,
// ws.OpPing or ws.OpPong operation code.
func (c ControlHandler) Handle(h ws.Header) error {
switch h.OpCode {
case ws.OpPing:
return c.HandlePing(h)
case ws.OpPong:
return c.HandlePong(h)
case ws.OpClose:
return c.HandleClose(h)
}
return ErrNotControlFrame
}
// HandlePing handles ping frame and writes specification compatible response
// to the c.Dst.
func (c ControlHandler) HandlePing(h ws.Header) error {
if h.Length == 0 {
// The most common case when ping is empty.
// Note that when sending masked frame the mask for empty payload is
// just four zero bytes.
return ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpPong,
Masked: c.State.ClientSide(),
})
}
// In other way reply with Pong frame with copied payload.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
_, err := io.Copy(w, r)
if err == nil {
err = w.Flush()
}
return err
}
// HandlePong handles pong frame by discarding it.
func (c ControlHandler) HandlePong(h ws.Header) error {
if h.Length == 0 {
return nil
}
buf := pbytes.GetLen(int(h.Length))
defer pbytes.Put(buf)
// Discard pong message according to the RFC6455:
// A Pong frame MAY be sent unsolicited. This serves as a
// unidirectional heartbeat. A response to an unsolicited Pong frame
// is not expected.
_, err := io.CopyBuffer(ioutil.Discard, c.Src, buf)
return err
}
// HandleClose handles close frame, makes protocol validity checks and writes
// specification compatible response to the c.Dst.
func (c ControlHandler) HandleClose(h ws.Header) error {
if h.Length == 0 {
err := ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpClose,
Masked: c.State.ClientSide(),
})
if err != nil {
return err
}
// Due to RFC, we should interpret the code as no status code
// received:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
//
// See https://tools.ietf.org/html/rfc6455#section-7.1.5
return ClosedError{
Code: ws.StatusNoStatusRcvd,
}
}
// Prepare bytes both for reading reason and sending response.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Get the subslice to read the frame payload out.
subp := p[:h.Length]
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
if _, err := io.ReadFull(r, subp); err != nil {
return err
}
code, reason := ws.ParseCloseFrameData(subp)
if err := ws.CheckCloseFrameData(code, reason); err != nil {
// Here we could not use the prepared bytes because there is no
// guarantee that it may fit our protocol error closure code and a
// reason.
c.closeWithProtocolError(err)
return err
}
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p)
// RFC6455#5.5.1:
// If an endpoint receives a Close frame and did not previously
// send a Close frame, the endpoint MUST send a Close frame in
// response. (When sending a Close frame in response, the endpoint
// typically echoes the status code it received.)
_, err := w.Write(p[:2])
if err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
return ClosedError{
Code: code,
Reason: reason,
}
}
func (c ControlHandler) closeWithProtocolError(reason error) error {
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
ws.StatusProtocolError, reason.Error(),
))
if c.State.ClientSide() {
ws.MaskFrameInPlace(f)
}
return ws.WriteFrame(c.Dst, f)
}

View file

@ -0,0 +1,279 @@
package wsutil
import (
"bytes"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
// Message represents a message from peer, that could be presented in one or
// more frames. That is, it contains payload of all message fragments and
// operation code of initial frame for this message.
type Message struct {
OpCode ws.OpCode
Payload []byte
}
// ReadMessage is a helper function that reads next message from r. It appends
// received message(s) to the third argument and returns the result of it and
// an error if some failure happened. That is, it probably could receive more
// than one message when peer sending fragmented message in multiple frames and
// want to send some control frame between fragments. Then returned slice will
// contain those control frames at first, and then result of gluing fragments.
//
// TODO(gobwas): add DefaultReader with buffer size options.
func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) {
rd := Reader{
Source: r,
State: s,
CheckUTF8: true,
OnIntermediate: func(hdr ws.Header, src io.Reader) error {
bts, err := ioutil.ReadAll(src)
if err != nil {
return err
}
m = append(m, Message{hdr.OpCode, bts})
return nil
},
}
h, err := rd.NextFrame()
if err != nil {
return m, err
}
var p []byte
if h.Fin {
// No more frames will be read. Use fixed sized buffer to read payload.
p = make([]byte, h.Length)
// It is not possible to receive io.EOF here because Reader does not
// return EOF if frame payload was successfully fetched.
// Thus we consistent here with io.Reader behavior.
_, err = io.ReadFull(&rd, p)
} else {
// Frame is fragmented, thus use ioutil.ReadAll behavior.
var buf bytes.Buffer
_, err = buf.ReadFrom(&rd)
p = buf.Bytes()
}
if err != nil {
return m, err
}
return append(m, Message{h.OpCode, p}), nil
}
// ReadClientMessage reads next message from r, considering that caller
// represents server side.
// It is a shortcut for ReadMessage(r, ws.StateServerSide, m)
func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) {
return ReadMessage(r, ws.StateServerSide, m)
}
// ReadServerMessage reads next message from r, considering that caller
// represents client side.
// It is a shortcut for ReadMessage(r, ws.StateClientSide, m)
func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) {
return ReadMessage(r, ws.StateClientSide, m)
}
// ReadData is a helper function that reads next data (non-control) message
// from rw.
// It takes care on handling all control frames. It will write response on
// control frames to the write part of rw. It blocks until some data frame
// will be received.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) {
return readData(rw, s, ws.OpText|ws.OpBinary)
}
// ReadClientData reads next data message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
return ReadData(rw, ws.StateServerSide)
}
// ReadClientText reads next text message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
// It discards received binary messages.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadClientText(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateServerSide, ws.OpText)
return p, err
}
// ReadClientBinary reads next binary message from rw, considering that caller
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
// It discards received text messages.
//
// Note this may handle and write control frames into the writer part of a given
// io.ReadWriter.
func ReadClientBinary(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary)
return p, err
}
// ReadServerData reads next data message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
return ReadData(rw, ws.StateClientSide)
}
// ReadServerText reads next text message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
// It discards received binary messages.
//
// Note this may handle and write control frames into the writer part of a given
// io.ReadWriter.
func ReadServerText(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateClientSide, ws.OpText)
return p, err
}
// ReadServerBinary reads next binary message from rw, considering that caller
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
// It discards received text messages.
//
// Note this may handle and write control frames into the writer part of a
// given io.ReadWriter.
func ReadServerBinary(rw io.ReadWriter) ([]byte, error) {
p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary)
return p, err
}
// WriteMessage is a helper function that writes message to the w. It
// constructs single frame with given operation code and payload.
// It uses given state to prepare side-dependent things, like cipher
// payload bytes from client to server. It will not mutate p bytes if
// cipher must be made.
//
// If you want to write message in fragmented frames, use Writer instead.
func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
return writeFrame(w, s, op, true, p)
}
// WriteServerMessage writes message to w, considering that caller
// represents server side.
func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error {
return WriteMessage(w, ws.StateServerSide, op, p)
}
// WriteServerText is the same as WriteServerMessage with
// ws.OpText.
func WriteServerText(w io.Writer, p []byte) error {
return WriteServerMessage(w, ws.OpText, p)
}
// WriteServerBinary is the same as WriteServerMessage with
// ws.OpBinary.
func WriteServerBinary(w io.Writer, p []byte) error {
return WriteServerMessage(w, ws.OpBinary, p)
}
// WriteClientMessage writes message to w, considering that caller
// represents client side.
func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error {
return WriteMessage(w, ws.StateClientSide, op, p)
}
// WriteClientText is the same as WriteClientMessage with
// ws.OpText.
func WriteClientText(w io.Writer, p []byte) error {
return WriteClientMessage(w, ws.OpText, p)
}
// WriteClientBinary is the same as WriteClientMessage with
// ws.OpBinary.
func WriteClientBinary(w io.Writer, p []byte) error {
return WriteClientMessage(w, ws.OpBinary, p)
}
// HandleClientControlMessage handles control frame from conn and writes
// response when needed.
//
// It considers that caller represents server side.
func HandleClientControlMessage(conn io.Writer, msg Message) error {
return HandleControlMessage(conn, ws.StateServerSide, msg)
}
// HandleServerControlMessage handles control frame from conn and writes
// response when needed.
//
// It considers that caller represents client side.
func HandleServerControlMessage(conn io.Writer, msg Message) error {
return HandleControlMessage(conn, ws.StateClientSide, msg)
}
// HandleControlMessage handles message which was read by ReadMessage()
// functions.
//
// That is, it is expected, that payload is already unmasked and frame header
// were checked by ws.CheckHeader() call.
func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error {
return (ControlHandler{
DisableSrcCiphering: true,
Src: bytes.NewReader(msg.Payload),
Dst: conn,
State: state,
}).Handle(ws.Header{
Length: int64(len(msg.Payload)),
OpCode: msg.OpCode,
Fin: true,
Masked: state.ServerSide(),
})
}
// ControlFrameHandler returns FrameHandlerFunc for handling control frames.
// For more info see ControlHandler docs.
func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc {
return func(h ws.Header, r io.Reader) error {
return (ControlHandler{
DisableSrcCiphering: true,
Src: r,
Dst: w,
State: state,
}).Handle(h)
}
}
func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) {
controlHandler := ControlFrameHandler(rw, s)
rd := Reader{
Source: rw,
State: s,
CheckUTF8: true,
SkipHeaderCheck: false,
OnIntermediate: controlHandler,
}
for {
hdr, err := rd.NextFrame()
if err != nil {
return nil, 0, err
}
if hdr.OpCode.IsControl() {
if err := controlHandler(hdr, &rd); err != nil {
return nil, 0, err
}
continue
}
if hdr.OpCode&want == 0 {
if err := rd.Discard(); err != nil {
return nil, 0, err
}
continue
}
bts, err := ioutil.ReadAll(&rd)
return bts, hdr.OpCode, err
}
}

View file

@ -0,0 +1,280 @@
package wsutil
import (
"errors"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
// ErrNoFrameAdvance means that Reader's Read() method was called without
// preceding NextFrame() call.
var ErrNoFrameAdvance = errors.New("no frame advance")
// FrameHandlerFunc handles parsed frame header and its body represented by
// io.Reader.
//
// Note that reader represents already unmasked body.
type FrameHandlerFunc func(ws.Header, io.Reader) error
// Reader is a wrapper around source io.Reader which represents WebSocket
// connection. It contains options for reading messages from source.
//
// Reader implements io.Reader, which Read() method reads payload of incoming
// WebSocket frames. It also takes care on fragmented frames and possibly
// intermediate control frames between them.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
Source io.Reader
State ws.State
// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
SkipHeaderCheck bool
// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
CheckUTF8 bool
// Extensions is a list of negotiated extensions for reader Source.
// It is used to meet the specs and clear appropriate bits in fragment
// header RSV segment.
Extensions []RecvExtension
// TODO(gobwas): add max frame size limit here.
OnContinuation FrameHandlerFunc
OnIntermediate FrameHandlerFunc
opCode ws.OpCode // Used to store message op code on fragmentation.
frame io.Reader // Used to as frame reader.
raw io.LimitedReader // Used to discard frames without cipher.
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
fseq int // Fragment sequence in message counter.
}
// NewReader creates new frame reader that reads from r keeping given state to
// make some protocol validity checks when it needed.
func NewReader(r io.Reader, s ws.State) *Reader {
return &Reader{
Source: r,
State: s,
}
}
// NewClientSideReader is a helper function that calls NewReader with r and
// ws.StateClientSide.
func NewClientSideReader(r io.Reader) *Reader {
return NewReader(r, ws.StateClientSide)
}
// NewServerSideReader is a helper function that calls NewReader with r and
// ws.StateServerSide.
func NewServerSideReader(r io.Reader) *Reader {
return NewReader(r, ws.StateServerSide)
}
// Read implements io.Reader. It reads the next message payload into p.
// It takes care on fragmented messages.
//
// The error is io.EOF only if all of message bytes were read.
// If an io.EOF happens during reading some but not all the message bytes
// Read() returns io.ErrUnexpectedEOF.
//
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
// reading next message bytes.
func (r *Reader) Read(p []byte) (n int, err error) {
if r.frame == nil {
if !r.fragmented() {
// Every new Read() must be preceded by NextFrame() call.
return 0, ErrNoFrameAdvance
}
// Read next continuation or intermediate control frame.
_, err := r.NextFrame()
if err != nil {
return 0, err
}
if r.frame == nil {
// We handled intermediate control and now got nothing to read.
return 0, nil
}
}
n, err = r.frame.Read(p)
if err != nil && err != io.EOF {
return
}
if err == nil && r.raw.N != 0 {
return n, nil
}
// EOF condition (either err is io.EOF or r.raw.N is zero).
switch {
case r.raw.N != 0:
err = io.ErrUnexpectedEOF
case r.fragmented():
err = nil
r.resetFragment()
case r.CheckUTF8 && !r.utf8.Valid():
// NOTE: check utf8 only when full message received, since partial
// reads may be invalid.
n = r.utf8.Accepted()
err = ErrInvalidUTF8
default:
r.reset()
err = io.EOF
}
return
}
// Discard discards current message unread bytes.
// It discards all frames of fragmented message.
func (r *Reader) Discard() (err error) {
for {
_, err = io.Copy(ioutil.Discard, &r.raw)
if err != nil {
break
}
if !r.fragmented() {
break
}
if _, err = r.NextFrame(); err != nil {
break
}
}
r.reset()
return err
}
// NextFrame prepares r to read next message. It returns received frame header
// and non-nil error on failure.
//
// Note that next NextFrame() call must be done after receiving or discarding
// all current message bytes.
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
hdr, err = ws.ReadHeader(r.Source)
if err == io.EOF && r.fragmented() {
// If we are in fragmented state EOF means that is was totally
// unexpected.
//
// NOTE: This is necessary to prevent callers such that
// ioutil.ReadAll to receive some amount of bytes without an error.
// ReadAll() ignores an io.EOF error, thus caller may think that
// whole message fetched, but actually only part of it.
err = io.ErrUnexpectedEOF
}
if err == nil && !r.SkipHeaderCheck {
err = ws.CheckHeader(hdr, r.State)
}
if err != nil {
return hdr, err
}
// Save raw reader to use it on discarding frame without ciphering and
// other streaming checks.
r.raw = io.LimitedReader{
R: r.Source,
N: hdr.Length,
}
frame := io.Reader(&r.raw)
if hdr.Masked {
frame = NewCipherReader(frame, hdr.Mask)
}
for _, ext := range r.Extensions {
hdr.Rsv, err = ext.BitsRecv(r.fseq, hdr.Rsv)
if err != nil {
return hdr, err
}
}
if r.fragmented() {
if hdr.OpCode.IsControl() {
if cb := r.OnIntermediate; cb != nil {
err = cb(hdr, frame)
}
if err == nil {
// Ensure that src is empty.
_, err = io.Copy(ioutil.Discard, &r.raw)
}
return
}
} else {
r.opCode = hdr.OpCode
}
if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
r.utf8.Source = frame
frame = &r.utf8
}
// Save reader with ciphering and other streaming checks.
r.frame = frame
if hdr.OpCode == ws.OpContinuation {
if cb := r.OnContinuation; cb != nil {
err = cb(hdr, frame)
}
}
if hdr.Fin {
r.State = r.State.Clear(ws.StateFragmented)
r.fseq = 0
} else {
r.State = r.State.Set(ws.StateFragmented)
r.fseq++
}
return
}
func (r *Reader) fragmented() bool {
return r.State.Fragmented()
}
func (r *Reader) resetFragment() {
r.raw = io.LimitedReader{}
r.frame = nil
// Reset source of the UTF8Reader, but not the state.
r.utf8.Source = nil
}
func (r *Reader) reset() {
r.raw = io.LimitedReader{}
r.frame = nil
r.utf8 = UTF8Reader{}
r.fseq = 0
r.opCode = 0
}
// NextReader prepares next message read from r. It returns header that
// describes the message and io.Reader to read message's payload. It returns
// non-nil error when it is not possible to read message's initial frame.
//
// Note that next NextReader() on the same r should be done after reading all
// bytes from previously returned io.Reader. For more performant way to discard
// message use Reader and its Discard() method.
//
// Note that it will not handle any "intermediate" frames, that possibly could
// be received between text/binary continuation frames. That is, if peer sent
// text/binary frame with fin flag "false", then it could send ping frame, and
// eventually remaining part of text/binary frame with fin "true" with
// NextReader() the ping frame will be dropped without any notice. To handle
// this rare, but possible situation (and if you do not know exactly which
// frames peer could send), you could use Reader with OnIntermediate field set.
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
rd := &Reader{
Source: r,
State: s,
}
header, err := rd.NextFrame()
if err != nil {
return header, nil, err
}
return header, rd, nil
}

View file

@ -0,0 +1,68 @@
package wsutil
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net/http"
"github.com/gobwas/ws"
)
// DebugUpgrader is a wrapper around ws.Upgrader. It tracks I/O of a
// WebSocket handshake.
//
// Note that it must not be used in production applications that requires
// Upgrade() to be efficient.
type DebugUpgrader struct {
// Upgrader contains upgrade to WebSocket options.
Upgrader ws.Upgrader
// OnRequest and OnResponse are the callbacks that will be called with the
// HTTP request and response respectively.
OnRequest, OnResponse func([]byte)
}
// Upgrade calls Upgrade() on underlying ws.Upgrader and tracks I/O on conn.
func (d *DebugUpgrader) Upgrade(conn io.ReadWriter) (hs ws.Handshake, err error) {
var (
// Take the Reader and Writer parts from conn to be probably replaced
// below.
r io.Reader = conn
w io.Writer = conn
)
if onRequest := d.OnRequest; onRequest != nil {
var buf bytes.Buffer
// First, we must read the entire request.
req, err := http.ReadRequest(bufio.NewReader(
io.TeeReader(conn, &buf),
))
if err == nil {
// Fulfill the buffer with the response body.
io.Copy(ioutil.Discard, req.Body)
req.Body.Close()
}
onRequest(buf.Bytes())
r = io.MultiReader(
&buf, conn,
)
}
if onResponse := d.OnResponse; onResponse != nil {
var buf bytes.Buffer
// Intercept the response stream written by the Upgrade().
w = io.MultiWriter(
conn, &buf,
)
defer func() {
onResponse(buf.Bytes())
}()
}
return d.Upgrader.Upgrade(struct {
io.Reader
io.Writer
}{r, w})
}

View file

@ -0,0 +1,140 @@
package wsutil
import (
"fmt"
"io"
)
// ErrInvalidUTF8 is returned by UTF8 reader on invalid utf8 sequence.
var ErrInvalidUTF8 = fmt.Errorf("invalid utf8")
// UTF8Reader implements io.Reader that calculates utf8 validity state after
// every read byte from Source.
//
// Note that in some cases client must call r.Valid() after all bytes are read
// to ensure that all of them are valid utf8 sequences. That is, some io helper
// functions such io.ReadAtLeast or io.ReadFull could discard the error
// information returned by the reader when they receive all of requested bytes.
// For example, the last read sequence is invalid and UTF8Reader returns number
// of bytes read and an error. But helper function decides to discard received
// error due to all requested bytes are completely read from the source.
//
// Another possible case is when some valid sequence become split by the read
// bound. Then UTF8Reader can not make decision about validity of the last
// sequence cause it is not fully read yet. And if the read stops, Valid() will
// return false, even if Read() by itself dit not.
type UTF8Reader struct {
Source io.Reader
accepted int
state uint32
codep uint32
}
// NewUTF8Reader creates utf8 reader that reads from r.
func NewUTF8Reader(r io.Reader) *UTF8Reader {
return &UTF8Reader{
Source: r,
}
}
// Reset resets utf8 reader to read from r.
func (u *UTF8Reader) Reset(r io.Reader) {
u.Source = r
u.state = 0
u.codep = 0
}
// Read implements io.Reader.
func (u *UTF8Reader) Read(p []byte) (n int, err error) {
n, err = u.Source.Read(p)
accepted := 0
s, c := u.state, u.codep
for i := 0; i < n; i++ {
c, s = decode(s, c, p[i])
if s == utf8Reject {
u.state = s
return accepted, ErrInvalidUTF8
}
if s == utf8Accept {
accepted = i + 1
}
}
u.state, u.codep = s, c
u.accepted = accepted
return
}
// Valid checks current reader state. It returns true if all read bytes are
// valid UTF-8 sequences, and false if not.
func (u *UTF8Reader) Valid() bool {
return u.state == utf8Accept
}
// Accepted returns number of valid bytes in last Read().
func (u *UTF8Reader) Accepted() int {
return u.accepted
}
// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
//
// Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
// sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
const (
utf8Accept = 0
utf8Reject = 12
)
var utf8d = [...]byte{
// The first part of the table maps bytes to character classes that
// to reduce the size of the transition table and create bitmasks.
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
// The second part is a transition table that maps a combination
// of a state of the automaton and a character class to a state.
0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12,
12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12,
12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12,
12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
}
func decode(state, codep uint32, b byte) (uint32, uint32) {
t := uint32(utf8d[b])
if state != utf8Accept {
codep = (uint32(b) & 0x3f) | (codep << 6)
} else {
codep = (0xff >> t) & uint32(b)
}
return codep, uint32(utf8d[256+state+t])
}

View file

@ -0,0 +1,572 @@
package wsutil
import (
"fmt"
"io"
"github.com/gobwas/pool"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// DefaultWriteBuffer contains size of Writer's default buffer. It used by
// Writer constructor functions.
var DefaultWriteBuffer = 4096
var (
// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is
// not empty and write through could not be done. That is, caller should call
// Writer.FlushFragment() to make buffer empty.
ErrNotEmpty = fmt.Errorf("writer not empty")
// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
// no more data could be written to the underlying io.Writer because
// MaxControlFramePayloadSize limit is reached.
ErrControlOverflow = fmt.Errorf("control frame payload overflow")
)
// Constants which are represent frame length ranges.
const (
len7 = int64(125) // 126 and 127 are reserved values
len16 = int64(^uint16(0))
len64 = int64((^uint64(0)) >> 1)
)
// ControlWriter is a wrapper around Writer that contains some guards for
// buffered writes of control frames.
type ControlWriter struct {
w *Writer
limit int
n int
}
// NewControlWriter contains ControlWriter with Writer inside whose buffer size
// is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
return &ControlWriter{
w: NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
limit: ws.MaxControlFramePayloadSize,
}
}
// NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
//
// Note that it reserves x bytes of buf for header data, where x could be
// ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
// (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
//
// It panics if len(buf) <= ws.MinHeaderSize + x.
func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
if len(buf) > max {
buf = buf[:max]
}
w := NewWriterBuffer(dest, state, op, buf)
return &ControlWriter{
w: w,
limit: len(w.buf),
}
}
// Write implements io.Writer. It writes to the underlying Writer until it
// returns error or until ControlWriter write limit will be exceeded.
func (c *ControlWriter) Write(p []byte) (n int, err error) {
if c.n+len(p) > c.limit {
return 0, ErrControlOverflow
}
return c.w.Write(p)
}
// Flush flushes all buffered data to the underlying io.Writer.
func (c *ControlWriter) Flush() error {
return c.w.Flush()
}
var writers = pool.New(128, 65536)
// GetWriter tries to reuse Writer getting it from the pool.
//
// This function is intended for memory consumption optimizations, because
// NewWriter*() functions make allocations for inner buffer.
//
// Note the it ceils n to the power of two.
//
// If you have your own bytes buffer pool you could use NewWriterBuffer to use
// pooled bytes in writer.
func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
x, m := writers.Get(n)
if x != nil {
w := x.(*Writer)
w.Reset(dest, state, op)
return w
}
// NOTE: we use m instead of n, because m is an attempt to reuse w of such
// size in the future.
return NewWriterBufferSize(dest, state, op, m)
}
// PutWriter puts w for future reuse by GetWriter().
func PutWriter(w *Writer) {
w.Reset(nil, 0, 0)
writers.Put(w, w.Size())
}
// Writer contains logic of buffering output data into a WebSocket fragments.
// It is much the same as bufio.Writer, except the thing that it works with
// WebSocket frames, not the raw data.
//
// Writer writes frames with specified OpCode.
// It uses ws.State to decide whether the output frames must be masked.
//
// Note that it does not check control frame size or other RFC rules.
// That is, it must be used with special care to write control frames without
// violation of RFC. You could use ControlWriter that wraps Writer and contains
// some guards for writing control frames.
//
// If an error occurs writing to a Writer, no more data will be accepted and
// all subsequent writes will return the error.
//
// After all data has been written, the client should call the Flush() method
// to guarantee all data has been forwarded to the underlying io.Writer.
type Writer struct {
// dest specifies a destination of buffer flushes.
dest io.Writer
// op specifies the WebSocket operation code used in flushed frames.
op ws.OpCode
// state specifies the state of the Writer.
state ws.State
// extensions is a list of negotiated extensions for writer Dest.
// It is used to meet the specs and set appropriate bits in fragment
// header RSV segment.
extensions []SendExtension
// noFlush reports whether buffer must grow instead of being flushed.
noFlush bool
// Raw representation of the buffer, including reserved header bytes.
raw []byte
// Writeable part of buffer, without reserved header bytes.
// Resetting this to nil will not result in reallocation if raw is not nil.
// And vice versa: if buf is not nil, then Writer is assumed as ready and
// initialized.
buf []byte
// Buffered bytes counter.
n int
dirty bool
fseq int
err error
}
// NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
return NewWriterBufferSize(dest, state, op, 0)
}
// NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
// That is, output frames payload length could be up to n, except the case when
// Write() is called on empty Writer with len(p) > n.
//
// If n <= 0 then the default buffer size is used as Writer's buffer size.
func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
if n > 0 {
n += headerSize(state, n)
}
return NewWriterBufferSize(dest, state, op, n)
}
// NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
// If n <= ws.MinHeaderSize then the default buffer size is used.
//
// Note that Writer will reserve x bytes for header data, where x is in range
// [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
// will not have payload length equal to n, except the case when Write() is
// called on empty Writer with len(p) > n.
func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
if n <= ws.MinHeaderSize {
n = DefaultWriteBuffer
}
return NewWriterBuffer(dest, state, op, make([]byte, n))
}
// NewWriterBuffer returns a new Writer with buf as a buffer.
//
// Note that it reserves x bytes of buf for header data, where x is in range
// [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
//
// You could use ws.HeaderSize() to calculate number of bytes needed to store
// header data.
//
// It panics if len(buf) is too small to fit header and payload data.
func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
w := &Writer{
dest: dest,
state: state,
op: op,
raw: buf,
}
w.initBuf()
return w
}
func (w *Writer) initBuf() {
offset := reserve(w.state, len(w.raw))
if len(w.raw) <= offset {
panic("wsutil: writer buffer is too small")
}
w.buf = w.raw[offset:]
}
// Reset resets Writer as it was created by New() methods.
// Note that Reset does reset extenstions and other options was set after
// Writer initialization.
func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
w.dest = dest
w.state = state
w.op = op
w.initBuf()
w.n = 0
w.dirty = false
w.fseq = 0
w.extensions = w.extensions[:0]
w.noFlush = false
}
// ResetOp is an quick version of Reset().
// ResetOp does reset unwritten fragments and does not reset results of
// SetExtensions() or DisableFlush() methods.
func (w *Writer) ResetOp(op ws.OpCode) {
w.op = op
w.n = 0
w.dirty = false
w.fseq = 0
}
// SetExtensions adds xs as extenstions to be used during writes.
func (w *Writer) SetExtensions(xs ...SendExtension) {
w.extensions = xs
}
// DisableFlush denies Writer to write fragments.
func (w *Writer) DisableFlush() {
w.noFlush = true
}
// Size returns the size of the underlying buffer in bytes (not including
// WebSocket header bytes).
func (w *Writer) Size() int {
return len(w.buf)
}
// Available returns how many bytes are unused in the buffer.
func (w *Writer) Available() int {
return len(w.buf) - w.n
}
// Buffered returns the number of bytes that have been written into the current
// buffer.
func (w *Writer) Buffered() int {
return w.n
}
// Write implements io.Writer.
//
// Note that even if the Writer was created to have N-sized buffer, Write()
// with payload of N bytes will not fit into that buffer. Writer reserves some
// space to fit WebSocket header data.
func (w *Writer) Write(p []byte) (n int, err error) {
// Even empty p may make a sense.
w.dirty = true
var nn int
for len(p) > w.Available() && w.err == nil {
if w.noFlush {
w.Grow(len(p) - w.Available())
continue
}
if w.Buffered() == 0 {
// Large write, empty buffer. Write directly from p to avoid copy.
// Trade off here is that we make additional Write() to underlying
// io.Writer when writing frame header.
//
// On large buffers additional write is better than copying.
nn, _ = w.WriteThrough(p)
} else {
nn = copy(w.buf[w.n:], p)
w.n += nn
w.FlushFragment()
}
n += nn
p = p[nn:]
}
if w.err != nil {
return n, w.err
}
nn = copy(w.buf[w.n:], p)
w.n += nn
n += nn
// Even if w.Available() == 0 we will not flush buffer preventively because
// this could bring unwanted fragmentation. That is, user could create
// buffer with size that fits exactly all further Write() call, and then
// call Flush(), excepting that single and not fragmented frame will be
// sent. With preemptive flush this case will produce two frames last one
// will be empty and just to set fin = true.
return n, w.err
}
func ceilPowerOfTwo(n int) int {
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n++
return n
}
func (w *Writer) Grow(n int) {
var (
offset = len(w.raw) - len(w.buf)
size = ceilPowerOfTwo(offset + w.n + n)
)
if size <= len(w.raw) {
panic("wsutil: buffer grow leads to its reduce")
}
p := make([]byte, size)
copy(p, w.raw[:offset+w.n])
w.raw = p
w.buf = w.raw[offset:]
}
// WriteThrough writes data bypassing the buffer.
// Note that Writer's buffer must be empty before calling WriteThrough().
func (w *Writer) WriteThrough(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
if w.Buffered() != 0 {
return 0, ErrNotEmpty
}
var frame ws.Frame
frame.Header = ws.Header{
OpCode: w.opCode(),
Fin: false,
Length: int64(len(p)),
}
for _, ext := range w.extensions {
frame.Header.Rsv, err = ext.BitsSend(w.fseq, frame.Header.Rsv)
if err != nil {
return 0, err
}
}
if w.state.ClientSide() {
// Should copy bytes to prevent corruption of caller data.
payload := pbytes.GetLen(len(p))
defer pbytes.Put(payload)
copy(payload, p)
frame.Payload = payload
frame = ws.MaskFrameInPlace(frame)
} else {
frame.Payload = p
}
w.err = ws.WriteFrame(w.dest, frame)
if w.err == nil {
n = len(p)
}
w.dirty = true
w.fseq++
return n, w.err
}
// ReadFrom implements io.ReaderFrom.
func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
var nn int
for err == nil {
if w.Available() == 0 {
if w.noFlush {
w.Grow(w.Buffered()) // Twice bigger.
} else {
err = w.FlushFragment()
}
continue
}
// We copy the behavior of bufio.Writer here.
// Also, from the docs on io.ReaderFrom:
// ReadFrom reads data from r until EOF or error.
//
// See https://codereview.appspot.com/76400048/#ps1
const maxEmptyReads = 100
var nr int
for nr < maxEmptyReads {
nn, err = src.Read(w.buf[w.n:])
if nn != 0 || err != nil {
break
}
nr++
}
if nr == maxEmptyReads {
return n, io.ErrNoProgress
}
w.n += nn
n += int64(nn)
}
if err == io.EOF {
// NOTE: Do not flush preemptively.
// See the Write() sources for more info.
err = nil
w.dirty = true
}
return n, err
}
// Flush writes any buffered data to the underlying io.Writer.
// It sends the frame with "fin" flag set to true.
//
// If no Write() or ReadFrom() was made, then Flush() does nothing.
func (w *Writer) Flush() error {
if (!w.dirty && w.Buffered() == 0) || w.err != nil {
return w.err
}
w.err = w.flushFragment(true)
w.n = 0
w.dirty = false
w.fseq = 0
return w.err
}
// FlushFragment writes any buffered data to the underlying io.Writer.
// It sends the frame with "fin" flag set to false.
func (w *Writer) FlushFragment() error {
if w.Buffered() == 0 || w.err != nil {
return w.err
}
w.err = w.flushFragment(false)
w.n = 0
w.fseq++
return w.err
}
func (w *Writer) flushFragment(fin bool) (err error) {
var (
payload = w.buf[:w.n]
header = ws.Header{
OpCode: w.opCode(),
Fin: fin,
Length: int64(len(payload)),
}
)
for _, ext := range w.extensions {
header.Rsv, err = ext.BitsSend(w.fseq, header.Rsv)
if err != nil {
return err
}
}
if w.state.ClientSide() {
header.Masked = true
header.Mask = ws.NewMask()
ws.Cipher(payload, header.Mask, 0)
}
// Write header to the header segment of the raw buffer.
var (
offset = len(w.raw) - len(w.buf)
skip = offset - ws.HeaderSize(header)
)
buf := bytesWriter{
buf: w.raw[skip:offset],
}
if err := ws.WriteHeader(&buf, header); err != nil {
// Must never be reached.
panic("dump header error: " + err.Error())
}
_, err = w.dest.Write(w.raw[skip : offset+w.n])
return err
}
func (w *Writer) opCode() ws.OpCode {
if w.fseq > 0 {
return ws.OpContinuation
}
return w.op
}
var errNoSpace = fmt.Errorf("not enough buffer space")
type bytesWriter struct {
buf []byte
pos int
}
func (w *bytesWriter) Write(p []byte) (int, error) {
n := copy(w.buf[w.pos:], p)
w.pos += n
if n != len(p) {
return n, errNoSpace
}
return n, nil
}
func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
var frame ws.Frame
if s.ClientSide() {
// Should copy bytes to prevent corruption of caller data.
payload := pbytes.GetLen(len(p))
defer pbytes.Put(payload)
copy(payload, p)
frame = ws.NewFrame(op, fin, payload)
frame = ws.MaskFrameInPlace(frame)
} else {
frame = ws.NewFrame(op, fin, p)
}
return ws.WriteFrame(w, frame)
}
func reserve(state ws.State, n int) (offset int) {
var mask int
if state.ClientSide() {
mask = 4
}
switch {
case n <= int(len7)+mask+2:
return mask + 2
case n <= int(len16)+mask+4:
return mask + 4
default:
return mask + 10
}
}
// headerSize returns number of bytes needed to encode header of a frame with
// given state and length.
func headerSize(s ws.State, n int) int {
return ws.HeaderSize(ws.Header{
Length: int64(n),
Masked: s.ClientSide(),
})
}

View file

@ -0,0 +1,57 @@
/*
Package wsutil provides utilities for working with WebSocket protocol.
Overview:
// Read masked text message from peer and check utf8 encoding.
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
// Prepare to read payload.
r := io.LimitReader(conn, header.Length)
r = wsutil.NewCipherReader(r, header.Mask)
r = wsutil.NewUTF8Reader(r)
payload, err := ioutil.ReadAll(r)
if err != nil {
// handle err
}
You could get the same behavior using just `wsutil.Reader`:
r := wsutil.Reader{
Source: conn,
CheckUTF8: true,
}
payload, err := ioutil.ReadAll(r)
if err != nil {
// handle err
}
Or even simplest:
payload, err := wsutil.ReadClientText(conn)
if err != nil {
// handle err
}
Package is also exports tools for buffered writing:
// Create buffered writer, that will buffer output bytes and send them as
// 128-length fragments (with exception on large writes, see the doc).
writer := wsutil.NewWriterSize(conn, ws.StateServerSide, ws.OpText, 128)
_, err := io.CopyN(writer, rand.Reader, 100)
if err == nil {
err = writer.Flush()
}
if err != nil {
// handle error
}
For more utils and helpers see the documentation.
*/
package wsutil