1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

For regression test, add srs-bench to 3rdparty

This commit is contained in:
winlin 2021-03-04 13:23:01 +08:00
parent de87dd427d
commit 876210f6c9
1158 changed files with 256967 additions and 3 deletions

View file

@ -0,0 +1,2 @@
# vim temporary files
*.sw[poe]

View file

@ -0,0 +1,8 @@
linters-settings:
govet:
check-shadowing: true
misspell:
locale: US
run:
skip-dirs-use-default: false

View file

@ -0,0 +1,20 @@
<h1 align="center">
Design
</h1>
### Portable
Pion Data Channels is written in Go and extremely portable. Anywhere Golang runs, Pion Data Channels should work as well! Instead of dealing with complicated
cross-compiling of multiple libraries, you now can run anywhere with one `go build`
### Simple API
The API is based on an io.ReadWriteCloser.
### Readable
If code comes from an RFC we try to make sure everything is commented with a link to the spec.
This makes learning and debugging easier, this library was written to also serve as a guide for others.
### Tested
Every commit is tested via travis-ci Go provides fantastic facilities for testing, and more will be added as time goes on.
### Shared libraries
Every pion product is built using shared libraries, allowing others to review and reuse our libraries.

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,45 @@
<h1 align="center">
<br>
Pion Data Channels
<br>
</h1>
<h4 align="center">A Go implementation of WebRTC Data Channels</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-datachannel-gray.svg?longCache=true&colorB=brightgreen" alt="Pion Data Channels"></a>
<!--<a href="https://sourcegraph.com/github.com/pion/webrtc?badge"><img src="https://sourcegraph.com/github.com/pion/webrtc/-/badge.svg" alt="Sourcegraph Widget"></a>-->
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<br>
<a href="https://travis-ci.org/pion/datachannel"><img src="https://travis-ci.org/pion/datachannel.svg?branch=master" alt="Build Status"></a>
<a href="https://pkg.go.dev/github.com/pion/datachannel"><img src="https://godoc.org/github.com/pion/datachannel?status.svg" alt="GoDoc"></a>
<a href="https://codecov.io/gh/pion/datachannel"><img src="https://codecov.io/gh/pion/datachannel/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/datachannel"><img src="https://goreportcard.com/badge/github.com/pion/datachannel" alt="Go Report Card"></a>
<!--<a href="https://www.codacy.com/app/Sean-Der/webrtc"><img src="https://api.codacy.com/project/badge/Grade/18f4aec384894e6aac0b94effe51961d" alt="Codacy Badge"></a>-->
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
</p>
<br>
See [DESIGN.md](DESIGN.md) for an overview of features and future goals.
### Roadmap
The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.
### Community
Pion has an active community on the [Golang Slack](https://invite.slack.golangbridge.org/). Sign up and join the **#pion** channel for discussions and support. You can also use [Pion mailing list](https://groups.google.com/forum/#!forum/pion).
We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)
### Contributing
Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible:
* [John Bradley](https://github.com/kc5nra) - *Original Author*
* [Sean DuBois](https://github.com/Sean-Der) - *Original Author*
* [Michiel De Backker](https://github.com/backkem) - *Public API*
* [Yutaka Takeda](https://github.com/enobufs) - *PR-SCTP*
* [Hugo Arregui](https://github.com/hugoArregui)
* [Atsushi Watanabe](https://github.com/at-wat)
* [Norman Rasmussen](https://github.com/normanr) - *Fix Empty DataChannel messages*
### License
MIT License - see [LICENSE](LICENSE) for full text

View file

@ -0,0 +1,20 @@
#
# DO NOT EDIT THIS FILE
#
# It is automatically copied from https://github.com/pion/.goassets repository.
#
coverage:
status:
project:
default:
# Allow decreasing 2% of total coverage to avoid noise.
threshold: 2%
patch:
default:
target: 70%
only_pulls: true
ignore:
- "examples/*"
- "examples/**/*"

View file

@ -0,0 +1,378 @@
// Package datachannel implements WebRTC Data Channels
package datachannel
import (
"fmt"
"io"
"sync/atomic"
"github.com/pion/logging"
"github.com/pion/sctp"
"github.com/pkg/errors"
)
const receiveMTU = 8192
// Reader is an extended io.Reader
// that also returns if the message is text.
type Reader interface {
ReadDataChannel([]byte) (int, bool, error)
}
// Writer is an extended io.Writer
// that also allows indicating if a message is text.
type Writer interface {
WriteDataChannel([]byte, bool) (int, error)
}
// ReadWriteCloser is an extended io.ReadWriteCloser
// that also implements our Reader and Writer.
type ReadWriteCloser interface {
io.Reader
io.Writer
Reader
Writer
io.Closer
}
// DataChannel represents a data channel
type DataChannel struct {
Config
// stats
messagesSent uint32
messagesReceived uint32
bytesSent uint64
bytesReceived uint64
stream *sctp.Stream
log logging.LeveledLogger
}
// Config is used to configure the data channel.
type Config struct {
ChannelType ChannelType
Negotiated bool
Priority uint16
ReliabilityParameter uint32
Label string
Protocol string
LoggerFactory logging.LoggerFactory
}
func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
return &DataChannel{
Config: *config,
stream: stream,
log: config.LoggerFactory.NewLogger("datachannel"),
}, nil
}
// Dial opens a data channels over SCTP
func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
if err != nil {
return nil, err
}
dc, err := Client(stream, config)
if err != nil {
return nil, err
}
return dc, nil
}
// Client opens a data channel over an SCTP stream
func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
msg := &channelOpen{
ChannelType: config.ChannelType,
Priority: config.Priority,
ReliabilityParameter: config.ReliabilityParameter,
Label: []byte(config.Label),
Protocol: []byte(config.Protocol),
}
if !config.Negotiated {
rawMsg, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal ChannelOpen %v", err)
}
if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
return nil, fmt.Errorf("failed to send ChannelOpen %v", err)
}
}
return newDataChannel(stream, config)
}
// Accept is used to accept incoming data channels over SCTP
func Accept(a *sctp.Association, config *Config) (*DataChannel, error) {
stream, err := a.AcceptStream()
if err != nil {
return nil, err
}
stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
dc, err := Server(stream, config)
if err != nil {
return nil, err
}
return dc, nil
}
// Server accepts a data channel over an SCTP stream
func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
buffer := make([]byte, receiveMTU) // TODO: Can probably be smaller
n, ppi, err := stream.ReadSCTP(buffer)
if err != nil {
return nil, err
}
if ppi != sctp.PayloadTypeWebRTCDCEP {
return nil, fmt.Errorf("unexpected packet type: %s", ppi)
}
openMsg, err := parseExpectDataChannelOpen(buffer[:n])
if err != nil {
return nil, errors.Wrap(err, "failed to parse DataChannelOpen packet")
}
config.ChannelType = openMsg.ChannelType
config.Priority = openMsg.Priority
config.ReliabilityParameter = openMsg.ReliabilityParameter
config.Label = string(openMsg.Label)
config.Protocol = string(openMsg.Protocol)
dataChannel, err := newDataChannel(stream, config)
if err != nil {
return nil, err
}
err = dataChannel.writeDataChannelAck()
if err != nil {
return nil, err
}
err = dataChannel.commitReliabilityParams()
if err != nil {
return nil, err
}
return dataChannel, nil
}
// Read reads a packet of len(p) bytes as binary data
func (c *DataChannel) Read(p []byte) (int, error) {
n, _, err := c.ReadDataChannel(p)
return n, err
}
// ReadDataChannel reads a packet of len(p) bytes
func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
for {
n, ppi, err := c.stream.ReadSCTP(p)
if err == io.EOF {
// When the peer sees that an incoming stream was
// reset, it also resets its corresponding outgoing stream.
closeErr := c.stream.Close()
if closeErr != nil {
return 0, false, closeErr
}
}
if err != nil {
return 0, false, err
}
var isString bool
switch ppi {
case sctp.PayloadTypeWebRTCDCEP:
err = c.handleDCEP(p[:n])
if err != nil {
c.log.Errorf("Failed to handle DCEP: %s", err.Error())
continue
}
continue
case sctp.PayloadTypeWebRTCString, sctp.PayloadTypeWebRTCStringEmpty:
isString = true
}
switch ppi {
case sctp.PayloadTypeWebRTCBinaryEmpty, sctp.PayloadTypeWebRTCStringEmpty:
n = 0
}
atomic.AddUint32(&c.messagesReceived, 1)
atomic.AddUint64(&c.bytesReceived, uint64(n))
return n, isString, err
}
}
// MessagesSent returns the number of messages sent
func (c *DataChannel) MessagesSent() uint32 {
return atomic.LoadUint32(&c.messagesSent)
}
// MessagesReceived returns the number of messages received
func (c *DataChannel) MessagesReceived() uint32 {
return atomic.LoadUint32(&c.messagesReceived)
}
// BytesSent returns the number of bytes sent
func (c *DataChannel) BytesSent() uint64 {
return atomic.LoadUint64(&c.bytesSent)
}
// BytesReceived returns the number of bytes received
func (c *DataChannel) BytesReceived() uint64 {
return atomic.LoadUint64(&c.bytesReceived)
}
// StreamIdentifier returns the Stream identifier associated to the stream.
func (c *DataChannel) StreamIdentifier() uint16 {
return c.stream.StreamIdentifier()
}
func (c *DataChannel) handleDCEP(data []byte) error {
msg, err := parse(data)
if err != nil {
return errors.Wrap(err, "Failed to parse DataChannel packet")
}
switch msg := msg.(type) {
case *channelOpen:
c.log.Debug("Received DATA_CHANNEL_OPEN")
err = c.writeDataChannelAck()
if err != nil {
return fmt.Errorf("failed to ACK channel open: %v", err)
}
// Note: DATA_CHANNEL_OPEN message is handled inside Server() method.
// Therefore, the message will not reach here.
case *channelAck:
c.log.Debug("Received DATA_CHANNEL_ACK")
err = c.commitReliabilityParams()
if err != nil {
return err
}
// TODO: handle ChannelAck (https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-5.2)
default:
return fmt.Errorf("unhandled DataChannel message %v", msg)
}
return nil
}
// Write writes len(p) bytes from p as binary data
func (c *DataChannel) Write(p []byte) (n int, err error) {
return c.WriteDataChannel(p, false)
}
// WriteDataChannel writes len(p) bytes from p
func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
// SCTP does not support the sending of empty user messages. Therefore,
// if an empty message has to be sent, the appropriate PPID (WebRTC
// String Empty or WebRTC Binary Empty) is used and the SCTP user
// message of one zero byte is sent. When receiving an SCTP user
// message with one of these PPIDs, the receiver MUST ignore the SCTP
// user message and process it as an empty message.
var ppi sctp.PayloadProtocolIdentifier
switch {
case !isString && len(p) > 0:
ppi = sctp.PayloadTypeWebRTCBinary
case !isString && len(p) == 0:
ppi = sctp.PayloadTypeWebRTCBinaryEmpty
case isString && len(p) > 0:
ppi = sctp.PayloadTypeWebRTCString
case isString && len(p) == 0:
ppi = sctp.PayloadTypeWebRTCStringEmpty
}
atomic.AddUint32(&c.messagesSent, 1)
atomic.AddUint64(&c.bytesSent, uint64(len(p)))
if len(p) == 0 {
_, err := c.stream.WriteSCTP([]byte{0}, ppi)
return 0, err
}
return c.stream.WriteSCTP(p, ppi)
}
func (c *DataChannel) writeDataChannelAck() error {
ack := channelAck{}
ackMsg, err := ack.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal ChannelOpen ACK: %v", err)
}
_, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP)
if err != nil {
return fmt.Errorf("failed to send ChannelOpen ACK: %v", err)
}
return err
}
// Close closes the DataChannel and the underlying SCTP stream.
func (c *DataChannel) Close() error {
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
// Closing of a data channel MUST be signaled by resetting the
// corresponding outgoing streams [RFC6525]. This means that if one
// side decides to close the data channel, it resets the corresponding
// outgoing stream. When the peer sees that an incoming stream was
// reset, it also resets its corresponding outgoing stream. Once this
// is completed, the data channel is closed. Resetting a stream sets
// the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
// a corresponding notification to the application layer that the reset
// has been performed. Streams are available for reuse after a reset
// has been performed.
return c.stream.Close()
}
// BufferedAmount returns the number of bytes of data currently queued to be
// sent over this stream.
func (c *DataChannel) BufferedAmount() uint64 {
return c.stream.BufferedAmount()
}
// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
// data that is considered "low." Defaults to 0.
func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
return c.stream.BufferedAmountLowThreshold()
}
// SetBufferedAmountLowThreshold is used to update the threshold.
// See BufferedAmountLowThreshold().
func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
c.stream.SetBufferedAmountLowThreshold(th)
}
// OnBufferedAmountLow sets the callback handler which would be called when the
// number of bytes of outgoing data buffered is lower than the threshold.
func (c *DataChannel) OnBufferedAmountLow(f func()) {
c.stream.OnBufferedAmountLow(f)
}
func (c *DataChannel) commitReliabilityParams() error {
switch c.Config.ChannelType {
case ChannelTypeReliable:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
case ChannelTypeReliableUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableRexmit:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableRexmitUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableTimed:
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
case ChannelTypePartialReliableTimedUnordered:
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
default:
return fmt.Errorf("invalid ChannelType: %v ", c.Config.ChannelType)
}
return nil
}

View file

@ -0,0 +1,11 @@
module github.com/pion/datachannel
require (
github.com/pion/logging v0.2.2
github.com/pion/sctp v1.7.10
github.com/pion/transport v0.10.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.6.1
)
go 1.13

View file

@ -0,0 +1,38 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/sctp v1.7.10 h1:o3p3/hZB5Cx12RMGyWmItevJtZ6o2cpuxaw6GOS4x+8=
github.com/pion/sctp v1.7.10/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -0,0 +1,94 @@
package datachannel
import (
"fmt"
"github.com/pkg/errors"
)
// message is a parsed DataChannel message
type message interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
// messageType is the first byte in a DataChannel message that specifies type
type messageType byte
// DataChannel Message Types
const (
dataChannelAck messageType = 0x02
dataChannelOpen messageType = 0x03
)
func (t messageType) String() string {
switch t {
case dataChannelAck:
return "DataChannelAck"
case dataChannelOpen:
return "DataChannelOpen"
default:
return fmt.Sprintf("Unknown MessageType: %d", t)
}
}
// parse accepts raw input and returns a DataChannel message
func parse(raw []byte) (message, error) {
if len(raw) == 0 {
return nil, errors.Errorf("DataChannel message is not long enough to determine type ")
}
var msg message
switch messageType(raw[0]) {
case dataChannelOpen:
msg = &channelOpen{}
case dataChannelAck:
msg = &channelAck{}
default:
return nil, errors.Errorf("Unknown MessageType %v", messageType(raw[0]))
}
if err := msg.Unmarshal(raw); err != nil {
return nil, err
}
return msg, nil
}
// parseExpectDataChannelOpen parses a DataChannelOpen message
// or throws an error
func parseExpectDataChannelOpen(raw []byte) (*channelOpen, error) {
if len(raw) == 0 {
return nil, errors.Errorf("the DataChannel message is not long enough to determine type")
}
if actualTyp := messageType(raw[0]); actualTyp != dataChannelOpen {
return nil, errors.Errorf("expected DataChannelOpen but got %s", actualTyp)
}
msg := &channelOpen{}
if err := msg.Unmarshal(raw); err != nil {
return nil, err
}
return msg, nil
}
// parseExpectDataChannelAck parses a DataChannelAck message
// or throws an error
// func parseExpectDataChannelAck(raw []byte) (*channelAck, error) {
// if len(raw) == 0 {
// return nil, errors.Errorf("the DataChannel message is not long enough to determine type")
// }
//
// if actualTyp := messageType(raw[0]); actualTyp != dataChannelAck {
// return nil, errors.Errorf("expected DataChannelAck but got %s", actualTyp)
// }
//
// msg := &channelAck{}
// if err := msg.Unmarshal(raw); err != nil {
// return nil, err
// }
//
// return msg, nil
// }

View file

@ -0,0 +1,22 @@
package datachannel
// channelAck is used to ACK a DataChannel open
type channelAck struct{}
const (
channelOpenAckLength = 4
)
// Marshal returns raw bytes for the given message
func (c *channelAck) Marshal() ([]byte, error) {
raw := make([]byte, channelOpenAckLength)
raw[0] = uint8(dataChannelAck)
return raw, nil
}
// Unmarshal populates the struct with the given raw data
func (c *channelAck) Unmarshal(raw []byte) error {
// Message type already checked in Parse and there is no further data
return nil
}

View file

@ -0,0 +1,123 @@
package datachannel
import (
"encoding/binary"
"github.com/pkg/errors"
)
/*
channelOpen represents a DATA_CHANNEL_OPEN Message
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Message Type | Channel Type | Priority |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reliability Parameter |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Label Length | Protocol Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Label |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Protocol |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type channelOpen struct {
ChannelType ChannelType
Priority uint16
ReliabilityParameter uint32
Label []byte
Protocol []byte
}
const (
channelOpenHeaderLength = 12
)
// ChannelType determines the reliability of the WebRTC DataChannel
type ChannelType byte
// ChannelType enums
const (
// ChannelTypeReliable determines the Data Channel provides a
// reliable in-order bi-directional communication.
ChannelTypeReliable ChannelType = 0x00
// ChannelTypeReliableUnordered determines the Data Channel
// provides a reliable unordered bi-directional communication.
ChannelTypeReliableUnordered ChannelType = 0x80
// ChannelTypePartialReliableRexmit determines the Data Channel
// provides a partially-reliable in-order bi-directional communication.
// User messages will not be retransmitted more times than specified in the Reliability Parameter.
ChannelTypePartialReliableRexmit ChannelType = 0x01
// ChannelTypePartialReliableRexmitUnordered determines
// the Data Channel provides a partial reliable unordered bi-directional communication.
// User messages will not be retransmitted more times than specified in the Reliability Parameter.
ChannelTypePartialReliableRexmitUnordered ChannelType = 0x81
// ChannelTypePartialReliableTimed determines the Data Channel
// provides a partial reliable in-order bi-directional communication.
// User messages might not be transmitted or retransmitted after
// a specified life-time given in milli- seconds in the Reliability Parameter.
// This life-time starts when providing the user message to the protocol stack.
ChannelTypePartialReliableTimed ChannelType = 0x02
// The Data Channel provides a partial reliable unordered bi-directional
// communication. User messages might not be transmitted or retransmitted
// after a specified life-time given in milli- seconds in the Reliability Parameter.
// This life-time starts when providing the user message to the protocol stack.
ChannelTypePartialReliableTimedUnordered ChannelType = 0x82
)
// ChannelPriority enums
const (
ChannelPriorityBelowNormal uint16 = 128
ChannelPriorityNormal uint16 = 256
ChannelPriorityHigh uint16 = 512
ChannelPriorityExtraHigh uint16 = 1024
)
// Marshal returns raw bytes for the given message
func (c *channelOpen) Marshal() ([]byte, error) {
labelLength := len(c.Label)
protocolLength := len(c.Protocol)
totalLen := channelOpenHeaderLength + labelLength + protocolLength
raw := make([]byte, totalLen)
raw[0] = uint8(dataChannelOpen)
raw[1] = byte(c.ChannelType)
binary.BigEndian.PutUint16(raw[2:], c.Priority)
binary.BigEndian.PutUint32(raw[4:], c.ReliabilityParameter)
binary.BigEndian.PutUint16(raw[8:], uint16(labelLength))
binary.BigEndian.PutUint16(raw[10:], uint16(protocolLength))
endLabel := channelOpenHeaderLength + labelLength
copy(raw[channelOpenHeaderLength:endLabel], c.Label)
copy(raw[endLabel:endLabel+protocolLength], c.Protocol)
return raw, nil
}
// Unmarshal populates the struct with the given raw data
func (c *channelOpen) Unmarshal(raw []byte) error {
if len(raw) < channelOpenHeaderLength {
return errors.Errorf("Length of input is not long enough to satisfy header %d", len(raw))
}
c.ChannelType = ChannelType(raw[1])
c.Priority = binary.BigEndian.Uint16(raw[2:])
c.ReliabilityParameter = binary.BigEndian.Uint32(raw[4:])
labelLength := binary.BigEndian.Uint16(raw[8:])
protocolLength := binary.BigEndian.Uint16(raw[10:])
if len(raw) != int(channelOpenHeaderLength+labelLength+protocolLength) {
return errors.Errorf("Label + Protocol length don't match full packet length")
}
c.Label = raw[channelOpenHeaderLength : channelOpenHeaderLength+labelLength]
c.Protocol = raw[channelOpenHeaderLength+labelLength : channelOpenHeaderLength+labelLength+protocolLength]
return nil
}

View file

@ -0,0 +1,15 @@
{
"extends": [
"config:base"
],
"postUpdateOptions": [
"gomodTidy"
],
"commitBody": "Generated by renovateBot",
"packageRules": [
{
"packagePatterns": ["^golang.org/x/"],
"schedule": ["on the first day of the month"]
}
]
}

View file

@ -0,0 +1,21 @@
# http://editorconfig.org/
root = true
[*]
charset = utf-8
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
[*.go]
indent_style = tab
indent_size = 4
[{*.yml,*.yaml}]
indent_style = space
indent_size = 2
# Makefiles always use tabs for indentation
[Makefile]
indent_style = tab

View file

@ -0,0 +1,2 @@
vendor
*-fuzz.zip

View file

@ -0,0 +1,89 @@
linters-settings:
govet:
check-shadowing: true
misspell:
locale: US
exhaustive:
default-signifies-exhaustive: true
gomodguard:
blocked:
modules:
- github.com/pkg/errors:
recommendations:
- errors
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bodyclose # checks whether HTTP response body is closed successfully
- deadcode # Finds unused code
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- godox # Tool for detection of FIXME, TODO and other comment keywords
- goerr113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # Detects when assignments to existing variables are not used
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- noctx # noctx finds sending http request without context.Context
- scopelint # Scopelint checks for unpinned variables in go programs
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- structcheck # Finds unused struct fields
- stylecheck # Stylecheck is a replacement for golint
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varcheck # Finds unused global variables and constants
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- lll # Reports long lines
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- nolintlint # Reports ill-formed or insufficient nolint directives
- prealloc # Finds slice declarations that could potentially be preallocated
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- wsl # Whitespace Linter - Forces you to use empty lines!
issues:
exclude-use-default: false
exclude-rules:
# Allow complex tests, better to be self contained
- path: _test\.go
linters:
- gocognit
# Allow complex main function in examples
- path: examples
text: "of func `main` is high"
linters:
- gocognit
run:
skip-dirs-use-default: false

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,6 @@
fuzz-build-record-layer: fuzz-prepare
go-fuzz-build -tags gofuzz -func FuzzRecordLayer
fuzz-run-record-layer:
go-fuzz -bin dtls-fuzz.zip -workdir fuzz
fuzz-prepare:
@GO111MODULE=on go mod vendor

View file

@ -0,0 +1,151 @@
<h1 align="center">
<br>
Pion DTLS
<br>
</h1>
<h4 align="center">A Go implementation of DTLS</h4>
<p align="center">
<a href="https://pion.ly"><img src="https://img.shields.io/badge/pion-dtls-gray.svg?longCache=true&colorB=brightgreen" alt="Pion DTLS"></a>
<a href="https://sourcegraph.com/github.com/pion/dtls"><img src="https://sourcegraph.com/github.com/pion/dtls/-/badge.svg" alt="Sourcegraph Widget"></a>
<a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a>
<br>
<a href="https://travis-ci.org/pion/dtls"><img src="https://travis-ci.org/pion/dtls.svg?branch=master" alt="Build Status"></a>
<a href="https://pkg.go.dev/github.com/pion/dtls"><img src="https://godoc.org/github.com/pion/dtls?status.svg" alt="GoDoc"></a>
<a href="https://codecov.io/gh/pion/dtls"><img src="https://codecov.io/gh/pion/dtls/branch/master/graph/badge.svg" alt="Coverage Status"></a>
<a href="https://goreportcard.com/report/github.com/pion/dtls"><img src="https://goreportcard.com/badge/github.com/pion/dtls" alt="Go Report Card"></a>
<a href="https://www.codacy.com/app/Sean-Der/dtls"><img src="https://api.codacy.com/project/badge/Grade/18f4aec384894e6aac0b94effe51961d" alt="Codacy Badge"></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a>
</p>
<br>
Native [DTLS 1.2][rfc6347] implementation in the Go programming language.
A long term goal is a professional security review, and maye inclusion in stdlib.
[rfc6347]: https://tools.ietf.org/html/rfc6347
### Goals/Progress
This will only be targeting DTLS 1.2, and the most modern/common cipher suites.
We would love contributes that fall under the 'Planned Features' and fixing any bugs!
#### Current features
* DTLS 1.2 Client/Server
* Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK
* Packet loss and re-ordering is handled during handshaking
* Key export ([RFC 5705][rfc5705])
* Serialization and Resumption of sessions
* Extended Master Secret extension ([RFC 7627][rfc7627])
[rfc5705]: https://tools.ietf.org/html/rfc5705
[rfc7627]: https://tools.ietf.org/html/rfc7627
#### Supported ciphers
##### ECDHE
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655])
* TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
* TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289])
* TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
* TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422])
##### PSK
* TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487])
[rfc5289]: https://tools.ietf.org/html/rfc5289
[rfc8422]: https://tools.ietf.org/html/rfc8422
[rfc6655]: https://tools.ietf.org/html/rfc6655
[rfc5487]: https://tools.ietf.org/html/rfc5487
#### Planned Features
* Chacha20Poly1305
#### Excluded Features
* DTLS 1.0
* Renegotiation
* Compression
### Using
This library needs at least Go 1.13, and you should have [Go modules
enabled](https://github.com/golang/go/wiki/Modules).
#### Pion DTLS
For a DTLS 1.2 Server that listens on 127.0.0.1:4444
```sh
go run examples/listen/selfsign/main.go
```
For a DTLS 1.2 Client that connects to 127.0.0.1:4444
```sh
go run examples/dial/selfsign/main.go
```
#### OpenSSL
Pion DTLS can connect to itself and OpenSSL.
```
// Generate a certificate
openssl ecparam -out key.pem -name prime256v1 -genkey
openssl req -new -sha256 -key key.pem -out server.csr
openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem
// Use with examples/dial/selfsign/main.go
openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444
// Use with examples/listen/selfsign/main.go
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem
```
### Using with PSK
Pion DTLS also comes with examples that do key exchange via PSK
#### Pion DTLS
```sh
go run examples/listen/psk/main.go
```
```sh
go run examples/dial/psk/main.go
```
#### OpenSSL
```
// Use with examples/dial/psk/main.go
openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8
// Use with examples/listen/psk/main.go
openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8
```
### Contributing
Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible:
* [Sean DuBois](https://github.com/Sean-Der) - *Original Author*
* [Michiel De Backker](https://github.com/backkem) - *Public API*
* [Chris Hiszpanski](https://github.com/thinkski) - *Support Signature Algorithms Extension*
* [Iñigo Garcia Olaizola](https://github.com/igolaizola) - *Serialization & resumption, cert verification, E2E*
* [Daniele Sluijters](https://github.com/daenney) - *AES-CCM support*
* [Jin Lei](https://github.com/jinleileiking) - *Logging*
* [Hugo Arregui](https://github.com/hugoArregui)
* [Lander Noterman](https://github.com/LanderN)
* [Aleksandr Razumov](https://github.com/ernado) - *Fuzzing*
* [Ryan Gordon](https://github.com/ryangordon)
* [Stefan Tatschner](https://rumpelsepp.org/contact.html)
* [Hayden James](https://github.com/hjames9)
* [Jozef Kralik](https://github.com/jkralik)
* [Robert Eperjesi](https://github.com/epes)
* [Atsushi Watanabe](https://github.com/at-wat)
* [Julien Salleyron](https://github.com/juliens) - *Server Name Indication*
* [Jeroen de Bruijn](https://github.com/vidavidorra)
* [bjdgyc](https://github.com/bjdgyc)
* [Jeffrey Stoke (Jeff Ctor)](https://github.com/jeffreystoke) - *Fragmentbuffer Fix*
* [Frank Olbricht](https://github.com/folbricht)
* [ZHENK](https://github.com/scorpionknifes)
* [Carson Hoffman](https://github.com/CarsonHoffman)
* [Vadim Filimonov](https://github.com/fffilimonov)
### License
MIT License - see [LICENSE](LICENSE) for full text

View file

@ -0,0 +1,145 @@
package dtls
import "fmt"
type alertLevel byte
const (
alertLevelWarning alertLevel = 1
alertLevelFatal alertLevel = 2
)
func (a alertLevel) String() string {
switch a {
case alertLevelWarning:
return "LevelWarning"
case alertLevelFatal:
return "LevelFatal"
default:
return "Invalid alert level"
}
}
type alertDescription byte
const (
alertCloseNotify alertDescription = 0
alertUnexpectedMessage alertDescription = 10
alertBadRecordMac alertDescription = 20
alertDecryptionFailed alertDescription = 21
alertRecordOverflow alertDescription = 22
alertDecompressionFailure alertDescription = 30
alertHandshakeFailure alertDescription = 40
alertNoCertificate alertDescription = 41
alertBadCertificate alertDescription = 42
alertUnsupportedCertificate alertDescription = 43
alertCertificateRevoked alertDescription = 44
alertCertificateExpired alertDescription = 45
alertCertificateUnknown alertDescription = 46
alertIllegalParameter alertDescription = 47
alertUnknownCA alertDescription = 48
alertAccessDenied alertDescription = 49
alertDecodeError alertDescription = 50
alertDecryptError alertDescription = 51
alertExportRestriction alertDescription = 60
alertProtocolVersion alertDescription = 70
alertInsufficientSecurity alertDescription = 71
alertInternalError alertDescription = 80
alertUserCanceled alertDescription = 90
alertNoRenegotiation alertDescription = 100
alertUnsupportedExtension alertDescription = 110
)
func (a alertDescription) String() string {
switch a {
case alertCloseNotify:
return "CloseNotify"
case alertUnexpectedMessage:
return "UnexpectedMessage"
case alertBadRecordMac:
return "BadRecordMac"
case alertDecryptionFailed:
return "DecryptionFailed"
case alertRecordOverflow:
return "RecordOverflow"
case alertDecompressionFailure:
return "DecompressionFailure"
case alertHandshakeFailure:
return "HandshakeFailure"
case alertNoCertificate:
return "NoCertificate"
case alertBadCertificate:
return "BadCertificate"
case alertUnsupportedCertificate:
return "UnsupportedCertificate"
case alertCertificateRevoked:
return "CertificateRevoked"
case alertCertificateExpired:
return "CertificateExpired"
case alertCertificateUnknown:
return "CertificateUnknown"
case alertIllegalParameter:
return "IllegalParameter"
case alertUnknownCA:
return "UnknownCA"
case alertAccessDenied:
return "AccessDenied"
case alertDecodeError:
return "DecodeError"
case alertDecryptError:
return "DecryptError"
case alertExportRestriction:
return "ExportRestriction"
case alertProtocolVersion:
return "ProtocolVersion"
case alertInsufficientSecurity:
return "InsufficientSecurity"
case alertInternalError:
return "InternalError"
case alertUserCanceled:
return "UserCanceled"
case alertNoRenegotiation:
return "NoRenegotiation"
case alertUnsupportedExtension:
return "UnsupportedExtension"
default:
return "Invalid alert description"
}
}
// One of the content types supported by the TLS record layer is the
// alert type. Alert messages convey the severity of the message
// (warning or fatal) and a description of the alert. Alert messages
// with a level of fatal result in the immediate termination of the
// connection. In this case, other connections corresponding to the
// session may continue, but the session identifier MUST be invalidated,
// preventing the failed session from being used to establish new
// connections. Like other messages, alert messages are encrypted and
// compressed, as specified by the current connection state.
// https://tools.ietf.org/html/rfc5246#section-7.2
type alert struct {
alertLevel alertLevel
alertDescription alertDescription
}
func (a alert) contentType() contentType {
return contentTypeAlert
}
func (a *alert) Marshal() ([]byte, error) {
return []byte{byte(a.alertLevel), byte(a.alertDescription)}, nil
}
func (a *alert) Unmarshal(data []byte) error {
if len(data) != 2 {
return errBufferTooSmall
}
a.alertLevel = alertLevel(data[0])
a.alertDescription = alertDescription(data[1])
return nil
}
func (a *alert) String() string {
return fmt.Sprintf("Alert %s: %s", a.alertLevel, a.alertDescription)
}

View file

@ -0,0 +1,23 @@
package dtls
// Application data messages are carried by the record layer and are
// fragmented, compressed, and encrypted based on the current connection
// state. The messages are treated as transparent data to the record
// layer.
// https://tools.ietf.org/html/rfc5246#section-10
type applicationData struct {
data []byte
}
func (a applicationData) contentType() contentType {
return contentTypeApplicationData
}
func (a *applicationData) Marshal() ([]byte, error) {
return append([]byte{}, a.data...), nil
}
func (a *applicationData) Unmarshal(data []byte) error {
a.data = append([]byte{}, data...)
return nil
}

View file

@ -0,0 +1,67 @@
package dtls
import (
"crypto/tls"
"crypto/x509"
"strings"
)
func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.nameToCertificate == nil {
nameToCertificate := make(map[string]*tls.Certificate)
for i := range c.localCertificates {
cert := &c.localCertificates[i]
x509Cert := cert.Leaf
if x509Cert == nil {
var parseErr error
x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
if parseErr != nil {
continue
}
}
if len(x509Cert.Subject.CommonName) > 0 {
nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
}
for _, san := range x509Cert.DNSNames {
nameToCertificate[strings.ToLower(san)] = cert
}
}
c.nameToCertificate = nameToCertificate
}
if len(c.localCertificates) == 0 {
return nil, errNoCertificates
}
if len(c.localCertificates) == 1 {
// There's only one choice, so no point doing any work.
return &c.localCertificates[0], nil
}
if len(serverName) == 0 {
return &c.localCertificates[0], nil
}
name := strings.TrimRight(strings.ToLower(serverName), ".")
if cert, ok := c.nameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.nameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.localCertificates[0], nil
}

View file

@ -0,0 +1,25 @@
package dtls
// The change cipher spec protocol exists to signal transitions in
// ciphering strategies. The protocol consists of a single message,
// which is encrypted and compressed under the current (not the pending)
// connection state. The message consists of a single byte of value 1.
// https://tools.ietf.org/html/rfc5246#section-7.1
type changeCipherSpec struct {
}
func (c changeCipherSpec) contentType() contentType {
return contentTypeChangeCipherSpec
}
func (c *changeCipherSpec) Marshal() ([]byte, error) {
return []byte{0x01}, nil
}
func (c *changeCipherSpec) Unmarshal(data []byte) error {
if len(data) == 1 && data[0] == 0x01 {
return nil
}
return errInvalidCipherSpec
}

View file

@ -0,0 +1,206 @@
package dtls
import (
"encoding/binary"
"fmt"
"hash"
)
// CipherSuiteID is an ID for our supported CipherSuites
type CipherSuiteID uint16
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = 0xc0ac //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0ae //nolint:golint,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02b //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02f //nolint:golint,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc00a //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc014 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = 0xc0a4 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0x00a8 //nolint:golint,stylecheck
)
var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14
func (c CipherSuiteID) String() string {
switch c {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8"
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
case TLS_PSK_WITH_AES_128_CCM:
return "TLS_PSK_WITH_AES_128_CCM"
case TLS_PSK_WITH_AES_128_CCM_8:
return "TLS_PSK_WITH_AES_128_CCM_8"
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
default:
return fmt.Sprintf("unknown(%v)", uint16(c))
}
}
type cipherSuite interface {
String() string
ID() CipherSuiteID
certificateType() clientCertificateType
hashFunc() func() hash.Hash
isPSK() bool
isInitialized() bool
// Generate the internal encryption state
init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
encrypt(pkt *recordLayer, raw []byte) ([]byte, error)
decrypt(in []byte) ([]byte, error)
}
// CipherSuiteName provides the same functionality as tls.CipherSuiteName
// that appeared first in Go 1.14.
//
// Our implementation differs slightly in that it takes in a CiperSuiteID,
// like the rest of our library, instead of a uint16 like crypto/tls.
func CipherSuiteName(id CipherSuiteID) string {
suite := cipherSuiteForID(id)
if suite != nil {
return suite.String()
}
return fmt.Sprintf("0x%04X", uint16(id))
}
// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
// A cipherSuite is a specific combination of key agreement, cipher and MAC
// function.
func cipherSuiteForID(id CipherSuiteID) cipherSuite {
switch id {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm()
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8()
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheRsaWithAes128GcmSha256{}
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{}
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheRsaWithAes256CbcSha{}
case TLS_PSK_WITH_AES_128_CCM:
return newCipherSuiteTLSPskWithAes128Ccm()
case TLS_PSK_WITH_AES_128_CCM_8:
return newCipherSuiteTLSPskWithAes128Ccm8()
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSPskWithAes128GcmSha256{}
}
return nil
}
// CipherSuites we support in order of preference
func defaultCipherSuites() []cipherSuite {
return []cipherSuite{
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
}
}
func allCipherSuites() []cipherSuite {
return []cipherSuite{
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm(),
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8(),
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
newCipherSuiteTLSPskWithAes128Ccm(),
newCipherSuiteTLSPskWithAes128Ccm8(),
&cipherSuiteTLSPskWithAes128GcmSha256{},
}
}
func decodeCipherSuites(buf []byte) ([]cipherSuite, error) {
if len(buf) < 2 {
return nil, errDTLSPacketInvalidLength
}
cipherSuitesCount := int(binary.BigEndian.Uint16(buf[0:])) / 2
rtrn := []cipherSuite{}
for i := 0; i < cipherSuitesCount; i++ {
if len(buf) < (i*2 + 4) {
return nil, errBufferTooSmall
}
id := CipherSuiteID(binary.BigEndian.Uint16(buf[(i*2)+2:]))
if c := cipherSuiteForID(id); c != nil {
rtrn = append(rtrn, c)
}
}
return rtrn, nil
}
func encodeCipherSuites(cipherSuites []cipherSuite) []byte {
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuites)*2))
for _, c := range cipherSuites {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(c.ID()))
}
return out
}
func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNonPSK bool) ([]cipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) {
cipherSuites := []cipherSuite{}
for _, id := range ids {
c := cipherSuiteForID(id)
if c == nil {
return nil, &invalidCipherSuite{id}
}
cipherSuites = append(cipherSuites, c)
}
return cipherSuites, nil
}
var (
cipherSuites []cipherSuite
err error
i int
)
if len(userSelectedSuites) != 0 {
cipherSuites, err = cipherSuitesForIDs(userSelectedSuites)
if err != nil {
return nil, err
}
} else {
cipherSuites = defaultCipherSuites()
}
for _, c := range cipherSuites {
if excludePSK && c.isPSK() || excludeNonPSK && !c.isPSK() {
continue
}
cipherSuites[i] = c
i++
}
cipherSuites = cipherSuites[:i]
if len(cipherSuites) == 0 {
return nil, errNoAvailableCipherSuites
}
return cipherSuites, nil
}

View file

@ -0,0 +1,93 @@
package dtls
import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteAes128Ccm struct {
ccm atomic.Value // *cryptoCCM
clientCertificateType clientCertificateType
id CipherSuiteID
psk bool
cryptoCCMTagLen cryptoCCMTagLen
}
func newCipherSuiteAes128Ccm(clientCertificateType clientCertificateType, id CipherSuiteID, psk bool, cryptoCCMTagLen cryptoCCMTagLen) *cipherSuiteAes128Ccm {
return &cipherSuiteAes128Ccm{
clientCertificateType: clientCertificateType,
id: id,
psk: psk,
cryptoCCMTagLen: cryptoCCMTagLen,
}
}
func (c *cipherSuiteAes128Ccm) certificateType() clientCertificateType {
return c.clientCertificateType
}
func (c *cipherSuiteAes128Ccm) ID() CipherSuiteID {
return c.id
}
func (c *cipherSuiteAes128Ccm) String() string {
return c.id.String()
}
func (c *cipherSuiteAes128Ccm) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteAes128Ccm) isPSK() bool {
return c.psk
}
func (c *cipherSuiteAes128Ccm) isInitialized() bool {
return c.ccm.Load() != nil
}
func (c *cipherSuiteAes128Ccm) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var ccm *cryptoCCM
if isClient {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.ccm.Store(ccm)
return err
}
var errCipherSuiteNotInit = errors.New("CipherSuite has not been initialized")
func (c *cipherSuiteAes128Ccm) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).encrypt(pkt, raw)
}
func (c *cipherSuiteAes128Ccm) decrypt(raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).decrypt(raw)
}

View file

@ -0,0 +1,36 @@
// +build go1.14
package dtls
import (
"crypto/tls"
)
// Convert from our cipherSuite interface to a tls.CipherSuite struct
func toTLSCipherSuite(c cipherSuite) *tls.CipherSuite {
return &tls.CipherSuite{
ID: uint16(c.ID()),
Name: c.String(),
SupportedVersions: []uint16{VersionDTLS12},
Insecure: false,
}
}
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
func CipherSuites() []*tls.CipherSuite {
suites := allCipherSuites()
res := make([]*tls.CipherSuite, len(suites))
for i, c := range suites {
res[i] = toTLSCipherSuite(c)
}
return res
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
func InsecureCipherSuites() []*tls.CipherSuite {
var res []*tls.CipherSuite
return res
}

View file

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, cryptoCCMTagLength)
}

View file

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, cryptoCCM8TagLength)
}

View file

@ -0,0 +1,77 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256 struct {
gcm atomic.Value // *cryptoGCM
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isInitialized() bool {
return c.gcm.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var gcm *cryptoGCM
if isClient {
gcm, err = newCryptoGCM(keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
gcm, err = newCryptoGCM(keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.gcm.Store(gcm)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) decrypt(raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).decrypt(raw)
}

View file

@ -0,0 +1,83 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes256CbcSha struct {
cbc atomic.Value // *cryptoCBC
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isInitialized() bool {
return c.cbc.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 20
prfKeyLen = 32
prfIvLen = 16
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var cbc *cryptoCBC
if isClient {
cbc, err = newCryptoCBC(
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
)
} else {
cbc, err = newCryptoCBC(
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
)
}
c.cbc.Store(cbc)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).decrypt(raw)
}

View file

@ -0,0 +1,17 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
}

View file

@ -0,0 +1,17 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes256CbcSha struct {
cipherSuiteTLSEcdheEcdsaWithAes256CbcSha
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
}

View file

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM, true, cryptoCCMTagLength)
}

View file

@ -0,0 +1,5 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM_8, true, cryptoCCM8TagLength)
}

View file

@ -0,0 +1,21 @@
package dtls
type cipherSuiteTLSPskWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateType(0)
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_PSK_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) String() string {
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) isPSK() bool {
return true
}

View file

@ -0,0 +1,16 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type clientCertificateType byte
const (
clientCertificateTypeRSASign clientCertificateType = 1
clientCertificateTypeECDSASign clientCertificateType = 64
)
func clientCertificateTypes() map[clientCertificateType]bool {
return map[clientCertificateType]bool{
clientCertificateTypeRSASign: true,
clientCertificateTypeECDSASign: true,
}
}

View file

@ -0,0 +1,20 @@
#
# DO NOT EDIT THIS FILE
#
# It is automatically copied from https://github.com/pion/.goassets repository.
#
coverage:
status:
project:
default:
# Allow decreasing 2% of total coverage to avoid noise.
threshold: 2%
patch:
default:
target: 70%
only_pulls: true
ignore:
- "examples/*"
- "examples/**/*"

View file

@ -0,0 +1,49 @@
package dtls
type compressionMethodID byte
const (
compressionMethodNull compressionMethodID = 0
)
type compressionMethod struct {
id compressionMethodID
}
func compressionMethods() map[compressionMethodID]*compressionMethod {
return map[compressionMethodID]*compressionMethod{
compressionMethodNull: {id: compressionMethodNull},
}
}
func defaultCompressionMethods() []*compressionMethod {
return []*compressionMethod{
compressionMethods()[compressionMethodNull],
}
}
func decodeCompressionMethods(buf []byte) ([]*compressionMethod, error) {
if len(buf) < 1 {
return nil, errDTLSPacketInvalidLength
}
compressionMethodsCount := int(buf[0])
c := []*compressionMethod{}
for i := 0; i < compressionMethodsCount; i++ {
if len(buf) <= i+1 {
return nil, errBufferTooSmall
}
id := compressionMethodID(buf[i+1])
if compressionMethod, ok := compressionMethods()[id]; ok {
c = append(c, compressionMethod)
}
}
return c, nil
}
func encodeCompressionMethods(c []*compressionMethod) []byte {
out := []byte{byte(len(c))}
for i := len(c); i > 0; i-- {
out = append(out, byte(c[i-1].id))
}
return out
}

View file

@ -0,0 +1,179 @@
package dtls
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"time"
"github.com/pion/logging"
)
// Config is used to configure a DTLS client or server.
// After a Config is passed to a DTLS function it must not be modified.
type Config struct {
// Certificates contains certificate chain to present to the other side of the connection.
// Server MUST set this if PSK is non-nil
// client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil
Certificates []tls.Certificate
// CipherSuites is a list of supported cipher suites.
// If CipherSuites is nil, a default list is used
CipherSuites []CipherSuiteID
// SignatureSchemes contains the signature and hash schemes that the peer requests to verify.
SignatureSchemes []tls.SignatureScheme
// SRTPProtectionProfiles are the supported protection profiles
// Clients will send this via use_srtp and assert that the server properly responds
// Servers will assert that clients send one of these profiles and will respond as needed
SRTPProtectionProfiles []SRTPProtectionProfile
// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
// RequireExtendedMasterSecret determines if the "Extended Master Secret" extension
// should be disabled, requested, or required (default requested).
ExtendedMasterSecret ExtendedMasterSecretType
// FlightInterval controls how often we send outbound handshake messages
// defaults to time.Second
FlightInterval time.Duration
// PSK sets the pre-shared key used by this DTLS connection
// If PSK is non-nil only PSK CipherSuites will be used
PSK PSKCallback
PSKIdentityHint []byte
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureSkipVerify bool
// InsecureHashes allows the use of hashing algorithms that are known
// to be vulnerable.
InsecureHashes bool
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a client or server. It
// receives the certificate provided by the peer and also a flag
// that tells if normal verification has succeedded. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify, or (for a server) when ClientAuth is
// RequestClientCert or RequireAnyClientCert, then this callback will
// be considered but the verifiedChains will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// RootCAs defines the set of root certificate authorities
// that one peer uses when verifying the other peer's certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given.
ServerName string
LoggerFactory logging.LoggerFactory
// ConnectContextMaker is a function to make a context used in Dial(),
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
// is used. It can be implemented as following.
//
// func ConnectContextMaker() (context.Context, func()) {
// return context.WithTimeout(context.Background(), 30*time.Second)
// }
ConnectContextMaker func() (context.Context, func())
// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int
// ReplayProtectionWindow is the size of the replay attack protection window.
// Duplication of the sequence number is checked in this window size.
// Packet with sequence number older than this value compared to the latest
// accepted packet will be discarded. (default is 64)
ReplayProtectionWindow int
}
func defaultConnectContextMaker() (context.Context, func()) {
return context.WithTimeout(context.Background(), 30*time.Second)
}
func (c *Config) connectContextMaker() (context.Context, func()) {
if c.ConnectContextMaker == nil {
return defaultConnectContextMaker()
}
return c.ConnectContextMaker()
}
const defaultMTU = 1200 // bytes
// PSKCallback is called once we have the remote's PSKIdentityHint.
// If the remote provided none it will be nil
type PSKCallback func([]byte) ([]byte, error)
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
// ClientAuthType enums
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// ExtendedMasterSecretType declares the policy the client and server
// will follow for the Extended Master Secret extension
type ExtendedMasterSecretType int
// ExtendedMasterSecretType enums
const (
RequestExtendedMasterSecret ExtendedMasterSecretType = iota
RequireExtendedMasterSecret
DisableExtendedMasterSecret
)
func validateConfig(config *Config) error {
switch {
case config == nil:
return errNoConfigProvided
case len(config.Certificates) > 0 && config.PSK != nil:
return errPSKAndCertificate
case config.PSKIdentityHint != nil && config.PSK == nil:
return errIdentityNoPSK
}
for _, cert := range config.Certificates {
if cert.Certificate == nil {
return errInvalidCertificate
}
if cert.PrivateKey != nil {
switch cert.PrivateKey.(type) {
case ed25519.PrivateKey:
case *ecdsa.PrivateKey:
default:
return errInvalidPrivateKey
}
}
}
_, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
return err
}

View file

@ -0,0 +1,978 @@
package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/connctx"
"github.com/pion/logging"
"github.com/pion/transport/deadline"
"github.com/pion/transport/replaydetector"
)
const (
initialTickerInterval = time.Second
cookieLength = 20
defaultNamedCurve = namedCurveX25519
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
)
var (
errApplicationDataEpochZero = errors.New("ApplicationData with epoch of 0")
errUnhandledContextType = errors.New("unhandled contentType")
)
func invalidKeyingLabels() map[string]bool {
return map[string]bool{
"client finished": true,
"server finished": true,
"master secret": true,
"key expansion": true,
}
}
// Conn represents a DTLS connection
type Conn struct {
lock sync.RWMutex // Internal lock (must not be public)
nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
state State // Internal state
maximumTransmissionUnit int
handshakeCompletedSuccessfully atomic.Value
encryptedPackets [][]byte
connectionClosedByUser bool
closeLock sync.Mutex
closed *closer.Closer
handshakeLoopsFinished sync.WaitGroup
readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline
log logging.LeveledLogger
reading chan struct{}
handshakeRecv chan chan struct{}
cancelHandshaker func()
cancelHandshakeReader func()
fsm *handshakeFSM
replayProtectionWindow uint
}
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
err := validateConfig(config)
if err != nil {
return nil, err
}
if nextConn == nil {
return nil, errNilNextConn
}
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
if err != nil {
return nil, err
}
signatureSchemes, err := parseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}
workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
logger := loggerFactory.NewLogger("dtls")
mtu := config.MTU
if mtu <= 0 {
mtu = defaultMTU
}
replayProtectionWindow := config.ReplayProtectionWindow
if replayProtectionWindow <= 0 {
replayProtectionWindow = defaultReplayProtectionWindow
}
c := &Conn{
nextConn: connctx.New(nextConn),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
maximumTransmissionUnit: mtu,
decrypted: make(chan interface{}, 1),
log: logger,
readDeadline: deadline.New(),
writeDeadline: deadline.New(),
reading: make(chan struct{}, 1),
handshakeRecv: make(chan chan struct{}),
closed: closer.NewCloser(),
cancelHandshaker: func() {},
replayProtectionWindow: uint(replayProtectionWindow),
state: State{
isClient: isClient,
},
}
c.setRemoteEpoch(0)
c.setLocalEpoch(0)
serverName := config.ServerName
// Use host from conn address when serverName is not provided
if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
remoteAddr := nextConn.RemoteAddr().String()
var host string
host, _, err = net.SplitHostPort(remoteAddr)
if err != nil {
serverName = remoteAddr
} else {
serverName = host
}
}
hsCfg := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
}
var initialFlight flightVal
var initialFSMState handshakeState
if initialState != nil {
if c.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished
c.state = *initialState
} else {
if c.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}
c.log.Trace("Handshake Completed")
return c, nil
}
// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return DialWithContext(ctx, network, raddr, config)
}
// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return ClientWithContext(ctx, conn, config)
}
// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.Conn, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()
return ServerWithContext(ctx, conn, config)
}
// DialWithContext connects to the given network address and establishes a DTLS connection on top.
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
pConn, err := net.DialUDP(network, nil, raddr)
if err != nil {
return nil, err
}
return ClientWithContext(ctx, pConn, config)
}
// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK != nil && config.PSKIdentityHint == nil:
return nil, errPSKAndIdentityMustBeSetForClient
}
return createConn(ctx, conn, config, true, nil)
}
// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK == nil && len(config.Certificates) == 0:
return nil, errServerMustHaveCertificate
}
return createConn(ctx, conn, config, false, nil)
}
// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}
select {
case <-c.readDeadline.Done():
return 0, errDeadlineExceeded
default:
}
for {
select {
case <-c.readDeadline.Done():
return 0, errDeadlineExceeded
case out, ok := <-c.decrypted:
if !ok {
return 0, io.EOF
}
switch val := out.(type) {
case ([]byte):
if len(p) < len(val) {
return 0, errBufferTooSmall
}
copy(p, val)
return len(val), nil
case (error):
return 0, val
}
}
}
}
// Write writes len(p) bytes from p to the DTLS connection
func (c *Conn) Write(p []byte) (int, error) {
if c.isConnectionClosed() {
return 0, ErrConnClosed
}
select {
case <-c.writeDeadline.Done():
return 0, errDeadlineExceeded
default:
}
if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}
return len(p), c.writePackets(c.writeDeadline, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
},
content: &applicationData{
data: p,
},
},
shouldEncrypt: true,
},
})
}
// Close closes the connection.
func (c *Conn) Close() error {
err := c.close(true)
c.handshakeLoopsFinished.Wait()
return err
}
// ConnectionState returns basic DTLS details about the connection.
// Note that this replaced the `Export` function of v1.
func (c *Conn) ConnectionState() State {
c.lock.RLock()
defer c.lock.RUnlock()
return *c.state.clone()
}
// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.state.srtpProtectionProfile == 0 {
return 0, false
}
return c.state.srtpProtectionProfile, true
}
func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
c.lock.Lock()
defer c.lock.Unlock()
var rawPackets [][]byte
for _, p := range pkts {
if h, ok := p.record.content.(*handshake); ok {
handshakeRaw, err := p.record.Marshal()
if err != nil {
return err
}
c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
srvCliStr(c.state.isClient), h.handshakeHeader.handshakeType.String(),
p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence)
c.handshakeCache.push(handshakeRaw[recordLayerHeaderSize:], p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence, h.handshakeHeader.handshakeType, c.state.isClient)
rawHandshakePackets, err := c.processHandshakePacket(p, h)
if err != nil {
return err
}
rawPackets = append(rawPackets, rawHandshakePackets...)
} else {
rawPacket, err := c.processPacket(p)
if err != nil {
return err
}
rawPackets = append(rawPackets, rawPacket)
}
}
if len(rawPackets) == 0 {
return nil
}
compactedRawPackets := c.compactRawPackets(rawPackets)
for _, compactedRawPackets := range compactedRawPackets {
if _, err := c.nextConn.Write(ctx, compactedRawPackets); err != nil {
return netError(err)
}
}
return nil
}
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
combinedRawPackets := make([][]byte, 0)
currentCombinedRawPacket := make([]byte, 0)
for _, rawPacket := range rawPackets {
if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
currentCombinedRawPacket = []byte{}
}
currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
}
combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
return combinedRawPackets
}
func (c *Conn) processPacket(p *packet) ([]byte, error) {
epoch := p.record.recordLayerHeader.epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
// RFC 6347 Section 4.1.0
// The implementation must either abandon an association or rehandshake
// prior to allowing the sequence number to wrap.
return nil, errSequenceNumberOverflow
}
p.record.recordLayerHeader.sequenceNumber = seq
rawPacket, err := p.record.Marshal()
if err != nil {
return nil, err
}
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
}
return rawPacket, nil
}
func (c *Conn) processHandshakePacket(p *packet, h *handshake) ([][]byte, error) {
rawPackets := make([][]byte, 0)
handshakeFragments, err := c.fragmentHandshake(h)
if err != nil {
return nil, err
}
epoch := p.record.recordLayerHeader.epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
for _, handshakeFragment := range handshakeFragments {
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
return nil, errSequenceNumberOverflow
}
recordLayerHeader := &recordLayerHeader{
protocolVersion: p.record.recordLayerHeader.protocolVersion,
contentType: p.record.recordLayerHeader.contentType,
contentLen: uint16(len(handshakeFragment)),
epoch: p.record.recordLayerHeader.epoch,
sequenceNumber: seq,
}
recordLayerHeaderBytes, err := recordLayerHeader.Marshal()
if err != nil {
return nil, err
}
p.record.recordLayerHeader = *recordLayerHeader
rawPacket := append(recordLayerHeaderBytes, handshakeFragment...)
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
}
rawPackets = append(rawPackets, rawPacket)
}
return rawPackets, nil
}
func (c *Conn) fragmentHandshake(h *handshake) ([][]byte, error) {
content, err := h.handshakeMessage.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshakes := make([][]byte, 0)
contentFragments := splitBytes(content, c.maximumTransmissionUnit)
if len(contentFragments) == 0 {
contentFragments = [][]byte{
{},
}
}
offset := 0
for _, contentFragment := range contentFragments {
contentFragmentLen := len(contentFragment)
handshakeHeaderFragment := &handshakeHeader{
handshakeType: h.handshakeHeader.handshakeType,
length: h.handshakeHeader.length,
messageSequence: h.handshakeHeader.messageSequence,
fragmentOffset: uint32(offset),
fragmentLength: uint32(contentFragmentLen),
}
offset += contentFragmentLen
handshakeHeaderFragmentRaw, err := handshakeHeaderFragment.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshake := append(handshakeHeaderFragmentRaw, contentFragment...)
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
}
return fragmentedHandshakes, nil
}
var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} {
b := make([]byte, inboundBufferSize)
return &b
},
}
func (c *Conn) readAndBuffer(ctx context.Context) error {
bufptr := poolReadBuffer.Get().(*[]byte)
defer poolReadBuffer.Put(bufptr)
b := *bufptr
i, err := c.nextConn.Read(ctx, b)
if err != nil {
return netError(err)
}
pkts, err := unpackDatagram(b[:i])
if err != nil {
return err
}
var hasHandshake bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(p, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
err = alertErr
}
}
}
if hs {
hasHandshake = true
}
switch e := err.(type) {
case nil:
case *errAlert:
if e.IsFatalOrCloseNotify() {
return e
}
default:
return e
}
}
if hasHandshake {
done := make(chan struct{})
select {
case c.handshakeRecv <- done:
// If the other party may retransmit the flight,
// we should respond even if it not a new message.
<-done
case <-c.fsm.Done():
}
}
return nil
}
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
pkts := c.encryptedPackets
c.encryptedPackets = nil
for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
err = alertErr
}
}
}
switch e := err.(type) {
case nil:
case *errAlert:
if e.IsFatalOrCloseNotify() {
return e
}
default:
return e
}
}
return nil
}
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) { //nolint:gocognit
h := &recordLayerHeader{}
if err := h.Unmarshal(buf); err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}
// Validate epoch
remoteEpoch := c.getRemoteEpoch()
if h.epoch > remoteEpoch {
if h.epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
}
return false, nil, nil
}
// Anti-replay protection
for len(c.state.replayDetector) <= int(h.epoch) {
c.state.replayDetector = append(c.state.replayDetector,
replaydetector.New(c.replayProtectionWindow, maxSequenceNumber),
)
}
markPacketAsValid, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber)
if !ok {
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
// Decrypt
if h.epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
}
return false, nil, nil
}
var err error
buf, err = c.state.cipherSuite.decrypt(buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
}
}
isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
if err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
} else if isHandshake {
markPacketAsValid()
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
rawHandshake := &handshake{}
if err := rawHandshake.Unmarshal(out); err != nil {
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
continue
}
_ = c.handshakeCache.push(out, epoch, rawHandshake.handshakeHeader.messageSequence, rawHandshake.handshakeHeader.handshakeType, !c.state.isClient)
}
return true, nil, nil
}
r := &recordLayer{}
if err := r.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
}
switch content := r.content.(type) {
case *alert:
c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
var a *alert
if content.alertDescription == alertCloseNotify {
// Respond with a close_notify [RFC5246 Section 7.2.1]
a = &alert{alertLevelWarning, alertCloseNotify}
}
markPacketAsValid()
return false, a, &errAlert{content}
case *changeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
return false, nil, nil
}
newRemoteEpoch := h.epoch + 1
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
if c.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
markPacketAsValid()
}
case *applicationData:
if h.epoch == 0 {
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, errApplicationDataEpochZero
}
markPacketAsValid()
select {
case c.decrypted <- content.data:
case <-c.closed.Done():
}
default:
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.contentType())
}
return false, nil, nil
}
func (c *Conn) recvHandshake() <-chan chan struct{} {
return c.handshakeRecv
}
func (c *Conn) notify(ctx context.Context, level alertLevel, desc alertDescription) error {
return c.writePackets(ctx, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
},
content: &alert{
alertLevel: level,
alertDescription: desc,
},
},
shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
},
})
}
func (c *Conn) setHandshakeCompletedSuccessfully() {
c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
}
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
return boolean.bool
}
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
done := make(chan struct{})
ctxRead, cancelRead := context.WithCancel(context.Background())
c.cancelHandshakeReader = cancelRead
cfg.onFlightState = func(f flightVal, s handshakeState) {
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
c.setHandshakeCompletedSuccessfully()
close(done)
}
}
ctxHs, cancel := context.WithCancel(context.Background())
c.cancelHandshaker = cancel
firstErr := make(chan error, 1)
c.handshakeLoopsFinished.Add(2)
// Handshake routine should be live until close.
// The other party may request retransmission of the last flight to cope with packet drop.
go func() {
defer c.handshakeLoopsFinished.Done()
err := c.fsm.Run(ctxHs, c, initialState)
if !errors.Is(err, context.Canceled) {
select {
case firstErr <- err:
default:
}
}
}()
go func() {
defer func() {
// Escaping read loop.
// It's safe to close decrypted channnel now.
close(c.decrypted)
// Force stop handshaker when the underlying connection is closed.
cancel()
}()
defer c.handshakeLoopsFinished.Done()
for {
if err := c.readAndBuffer(ctxRead); err != nil {
switch e := err.(type) {
case *errAlert:
if !e.IsFatalOrCloseNotify() {
if c.isHandshakeCompletedSuccessfully() {
// Pass the error to Read()
select {
case c.decrypted <- err:
case <-c.closed.Done():
}
}
continue // non-fatal alert must not stop read loop
}
case error:
switch err {
case context.DeadlineExceeded, context.Canceled, io.EOF:
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
select {
case c.decrypted <- err:
case <-c.closed.Done():
}
continue // non-fatal alert must not stop read loop
}
}
}
select {
case firstErr <- err:
default:
}
if e, ok := err.(*errAlert); ok {
if e.IsFatalOrCloseNotify() {
_ = c.close(false)
}
}
return
}
}
}()
select {
case err := <-firstErr:
cancelRead()
cancel()
return c.translateHandshakeCtxError(err)
case <-ctx.Done():
cancelRead()
cancel()
return c.translateHandshakeCtxError(ctx.Err())
case <-done:
return nil
}
}
func (c *Conn) translateHandshakeCtxError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
return nil
}
return &HandshakeError{err}
}
func (c *Conn) close(byUser bool) error {
c.cancelHandshaker()
c.cancelHandshakeReader()
if c.isHandshakeCompletedSuccessfully() && byUser {
// Discard error from notify() to return non-error on the first user call of Close()
// even if the underlying connection is already closed.
_ = c.notify(context.Background(), alertLevelWarning, alertCloseNotify)
}
c.closeLock.Lock()
// Don't return ErrConnClosed at the first time of the call from user.
closedByUser := c.connectionClosedByUser
if byUser {
c.connectionClosedByUser = true
}
c.closed.Close()
c.closeLock.Unlock()
if closedByUser {
return ErrConnClosed
}
return c.nextConn.Close()
}
func (c *Conn) isConnectionClosed() bool {
select {
case <-c.closed.Done():
return true
default:
return false
}
}
func (c *Conn) setLocalEpoch(epoch uint16) {
c.state.localEpoch.Store(epoch)
}
func (c *Conn) getLocalEpoch() uint16 {
return c.state.localEpoch.Load().(uint16)
}
func (c *Conn) setRemoteEpoch(epoch uint16) {
c.state.remoteEpoch.Store(epoch)
}
func (c *Conn) getRemoteEpoch() uint16 {
return c.state.remoteEpoch.Load().(uint16)
}
// LocalAddr implements net.Conn.LocalAddr
func (c *Conn) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}
// RemoteAddr implements net.Conn.RemoteAddr
func (c *Conn) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}
// SetDeadline implements net.Conn.SetDeadline
func (c *Conn) SetDeadline(t time.Time) error {
c.readDeadline.Set(t)
return c.SetWriteDeadline(t)
}
// SetReadDeadline implements net.Conn.SetReadDeadline
func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
// Read deadline is fully managed by this layer.
// Don't set read deadline to underlying connection.
return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.Set(t)
// Write deadline is also fully managed by this layer.
return nil
}

View file

@ -0,0 +1,17 @@
package dtls
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type contentType uint8
const (
contentTypeChangeCipherSpec contentType = 20
contentTypeAlert contentType = 21
contentTypeHandshake contentType = 22
contentTypeApplicationData contentType = 23
)
type content interface {
contentType() contentType
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}

View file

@ -0,0 +1,232 @@
package dtls
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"math/big"
"time"
)
type ecdsaSignature struct {
R, S *big.Int
}
func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve) []byte {
serverECDHParams := make([]byte, 4)
serverECDHParams[0] = 3 // named curve
binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve))
serverECDHParams[3] = byte(len(publicKey))
plaintext := []byte{}
plaintext = append(plaintext, clientRandom...)
plaintext = append(plaintext, serverRandom...)
plaintext = append(plaintext, serverECDHParams...)
plaintext = append(plaintext, publicKey...)
return plaintext
}
// If the client provided a "signature_algorithms" extension, then all
// certificates provided by the server MUST be signed by a
// hash/signature algorithm pair that appears in that extension
//
// https://tools.ietf.org/html/rfc5246#section-7.4.2
func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve)
switch p := privateKey.(type) {
case ed25519.PrivateKey:
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, msg, crypto.Hash(0))
case *ecdsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
case *rsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
}
return nil, errKeySignatureGenerateUnimplemented
}
func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAlgorithm, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
certificate, err := x509.ParseCertificate(rawCertificates[0])
if err != nil {
return err
}
switch p := certificate.PublicKey.(type) {
case ed25519.PublicKey:
if ok := ed25519.Verify(p, message, remoteKeySignature); !ok {
return errKeySignatureMismatch
}
return nil
case *ecdsa.PublicKey:
ecdsaSig := &ecdsaSignature{}
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
return err
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hashed := hashAlgorithm.digest(message)
if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
return nil
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hashed := hashAlgorithm.digest(message)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hashed, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
}
return errKeySignatureVerifyUnimplemented
}
// If the server has sent a CertificateRequest message, the client MUST send the Certificate
// message. The ClientKeyExchange message is now sent, and the content
// of that message will depend on the public key algorithm selected
// between the ClientHello and the ServerHello. If the client has sent
// a certificate with signing ability, a digitally-signed
// CertificateVerify message is sent to explicitly verify possession of
// the private key in the certificate.
// https://tools.ietf.org/html/rfc5246#section-7.3
func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
h := sha256.New()
if _, err := h.Write(handshakeBodies); err != nil {
return nil, err
}
hashed := h.Sum(nil)
switch p := privateKey.(type) {
case ed25519.PrivateKey:
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, hashed, crypto.Hash(0))
case *ecdsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
case *rsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
}
return nil, errInvalidSignatureAlgorithm
}
func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hashAlgorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
certificate, err := x509.ParseCertificate(rawCertificates[0])
if err != nil {
return err
}
switch p := certificate.PublicKey.(type) {
case ed25519.PublicKey:
if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok {
return errKeySignatureMismatch
}
return nil
case *ecdsa.PublicKey:
ecdsaSig := &ecdsaSignature{}
if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil {
return err
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hash := hashAlgorithm.digest(handshakeBodies)
if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
return nil
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hash := hashAlgorithm.digest(handshakeBodies)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hash, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
}
return errKeySignatureVerifyUnimplemented
}
func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) {
if len(rawCertificates) == 0 {
return nil, errLengthMismatch
}
certs := make([]*x509.Certificate, 0, len(rawCertificates))
for _, rawCert := range rawCertificates {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
return certs, nil
}
func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) {
certificate, err := loadCerts(rawCertificates)
if err != nil {
return nil, err
}
intermediateCAPool := x509.NewCertPool()
for _, cert := range certificate[1:] {
intermediateCAPool.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: intermediateCAPool,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
return certificate[0].Verify(opts)
}
func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) {
certificate, err := loadCerts(rawCertificates)
if err != nil {
return nil, err
}
intermediateCAPool := x509.NewCertPool()
for _, cert := range certificate[1:] {
intermediateCAPool.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: serverName,
Intermediates: intermediateCAPool,
}
return certificate[0].Verify(opts)
}
func generateAEADAdditionalData(h *recordLayerHeader, payloadLen int) []byte {
var additionalData [13]byte
// SequenceNumber MUST be set first
// we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48)
binary.BigEndian.PutUint64(additionalData[:], h.sequenceNumber)
binary.BigEndian.PutUint16(additionalData[:], h.epoch)
additionalData[8] = byte(h.contentType)
additionalData[9] = h.protocolVersion.major
additionalData[10] = h.protocolVersion.minor
binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen))
return additionalData[:]
}

View file

@ -0,0 +1,133 @@
package dtls
import ( //nolint:gci
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1" //nolint:gosec
"encoding/binary"
)
// block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// State needed to handle encrypted input/output
type cryptoCBC struct {
writeCBC, readCBC cbcMode
writeMac, readMac []byte
}
// Currently hardcoded to be SHA1 only
var cryptoCBCMacFunc = sha1.New //nolint:gochecknoglobals
func newCryptoCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte) (*cryptoCBC, error) {
writeBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
readBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
return &cryptoCBC{
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
writeMac: localMac,
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
readMac: remoteMac,
}, nil
}
func (c *cryptoCBC) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
blockSize := c.writeCBC.BlockSize()
// Generate + Append MAC
h := pkt.recordLayerHeader
MAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, payload, c.writeMac)
if err != nil {
return nil, err
}
payload = append(payload, MAC...)
// Generate + Append padding
padding := make([]byte, blockSize-len(payload)%blockSize)
paddingLen := len(padding)
for i := 0; i < paddingLen; i++ {
padding[i] = byte(paddingLen - 1)
}
payload = append(payload, padding...)
// Generate IV
iv := make([]byte, blockSize)
if _, err := rand.Read(iv); err != nil {
return nil, err
}
// Set IV + Encrypt + Prepend IV
c.writeCBC.SetIV(iv)
c.writeCBC.CryptBlocks(payload, payload)
payload = append(iv, payload...)
// Prepend unencrypte header with encrypted payload
raw = append(raw, payload...)
// Update recordLayer size to include IV+MAC+Padding
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
return raw, nil
}
func (c *cryptoCBC) decrypt(in []byte) ([]byte, error) {
body := in[recordLayerHeaderSize:]
blockSize := c.readCBC.BlockSize()
mac := cryptoCBCMacFunc()
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(body)%blockSize != 0 || len(body) < blockSize+max(mac.Size()+1, blockSize):
return nil, errNotEnoughRoomForNonce
}
// Set + remove per record IV
c.readCBC.SetIV(body[:blockSize])
body = body[blockSize:]
// Decrypt
c.readCBC.CryptBlocks(body, body)
// Padding+MAC needs to be checked in constant time
// Otherwise we reveal information about the level of correctness
paddingLen, paddingGood := examinePadding(body)
macSize := mac.Size()
if len(body) < macSize {
return nil, errInvalidMAC
}
dataEnd := len(body) - macSize - paddingLen
expectedMAC := body[dataEnd : dataEnd+macSize]
actualMAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, body[:dataEnd], c.readMac)
// Compute Local MAC and compare
if paddingGood != 255 || err != nil || !hmac.Equal(actualMAC, expectedMAC) {
return nil, errInvalidMAC
}
return append(in[:recordLayerHeaderSize], body[:dataEnd]...), nil
}

View file

@ -0,0 +1,100 @@
package dtls
import (
"crypto/aes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/crypto/ccm"
)
var errDecryptPacket = errors.New("decryptPacket")
type cryptoCCMTagLen int
const (
cryptoCCM8TagLength cryptoCCMTagLen = 8
cryptoCCMTagLength cryptoCCMTagLen = 16
cryptoCCMNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoCCM struct {
localCCM, remoteCCM ccm.CCM
localWriteIV, remoteWriteIV []byte
tagLen cryptoCCMTagLen
}
func newCryptoCCM(tagLen cryptoCCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoCCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localCCM, err := ccm.NewCCM(localBlock, int(tagLen), cryptoCCMNonceLength)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), cryptoCCMNonceLength)
if err != nil {
return nil, err
}
return &cryptoCCM{
localCCM: localCCM,
localWriteIV: localWriteIV,
remoteCCM: remoteCCM,
remoteWriteIV: remoteWriteIV,
tagLen: tagLen,
}, nil
}
func (c *cryptoCCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
encryptedPayload = append(nonce[4:], encryptedPayload...)
raw = append(raw, encryptedPayload...)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
return raw, nil
}
func (c *cryptoCCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen))
out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
}

View file

@ -0,0 +1,94 @@
package dtls
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
)
const (
cryptoGCMTagLength = 16
cryptoGCMNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoGCM struct {
localGCM, remoteGCM cipher.AEAD
localWriteIV, remoteWriteIV []byte
}
func newCryptoGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoGCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localGCM, err := cipher.NewGCM(localBlock)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteGCM, err := cipher.NewGCM(remoteBlock)
if err != nil {
return nil, err
}
return &cryptoGCM{
localGCM: localGCM,
localWriteIV: localWriteIV,
remoteGCM: remoteGCM,
remoteWriteIV: remoteWriteIV,
}, nil
}
func (c *cryptoGCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
nonce := make([]byte, cryptoGCMNonceLength)
copy(nonce, c.localWriteIV[:4])
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
encryptedPayload := c.localGCM.Seal(nil, nonce, payload, additionalData)
r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
copy(r, raw)
copy(r[len(raw):], nonce[4:])
copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(r[recordLayerHeaderSize-2:], uint16(len(r)-recordLayerHeaderSize))
return r, nil
}
func (c *cryptoGCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := make([]byte, 0, cryptoGCMNonceLength)
nonce = append(append(nonce, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-cryptoGCMTagLength)
out, err = c.remoteGCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
}

View file

@ -0,0 +1,14 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type ellipticCurveType byte
const (
ellipticCurveTypeNamedCurve ellipticCurveType = 0x03
)
func ellipticCurveTypes() map[ellipticCurveType]bool {
return map[ellipticCurveType]bool{
ellipticCurveTypeNamedCurve: true,
}
}

View file

@ -0,0 +1,2 @@
// Package dtls implements Datagram Transport Layer Security (DTLS) 1.2
package dtls

View file

@ -0,0 +1,229 @@
package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"golang.org/x/xerrors"
)
// Typed errors
var (
ErrConnClosed = &FatalError{errors.New("conn is closed")} //nolint:goerr113
errDeadlineExceeded = &TimeoutError{xerrors.Errorf("read/write timeout: %w", context.DeadlineExceeded)}
errBufferTooSmall = &TemporaryError{errors.New("buffer is too small")} //nolint:goerr113
errContextUnsupported = &TemporaryError{errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113
errDTLSPacketInvalidLength = &TemporaryError{errors.New("packet is too short")} //nolint:goerr113
errHandshakeInProgress = &TemporaryError{errors.New("handshake is in progress")} //nolint:goerr113
errInvalidContentType = &TemporaryError{errors.New("invalid content type")} //nolint:goerr113
errInvalidMAC = &TemporaryError{errors.New("invalid mac")} //nolint:goerr113
errInvalidPacketLength = &TemporaryError{errors.New("packet length and declared length do not match")} //nolint:goerr113
errReservedExportKeyingMaterial = &TemporaryError{errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113
errCertificateVerifyNoCertificate = &FatalError{errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113
errCipherSuiteNoIntersection = &FatalError{errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113
errCipherSuiteUnset = &FatalError{errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113
errClientCertificateNotVerified = &FatalError{errors.New("client sent certificate but did not verify it")} //nolint:goerr113
errClientCertificateRequired = &FatalError{errors.New("server required client verification, but got none")} //nolint:goerr113
errClientNoMatchingSRTPProfile = &FatalError{errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113
errClientRequiredButNoServerEMS = &FatalError{errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113
errCompressionMethodUnset = &FatalError{errors.New("server hello can not be created without a compression method")} //nolint:goerr113
errCookieMismatch = &FatalError{errors.New("client+server cookie does not match")} //nolint:goerr113
errCookieTooLong = &FatalError{errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113
errIdentityNoPSK = &FatalError{errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113
errInvalidCertificate = &FatalError{errors.New("no certificate provided")} //nolint:goerr113
errInvalidCipherSpec = &FatalError{errors.New("cipher spec invalid")} //nolint:goerr113
errInvalidCipherSuite = &FatalError{errors.New("invalid or unknown cipher suite")} //nolint:goerr113
errInvalidClientKeyExchange = &FatalError{errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113
errInvalidCompressionMethod = &FatalError{errors.New("invalid or unknown compression method")} //nolint:goerr113
errInvalidECDSASignature = &FatalError{errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113
errInvalidEllipticCurveType = &FatalError{errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113
errInvalidExtensionType = &FatalError{errors.New("invalid extension type")} //nolint:goerr113
errInvalidHashAlgorithm = &FatalError{errors.New("invalid hash algorithm")} //nolint:goerr113
errInvalidNamedCurve = &FatalError{errors.New("invalid named curve")} //nolint:goerr113
errInvalidPrivateKey = &FatalError{errors.New("invalid private key type")} //nolint:goerr113
errInvalidSNIFormat = &FatalError{errors.New("invalid server name format")} //nolint:goerr113
errInvalidSignatureAlgorithm = &FatalError{errors.New("invalid signature algorithm")} //nolint:goerr113
errKeySignatureMismatch = &FatalError{errors.New("expected and actual key signature do not match")} //nolint:goerr113
errNilNextConn = &FatalError{errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113
errNoAvailableCipherSuites = &FatalError{errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113
errNoAvailableSignatureSchemes = &FatalError{errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113
errNoCertificates = &FatalError{errors.New("no certificates configured")} //nolint:goerr113
errNoConfigProvided = &FatalError{errors.New("no config provided")} //nolint:goerr113
errNoSupportedEllipticCurves = &FatalError{errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113
errUnsupportedProtocolVersion = &FatalError{errors.New("unsupported protocol version")} //nolint:goerr113
errPSKAndCertificate = &FatalError{errors.New("Certificate and PSK provided")} //nolint:stylecheck
errPSKAndIdentityMustBeSetForClient = &FatalError{errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113
errRequestedButNoSRTPExtension = &FatalError{errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113
errServerMustHaveCertificate = &FatalError{errors.New("Certificate is mandatory for server")} //nolint:stylecheck
errServerNoMatchingSRTPProfile = &FatalError{errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
errServerRequiredButNoClientEMS = &FatalError{errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
errVerifyDataMismatch = &FatalError{errors.New("expected and actual verify data does not match")} //nolint:goerr113
errHandshakeMessageUnset = &InternalError{errors.New("handshake message unset, unable to marshal")} //nolint:goerr113
errInvalidFlight = &InternalError{errors.New("invalid flight number")} //nolint:goerr113
errKeySignatureGenerateUnimplemented = &InternalError{errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
errKeySignatureVerifyUnimplemented = &InternalError{errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113
errLengthMismatch = &InternalError{errors.New("data length and declared length do not match")} //nolint:goerr113
errNotEnoughRoomForNonce = &InternalError{errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
errNotImplemented = &InternalError{errors.New("feature has not been implemented yet")} //nolint:goerr113
errSequenceNumberOverflow = &InternalError{errors.New("sequence number overflow")} //nolint:goerr113
errUnableToMarshalFragmented = &InternalError{errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113
)
// FatalError indicates that the DTLS connection is no longer available.
// It is mainly caused by wrong configuration of server or client.
type FatalError struct {
Err error
}
// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available.
// It is mainly caused by bugs or tried to use unimplemented features.
type InternalError struct {
Err error
}
// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary.
type TemporaryError struct {
Err error
}
// TimeoutError indicates that the request was timed out.
type TimeoutError struct {
Err error
}
// HandshakeError indicates that the handshake failed.
type HandshakeError struct {
Err error
}
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuite struct {
id CipherSuiteID
}
func (e *invalidCipherSuite) Error() string {
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
}
func (e *invalidCipherSuite) Is(err error) bool {
if other, ok := err.(*invalidCipherSuite); ok {
return e.id == other.id
}
return false
}
// Timeout implements net.Error.Timeout()
func (*FatalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*FatalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *FatalError) Unwrap() error { return e.Err }
func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*InternalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*InternalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *InternalError) Unwrap() error { return e.Err }
func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TemporaryError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*TemporaryError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TemporaryError) Unwrap() error { return e.Err }
func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TimeoutError) Timeout() bool { return true }
// Temporary implements net.Error.Temporary()
func (*TimeoutError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TimeoutError) Unwrap() error { return e.Err }
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Timeout()
}
return false
}
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
}
return false
}
// Unwrap implements Go1.13 error unwrapper.
func (e *HandshakeError) Unwrap() error { return e.Err }
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }
// errAlert wraps DTLS alert notification as an error
type errAlert struct {
*alert
}
func (e *errAlert) Error() string {
return fmt.Sprintf("alert: %s", e.alert.String())
}
func (e *errAlert) IsFatalOrCloseNotify() bool {
return e.alertLevel == alertLevelFatal || e.alertDescription == alertCloseNotify
}
func (e *errAlert) Is(err error) bool {
if other, ok := err.(*errAlert); ok {
return e.alertLevel == other.alertLevel && e.alertDescription == other.alertDescription
}
return false
}
// netError translates an error from underlying Conn to corresponding net.Error.
func netError(err error) error {
switch err {
case io.EOF, context.Canceled, context.DeadlineExceeded:
// Return io.EOF and context errors as is.
return err
}
switch e := err.(type) {
case (*net.OpError):
if se, ok := e.Err.(*os.SyscallError); ok {
if se.Timeout() {
return &TimeoutError{err}
}
if isOpErrorTemporary(se) {
return &TemporaryError{err}
}
}
case (net.Error):
return err
}
return &FatalError{err}
}

View file

@ -0,0 +1,25 @@
// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows
// For systems having syscall.Errno.
// Update build targets by following command:
// $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \
// | tr "." "_" | cut -d"_" -f"2" | sort | uniq
package dtls
import (
"os"
"syscall"
)
func isOpErrorTemporary(err *os.SyscallError) bool {
if ne, ok := err.Err.(syscall.Errno); ok {
switch ne {
case syscall.ECONNREFUSED:
return true
default:
return false
}
}
return false
}

View file

@ -0,0 +1,14 @@
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows
// For systems without syscall.Errno.
// Build targets must be inverse of errors_errno.go
package dtls
import (
"os"
)
func isOpErrorTemporary(err *os.SyscallError) bool {
return false
}

View file

@ -0,0 +1,88 @@
package dtls
import (
"encoding/binary"
)
// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
type extensionValue uint16
const (
extensionServerNameValue extensionValue = 0
extensionSupportedEllipticCurvesValue extensionValue = 10
extensionSupportedPointFormatsValue extensionValue = 11
extensionSupportedSignatureAlgorithmsValue extensionValue = 13
extensionUseSRTPValue extensionValue = 14
extensionUseExtendedMasterSecretValue extensionValue = 23
extensionRenegotiationInfoValue extensionValue = 65281
)
type extension interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
extensionValue() extensionValue
}
func decodeExtensions(buf []byte) ([]extension, error) {
if len(buf) < 2 {
return nil, errBufferTooSmall
}
declaredLen := binary.BigEndian.Uint16(buf)
if len(buf)-2 != int(declaredLen) {
return nil, errLengthMismatch
}
extensions := []extension{}
unmarshalAndAppend := func(data []byte, e extension) error {
err := e.Unmarshal(data)
if err != nil {
return err
}
extensions = append(extensions, e)
return nil
}
for offset := 2; offset < len(buf); {
if len(buf) < (offset + 2) {
return nil, errBufferTooSmall
}
var err error
switch extensionValue(binary.BigEndian.Uint16(buf[offset:])) {
case extensionServerNameValue:
err = unmarshalAndAppend(buf[offset:], &extensionServerName{})
case extensionSupportedEllipticCurvesValue:
err = unmarshalAndAppend(buf[offset:], &extensionSupportedEllipticCurves{})
case extensionUseSRTPValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseSRTP{})
case extensionUseExtendedMasterSecretValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseExtendedMasterSecret{})
case extensionRenegotiationInfoValue:
err = unmarshalAndAppend(buf[offset:], &extensionRenegotiationInfo{})
default:
}
if err != nil {
return nil, err
}
if len(buf) < (offset + 4) {
return nil, errBufferTooSmall
}
extensionLength := binary.BigEndian.Uint16(buf[offset+2:])
offset += (4 + int(extensionLength))
}
return extensions, nil
}
func encodeExtensions(e []extension) ([]byte, error) {
extensions := []byte{}
for _, e := range e {
raw, err := e.Marshal()
if err != nil {
return nil, err
}
extensions = append(extensions, raw...)
}
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out, uint16(len(extensions)))
return append(out, extensions...), nil
}

View file

@ -0,0 +1,37 @@
package dtls
import "encoding/binary"
const (
extensionRenegotiationInfoHeaderSize = 5
)
// https://tools.ietf.org/html/rfc5746
type extensionRenegotiationInfo struct {
renegotiatedConnection uint8
}
func (e extensionRenegotiationInfo) extensionValue() extensionValue {
return extensionRenegotiationInfoValue
}
func (e *extensionRenegotiationInfo) Marshal() ([]byte, error) {
out := make([]byte, extensionRenegotiationInfoHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1)) // length
out[4] = e.renegotiatedConnection
return out, nil
}
func (e *extensionRenegotiationInfo) Unmarshal(data []byte) error {
if len(data) < extensionRenegotiationInfoHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.renegotiatedConnection = data[4]
return nil
}

View file

@ -0,0 +1,70 @@
package dtls
import (
"strings"
"golang.org/x/crypto/cryptobyte"
)
const extensionServerNameTypeDNSHostName = 0
type extensionServerName struct {
serverName string
}
func (e extensionServerName) extensionValue() extensionValue {
return extensionServerNameValue
}
func (e *extensionServerName) Marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(uint16(e.extensionValue()))
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(extensionServerNameTypeDNSHostName)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(e.serverName))
})
})
})
return b.Bytes()
}
func (e *extensionServerName) Unmarshal(data []byte) error {
s := cryptobyte.String(data)
var extension uint16
s.ReadUint16(&extension)
if extensionValue(extension) != e.extensionValue() {
return errInvalidExtensionType
}
var extData cryptobyte.String
s.ReadUint16LengthPrefixed(&extData)
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return errInvalidSNIFormat
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return errInvalidSNIFormat
}
if nameType != extensionServerNameTypeDNSHostName {
continue
}
if len(e.serverName) != 0 {
// Multiple names of the same name_type are prohibited.
return errInvalidSNIFormat
}
e.serverName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(e.serverName, ".") {
return errInvalidSNIFormat
}
}
return nil
}

View file

@ -0,0 +1,54 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedGroupsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422#section-5.1.1
type extensionSupportedEllipticCurves struct {
ellipticCurves []namedCurve
}
func (e extensionSupportedEllipticCurves) extensionValue() extensionValue {
return extensionSupportedEllipticCurvesValue
}
func (e *extensionSupportedEllipticCurves) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedGroupsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.ellipticCurves)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.ellipticCurves)*2))
for _, v := range e.ellipticCurves {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
return out, nil
}
func (e *extensionSupportedEllipticCurves) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedGroupsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
groupCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(groupCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < groupCount; i++ {
supportedGroupID := namedCurve(binary.BigEndian.Uint16(data[(extensionSupportedGroupsHeaderSize + (i * 2)):]))
if _, ok := namedCurves()[supportedGroupID]; ok {
e.ellipticCurves = append(e.ellipticCurves, supportedGroupID)
}
}
return nil
}

View file

@ -0,0 +1,56 @@
package dtls
import "encoding/binary"
const (
extensionSupportedPointFormatsSize = 5
)
type ellipticCurvePointFormat byte
const ellipticCurvePointFormatUncompressed ellipticCurvePointFormat = 0
// https://tools.ietf.org/html/rfc4492#section-5.1.2
type extensionSupportedPointFormats struct {
pointFormats []ellipticCurvePointFormat
}
func (e extensionSupportedPointFormats) extensionValue() extensionValue {
return extensionSupportedPointFormatsValue
}
func (e *extensionSupportedPointFormats) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedPointFormatsSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1+(len(e.pointFormats))))
out[4] = byte(len(e.pointFormats))
for _, v := range e.pointFormats {
out = append(out, byte(v))
}
return out, nil
}
func (e *extensionSupportedPointFormats) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedPointFormatsSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
pointFormatCount := int(binary.BigEndian.Uint16(data[4:]))
if extensionSupportedGroupsHeaderSize+(pointFormatCount) > len(data) {
return errLengthMismatch
}
for i := 0; i < pointFormatCount; i++ {
p := ellipticCurvePointFormat(data[extensionSupportedPointFormatsSize+i])
switch p {
case ellipticCurvePointFormatUncompressed:
e.pointFormats = append(e.pointFormats, p)
default:
}
}
return nil
}

View file

@ -0,0 +1,60 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedSignatureAlgorithmsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
type extensionSupportedSignatureAlgorithms struct {
signatureHashAlgorithms []signatureHashAlgorithm
}
func (e extensionSupportedSignatureAlgorithms) extensionValue() extensionValue {
return extensionSupportedSignatureAlgorithmsValue
}
func (e *extensionSupportedSignatureAlgorithms) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedSignatureAlgorithmsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.signatureHashAlgorithms)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.signatureHashAlgorithms)*2))
for _, v := range e.signatureHashAlgorithms {
out = append(out, []byte{0x00, 0x00}...)
out[len(out)-2] = byte(v.hash)
out[len(out)-1] = byte(v.signature)
}
return out, nil
}
func (e *extensionSupportedSignatureAlgorithms) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedSignatureAlgorithmsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
algorithmCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedSignatureAlgorithmsHeaderSize+(algorithmCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < algorithmCount; i++ {
supportedHashAlgorithm := hashAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)])
supportedSignatureAlgorithm := signatureAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)+1])
if _, ok := hashAlgorithms()[supportedHashAlgorithm]; ok {
if _, ok := signatureAlgorithms()[supportedSignatureAlgorithm]; ok {
e.signatureHashAlgorithms = append(e.signatureHashAlgorithms, signatureHashAlgorithm{
supportedHashAlgorithm,
supportedSignatureAlgorithm,
})
}
}
}
return nil
}

View file

@ -0,0 +1,40 @@
package dtls
import "encoding/binary"
const (
extensionUseExtendedMasterSecretHeaderSize = 4
)
// https://tools.ietf.org/html/rfc8422
type extensionUseExtendedMasterSecret struct {
supported bool
}
func (e extensionUseExtendedMasterSecret) extensionValue() extensionValue {
return extensionUseExtendedMasterSecretValue
}
func (e *extensionUseExtendedMasterSecret) Marshal() ([]byte, error) {
if !e.supported {
return []byte{}, nil
}
out := make([]byte, extensionUseExtendedMasterSecretHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(0)) // length
return out, nil
}
func (e *extensionUseExtendedMasterSecret) Unmarshal(data []byte) error {
if len(data) < extensionUseExtendedMasterSecretHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.supported = true
return nil
}

View file

@ -0,0 +1,53 @@
package dtls
import "encoding/binary"
const (
extensionUseSRTPHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422
type extensionUseSRTP struct {
protectionProfiles []SRTPProtectionProfile
}
func (e extensionUseSRTP) extensionValue() extensionValue {
return extensionUseSRTPValue
}
func (e *extensionUseSRTP) Marshal() ([]byte, error) {
out := make([]byte, extensionUseSRTPHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.protectionProfiles)*2)+ /* MKI Length */ 1))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.protectionProfiles)*2))
for _, v := range e.protectionProfiles {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
out = append(out, 0x00) /* MKI Length */
return out, nil
}
func (e *extensionUseSRTP) Unmarshal(data []byte) error {
if len(data) <= extensionUseSRTPHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(profileCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < profileCount; i++ {
supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):]))
if _, ok := srtpProtectionProfiles()[supportedProfile]; ok {
e.protectionProfiles = append(e.protectionProfiles, supportedProfile)
}
}
return nil
}

View file

@ -0,0 +1,75 @@
package dtls
/*
DTLS messages are grouped into a series of message flights, according
to the diagrams below. Although each flight of messages may consist
of a number of messages, they should be viewed as monolithic for the
purpose of timeout and retransmission.
https://tools.ietf.org/html/rfc4347#section-4.2.4
Client Server
------ ------
Waiting Flight 0
ClientHello --------> Flight 1
<------- HelloVerifyRequest Flight 2
ClientHello --------> Flight 3
ServerHello \
Certificate* \
ServerKeyExchange* Flight 4
CertificateRequest* /
<-------- ServerHelloDone /
Certificate* \
ClientKeyExchange \
CertificateVerify* Flight 5
[ChangeCipherSpec] /
Finished --------> /
[ChangeCipherSpec] \ Flight 6
<-------- Finished /
*/
type flightVal uint8
const (
flight0 flightVal = iota + 1
flight1
flight2
flight3
flight4
flight5
flight6
)
func (f flightVal) String() string {
switch f {
case flight0:
return "Flight 0"
case flight1:
return "Flight 1"
case flight2:
return "Flight 2"
case flight3:
return "Flight 3"
case flight4:
return "Flight 4"
case flight5:
return "Flight 5"
case flight6:
return "Flight 6"
default:
return "Invalid Flight"
}
}
func (f flightVal) isLastSendFlight() bool {
return f == flight6
}
func (f flightVal) isLastRecvFlight() bool {
return f == flight5
}

View file

@ -0,0 +1,89 @@
package dtls
import (
"context"
"crypto/rand"
)
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
seq, msgs, ok := cache.fullPullMap(0,
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
state.handshakeRecvSequence = seq
var clientHello *handshakeMessageClientHello
// Validate type
if clientHello, ok = msgs[handshakeTypeClientHello].(*handshakeMessageClientHello); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
state.remoteRandom = clientHello.random
if state.cipherSuite, ok = findMatchingCipherSuite(clientHello.cipherSuites, cfg.localCipherSuites); !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errCipherSuiteNoIntersection
}
for _, extension := range clientHello.extensions {
switch e := extension.(type) {
case *extensionSupportedEllipticCurves:
if len(e.ellipticCurves) == 0 {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errNoSupportedEllipticCurves
}
state.namedCurve = e.ellipticCurves[0]
case *extensionUseSRTP:
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, cfg.localSRTPProtectionProfiles)
if !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
case *extensionUseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
}
case *extensionServerName:
state.serverName = e.serverName // remote server name
}
}
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerRequiredButNoClientEMS
}
if state.localKeypair == nil {
var err error
state.localKeypair, err = generateKeypair(state.namedCurve)
if err != nil {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, err
}
}
return flight2, nil, nil
}
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
// Initialize
state.cookie = make([]byte, cookieLength)
if _, err := rand.Read(state.cookie); err != nil {
return nil, nil, err
}
var zeroEpoch uint16
state.localEpoch.Store(zeroEpoch)
state.remoteEpoch.Store(zeroEpoch)
state.namedCurve = defaultNamedCurve
if err := state.localRandom.populate(); err != nil {
return nil, nil, err
}
return nil, nil, nil
}

View file

@ -0,0 +1,105 @@
package dtls
import (
"context"
)
func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
// HelloVerifyRequest can be skipped by the server,
// so allow ServerHello during flight1 also
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeHelloVerifyRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, true},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
if _, ok := msgs[handshakeTypeServerHello]; ok {
// Flight1 and flight2 were skipped.
// Parse as flight3.
return flight3Parse(ctx, c, state, cache, cfg)
}
if h, ok := msgs[handshakeTypeHelloVerifyRequest].(*handshakeMessageHelloVerifyRequest); ok {
// DTLS 1.2 clients must not assume that the server will use the protocol version
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
if !h.version.Equal(protocolVersion1_0) && !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
state.cookie = append([]byte{}, h.cookie...)
state.handshakeRecvSequence = seq
return flight3, nil, nil
}
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
var zeroEpoch uint16
state.localEpoch.Store(zeroEpoch)
state.remoteEpoch.Store(zeroEpoch)
state.namedCurve = defaultNamedCurve
state.cookie = nil
if err := state.localRandom.populate(); err != nil {
return nil, nil, err
}
extensions := []extension{
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: cfg.localSignatureSchemes,
},
&extensionRenegotiationInfo{
renegotiatedConnection: 0,
},
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
}...)
}
if len(cfg.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: cfg.localSRTPProtectionProfiles,
})
}
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
})
}
if len(cfg.serverName) > 0 {
extensions = append(extensions, &extensionServerName{serverName: cfg.serverName})
}
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion1_2,
cookie: state.cookie,
random: state.localRandom,
cipherSuites: cfg.localCipherSuites,
compressionMethods: defaultCompressionMethods(),
extensions: extensions,
},
},
},
},
}, nil, nil
}

View file

@ -0,0 +1,56 @@
package dtls
import (
"bytes"
"context"
)
func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
)
if !ok {
// Client may retransmit the first ClientHello when HelloVerifyRequest is dropped.
// Parse as flight 0 in this case.
return flight0Parse(ctx, c, state, cache, cfg)
}
state.handshakeRecvSequence = seq
var clientHello *handshakeMessageClientHello
// Validate type
if clientHello, ok = msgs[handshakeTypeClientHello].(*handshakeMessageClientHello); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
if len(clientHello.cookie) == 0 {
return 0, nil, nil
}
if !bytes.Equal(state.cookie, clientHello.cookie) {
return 0, &alert{alertLevelFatal, alertAccessDenied}, errCookieMismatch
}
return flight4, nil, nil
}
func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
state.handshakeSendSequence = 0
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageHelloVerifyRequest{
version: protocolVersion1_2,
cookie: state.cookie,
},
},
},
},
}, nil, nil
}

View file

@ -0,0 +1,177 @@
package dtls
import (
"context"
)
func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) { //nolint:gocognit
// Clients may receive multiple HelloVerifyRequest messages with different cookies.
// Clients SHOULD handle this by sending a new ClientHello with a cookie in response
// to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeHelloVerifyRequest, cfg.initialEpoch, false, true},
)
if ok {
if h, msgOk := msgs[handshakeTypeHelloVerifyRequest].(*handshakeMessageHelloVerifyRequest); msgOk {
// DTLS 1.2 clients must not assume that the server will use the protocol version
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
if !h.version.Equal(protocolVersion1_0) && !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
state.cookie = append([]byte{}, h.cookie...)
state.handshakeRecvSequence = seq
return flight3, nil, nil
}
}
if cfg.localPSKCallback != nil {
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
)
} else {
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
)
}
if !ok {
// Don't have enough messages. Keep reading
return 0, nil, nil
}
state.handshakeRecvSequence = seq
if h, ok := msgs[handshakeTypeServerHello].(*handshakeMessageServerHello); ok {
if !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
for _, extension := range h.extensions {
switch e := extension.(type) {
case *extensionUseSRTP:
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, cfg.localSRTPProtectionProfiles)
if !ok {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, errClientNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
case *extensionUseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
}
}
}
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errClientRequiredButNoServerEMS
}
if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errRequestedButNoSRTPExtension
}
if _, ok := findMatchingCipherSuite([]cipherSuite{h.cipherSuite}, cfg.localCipherSuites); !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errCipherSuiteNoIntersection
}
state.cipherSuite = h.cipherSuite
state.remoteRandom = h.random
cfg.log.Tracef("[handshake] use cipher suite: %s", h.cipherSuite.String())
}
if h, ok := msgs[handshakeTypeCertificate].(*handshakeMessageCertificate); ok {
state.PeerCertificates = h.certificate
}
if h, ok := msgs[handshakeTypeServerKeyExchange].(*handshakeMessageServerKeyExchange); ok {
alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
if err != nil {
return 0, alertPtr, err
}
}
if _, ok := msgs[handshakeTypeCertificateRequest].(*handshakeMessageCertificateRequest); ok {
state.remoteRequestedCertificate = true
}
return flight5, nil, nil
}
func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshakeMessageServerKeyExchange) (*alert, error) {
var err error
if cfg.localPSKCallback != nil {
var psk []byte
if psk, err = cfg.localPSKCallback(h.identityHint); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
state.preMasterSecret = prfPSKPreMasterSecret(psk)
} else {
if state.localKeypair, err = generateKeypair(h.namedCurve); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
if state.preMasterSecret, err = prfPreMasterSecret(h.publicKey, state.localKeypair.privateKey, state.localKeypair.curve); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
}
return nil, nil
}
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
extensions := []extension{
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: cfg.localSignatureSchemes,
},
&extensionRenegotiationInfo{
renegotiatedConnection: 0,
},
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
}...)
}
if len(cfg.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: cfg.localSRTPProtectionProfiles,
})
}
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
})
}
if len(cfg.serverName) > 0 {
extensions = append(extensions, &extensionServerName{serverName: cfg.serverName})
}
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion1_2,
cookie: state.cookie,
random: state.localRandom,
cipherSuites: cfg.localCipherSuites,
compressionMethods: defaultCompressionMethods(),
extensions: extensions,
},
},
},
},
}, nil, nil
}

View file

@ -0,0 +1,303 @@
package dtls
import (
"context"
"crypto/x509"
)
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) { //nolint:gocognit
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, true},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, true},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
// Validate type
var clientKeyExchange *handshakeMessageClientKeyExchange
if clientKeyExchange, ok = msgs[handshakeTypeClientKeyExchange].(*handshakeMessageClientKeyExchange); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
if h, hasCert := msgs[handshakeTypeCertificate].(*handshakeMessageCertificate); hasCert {
state.PeerCertificates = h.certificate
}
if h, hasCertVerify := msgs[handshakeTypeCertificateVerify].(*handshakeMessageCertificateVerify); hasCertVerify {
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errCertificateVerifyNoCertificate
}
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
)
// Verify that the pair of hash algorithm and signiture is listed.
var validSignatureScheme bool
for _, ss := range cfg.localSignatureSchemes {
if ss.hash == h.hashAlgorithm && ss.signature == h.signatureAlgorithm {
validSignatureScheme = true
break
}
}
if !validSignatureScheme {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errNoAvailableSignatureSchemes
}
if err := verifyCertificateVerify(plainText, h.hashAlgorithm, h.signature, state.PeerCertificates); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
}
var chains [][]*x509.Certificate
var err error
var verified bool
if cfg.clientAuth >= VerifyClientCertIfGiven {
if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
}
verified = true
}
if cfg.verifyPeerCertificate != nil {
if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
}
}
state.peerCertificatesVerified = verified
}
if !state.cipherSuite.isInitialized() {
serverRandom := state.localRandom.marshalFixed()
clientRandom := state.remoteRandom.marshalFixed()
var err error
var preMasterSecret []byte
if cfg.localPSKCallback != nil {
var psk []byte
if psk, err = cfg.localPSKCallback(clientKeyExchange.identityHint); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
preMasterSecret = prfPSKPreMasterSecret(psk)
} else {
preMasterSecret, err = prfPreMasterSecret(clientKeyExchange.publicKey, state.localKeypair.privateKey, state.localKeypair.curve)
if err != nil {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, err
}
}
if state.extendedMasterSecret {
var sessionHash []byte
sessionHash, err = cache.sessionHash(state.cipherSuite.hashFunc(), cfg.initialEpoch)
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
state.masterSecret, err = prfExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.hashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
} else {
state.masterSecret, err = prfMasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.hashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
}
if err := state.cipherSuite.init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
}
// Now, encrypted packets can be handled
if err := c.handleQueuedPackets(ctx); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
seq, msgs, ok = cache.fullPullMap(seq,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
state.handshakeRecvSequence = seq
if _, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
switch cfg.clientAuth {
case RequireAnyClientCert:
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errClientCertificateRequired
}
case VerifyClientCertIfGiven:
if state.PeerCertificates != nil && !state.peerCertificatesVerified {
return 0, &alert{alertLevelFatal, alertBadCertificate}, errClientCertificateNotVerified
}
case RequireAndVerifyClientCert:
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errClientCertificateRequired
}
if !state.peerCertificatesVerified {
return 0, &alert{alertLevelFatal, alertBadCertificate}, errClientCertificateNotVerified
}
case NoClientCert, RequestClientCert:
return flight6, nil, nil
}
return flight6, nil, nil
}
func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
extensions := []extension{}
if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
})
}
if state.srtpProtectionProfile != 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
})
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
}...)
}
var pkts []*packet
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerHello{
version: protocolVersion1_2,
random: state.localRandom,
cipherSuite: state.cipherSuite,
compressionMethod: defaultCompressionMethods()[0],
extensions: extensions,
},
},
},
})
if cfg.localPSKCallback == nil {
certificate, err := cfg.getCertificate(cfg.serverName)
if err != nil {
return nil, &alert{alertLevelFatal, alertHandshakeFailure}, err
}
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificate{
certificate: certificate.Certificate,
},
},
},
})
serverRandom := state.localRandom.marshalFixed()
clientRandom := state.remoteRandom.marshalFixed()
// Find compatible signature scheme
signatureHashAlgo, err := selectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
if err != nil {
return nil, &alert{alertLevelFatal, alertInsufficientSecurity}, err
}
signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.publicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.hash)
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
state.localKeySignature = signature
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerKeyExchange{
ellipticCurveType: ellipticCurveTypeNamedCurve,
namedCurve: state.namedCurve,
publicKey: state.localKeypair.publicKey,
hashAlgorithm: signatureHashAlgo.hash,
signatureAlgorithm: signatureHashAlgo.signature,
signature: state.localKeySignature,
},
},
},
})
if cfg.clientAuth > NoClientCert {
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificateRequest{
certificateTypes: []clientCertificateType{clientCertificateTypeRSASign, clientCertificateTypeECDSASign},
signatureHashAlgorithms: cfg.localSignatureSchemes,
},
},
},
})
}
} else if cfg.localPSKIdentityHint != nil {
// To help the client in selecting which identity to use, the server
// can provide a "PSK identity hint" in the ServerKeyExchange message.
// If no hint is provided, the ServerKeyExchange message is omitted.
//
// https://tools.ietf.org/html/rfc4279#section-2
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerKeyExchange{
identityHint: cfg.localPSKIdentityHint,
},
},
},
})
}
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerHelloDone{},
},
},
})
return pkts, nil, nil
}

View file

@ -0,0 +1,313 @@
package dtls
import (
"bytes"
"context"
"crypto"
"crypto/x509"
)
func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, false, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
var finished *handshakeMessageFinished
if finished, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
)
expectedVerifyData, err := prfVerifyDataServer(state.masterSecret, plainText, state.cipherSuite.hashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
}
if !bytes.Equal(expectedVerifyData, finished.verifyData) {
return 0, &alert{alertLevelFatal, alertHandshakeFailure}, errVerifyDataMismatch
}
return flight5, nil, nil
}
func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) { //nolint:gocognit
var certBytes [][]byte
var privateKey crypto.PrivateKey
if len(cfg.localCertificates) > 0 {
certificate, err := cfg.getCertificate(cfg.serverName)
if err != nil {
return nil, &alert{alertLevelFatal, alertHandshakeFailure}, err
}
certBytes = certificate.Certificate
privateKey = certificate.PrivateKey
}
var pkts []*packet
if state.remoteRequestedCertificate {
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificate{
certificate: certBytes,
},
},
},
})
}
clientKeyExchange := &handshakeMessageClientKeyExchange{}
if cfg.localPSKCallback == nil {
clientKeyExchange.publicKey = state.localKeypair.publicKey
} else {
clientKeyExchange.identityHint = cfg.localPSKIdentityHint
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: clientKeyExchange,
},
},
})
serverKeyExchangeData := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
)
serverKeyExchange := &handshakeMessageServerKeyExchange{}
// handshakeMessageServerKeyExchange is optional for PSK
if len(serverKeyExchangeData) == 0 {
alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshakeMessageServerKeyExchange{})
if err != nil {
return nil, alertPtr, err
}
} else {
rawHandshake := &handshake{}
err := rawHandshake.Unmarshal(serverKeyExchangeData)
if err != nil {
return nil, &alert{alertLevelFatal, alertUnexpectedMessage}, err
}
switch h := rawHandshake.handshakeMessage.(type) {
case *handshakeMessageServerKeyExchange:
serverKeyExchange = h
default:
return nil, &alert{alertLevelFatal, alertUnexpectedMessage}, errInvalidContentType
}
}
// Append not-yet-sent packets
merged := []byte{}
seqPred := uint16(state.handshakeSendSequence)
for _, p := range pkts {
h, ok := p.record.content.(*handshake)
if !ok {
return nil, &alert{alertLevelFatal, alertInternalError}, errInvalidContentType
}
h.handshakeHeader.messageSequence = seqPred
seqPred++
raw, err := h.Marshal()
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
merged = append(merged, raw...)
}
if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
return nil, alertPtr, err
}
// If the client has sent a certificate with signing ability, a digitally-signed
// CertificateVerify message is sent to explicitly verify possession of the
// private key in the certificate.
if state.remoteRequestedCertificate && len(cfg.localCertificates) > 0 {
plainText := append(cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
), merged...)
// Find compatible signature scheme
signatureHashAlgo, err := selectSignatureScheme(cfg.localSignatureSchemes, privateKey)
if err != nil {
return nil, &alert{alertLevelFatal, alertInsufficientSecurity}, err
}
certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.hash)
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
state.localCertificatesVerify = certVerify
p := &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificateVerify{
hashAlgorithm: signatureHashAlgo.hash,
signatureAlgorithm: signatureHashAlgo.signature,
signature: state.localCertificatesVerify,
},
},
},
}
pkts = append(pkts, p)
h, ok := p.record.content.(*handshake)
if !ok {
return nil, &alert{alertLevelFatal, alertInternalError}, errInvalidContentType
}
h.handshakeHeader.messageSequence = seqPred
// seqPred++ // this is the last use of seqPred
raw, err := h.Marshal()
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
merged = append(merged, raw...)
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &changeCipherSpec{},
},
})
if len(state.localVerifyData) == 0 {
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
)
var err error
state.localVerifyData, err = prfVerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.hashFunc())
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
epoch: 1,
},
content: &handshake{
handshakeMessage: &handshakeMessageFinished{
verifyData: state.localVerifyData,
},
},
},
shouldEncrypt: true,
resetLocalSequenceNumber: true,
})
return pkts, nil, nil
}
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshakeMessageServerKeyExchange, sendingPlainText []byte) (*alert, error) { //nolint:gocognit
if state.cipherSuite.isInitialized() {
return nil, nil
}
clientRandom := state.localRandom.marshalFixed()
serverRandom := state.remoteRandom.marshalFixed()
var err error
if state.extendedMasterSecret {
var sessionHash []byte
sessionHash, err = cache.sessionHash(state.cipherSuite.hashFunc(), cfg.initialEpoch, sendingPlainText)
if err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
state.masterSecret, err = prfExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.hashFunc())
if err != nil {
return &alert{alertLevelFatal, alertIllegalParameter}, err
}
} else {
state.masterSecret, err = prfMasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.hashFunc())
if err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
}
if cfg.localPSKCallback == nil {
// Verify that the pair of hash algorithm and signiture is listed.
var validSignatureScheme bool
for _, ss := range cfg.localSignatureSchemes {
if ss.hash == h.hashAlgorithm && ss.signature == h.signatureAlgorithm {
validSignatureScheme = true
break
}
}
if !validSignatureScheme {
return &alert{alertLevelFatal, alertInsufficientSecurity}, errNoAvailableSignatureSchemes
}
expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.publicKey, h.namedCurve)
if err = verifyKeySignature(expectedMsg, h.signature, h.hashAlgorithm, state.PeerCertificates); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
}
var chains [][]*x509.Certificate
if !cfg.insecureSkipVerify {
if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
}
}
if cfg.verifyPeerCertificate != nil {
if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
}
}
}
if err = state.cipherSuite.init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
}
return nil, nil
}

View file

@ -0,0 +1,76 @@
package dtls
import (
"context"
)
func flight6Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
if _, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}
// Other party retransmitted the last flight.
return flight6, nil, nil
}
func flight6Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
var pkts []*packet
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &changeCipherSpec{},
},
})
if len(state.localVerifyData) == 0 {
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
)
var err error
state.localVerifyData, err = prfVerifyDataServer(state.masterSecret, plainText, state.cipherSuite.hashFunc())
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
}
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
epoch: 1,
},
content: &handshake{
handshakeMessage: &handshakeMessageFinished{
verifyData: state.localVerifyData,
},
},
},
shouldEncrypt: true,
resetLocalSequenceNumber: true,
},
)
return pkts, nil, nil
}

View file

@ -0,0 +1,55 @@
package dtls
import (
"context"
)
// Parse received handshakes and return next flightVal
type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert, error)
// Generate flights
type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert, error)
func (f flightVal) getFlightParser() (flightParser, error) {
switch f {
case flight0:
return flight0Parse, nil
case flight1:
return flight1Parse, nil
case flight2:
return flight2Parse, nil
case flight3:
return flight3Parse, nil
case flight4:
return flight4Parse, nil
case flight5:
return flight5Parse, nil
case flight6:
return flight6Parse, nil
default:
return nil, errInvalidFlight
}
}
func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) {
switch f {
case flight0:
return flight0Generate, true, nil
case flight1:
return flight1Generate, true, nil
case flight2:
// https://tools.ietf.org/html/rfc6347#section-3.2.1
// HelloVerifyRequests must not be retransmitted.
return flight2Generate, false, nil
case flight3:
return flight3Generate, true, nil
case flight4:
return flight4Generate, true, nil
case flight5:
return flight5Generate, true, nil
case flight6:
return flight6Generate, true, nil
default:
return nil, false, errInvalidFlight
}
}

View file

@ -0,0 +1,105 @@
package dtls
type fragment struct {
recordLayerHeader recordLayerHeader
handshakeHeader handshakeHeader
data []byte
}
type fragmentBuffer struct {
// map of MessageSequenceNumbers that hold slices of fragments
cache map[uint16][]*fragment
currentMessageSequenceNumber uint16
}
func newFragmentBuffer() *fragmentBuffer {
return &fragmentBuffer{cache: map[uint16][]*fragment{}}
}
// Attempts to push a DTLS packet to the fragmentBuffer
// when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
// when an error returns it is fatal, and the DTLS connection should be stopped
func (f *fragmentBuffer) push(buf []byte) (bool, error) {
frag := new(fragment)
if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
return false, err
}
// fragment isn't a handshake, we don't need to handle it
if frag.recordLayerHeader.contentType != contentTypeHandshake {
return false, nil
}
for buf = buf[recordLayerHeaderSize:]; len(buf) != 0; frag = new(fragment) {
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
return false, err
}
if _, ok := f.cache[frag.handshakeHeader.messageSequence]; !ok {
f.cache[frag.handshakeHeader.messageSequence] = []*fragment{}
}
// end index should be the length of handshake header but if the handshake
// was fragmented, we should keep them all
end := int(handshakeHeaderLength + frag.handshakeHeader.length)
if size := len(buf); end > size {
end = size
}
// Discard all headers, when rebuilding the packet we will re-build
frag.data = append([]byte{}, buf[handshakeHeaderLength:end]...)
f.cache[frag.handshakeHeader.messageSequence] = append(f.cache[frag.handshakeHeader.messageSequence], frag)
buf = buf[end:]
}
return true, nil
}
func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
frags, ok := f.cache[f.currentMessageSequenceNumber]
if !ok {
return nil, 0
}
// Go doesn't support recursive lambdas
var appendMessage func(targetOffset uint32) bool
rawMessage := []byte{}
appendMessage = func(targetOffset uint32) bool {
for _, f := range frags {
if f.handshakeHeader.fragmentOffset == targetOffset {
fragmentEnd := (f.handshakeHeader.fragmentOffset + f.handshakeHeader.fragmentLength)
if fragmentEnd != f.handshakeHeader.length {
if !appendMessage(fragmentEnd) {
return false
}
}
rawMessage = append(f.data, rawMessage...)
return true
}
}
return false
}
// Recursively collect up
if !appendMessage(0) {
return nil, 0
}
firstHeader := frags[0].handshakeHeader
firstHeader.fragmentOffset = 0
firstHeader.fragmentLength = firstHeader.length
rawHeader, err := firstHeader.Marshal()
if err != nil {
return nil, 0
}
messageEpoch := frags[0].recordLayerHeader.epoch
delete(f.cache, f.currentMessageSequenceNumber)
f.currentMessageSequenceNumber++
return append(rawHeader, rawMessage...), messageEpoch
}

View file

@ -0,0 +1,38 @@
// +build gofuzz
package dtls
import "fmt"
func partialHeaderMismatch(a, b recordLayerHeader) bool {
// Ignoring content length for now.
a.contentLen = b.contentLen
return a != b
}
func FuzzRecordLayer(data []byte) int {
var r recordLayer
if err := r.Unmarshal(data); err != nil {
return 0
}
buf, err := r.Marshal()
if err != nil {
return 1
}
if len(buf) == 0 {
panic("zero buff") // nolint
}
var nr recordLayer
if err = nr.Unmarshal(data); err != nil {
panic(err) // nolint
}
if partialHeaderMismatch(nr.recordLayerHeader, r.recordLayerHeader) {
panic( // nolint
fmt.Sprintf("header mismatch: %+v != %+v",
nr.recordLayerHeader, r.recordLayerHeader,
),
)
}
return 1
}

View file

@ -0,0 +1,12 @@
module github.com/pion/dtls/v2
require (
github.com/pion/logging v0.2.2
github.com/pion/transport v0.10.1
github.com/pion/udp v0.1.0
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
)
go 1.13

View file

@ -0,0 +1,45 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/transport v0.10.0 h1:9M12BSneJm6ggGhJyWpDveFOstJsTiQjkLf4M44rm80=
github.com/pion/transport v0.10.0/go.mod h1:BnHnUipd0rZQyTVB2SBGojFHT9CBt5C5TcsJSQGkvSE=
github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI=
github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102 h1:42cLlJJdEh+ySyeUUbEQ5bsTiq8voBeTuweGVkY6Puw=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -0,0 +1,136 @@
package dtls
// https://tools.ietf.org/html/rfc5246#section-7.4
type handshakeType uint8
const (
handshakeTypeHelloRequest handshakeType = 0
handshakeTypeClientHello handshakeType = 1
handshakeTypeServerHello handshakeType = 2
handshakeTypeHelloVerifyRequest handshakeType = 3
handshakeTypeCertificate handshakeType = 11
handshakeTypeServerKeyExchange handshakeType = 12
handshakeTypeCertificateRequest handshakeType = 13
handshakeTypeServerHelloDone handshakeType = 14
handshakeTypeCertificateVerify handshakeType = 15
handshakeTypeClientKeyExchange handshakeType = 16
handshakeTypeFinished handshakeType = 20
// msg_len for Handshake messages assumes an extra 12 bytes for
// sequence, fragment and version information
handshakeMessageHeaderLength = 12
)
type handshakeMessage interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
handshakeType() handshakeType
}
func (h handshakeType) String() string {
switch h {
case handshakeTypeHelloRequest:
return "HelloRequest"
case handshakeTypeClientHello:
return "ClientHello"
case handshakeTypeServerHello:
return "ServerHello"
case handshakeTypeHelloVerifyRequest:
return "HelloVerifyRequest"
case handshakeTypeCertificate:
return "TypeCertificate"
case handshakeTypeServerKeyExchange:
return "ServerKeyExchange"
case handshakeTypeCertificateRequest:
return "CertificateRequest"
case handshakeTypeServerHelloDone:
return "ServerHelloDone"
case handshakeTypeCertificateVerify:
return "CertificateVerify"
case handshakeTypeClientKeyExchange:
return "ClientKeyExchange"
case handshakeTypeFinished:
return "Finished"
}
return ""
}
// The handshake protocol is responsible for selecting a cipher spec and
// generating a master secret, which together comprise the primary
// cryptographic parameters associated with a secure session. The
// handshake protocol can also optionally authenticate parties who have
// certificates signed by a trusted certificate authority.
// https://tools.ietf.org/html/rfc5246#section-7.3
type handshake struct {
handshakeHeader handshakeHeader
handshakeMessage handshakeMessage
}
func (h handshake) contentType() contentType {
return contentTypeHandshake
}
func (h *handshake) Marshal() ([]byte, error) {
if h.handshakeMessage == nil {
return nil, errHandshakeMessageUnset
} else if h.handshakeHeader.fragmentOffset != 0 {
return nil, errUnableToMarshalFragmented
}
msg, err := h.handshakeMessage.Marshal()
if err != nil {
return nil, err
}
h.handshakeHeader.length = uint32(len(msg))
h.handshakeHeader.fragmentLength = h.handshakeHeader.length
h.handshakeHeader.handshakeType = h.handshakeMessage.handshakeType()
header, err := h.handshakeHeader.Marshal()
if err != nil {
return nil, err
}
return append(header, msg...), nil
}
func (h *handshake) Unmarshal(data []byte) error {
if err := h.handshakeHeader.Unmarshal(data); err != nil {
return err
}
reportedLen := bigEndianUint24(data[1:])
if uint32(len(data)-handshakeMessageHeaderLength) != reportedLen {
return errLengthMismatch
} else if reportedLen != h.handshakeHeader.fragmentLength {
return errLengthMismatch
}
switch handshakeType(data[0]) {
case handshakeTypeHelloRequest:
return errNotImplemented
case handshakeTypeClientHello:
h.handshakeMessage = &handshakeMessageClientHello{}
case handshakeTypeHelloVerifyRequest:
h.handshakeMessage = &handshakeMessageHelloVerifyRequest{}
case handshakeTypeServerHello:
h.handshakeMessage = &handshakeMessageServerHello{}
case handshakeTypeCertificate:
h.handshakeMessage = &handshakeMessageCertificate{}
case handshakeTypeServerKeyExchange:
h.handshakeMessage = &handshakeMessageServerKeyExchange{}
case handshakeTypeCertificateRequest:
h.handshakeMessage = &handshakeMessageCertificateRequest{}
case handshakeTypeServerHelloDone:
h.handshakeMessage = &handshakeMessageServerHelloDone{}
case handshakeTypeClientKeyExchange:
h.handshakeMessage = &handshakeMessageClientKeyExchange{}
case handshakeTypeFinished:
h.handshakeMessage = &handshakeMessageFinished{}
case handshakeTypeCertificateVerify:
h.handshakeMessage = &handshakeMessageCertificateVerify{}
default:
return errNotImplemented
}
return h.handshakeMessage.Unmarshal(data[handshakeMessageHeaderLength:])
}

View file

@ -0,0 +1,168 @@
package dtls
import (
"sync"
)
type handshakeCacheItem struct {
typ handshakeType
isClient bool
epoch uint16
messageSequence uint16
data []byte
}
type handshakeCachePullRule struct {
typ handshakeType
epoch uint16
isClient bool
optional bool
}
type handshakeCache struct {
cache []*handshakeCacheItem
mu sync.Mutex
}
func newHandshakeCache() *handshakeCache {
return &handshakeCache{}
}
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshakeType, isClient bool) bool { //nolint
h.mu.Lock()
defer h.mu.Unlock()
for _, i := range h.cache {
if i.messageSequence == messageSequence &&
i.isClient == isClient {
return false
}
}
h.cache = append(h.cache, &handshakeCacheItem{
data: append([]byte{}, data...),
epoch: epoch,
messageSequence: messageSequence,
typ: typ,
isClient: isClient,
})
return true
}
// returns a list handshakes that match the requested rules
// the list will contain null entries for rules that can't be satisfied
// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies)
func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem {
h.mu.Lock()
defer h.mu.Unlock()
out := make([]*handshakeCacheItem, len(rules))
for i, r := range rules {
for _, c := range h.cache {
if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
switch {
case out[i] == nil:
out[i] = c
case out[i].messageSequence < c.messageSequence:
out[i] = c
}
}
}
}
return out
}
// fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
func (h *handshakeCache) fullPullMap(startSeq int, rules ...handshakeCachePullRule) (int, map[handshakeType]handshakeMessage, bool) {
h.mu.Lock()
defer h.mu.Unlock()
ci := make(map[handshakeType]*handshakeCacheItem)
for _, r := range rules {
var item *handshakeCacheItem
for _, c := range h.cache {
if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
switch {
case item == nil:
item = c
case item.messageSequence < c.messageSequence:
item = c
}
}
}
if !r.optional && item == nil {
// Missing mandatory message.
return startSeq, nil, false
}
ci[r.typ] = item
}
out := make(map[handshakeType]handshakeMessage)
seq := startSeq
for _, r := range rules {
t := r.typ
i := ci[t]
if i == nil {
continue
}
rawHandshake := &handshake{}
if err := rawHandshake.Unmarshal(i.data); err != nil {
return startSeq, nil, false
}
if uint16(seq) != rawHandshake.handshakeHeader.messageSequence {
// There is a gap. Some messages are not arrived.
return startSeq, nil, false
}
seq++
out[t] = rawHandshake.handshakeMessage
}
return seq, out, true
}
// pullAndMerge calls pull and then merges the results, ignoring any null entries
func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
merged := []byte{}
for _, p := range h.pull(rules...) {
if p != nil {
merged = append(merged, p.data...)
}
}
return merged
}
// sessionHash returns the session hash for Extended Master Secret support
// https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
func (h *handshakeCache) sessionHash(hf hashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
merged := []byte{}
// Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
handshakeBuffer := h.pull(
handshakeCachePullRule{handshakeTypeClientHello, epoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, epoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, epoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, epoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, epoch, true, false},
)
for _, p := range handshakeBuffer {
if p == nil {
continue
}
merged = append(merged, p.data...)
}
for _, a := range additional {
merged = append(merged, a...)
}
hash := hf()
if _, err := hash.Write(merged); err != nil {
return []byte{}, err
}
return hash.Sum(nil), nil
}

View file

@ -0,0 +1,41 @@
package dtls
import (
"encoding/binary"
)
// msg_len for Handshake messages assumes an extra 12 bytes for
// sequence, fragment and version information
const handshakeHeaderLength = 12
type handshakeHeader struct {
handshakeType handshakeType
length uint32 // uint24 in spec
messageSequence uint16
fragmentOffset uint32 // uint24 in spec
fragmentLength uint32 // uint24 in spec
}
func (h *handshakeHeader) Marshal() ([]byte, error) {
out := make([]byte, handshakeMessageHeaderLength)
out[0] = byte(h.handshakeType)
putBigEndianUint24(out[1:], h.length)
binary.BigEndian.PutUint16(out[4:], h.messageSequence)
putBigEndianUint24(out[6:], h.fragmentOffset)
putBigEndianUint24(out[9:], h.fragmentLength)
return out, nil
}
func (h *handshakeHeader) Unmarshal(data []byte) error {
if len(data) < handshakeHeaderLength {
return errBufferTooSmall
}
h.handshakeType = handshakeType(data[0])
h.length = bigEndianUint24(data[1:])
h.messageSequence = binary.BigEndian.Uint16(data[4:])
h.fragmentOffset = bigEndianUint24(data[6:])
h.fragmentLength = bigEndianUint24(data[9:])
return nil
}

View file

@ -0,0 +1,55 @@
package dtls
type handshakeMessageCertificate struct {
certificate [][]byte
}
func (h handshakeMessageCertificate) handshakeType() handshakeType {
return handshakeTypeCertificate
}
const (
handshakeMessageCertificateLengthFieldSize = 3
)
func (h *handshakeMessageCertificate) Marshal() ([]byte, error) {
out := make([]byte, handshakeMessageCertificateLengthFieldSize)
for _, r := range h.certificate {
// Certificate Length
out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...)
putBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r)))
// Certificate body
out = append(out, append([]byte{}, r...)...)
}
// Total Payload Size
putBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:])))
return out, nil
}
func (h *handshakeMessageCertificate) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateLengthFieldSize {
return errBufferTooSmall
}
if certificateBodyLen := int(bigEndianUint24(data)); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) {
return errLengthMismatch
}
offset := handshakeMessageCertificateLengthFieldSize
for offset < len(data) {
certificateLen := int(bigEndianUint24(data[offset:]))
offset += handshakeMessageCertificateLengthFieldSize
if offset+certificateLen > len(data) {
return errLengthMismatch
}
h.certificate = append(h.certificate, append([]byte{}, data[offset:offset+certificateLen]...))
offset += certificateLen
}
return nil
}

View file

@ -0,0 +1,91 @@
package dtls
import (
"encoding/binary"
)
/*
A non-anonymous server can optionally request a certificate from
the client, if appropriate for the selected cipher suite. This
message, if sent, will immediately follow the ServerKeyExchange
message (if it is sent; otherwise, this message follows the
server's Certificate message).
*/
type handshakeMessageCertificateRequest struct {
certificateTypes []clientCertificateType
signatureHashAlgorithms []signatureHashAlgorithm
}
const (
handshakeMessageCertificateRequestMinLength = 5
)
func (h handshakeMessageCertificateRequest) handshakeType() handshakeType {
return handshakeTypeCertificateRequest
}
func (h *handshakeMessageCertificateRequest) Marshal() ([]byte, error) {
out := []byte{byte(len(h.certificateTypes))}
for _, v := range h.certificateTypes {
out = append(out, byte(v))
}
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(h.signatureHashAlgorithms)*2))
for _, v := range h.signatureHashAlgorithms {
out = append(out, byte(v.hash))
out = append(out, byte(v.signature))
}
out = append(out, []byte{0x00, 0x00}...) // Distinguished Names Length
return out, nil
}
func (h *handshakeMessageCertificateRequest) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateRequestMinLength {
return errBufferTooSmall
}
offset := 0
certificateTypesLength := int(data[0])
offset++
if (offset + certificateTypesLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < certificateTypesLength; i++ {
certType := clientCertificateType(data[offset+i])
if _, ok := clientCertificateTypes()[certType]; ok {
h.certificateTypes = append(h.certificateTypes, certType)
}
}
offset += certificateTypesLength
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureHashAlgorithmsLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if (offset + signatureHashAlgorithmsLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < signatureHashAlgorithmsLength; i += 2 {
if len(data) < (offset + i + 2) {
return errBufferTooSmall
}
hash := hashAlgorithm(data[offset+i])
signature := signatureAlgorithm(data[offset+i+1])
if _, ok := hashAlgorithms()[hash]; !ok {
continue
} else if _, ok := signatureAlgorithms()[signature]; !ok {
continue
}
h.signatureHashAlgorithms = append(h.signatureHashAlgorithms, signatureHashAlgorithm{signature: signature, hash: hash})
}
return nil
}

View file

@ -0,0 +1,51 @@
package dtls
import (
"encoding/binary"
)
type handshakeMessageCertificateVerify struct {
hashAlgorithm hashAlgorithm
signatureAlgorithm signatureAlgorithm
signature []byte
}
const handshakeMessageCertificateVerifyMinLength = 4
func (h handshakeMessageCertificateVerify) handshakeType() handshakeType {
return handshakeTypeCertificateVerify
}
func (h *handshakeMessageCertificateVerify) Marshal() ([]byte, error) {
out := make([]byte, 1+1+2+len(h.signature))
out[0] = byte(h.hashAlgorithm)
out[1] = byte(h.signatureAlgorithm)
binary.BigEndian.PutUint16(out[2:], uint16(len(h.signature)))
copy(out[4:], h.signature)
return out, nil
}
func (h *handshakeMessageCertificateVerify) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateVerifyMinLength {
return errBufferTooSmall
}
h.hashAlgorithm = hashAlgorithm(data[0])
if _, ok := hashAlgorithms()[h.hashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
h.signatureAlgorithm = signatureAlgorithm(data[1])
if _, ok := signatureAlgorithms()[h.signatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
signatureLength := int(binary.BigEndian.Uint16(data[2:]))
if (signatureLength + 4) != len(data) {
return errBufferTooSmall
}
h.signature = append([]byte{}, data[4:]...)
return nil
}

View file

@ -0,0 +1,119 @@
package dtls
import (
"encoding/binary"
)
/*
When a client first connects to a server it is required to send
the client hello as its first message. The client can also send a
client hello in response to a hello request or on its own
initiative in order to renegotiate the security parameters in an
existing connection.
*/
type handshakeMessageClientHello struct {
version protocolVersion
random handshakeRandom
cookie []byte
cipherSuites []cipherSuite
compressionMethods []*compressionMethod
extensions []extension
}
const handshakeMessageClientHelloVariableWidthStart = 34
func (h handshakeMessageClientHello) handshakeType() handshakeType {
return handshakeTypeClientHello
}
func (h *handshakeMessageClientHello) Marshal() ([]byte, error) {
if len(h.cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, handshakeMessageClientHelloVariableWidthStart)
out[0] = h.version.major
out[1] = h.version.minor
rand := h.random.marshalFixed()
copy(out[2:], rand[:])
out = append(out, 0x00) // SessionID
out = append(out, byte(len(h.cookie)))
out = append(out, h.cookie...)
out = append(out, encodeCipherSuites(h.cipherSuites)...)
out = append(out, encodeCompressionMethods(h.compressionMethods)...)
extensions, err := encodeExtensions(h.extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
func (h *handshakeMessageClientHello) Unmarshal(data []byte) error {
if len(data) < 2+handshakeRandomLength {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
var random [handshakeRandomLength]byte
copy(random[:], data[2:])
h.random.unmarshalFixed(random)
// rest of packet has variable width sections
currOffset := handshakeMessageClientHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
currOffset++
if len(data) <= currOffset {
return errBufferTooSmall
}
n := int(data[currOffset-1])
if len(data) <= currOffset+n {
return errBufferTooSmall
}
h.cookie = append([]byte{}, data[currOffset:currOffset+n]...)
currOffset += len(h.cookie)
// Cipher Suites
if len(data) < currOffset {
return errBufferTooSmall
}
cipherSuites, err := decodeCipherSuites(data[currOffset:])
if err != nil {
return err
}
h.cipherSuites = cipherSuites
if len(data) < currOffset+2 {
return errBufferTooSmall
}
currOffset += int(binary.BigEndian.Uint16(data[currOffset:])) + 2
// Compression Methods
if len(data) < currOffset {
return errBufferTooSmall
}
compressionMethods, err := decodeCompressionMethods(data[currOffset:])
if err != nil {
return err
}
h.compressionMethods = compressionMethods
if len(data) < currOffset {
return errBufferTooSmall
}
currOffset += int(data[currOffset]) + 1
// Extensions
extensions, err := decodeExtensions(data[currOffset:])
if err != nil {
return err
}
h.extensions = extensions
return nil
}

View file

@ -0,0 +1,46 @@
package dtls
import (
"encoding/binary"
)
type handshakeMessageClientKeyExchange struct {
identityHint []byte
publicKey []byte
}
func (h handshakeMessageClientKeyExchange) handshakeType() handshakeType {
return handshakeTypeClientKeyExchange
}
func (h *handshakeMessageClientKeyExchange) Marshal() ([]byte, error) {
switch {
case (h.identityHint != nil && h.publicKey != nil) || (h.identityHint == nil && h.publicKey == nil):
return nil, errInvalidClientKeyExchange
case h.publicKey != nil:
return append([]byte{byte(len(h.publicKey))}, h.publicKey...), nil
default:
out := append([]byte{0x00, 0x00}, h.identityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
}
func (h *handshakeMessageClientKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
h.identityHint = append([]byte{}, data[2:]...)
return nil
}
if publicKeyLength := int(data[0]); len(data) != publicKeyLength+1 {
return errBufferTooSmall
}
h.publicKey = append([]byte{}, data[1:]...)
return nil
}

View file

@ -0,0 +1,18 @@
package dtls
type handshakeMessageFinished struct {
verifyData []byte
}
func (h handshakeMessageFinished) handshakeType() handshakeType {
return handshakeTypeFinished
}
func (h *handshakeMessageFinished) Marshal() ([]byte, error) {
return append([]byte{}, h.verifyData...), nil
}
func (h *handshakeMessageFinished) Unmarshal(data []byte) error {
h.verifyData = append([]byte{}, data...)
return nil
}

View file

@ -0,0 +1,57 @@
package dtls
/*
The definition of HelloVerifyRequest is as follows:
struct {
ProtocolVersion server_version;
opaque cookie<0..2^8-1>;
} HelloVerifyRequest;
The HelloVerifyRequest message type is hello_verify_request(3).
When the client sends its ClientHello message to the server, the server
MAY respond with a HelloVerifyRequest message. This message contains
a stateless cookie generated using the technique of [PHOTURIS]. The
client MUST retransmit the ClientHello with the cookie added.
https://tools.ietf.org/html/rfc6347#section-4.2.1
*/
type handshakeMessageHelloVerifyRequest struct {
version protocolVersion
cookie []byte
}
func (h handshakeMessageHelloVerifyRequest) handshakeType() handshakeType {
return handshakeTypeHelloVerifyRequest
}
func (h *handshakeMessageHelloVerifyRequest) Marshal() ([]byte, error) {
if len(h.cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, 3+len(h.cookie))
out[0] = h.version.major
out[1] = h.version.minor
out[2] = byte(len(h.cookie))
copy(out[3:], h.cookie)
return out, nil
}
func (h *handshakeMessageHelloVerifyRequest) Unmarshal(data []byte) error {
if len(data) < 3 {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
cookieLength := data[2]
if len(data) < (int(cookieLength) + 3) {
return errBufferTooSmall
}
h.cookie = make([]byte, cookieLength)
copy(h.cookie, data[3:3+cookieLength])
return nil
}

View file

@ -0,0 +1,102 @@
package dtls
import (
"encoding/binary"
)
/*
The server will send this message in response to a ClientHello
message when it was able to find an acceptable set of algorithms.
If it cannot find such a match, it will respond with a handshake
failure alert.
https://tools.ietf.org/html/rfc5246#section-7.4.1.3
*/
type handshakeMessageServerHello struct {
version protocolVersion
random handshakeRandom
cipherSuite cipherSuite
compressionMethod *compressionMethod
extensions []extension
}
const handshakeMessageServerHelloVariableWidthStart = 2 + handshakeRandomLength
func (h handshakeMessageServerHello) handshakeType() handshakeType {
return handshakeTypeServerHello
}
func (h *handshakeMessageServerHello) Marshal() ([]byte, error) {
if h.cipherSuite == nil {
return nil, errCipherSuiteUnset
} else if h.compressionMethod == nil {
return nil, errCompressionMethodUnset
}
out := make([]byte, handshakeMessageServerHelloVariableWidthStart)
out[0] = h.version.major
out[1] = h.version.minor
rand := h.random.marshalFixed()
copy(out[2:], rand[:])
out = append(out, 0x00) // SessionID
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(h.cipherSuite.ID()))
out = append(out, byte(h.compressionMethod.id))
extensions, err := encodeExtensions(h.extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
func (h *handshakeMessageServerHello) Unmarshal(data []byte) error {
if len(data) < 2+handshakeRandomLength {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
var random [handshakeRandomLength]byte
copy(random[:], data[2:])
h.random.unmarshalFixed(random)
currOffset := handshakeMessageServerHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
if len(data) < (currOffset + 2) {
return errBufferTooSmall
}
if c := cipherSuiteForID(CipherSuiteID(binary.BigEndian.Uint16(data[currOffset:]))); c != nil {
h.cipherSuite = c
currOffset += 2
} else {
return errInvalidCipherSuite
}
if len(data) < currOffset {
return errBufferTooSmall
}
if compressionMethod, ok := compressionMethods()[compressionMethodID(data[currOffset])]; ok {
h.compressionMethod = compressionMethod
currOffset++
} else {
return errInvalidCompressionMethod
}
if len(data) <= currOffset {
h.extensions = []extension{}
return nil
}
extensions, err := decodeExtensions(data[currOffset:])
if err != nil {
return err
}
h.extensions = extensions
return nil
}

View file

@ -0,0 +1,16 @@
package dtls
type handshakeMessageServerHelloDone struct {
}
func (h handshakeMessageServerHelloDone) handshakeType() handshakeType {
return handshakeTypeServerHelloDone
}
func (h *handshakeMessageServerHelloDone) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (h *handshakeMessageServerHelloDone) Unmarshal(data []byte) error {
return nil
}

View file

@ -0,0 +1,104 @@
package dtls
import (
"encoding/binary"
)
// Structure supports ECDH and PSK
type handshakeMessageServerKeyExchange struct {
identityHint []byte
ellipticCurveType ellipticCurveType
namedCurve namedCurve
publicKey []byte
hashAlgorithm hashAlgorithm
signatureAlgorithm signatureAlgorithm
signature []byte
}
func (h handshakeMessageServerKeyExchange) handshakeType() handshakeType {
return handshakeTypeServerKeyExchange
}
func (h *handshakeMessageServerKeyExchange) Marshal() ([]byte, error) {
if h.identityHint != nil {
out := append([]byte{0x00, 0x00}, h.identityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
out := []byte{byte(h.ellipticCurveType), 0x00, 0x00}
binary.BigEndian.PutUint16(out[1:], uint16(h.namedCurve))
out = append(out, byte(len(h.publicKey)))
out = append(out, h.publicKey...)
out = append(out, []byte{byte(h.hashAlgorithm), byte(h.signatureAlgorithm), 0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(h.signature)))
out = append(out, h.signature...)
return out, nil
}
func (h *handshakeMessageServerKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
h.identityHint = append([]byte{}, data[2:]...)
return nil
}
if _, ok := ellipticCurveTypes()[ellipticCurveType(data[0])]; ok {
h.ellipticCurveType = ellipticCurveType(data[0])
} else {
return errInvalidEllipticCurveType
}
if len(data[1:]) < 2 {
return errBufferTooSmall
}
h.namedCurve = namedCurve(binary.BigEndian.Uint16(data[1:3]))
if _, ok := namedCurves()[h.namedCurve]; !ok {
return errInvalidNamedCurve
}
if len(data) < 4 {
return errBufferTooSmall
}
publicKeyLength := int(data[3])
offset := 4 + publicKeyLength
if len(data) < offset {
return errBufferTooSmall
}
h.publicKey = append([]byte{}, data[4:offset]...)
if len(data) <= offset {
return errBufferTooSmall
}
h.hashAlgorithm = hashAlgorithm(data[offset])
if _, ok := hashAlgorithms()[h.hashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
offset++
if len(data) <= offset {
return errBufferTooSmall
}
h.signatureAlgorithm = signatureAlgorithm(data[offset])
if _, ok := signatureAlgorithms()[h.signatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
offset++
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if len(data) < offset+signatureLength {
return errBufferTooSmall
}
h.signature = append([]byte{}, data[offset:offset+signatureLength]...)
return nil
}

View file

@ -0,0 +1,44 @@
package dtls
import (
"crypto/rand"
"encoding/binary"
"time"
)
const (
randomBytesLength = 28
handshakeRandomLength = randomBytesLength + 4
)
// https://tools.ietf.org/html/rfc4346#section-7.4.1.2
type handshakeRandom struct {
gmtUnixTime time.Time
randomBytes [randomBytesLength]byte
}
func (h *handshakeRandom) marshalFixed() [handshakeRandomLength]byte {
var out [handshakeRandomLength]byte
binary.BigEndian.PutUint32(out[0:], uint32(h.gmtUnixTime.Unix()))
copy(out[4:], h.randomBytes[:])
return out
}
func (h *handshakeRandom) unmarshalFixed(data [handshakeRandomLength]byte) {
h.gmtUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0)
copy(h.randomBytes[:], data[4:])
}
// populate fills the handshakeRandom with random values
// may be called multiple times
func (h *handshakeRandom) populate() error {
h.gmtUnixTime = time.Now()
tmp := make([]byte, randomBytesLength)
_, err := rand.Read(tmp)
copy(h.randomBytes[:], tmp)
return err
}

View file

@ -0,0 +1,318 @@
package dtls
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"sync"
"time"
"github.com/pion/logging"
)
// [RFC6347 Section-4.2.4]
// +-----------+
// +---> | PREPARING | <--------------------+
// | +-----------+ |
// | | |
// | | Buffer next flight |
// | | |
// | \|/ |
// | +-----------+ |
// | | SENDING |<------------------+ | Send
// | +-----------+ | | HelloRequest
// Receive | | | |
// next | | Send flight | | or
// flight | +--------+ | |
// | | | Set retransmit timer | | Receive
// | | \|/ | | HelloRequest
// | | +-----------+ | | Send
// +--)--| WAITING |-------------------+ | ClientHello
// | | +-----------+ Timer expires | |
// | | | | |
// | | +------------------------+ |
// Receive | | Send Read retransmit |
// last | | last |
// flight | | flight |
// | | |
// \|/\|/ |
// +-----------+ |
// | FINISHED | -------------------------------+
// +-----------+
// | /|\
// | |
// +---+
// Read retransmit
// Retransmit last flight
var errInvalidFSMTransition = errors.New("invalid state machine transition")
type handshakeState uint8
const (
handshakeErrored handshakeState = iota
handshakePreparing
handshakeSending
handshakeWaiting
handshakeFinished
)
func (s handshakeState) String() string {
switch s {
case handshakeErrored:
return "Errored"
case handshakePreparing:
return "Preparing"
case handshakeSending:
return "Sending"
case handshakeWaiting:
return "Waiting"
case handshakeFinished:
return "Finished"
default:
return "Unknown"
}
}
type handshakeFSM struct {
currentFlight flightVal
flights []*packet
retransmit bool
state *State
cache *handshakeCache
cfg *handshakeConfig
closed chan struct{}
}
type handshakeConfig struct {
localPSKCallback PSKCallback
localPSKIdentityHint []byte
localCipherSuites []cipherSuite // Available CipherSuites
localSignatureSchemes []signatureHashAlgorithm // Available signature schemes
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
serverName string
clientAuth ClientAuthType // If we are a client should we request a client certificate
localCertificates []tls.Certificate
nameToCertificate map[string]*tls.Certificate
insecureSkipVerify bool
verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
rootCAs *x509.CertPool
clientCAs *x509.CertPool
retransmitInterval time.Duration
onFlightState func(flightVal, handshakeState)
log logging.LeveledLogger
initialEpoch uint16
mu sync.Mutex
}
type flightConn interface {
notify(ctx context.Context, level alertLevel, desc alertDescription) error
writePackets(context.Context, []*packet) error
recvHandshake() <-chan chan struct{}
setLocalEpoch(epoch uint16)
handleQueuedPackets(context.Context) error
}
func srvCliStr(isClient bool) string {
if isClient {
return "client"
}
return "server"
}
func newHandshakeFSM(
s *State, cache *handshakeCache, cfg *handshakeConfig,
initialFlight flightVal,
) *handshakeFSM {
return &handshakeFSM{
currentFlight: initialFlight,
state: s,
cache: cache,
cfg: cfg,
closed: make(chan struct{}),
}
}
func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
state := initialState
defer func() {
close(s.closed)
}()
for {
s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
if s.cfg.onFlightState != nil {
s.cfg.onFlightState(s.currentFlight, state)
}
var err error
switch state {
case handshakePreparing:
state, err = s.prepare(ctx, c)
case handshakeSending:
state, err = s.send(ctx, c)
case handshakeWaiting:
state, err = s.wait(ctx, c)
case handshakeFinished:
state, err = s.finish(ctx, c)
default:
return errInvalidFSMTransition
}
if err != nil {
return err
}
}
}
func (s *handshakeFSM) Done() <-chan struct{} {
return s.closed
}
func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
s.flights = nil
// Prepare flights
var (
a *alert
err error
pkts []*packet
)
gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
if errFlight != nil {
err = errFlight
a = &alert{alertLevelFatal, alertInternalError}
} else {
pkts, a, err = gen(c, s.state, s.cache, s.cfg)
s.retransmit = retransmit
}
if a != nil {
if alertErr := c.notify(ctx, a.alertLevel, a.alertDescription); alertErr != nil {
if err != nil {
err = alertErr
}
}
}
if err != nil {
return handshakeErrored, err
}
s.flights = pkts
epoch := s.cfg.initialEpoch
nextEpoch := epoch
for _, p := range s.flights {
p.record.recordLayerHeader.epoch += epoch
if p.record.recordLayerHeader.epoch > nextEpoch {
nextEpoch = p.record.recordLayerHeader.epoch
}
if h, ok := p.record.content.(*handshake); ok {
h.handshakeHeader.messageSequence = uint16(s.state.handshakeSendSequence)
s.state.handshakeSendSequence++
}
}
if epoch != nextEpoch {
s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
c.setLocalEpoch(nextEpoch)
}
return handshakeSending, nil
}
func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
// Send flights
if err := c.writePackets(ctx, s.flights); err != nil {
return handshakeErrored, err
}
if s.currentFlight.isLastSendFlight() {
return handshakeFinished, nil
}
return handshakeWaiting, nil
}
func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
parse, errFlight := s.currentFlight.getFlightParser()
if errFlight != nil {
if alertErr := c.notify(ctx, alertLevelFatal, alertInternalError); alertErr != nil {
if errFlight != nil {
return handshakeErrored, alertErr
}
}
return handshakeErrored, errFlight
}
retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
for {
select {
case done := <-c.recvHandshake():
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
close(done)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err != nil {
err = alertErr
}
}
}
if err != nil {
return handshakeErrored, err
}
if nextFlight == 0 {
break
}
s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
return handshakeFinished, nil
}
s.currentFlight = nextFlight
return handshakePreparing, nil
case <-retransmitTimer.C:
if !s.retransmit {
return handshakeWaiting, nil
}
return handshakeSending, nil
case <-ctx.Done():
return handshakeErrored, ctx.Err()
}
}
}
func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
parse, errFlight := s.currentFlight.getFlightParser()
if errFlight != nil {
if alertErr := c.notify(ctx, alertLevelFatal, alertInternalError); alertErr != nil {
if errFlight != nil {
return handshakeErrored, alertErr
}
}
return handshakeErrored, errFlight
}
retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
select {
case done := <-c.recvHandshake():
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
close(done)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err != nil {
err = alertErr
}
}
}
if err != nil {
return handshakeErrored, err
}
if nextFlight == 0 {
break
}
<-retransmitTimer.C
// Retransmit last flight
return handshakeSending, nil
case <-ctx.Done():
return handshakeErrored, ctx.Err()
}
return handshakeFinished, nil
}

View file

@ -0,0 +1,116 @@
package dtls
import ( //nolint:gci
"crypto"
"crypto/md5" //nolint:gosec
"crypto/sha1" //nolint:gosec
"crypto/sha256"
"crypto/sha512"
)
// hashAlgorithm is used to indicate the hash algorithm used
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
type hashAlgorithm uint16
// Supported hash hash algorithms
const (
hashAlgorithmMD2 hashAlgorithm = 0 // Blacklisted
hashAlgorithmMD5 hashAlgorithm = 1 // Blacklisted
hashAlgorithmSHA1 hashAlgorithm = 2 // Blacklisted
hashAlgorithmSHA224 hashAlgorithm = 3
hashAlgorithmSHA256 hashAlgorithm = 4
hashAlgorithmSHA384 hashAlgorithm = 5
hashAlgorithmSHA512 hashAlgorithm = 6
hashAlgorithmEd25519 hashAlgorithm = 8
)
// String makes hashAlgorithm printable
func (h hashAlgorithm) String() string {
switch h {
case hashAlgorithmMD2:
return "md2"
case hashAlgorithmMD5:
return "md5" // [RFC3279]
case hashAlgorithmSHA1:
return "sha-1" // [RFC3279]
case hashAlgorithmSHA224:
return "sha-224" // [RFC4055]
case hashAlgorithmSHA256:
return "sha-256" // [RFC4055]
case hashAlgorithmSHA384:
return "sha-384" // [RFC4055]
case hashAlgorithmSHA512:
return "sha-512" // [RFC4055]
case hashAlgorithmEd25519:
return "null"
default:
return "unknown or unsupported hash algorithm"
}
}
func (h hashAlgorithm) digest(b []byte) []byte {
switch h {
case hashAlgorithmMD5:
hash := md5.Sum(b) // #nosec
return hash[:]
case hashAlgorithmSHA1:
hash := sha1.Sum(b) // #nosec
return hash[:]
case hashAlgorithmSHA224:
hash := sha256.Sum224(b)
return hash[:]
case hashAlgorithmSHA256:
hash := sha256.Sum256(b)
return hash[:]
case hashAlgorithmSHA384:
hash := sha512.Sum384(b)
return hash[:]
case hashAlgorithmSHA512:
hash := sha512.Sum512(b)
return hash[:]
default:
return nil
}
}
func (h hashAlgorithm) insecure() bool {
switch h {
case hashAlgorithmMD2, hashAlgorithmMD5, hashAlgorithmSHA1:
return true
default:
return false
}
}
func (h hashAlgorithm) cryptoHash() crypto.Hash {
switch h {
case hashAlgorithmMD5:
return crypto.MD5
case hashAlgorithmSHA1:
return crypto.SHA1
case hashAlgorithmSHA224:
return crypto.SHA224
case hashAlgorithmSHA256:
return crypto.SHA256
case hashAlgorithmSHA384:
return crypto.SHA384
case hashAlgorithmSHA512:
return crypto.SHA512
case hashAlgorithmEd25519:
return crypto.Hash(0)
default:
return crypto.Hash(0)
}
}
func hashAlgorithms() map[hashAlgorithm]struct{} {
return map[hashAlgorithm]struct{}{
hashAlgorithmMD5: {},
hashAlgorithmSHA1: {},
hashAlgorithmSHA224: {},
hashAlgorithmSHA256: {},
hashAlgorithmSHA384: {},
hashAlgorithmSHA512: {},
hashAlgorithmEd25519: {},
}
}

View file

@ -0,0 +1,45 @@
// Package closer provides signaling channel for shutdown
package closer
import (
"context"
)
// Closer allows for each signaling a channel for shutdown
type Closer struct {
ctx context.Context
closeFunc func()
}
// NewCloser creates a new instance of Closer
func NewCloser() *Closer {
ctx, closeFunc := context.WithCancel(context.Background())
return &Closer{
ctx: ctx,
closeFunc: closeFunc,
}
}
// NewCloserWithParent creates a new instance of Closer with a parent context
func NewCloserWithParent(ctx context.Context) *Closer {
ctx, closeFunc := context.WithCancel(ctx)
return &Closer{
ctx: ctx,
closeFunc: closeFunc,
}
}
// Done returns a channel signaling when it is done
func (c *Closer) Done() <-chan struct{} {
return c.ctx.Done()
}
// Err returns an error of the context
func (c *Closer) Err() error {
return c.ctx.Err()
}
// Close sends a signal to trigger the ctx done channel
func (c *Closer) Close() {
c.closeFunc()
}

View file

@ -0,0 +1,156 @@
// Package connctx wraps net.Conn using context.Context.
package connctx
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// ErrClosing is returned on Write to closed connection.
var ErrClosing = errors.New("use of closed network connection")
// ConnCtx is a wrapper of net.Conn using context.Context.
type ConnCtx interface {
Read(context.Context, []byte) (int, error)
Write(context.Context, []byte) (int, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
Conn() net.Conn
}
type connCtx struct {
nextConn net.Conn
closed chan struct{}
closeOnce sync.Once
readMu sync.Mutex
writeMu sync.Mutex
}
var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals
// New creates a new ConnCtx wrapping given net.Conn.
func New(conn net.Conn) ConnCtx {
c := &connCtx{
nextConn: conn,
closed: make(chan struct{}),
}
return c
}
func (c *connCtx) Read(ctx context.Context, b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
select {
case <-c.closed:
return 0, io.EOF
default:
}
done := make(chan struct{})
var wg sync.WaitGroup
var errSetDeadline atomic.Value
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ctx.Done():
// context canceled
if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
errSetDeadline.Store(err)
return
}
<-done
if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
errSetDeadline.Store(err)
}
case <-done:
}
}()
n, err := c.nextConn.Read(b)
close(done)
wg.Wait()
if e := ctx.Err(); e != nil && n == 0 {
err = e
}
if err2 := errSetDeadline.Load(); err == nil && err2 != nil {
err = err2.(error)
}
return n, err
}
func (c *connCtx) Write(ctx context.Context, b []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()
select {
case <-c.closed:
return 0, ErrClosing
default:
}
done := make(chan struct{})
var wg sync.WaitGroup
var errSetDeadline atomic.Value
wg.Add(1)
go func() {
select {
case <-ctx.Done():
// context canceled
if err := c.nextConn.SetWriteDeadline(veryOld); err != nil {
errSetDeadline.Store(err)
return
}
<-done
if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil {
errSetDeadline.Store(err)
}
case <-done:
}
wg.Done()
}()
n, err := c.nextConn.Write(b)
close(done)
wg.Wait()
if e := ctx.Err(); e != nil && n == 0 {
err = e
}
if err2 := errSetDeadline.Load(); err == nil && err2 != nil {
err = err2.(error)
}
return n, err
}
func (c *connCtx) Close() error {
err := c.nextConn.Close()
c.closeOnce.Do(func() {
c.writeMu.Lock()
c.readMu.Lock()
close(c.closed)
c.readMu.Unlock()
c.writeMu.Unlock()
})
return err
}
func (c *connCtx) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}
func (c *connCtx) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}
func (c *connCtx) Conn() net.Conn {
return c.nextConn
}

View file

@ -0,0 +1,78 @@
package dtls
import (
"net"
"github.com/pion/udp"
)
// Listen creates a DTLS listener
func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
lc := udp.ListenConfig{
AcceptFilter: func(packet []byte) bool {
pkts, err := unpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return false
}
h := &recordLayerHeader{}
if err := h.Unmarshal(pkts[0]); err != nil {
return false
}
return h.contentType == contentTypeHandshake
},
}
parent, err := lc.Listen(network, laddr)
if err != nil {
return nil, err
}
return &listener{
config: config,
parent: parent,
}, nil
}
// NewListener creates a DTLS listener which accepts connections from an inner Listener.
func NewListener(inner net.Listener, config *Config) (net.Listener, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
return &listener{
config: config,
parent: inner,
}, nil
}
// listener represents a DTLS listener
type listener struct {
config *Config
parent net.Listener
}
// Accept waits for and returns the next connection to the listener.
// You have to either close or read on all connection that are created.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, set ConnectContextMaker.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.parent.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config)
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
// Already Accepted connections are not closed.
func (l *listener) Close() error {
return l.parent.Close()
}
// Addr returns the listener's network address.
func (l *listener) Addr() net.Addr {
return l.parent.Addr()
}

View file

@ -0,0 +1,62 @@
package dtls
import (
"crypto/elliptic"
"crypto/rand"
"golang.org/x/crypto/curve25519"
)
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
type namedCurve uint16
type namedCurveKeypair struct {
curve namedCurve
publicKey []byte
privateKey []byte
}
const (
namedCurveP256 namedCurve = 0x0017
namedCurveP384 namedCurve = 0x0018
namedCurveX25519 namedCurve = 0x001d
)
func namedCurves() map[namedCurve]bool {
return map[namedCurve]bool{
namedCurveX25519: true,
namedCurveP256: true,
namedCurveP384: true,
}
}
func generateKeypair(c namedCurve) (*namedCurveKeypair, error) {
switch c { //nolint:golint
case namedCurveX25519:
tmp := make([]byte, 32)
if _, err := rand.Read(tmp); err != nil {
return nil, err
}
var public, private [32]byte
copy(private[:], tmp)
curve25519.ScalarBaseMult(&public, &private)
return &namedCurveKeypair{namedCurveX25519, public[:], private[:]}, nil
case namedCurveP256:
return ellipticCurveKeypair(namedCurveP256, elliptic.P256(), elliptic.P256())
case namedCurveP384:
return ellipticCurveKeypair(namedCurveP384, elliptic.P384(), elliptic.P384())
default:
return nil, errInvalidNamedCurve
}
}
func ellipticCurveKeypair(nc namedCurve, c1, c2 elliptic.Curve) (*namedCurveKeypair, error) {
privateKey, x, y, err := elliptic.GenerateKey(c1, rand.Reader)
if err != nil {
return nil, err
}
return &namedCurveKeypair{nc, elliptic.Marshal(c2, x, y), privateKey}, nil
}

View file

@ -0,0 +1,7 @@
package dtls
type packet struct {
record *recordLayer
shouldEncrypt bool
resetLocalSequenceNumber bool
}

View file

@ -0,0 +1,251 @@
// Package ccm implements a CCM, Counter with CBC-MAC
// as per RFC 3610.
//
// See https://tools.ietf.org/html/rfc3610
//
// This code was lifted from https://github.com/bocajim/dtls/blob/a3300364a283fcb490d28a93d7fcfa7ba437fbbe/ccm/ccm.go
// and as such was not written by the Pions authors. Like Pions this
// code is licensed under MIT.
//
// A request for including CCM into the Go standard library
// can be found as issue #27484 on the https://github.com/golang/go/
// repository.
package ccm
import (
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
"errors"
"math"
)
// ccm represents a Counter with CBC-MAC with a specific key.
type ccm struct {
b cipher.Block
M uint8
L uint8
}
const ccmBlockSize = 16
// CCM is a block cipher in Counter with CBC-MAC mode.
// Providing authenticated encryption with associated data via the cipher.AEAD interface.
type CCM interface {
cipher.AEAD
// MaxLength returns the maxium length of plaintext in calls to Seal.
// The maximum length of ciphertext in calls to Open is MaxLength()+Overhead().
// The maximum length is related to CCM's `L` parameter (15-noncesize) and
// is 1<<(8*L) - 1 (but also limited by the maxium size of an int).
MaxLength() int
}
var (
errInvalidBlockSize = errors.New("ccm: NewCCM requires 128-bit block cipher")
errInvalidTagSize = errors.New("ccm: tagsize must be 4, 6, 8, 10, 12, 14, or 16")
errInvalidNonceSize = errors.New("ccm: invalid nonce size")
)
// NewCCM returns the given 128-bit block cipher wrapped in CCM.
// The tagsize must be an even integer between 4 and 16 inclusive
// and is used as CCM's `M` parameter.
// The noncesize must be an integer between 7 and 13 inclusive,
// 15-noncesize is used as CCM's `L` parameter.
func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) {
if b.BlockSize() != ccmBlockSize {
return nil, errInvalidBlockSize
}
if tagsize < 4 || tagsize > 16 || tagsize&1 != 0 {
return nil, errInvalidTagSize
}
lensize := 15 - noncesize
if lensize < 2 || lensize > 8 {
return nil, errInvalidNonceSize
}
c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)}
return c, nil
}
func (c *ccm) NonceSize() int { return 15 - int(c.L) }
func (c *ccm) Overhead() int { return int(c.M) }
func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) }
func maxlen(l uint8, tagsize int) int {
max := (uint64(1) << (8 * l)) - 1
if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || max > m64 {
max = m64 // The maximum lentgh on a 64bit arch
}
if max != uint64(int(max)) {
return math.MaxInt32 - tagsize // We have only 32bit int's
}
return int(max)
}
// MaxNonceLength returns the maximum nonce length for a given plaintext length.
// A return value <= 0 indicates that plaintext length is too large for
// any nonce length.
func MaxNonceLength(pdatalen int) int {
const tagsize = 16
for L := 2; L <= 8; L++ {
if maxlen(uint8(L), tagsize) >= pdatalen {
return 15 - L
}
}
return 0
}
func (c *ccm) cbcRound(mac, data []byte) {
for i := 0; i < ccmBlockSize; i++ {
mac[i] ^= data[i]
}
c.b.Encrypt(mac, mac)
}
func (c *ccm) cbcData(mac, data []byte) {
for len(data) >= ccmBlockSize {
c.cbcRound(mac, data[:ccmBlockSize])
data = data[ccmBlockSize:]
}
if len(data) > 0 {
var block [ccmBlockSize]byte
copy(block[:], data)
c.cbcRound(mac, block[:])
}
}
var errPlaintextTooLong = errors.New("ccm: plaintext too large")
func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) {
var mac [ccmBlockSize]byte
if len(adata) > 0 {
mac[0] |= 1 << 6
}
mac[0] |= (c.M - 2) << 2
mac[0] |= c.L - 1
if len(nonce) != c.NonceSize() {
return nil, errInvalidNonceSize
}
if len(plaintext) > c.MaxLength() {
return nil, errPlaintextTooLong
}
binary.BigEndian.PutUint64(mac[ccmBlockSize-8:], uint64(len(plaintext)))
copy(mac[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(mac[:], mac[:])
var block [ccmBlockSize]byte
if n := uint64(len(adata)); n > 0 {
// First adata block includes adata length
i := 2
if n <= 0xfeff {
binary.BigEndian.PutUint16(block[:i], uint16(n))
} else {
block[0] = 0xfe
block[1] = 0xff
if n < uint64(1<<32) {
i = 2 + 4
binary.BigEndian.PutUint32(block[2:i], uint32(n))
} else {
i = 2 + 8
binary.BigEndian.PutUint64(block[2:i], n)
}
}
i = copy(block[i:], adata)
c.cbcRound(mac[:], block[:])
c.cbcData(mac[:], adata[i:])
}
if len(plaintext) > 0 {
c.cbcData(mac[:], plaintext)
}
return mac[:c.M], nil
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
// From crypto/cipher/gcm.go
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// Seal encrypts and authenticates plaintext, authenticates the
// additional data and appends the result to dst, returning the updated
// slice. The nonce must be NonceSize() bytes long and unique for all
// time, for a given key.
// The plaintext must be no longer than MaxLength() bytes long.
//
// The plaintext and dst may alias exactly or not at all.
func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte {
tag, err := c.tag(nonce, plaintext, adata)
if err != nil {
// The cipher.AEAD interface doesn't allow for an error return.
panic(err) // nolint
}
var iv, s0 [ccmBlockSize]byte
iv[0] = c.L - 1
copy(iv[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(s0[:], iv[:])
for i := 0; i < int(c.M); i++ {
tag[i] ^= s0[i]
}
iv[len(iv)-1] |= 1
stream := cipher.NewCTR(c.b, iv[:])
ret, out := sliceForAppend(dst, len(plaintext)+int(c.M))
stream.XORKeyStream(out, plaintext)
copy(out[len(plaintext):], tag)
return ret
}
var (
errOpen = errors.New("ccm: message authentication failed")
errCiphertextTooShort = errors.New("ccm: ciphertext too short")
errCiphertextTooLong = errors.New("ccm: ciphertext too long")
)
func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
if len(ciphertext) < int(c.M) {
return nil, errCiphertextTooShort
}
if len(ciphertext) > c.MaxLength()+c.Overhead() {
return nil, errCiphertextTooLong
}
tag := make([]byte, int(c.M))
copy(tag, ciphertext[len(ciphertext)-int(c.M):])
ciphertextWithoutTag := ciphertext[:len(ciphertext)-int(c.M)]
var iv, s0 [ccmBlockSize]byte
iv[0] = c.L - 1
copy(iv[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(s0[:], iv[:])
for i := 0; i < int(c.M); i++ {
tag[i] ^= s0[i]
}
iv[len(iv)-1] |= 1
stream := cipher.NewCTR(c.b, iv[:])
// Cannot decrypt directly to dst since we're not supposed to
// reveal the plaintext to the caller if authentication fails.
plaintext := make([]byte, len(ciphertextWithoutTag))
stream.XORKeyStream(plaintext, ciphertextWithoutTag)
expectedTag, err := c.tag(nonce, plaintext, adata)
if err != nil {
return nil, err
}
if subtle.ConstantTimeCompare(tag, expectedTag) != 1 {
return nil, errOpen
}
return append(dst, plaintext...), nil
}

View file

@ -0,0 +1,50 @@
// Package fingerprint provides a helper to create fingerprint string from certificate
package fingerprint
import (
"crypto"
"crypto/x509"
"errors"
"fmt"
)
var (
errHashUnavailable = errors.New("fingerprint: hash algorithm is not linked into the binary")
errInvalidFingerprintLength = errors.New("fingerprint: invalid fingerprint length")
)
// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm
func Fingerprint(cert *x509.Certificate, algo crypto.Hash) (string, error) {
if !algo.Available() {
return "", errHashUnavailable
}
h := algo.New()
for i := 0; i < len(cert.Raw); {
n, _ := h.Write(cert.Raw[i:])
// Hash.Writer is specified to be never returning an error.
// https://golang.org/pkg/hash/#Hash
i += n
}
digest := []byte(fmt.Sprintf("%x", h.Sum(nil)))
digestlen := len(digest)
if digestlen == 0 {
return "", nil
}
if digestlen%2 != 0 {
return "", errInvalidFingerprintLength
}
res := make([]byte, digestlen>>1+digestlen-1)
pos := 0
for i, c := range digest {
res[pos] = c
pos++
if (i)%2 != 0 && i < digestlen-1 {
res[pos] = byte(':')
pos++
}
}
return string(res), nil
}

View file

@ -0,0 +1,37 @@
package fingerprint
import (
"crypto"
"errors"
)
var errInvalidHashAlgorithm = errors.New("fingerprint: invalid hash algorithm")
func nameToHash() map[string]crypto.Hash {
return map[string]crypto.Hash{
"md5": crypto.MD5, // [RFC3279]
"sha-1": crypto.SHA1, // [RFC3279]
"sha-224": crypto.SHA224, // [RFC4055]
"sha-256": crypto.SHA256, // [RFC4055]
"sha-384": crypto.SHA384, // [RFC4055]
"sha-512": crypto.SHA512, // [RFC4055]
}
}
// HashFromString allows looking up a hash algorithm by it's string representation
func HashFromString(s string) (crypto.Hash, error) {
if h, ok := nameToHash()[s]; ok {
return h, nil
}
return 0, errInvalidHashAlgorithm
}
// StringFromHash allows looking up a string representation of the crypto.Hash.
func StringFromHash(hash crypto.Hash) (string, error) {
for s, h := range nameToHash() {
if h == hash {
return s, nil
}
}
return "", errInvalidHashAlgorithm
}

View file

@ -0,0 +1,230 @@
package dtls
import ( //nolint:gci
"crypto/elliptic"
"crypto/hmac"
"crypto/sha1" //nolint:gosec
"encoding/binary"
"fmt"
"hash"
"math"
"golang.org/x/crypto/curve25519"
)
const (
prfMasterSecretLabel = "master secret"
prfExtendedMasterSecretLabel = "extended master secret"
prfKeyExpansionLabel = "key expansion"
prfVerifyDataClientLabel = "client finished"
prfVerifyDataServerLabel = "server finished"
)
type hashFunc func() hash.Hash
type encryptionKeys struct {
masterSecret []byte
clientMACKey []byte
serverMACKey []byte
clientWriteKey []byte
serverWriteKey []byte
clientWriteIV []byte
serverWriteIV []byte
}
func (e *encryptionKeys) String() string {
return fmt.Sprintf(`encryptionKeys:
- masterSecret: %#v
- clientMACKey: %#v
- serverMACKey: %#v
- clientWriteKey: %#v
- serverWriteKey: %#v
- clientWriteIV: %#v
- serverWriteIV: %#v
`,
e.masterSecret,
e.clientMACKey,
e.serverMACKey,
e.clientWriteKey,
e.serverWriteKey,
e.clientWriteIV,
e.serverWriteIV)
}
// The premaster secret is formed as follows: if the PSK is N octets
// long, concatenate a uint16 with the value N, N zero octets, a second
// uint16 with the value N, and the PSK itself.
//
// https://tools.ietf.org/html/rfc4279#section-2
func prfPSKPreMasterSecret(psk []byte) []byte {
pskLen := uint16(len(psk))
out := append(make([]byte, 2+pskLen+2), psk...)
binary.BigEndian.PutUint16(out, pskLen)
binary.BigEndian.PutUint16(out[2+pskLen:], pskLen)
return out
}
func prfPreMasterSecret(publicKey, privateKey []byte, curve namedCurve) ([]byte, error) {
switch curve {
case namedCurveX25519:
return curve25519.X25519(privateKey, publicKey)
case namedCurveP256:
return ellipticCurvePreMasterSecret(publicKey, privateKey, elliptic.P256(), elliptic.P256())
case namedCurveP384:
return ellipticCurvePreMasterSecret(publicKey, privateKey, elliptic.P384(), elliptic.P384())
default:
return nil, errInvalidNamedCurve
}
}
func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 elliptic.Curve) ([]byte, error) {
x, y := elliptic.Unmarshal(c1, publicKey)
if x == nil || y == nil {
return nil, errInvalidNamedCurve
}
result, _ := c2.ScalarMult(x, y, privateKey)
preMasterSecret := make([]byte, (c2.Params().BitSize+7)>>3)
resultBytes := result.Bytes()
copy(preMasterSecret[len(preMasterSecret)-len(resultBytes):], resultBytes)
return preMasterSecret, nil
}
// This PRF with the SHA-256 hash function is used for all cipher suites
// defined in this document and in TLS documents published prior to this
// document when TLS 1.2 is negotiated. New cipher suites MUST explicitly
// specify a PRF and, in general, SHOULD use the TLS PRF with SHA-256 or a
// stronger standard hash function.
//
// P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
// HMAC_hash(secret, A(2) + seed) +
// HMAC_hash(secret, A(3) + seed) + ...
//
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
//
// P_hash can be iterated as many times as necessary to produce the
// required quantity of data. For example, if P_SHA256 is being used to
// create 80 bytes of data, it will have to be iterated three times
// (through A(3)), creating 96 bytes of output data; the last 16 bytes
// of the final iteration will then be discarded, leaving 80 bytes of
// output data.
//
// https://tools.ietf.org/html/rfc4346w
func prfPHash(secret, seed []byte, requestedLength int, h hashFunc) ([]byte, error) {
hmacSHA256 := func(key, data []byte) ([]byte, error) {
mac := hmac.New(h, key)
if _, err := mac.Write(data); err != nil {
return nil, err
}
return mac.Sum(nil), nil
}
var err error
lastRound := seed
out := []byte{}
iterations := int(math.Ceil(float64(requestedLength) / float64(h().Size())))
for i := 0; i < iterations; i++ {
lastRound, err = hmacSHA256(secret, lastRound)
if err != nil {
return nil, err
}
withSecret, err := hmacSHA256(secret, append(lastRound, seed...))
if err != nil {
return nil, err
}
out = append(out, withSecret...)
}
return out[:requestedLength], nil
}
func prfExtendedMasterSecret(preMasterSecret, sessionHash []byte, h hashFunc) ([]byte, error) {
seed := append([]byte(prfExtendedMasterSecretLabel), sessionHash...)
return prfPHash(preMasterSecret, seed, 48, h)
}
func prfMasterSecret(preMasterSecret, clientRandom, serverRandom []byte, h hashFunc) ([]byte, error) {
seed := append(append([]byte(prfMasterSecretLabel), clientRandom...), serverRandom...)
return prfPHash(preMasterSecret, seed, 48, h)
}
func prfEncryptionKeys(masterSecret, clientRandom, serverRandom []byte, prfMacLen, prfKeyLen, prfIvLen int, h hashFunc) (*encryptionKeys, error) {
seed := append(append([]byte(prfKeyExpansionLabel), serverRandom...), clientRandom...)
keyMaterial, err := prfPHash(masterSecret, seed, (2*prfMacLen)+(2*prfKeyLen)+(2*prfIvLen), h)
if err != nil {
return nil, err
}
clientMACKey := keyMaterial[:prfMacLen]
keyMaterial = keyMaterial[prfMacLen:]
serverMACKey := keyMaterial[:prfMacLen]
keyMaterial = keyMaterial[prfMacLen:]
clientWriteKey := keyMaterial[:prfKeyLen]
keyMaterial = keyMaterial[prfKeyLen:]
serverWriteKey := keyMaterial[:prfKeyLen]
keyMaterial = keyMaterial[prfKeyLen:]
clientWriteIV := keyMaterial[:prfIvLen]
keyMaterial = keyMaterial[prfIvLen:]
serverWriteIV := keyMaterial[:prfIvLen]
return &encryptionKeys{
masterSecret: masterSecret,
clientMACKey: clientMACKey,
serverMACKey: serverMACKey,
clientWriteKey: clientWriteKey,
serverWriteKey: serverWriteKey,
clientWriteIV: clientWriteIV,
serverWriteIV: serverWriteIV,
}, nil
}
func prfVerifyData(masterSecret, handshakeBodies []byte, label string, hashFunc hashFunc) ([]byte, error) {
h := hashFunc()
if _, err := h.Write(handshakeBodies); err != nil {
return nil, err
}
seed := append([]byte(label), h.Sum(nil)...)
return prfPHash(masterSecret, seed, 12, hashFunc)
}
func prfVerifyDataClient(masterSecret, handshakeBodies []byte, h hashFunc) ([]byte, error) {
return prfVerifyData(masterSecret, handshakeBodies, prfVerifyDataClientLabel, h)
}
func prfVerifyDataServer(masterSecret, handshakeBodies []byte, h hashFunc) ([]byte, error) {
return prfVerifyData(masterSecret, handshakeBodies, prfVerifyDataServerLabel, h)
}
// compute the MAC using HMAC-SHA1
func prfMac(epoch uint16, sequenceNumber uint64, contentType contentType, protocolVersion protocolVersion, payload []byte, key []byte) ([]byte, error) {
h := hmac.New(sha1.New, key)
msg := make([]byte, 13)
binary.BigEndian.PutUint16(msg, epoch)
putBigEndianUint48(msg[2:], sequenceNumber)
msg[8] = byte(contentType)
msg[9] = protocolVersion.major
msg[10] = protocolVersion.minor
binary.BigEndian.PutUint16(msg[11:], uint16(len(payload)))
if _, err := h.Write(msg); err != nil {
return nil, err
} else if _, err := h.Write(payload); err != nil {
return nil, err
}
return h.Sum(nil), nil
}

View file

@ -0,0 +1,93 @@
package dtls
import (
"encoding/binary"
)
/*
The TLS Record Layer which handles all data transport.
The record layer is assumed to sit directly on top of some
reliable transport such as TCP. The record layer can carry four types of content:
1. Handshake messagesused for algorithm negotiation and key establishment.
2. ChangeCipherSpec messagesreally part of the handshake but technically a separate kind of message.
3. Alert messagesused to signal that errors have occurred
4. Application layer data
The DTLS record layer is extremely similar to that of TLS 1.1. The
only change is the inclusion of an explicit sequence number in the
record. This sequence number allows the recipient to correctly
verify the TLS MAC.
https://tools.ietf.org/html/rfc4347#section-4.1
*/
type recordLayer struct {
recordLayerHeader recordLayerHeader
content content
}
func (r *recordLayer) Marshal() ([]byte, error) {
contentRaw, err := r.content.Marshal()
if err != nil {
return nil, err
}
r.recordLayerHeader.contentLen = uint16(len(contentRaw))
r.recordLayerHeader.contentType = r.content.contentType()
headerRaw, err := r.recordLayerHeader.Marshal()
if err != nil {
return nil, err
}
return append(headerRaw, contentRaw...), nil
}
func (r *recordLayer) Unmarshal(data []byte) error {
if len(data) < recordLayerHeaderSize {
return errBufferTooSmall
}
if err := r.recordLayerHeader.Unmarshal(data); err != nil {
return err
}
switch contentType(data[0]) {
case contentTypeChangeCipherSpec:
r.content = &changeCipherSpec{}
case contentTypeAlert:
r.content = &alert{}
case contentTypeHandshake:
r.content = &handshake{}
case contentTypeApplicationData:
r.content = &applicationData{}
default:
return errInvalidContentType
}
return r.content.Unmarshal(data[recordLayerHeaderSize:])
}
// Note that as with TLS, multiple handshake messages may be placed in
// the same DTLS record, provided that there is room and that they are
// part of the same flight. Thus, there are two acceptable ways to pack
// two DTLS messages into the same datagram: in the same record or in
// separate records.
// https://tools.ietf.org/html/rfc6347#section-4.2.3
func unpackDatagram(buf []byte) ([][]byte, error) {
out := [][]byte{}
for offset := 0; len(buf) != offset; {
if len(buf)-offset <= recordLayerHeaderSize {
return nil, errInvalidPacketLength
}
pktLen := (recordLayerHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:])))
if offset+pktLen > len(buf) {
return nil, errInvalidPacketLength
}
out = append(out, buf[offset:offset+pktLen])
offset += pktLen
}
return out, nil
}

View file

@ -0,0 +1,76 @@
package dtls
import "encoding/binary"
type recordLayerHeader struct {
contentType contentType
contentLen uint16
protocolVersion protocolVersion
epoch uint16
sequenceNumber uint64 // uint48 in spec
}
const (
recordLayerHeaderSize = 13
maxSequenceNumber = 0x0000FFFFFFFFFFFF
dtls1_2Major = 0xfe
dtls1_2Minor = 0xfd
dtls1_0Major = 0xfe
dtls1_0Minor = 0xff
// VersionDTLS12 is the DTLS version in the same style as
// VersionTLSXX from crypto/tls
VersionDTLS12 = 0xfefd
)
var (
protocolVersion1_0 = protocolVersion{dtls1_0Major, dtls1_0Minor} //nolint:gochecknoglobals
protocolVersion1_2 = protocolVersion{dtls1_2Major, dtls1_2Minor} //nolint:gochecknoglobals
)
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type protocolVersion struct {
major, minor uint8
}
func (v protocolVersion) Equal(x protocolVersion) bool {
return v.major == x.major && v.minor == x.minor
}
func (r *recordLayerHeader) Marshal() ([]byte, error) {
if r.sequenceNumber > maxSequenceNumber {
return nil, errSequenceNumberOverflow
}
out := make([]byte, recordLayerHeaderSize)
out[0] = byte(r.contentType)
out[1] = r.protocolVersion.major
out[2] = r.protocolVersion.minor
binary.BigEndian.PutUint16(out[3:], r.epoch)
putBigEndianUint48(out[5:], r.sequenceNumber)
binary.BigEndian.PutUint16(out[recordLayerHeaderSize-2:], r.contentLen)
return out, nil
}
func (r *recordLayerHeader) Unmarshal(data []byte) error {
if len(data) < recordLayerHeaderSize {
return errBufferTooSmall
}
r.contentType = contentType(data[0])
r.protocolVersion.major = data[1]
r.protocolVersion.minor = data[2]
r.epoch = binary.BigEndian.Uint16(data[3:])
// SequenceNumber is stored as uint48, make into uint64
seqCopy := make([]byte, 8)
copy(seqCopy[2:], data[5:11])
r.sequenceNumber = binary.BigEndian.Uint64(seqCopy)
if !r.protocolVersion.Equal(protocolVersion1_0) && !r.protocolVersion.Equal(protocolVersion1_2) {
return errUnsupportedProtocolVersion
}
return nil
}

View file

@ -0,0 +1,15 @@
{
"extends": [
"config:base"
],
"postUpdateOptions": [
"gomodTidy"
],
"commitBody": "Generated by renovateBot",
"packageRules": [
{
"packagePatterns": ["^golang.org/x/"],
"schedule": ["on the first day of the month"]
}
]
}

View file

@ -0,0 +1,19 @@
package dtls
import (
"context"
"net"
)
// Resume imports an already established dtls connection using a specific dtls state
func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) {
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, config, state.isClient, state)
if err != nil {
return nil, err
}
return c, nil
}

Some files were not shown because too many files have changed in this diff Show more