1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-02-12 19:31:53 +00:00

Squash: Fix rtc to rtmp sync timestamp using sender report. #2470

This commit is contained in:
winlin 2021-08-17 07:25:03 +08:00
parent 3d58e98d1c
commit 85620a34f5
309 changed files with 14837 additions and 8525 deletions

219
README.md
View file

@ -5,7 +5,7 @@
[![](https://github.com/ossrs/srs/actions/workflows/release.yml/badge.svg)](https://github.com/ossrs/srs/actions/workflows/release.yml?query=workflow%3ARelease)
[![](https://github.com/ossrs/srs/actions/workflows/test.yml/badge.svg?branch=develop)](https://github.com/ossrs/srs/actions?query=workflow%3ATest+branch%3Adevelop)
[![](https://codecov.io/gh/ossrs/srs/branch/develop/graph/badge.svg)](https://codecov.io/gh/ossrs/srs/branch/develop)
[![](https://gitee.com/winlinvip/srs-wiki/raw/master/images/wechat-badge.png)](../../wikis/Contact#wechat)
[![](https://gitee.com/winlinvip/srs-wiki/raw/master/images/wechat-badge2.png)](../../wikis/Contact#wechat)
[![](https://gitee.com/winlinvip/srs-wiki/raw/master/images/bbs2.png)](http://bbs.ossrs.net)
SRS/4.0 [Leo](https://github.com/ossrs/srs/wiki/v4_CN_Product#release40) 是一个简单高效的实时视频服务器支持RTMP/WebRTC/HLS/HTTP-FLV/SRT。
@ -89,87 +89,9 @@ A big `THANK YOU` also goes to:
* All friends of SRS for [big supports](https://github.com/ossrs/srs/wiki/Product).
* [Genes](http://sourceforge.net/users/genes), [Mabbott](http://sourceforge.net/users/mabbott) and [Michael Talyanksy](https://github.com/michaeltalyansky) for [st](https://github.com/ossrs/state-threads/tree/srs).
## Features
- [x] Using coroutine by ST, it's really simple and stupid enough.
- [x] Support cluster which consists of origin ([CN][v4_CN_DeliveryRTMP],[EN][v4_EN_DeliveryRTMP]) and edge([CN][v4_CN_Edge], [EN][v4_EN_Edge]) server and uses RTMP as default transport protocol.
- [x] Origin server supports remuxing RTMP to HTTP-FLV([CN][v4_CN_SampleHttpFlv], [EN][v4_EN_SampleHttpFlv]) and HLS([CN][v4_CN_DeliveryHLS], [EN][v4_EN_DeliveryHLS]).
- [x] Edge server supports remuxing RTMP to HTTP-FLV([CN][v4_CN_SampleHttpFlv], [EN][v4_EN_SampleHttpFlv]). As for HLS([CN][v4_CN_DeliveryHLS], [EN][v4_EN_DeliveryHLS]) edge server, recomment to use HTTP edge server, such as [NGINX](http://nginx.org/).
- [x] Support HLS with audio-only([CN][v4_CN_DeliveryHLS2], [EN][v4_EN_DeliveryHLS2]), which need to build the timestamp from AAC samples, so we enhanced it please read [#547][bug #547].
- [x] Support HLS with mp3(h.264+mp3) audio codec, please read [bug #301][bug #301].
- [x] Support transmux RTMP to HTTP-FLV/MP3/AAC/TS, please read wiki([CN][v4_CN_DeliveryHttpStream], [EN][v4_CN_DeliveryHttpStream]).
- [x] Support ingesting([CN][v4_CN_Ingest], [EN][v4_EN_Ingest]) other protocols to SRS by FFMPEG.
- [x] Support RTMP long time(>4.6hours) publishing/playing, with the timestamp corrected.
- [x] Support native HTTP server([CN][v4_CN_SampleHTTP], [EN][v4_EN_SampleHTTP]) for http api and http live streaming.
- [x] Support HTTP CORS for js in http api and http live streaming.
- [x] Support HTTP API([CN][v4_CN_HTTPApi], [EN][v4_EN_HTTPApi]) for system management.
- [x] Support HTTP callback([CN][v4_CN_HTTPCallback], [EN][v4_EN_HTTPCallback]) for authentication and integration.
- [x] Support DVR([CN][v4_CN_DVR], [EN][v4_EN_DVR]) to record live streaming to FLV file.
- [x] Support DVR control module like NGINX-RTMP, please read [#459][bug #459].
- [x] Support EXEC like NGINX-RTMP, please read [bug #367][bug #367].
- [x] Support security strategy including allow/deny publish/play IP([CN][v4_CN_Security], [EN][v4_EN_Security]).
- [x] Support low latency(0.1s+) transport model, please read [bug #257][bug #257].
- [x] Support gop-cache([CN][v4_CN_LowLatency2], [EN][v4_EN_LowLatency2]) for player fast startup.
- [x] Support Vhost([CN][v4_CN_RtmpUrlVhost], [EN][v4_EN_RtmpUrlVhost]) and \_\_defaultVhost\_\_.
- [x] Support reloading([CN][v4_CN_Reload], [EN][v4_EN_Reload]) to apply changes of config.
- [x] Support listening at multiple ports.
- [x] Support forwarding([CN][v4_CN_Forward], [EN][v4_EN_Forward]) to other RTMP servers.
- [x] Support transcoding([CN][v4_CN_FFMPEG], [EN][v4_EN_FFMPEG]) by FFMPEG.
- [x] All wikis are writen in [Chinese][v4_CN_Home] and [English][v4_EN_Home].
- [x] Enhanced json, replace NXJSON(LGPL) with json-parser(BSD), read [#904][bug #904].
- [x] Support valgrind and latest ARM by patching ST, read [ST#1](https://github.com/ossrs/state-threads/issues/1) and [ST#2](https://github.com/ossrs/state-threads/issues/2).
- [x] Support traceable and session-based log([CN][v4_CN_SrsLog], [EN][v4_EN_SrsLog]).
- [x] High performance([CN][v4_CN_Performance], [EN][v4_EN_Performance]) RTMP/HTTP-FLV, 6000+ connections.
- [x] Enhanced complex error code with description and stack, read [#913][bug #913].
- [x] Enhanced RTMP url which supports vhost in stream, read [#1059][bug #1059].
- [x] Support origin cluster, please read [#464][bug #464], [RTMP 302][bug #92].
- [x] Support listen at IPv4 and IPv6, read [#460][bug #460].
- [x] Improve test coverage for core/kernel/protocol/service.
- [x] Support docker by [srs-docker](https://github.com/ossrs/srs-docker).
- [x] Support multiple processes by ReusePort([CN][v4_CN_REUSEPORT], [EN][v4_EN_REUSEPORT]), [#775][bug #775].
- [x] Support a simple [mgmt console](http://ossrs.net:8080/console), please read [srs-console](https://github.com/ossrs/srs-console).
- [x] [Experimental] Support playing stream by WebRTC, [#307][bug #307].
- [x] [Experimental] Support publishing stream by WebRTC, [#307][bug #307].
- [x] [Experimental] Support mux RTP/RTCP/DTLS/SRTP on one port for WebRTC, [#307][bug #307].
- [x] [Experimental] Support client address changing for WebRTC, [#307][bug #307].
- [x] [Experimental] Support transcode RTMP/AAC to WebRTC/Opus, [#307][bug #307].
- [x] [Experimental] Support AV1 codec for WebRTC, [#2324][bug #2324].
- [x] [Experimental] Enhance HTTP Stream Server for HTTP-FLV, HTTPS, HLS etc. [#1657][bug #1657].
- [x] [Experimental] Support DVR in MP4 format, read [#738][bug #738].
- [x] [Experimental] Support MPEG-DASH, the future live streaming protocol, read [#299][bug #299].
- [x] [Experimental] Support pushing MPEG-TS over UDP, please read [bug #250][bug #250].
- [x] [Experimental] Support pushing FLV over HTTP POST, please read wiki([CN][v4_CN_Streamer2], [EN][v4_EN_Streamer2]).
- [x] [Experimental] Support HTTP RAW API, please read [#459][bug #459], [#470][bug #470], [#319][bug #319].
- [x] [Experimental] Support SRT server, read [#1147][bug #1147].
- [x] [Experimental] Support transmux RTC to RTMP, [#2093][bug #2093].
- [x] [Deprecated] Support pushing RTSP, please read [bug #2304][bug #2304].
- [x] [Deprecated] Support Adobe HDS(f4m), please read wiki([CN][v4_CN_DeliveryHDS], [EN][v4_EN_DeliveryHDS]) and [#1535][bug #1535].
- [x] [Deprecated] Support bandwidth testing, please read [#1535][bug #1535].
- [x] [Deprecated] Support Adobe FMS/AMS token traverse([CN][v4_CN_DRM2], [EN][v4_EN_DRM2]) authentication, please read [#1535][bug #1535].
- [x] [Removed] Support RTMP client library: [srs-librtmp](https://github.com/ossrs/srs-librtmp).
- [ ] Support Windows/Cygwin 64bits, [#2532](https://github.com/ossrs/srs/issues/2532).
- [ ] Support push stream by GB28181, [#1500][bug #1500].
- [ ] Support IETF-QUIC for WebRTC Cluster, [#2091][bug #2091].
- [ ] Enhanced forwarding with vhost and variables, [#1342][bug #1342].
- [ ] Support DVR to Cloud Storage, [#1193][bug #1193].
- [ ] Support H.265 over RTMP and HLS, [#465][bug #465].
- [ ] Improve RTC performance to 5K by multiple threading, [#2188][bug #2188].
- [ ] Support source cleanup for idle streams, [#413][bug #413].
- [ ] Support change user to run SRS, [#1111][bug #1111].
- [ ] Support HLS variant, [#463][bug #463].
> Remark: About the milestone and product plan, please read ([CN][v4_CN_Product], [EN][v4_EN_Product]) wiki.
<a name="history"></a>
<a name="changes"></a>
<a name="change-logs"></a>
## Changelog
Please read [CHANGELOG](CHANGELOG.md#changelog).
## Releases
* 2020-08-15, Release [v4.0.156](https://github.com/ossrs/srs/releases/tag/v4.0.156), 4.0 dev4, v4.0.156, 145490 lines.
* 2020-08-14, Release [v4.0.153](https://github.com/ossrs/srs/releases/tag/v4.0.153), 4.0 dev3, v4.0.153, 145506 lines.
* 2020-08-07, Release [v4.0.150](https://github.com/ossrs/srs/releases/tag/v4.0.150), 4.0 dev2, v4.0.150, 145289 lines.
* 2020-07-25, Release [v4.0.146](https://github.com/ossrs/srs/releases/tag/v4.0.146), 4.0 dev1, v4.0.146, 144026 lines.
@ -186,6 +108,17 @@ Please read [CHANGELOG](CHANGELOG.md#changelog).
* 2013-10-23, [Release v0.1.0][r0.1], rtmp. 8287 lines.
* 2013-10-17, Created.
## Features
Please read [FEATURES](trunk/doc/Features.md#features).
<a name="history"></a>
<a name="change-logs"></a>
## Changelog
Please read [CHANGELOG](trunk/doc/CHANGELOG.md#changelog).
## Compare
Comparing with other media servers, SRS is much better and stronger, for details please
@ -193,139 +126,23 @@ read Product([CN][v4_CN_Compare]/[EN][v4_EN_Compare]).
## Performance
Please read [PERFORMANCE](PERFORMANCE.md#performance).
Please read [PERFORMANCE](trunk/doc/PERFORMANCE.md#performance).
## Architecture
The stream architecture of SRS.
```
+----------+ +----------+
| Upstream | | Deliver |
+---|------+ +----|-----+
+---+------------------+------+---------------------+----------------+
| Input | SRS(Simple RTMP Server) | Output |
+----------------------+----------------------------+----------------+
| | +-> DASH ----------------+-> DASH player |
| Encoder(1) | +-> RTMP/HDS -----------+-> Flash player |
| (FMLE,OBS, --RTMP-+->-+-> HLS/HTTP ------------+-> M3U8 player |
| FFmpeg,XSPLIT, | +-> FLV/MP3/Aac/Ts ------+-> HTTP player |
| ......) | +-> Fowarder ------------+-> RTMP server |
| | +-> Transcoder ----------+-> RTMP server |
| | +-> EXEC(5) -------------+-> External app |
| | +-> DVR -----------------+-> FLV file |
| | +-> BandwidthTest -------+-> Flash |
| | +-> WebRTC --------------+-> Flash |
+----------------------+ | |
| WebRTC Client | +--> RTMP | |
| (H5,Native...) --RTC-+---+---> WebRTC ------------+-> WebRTC Client|
+----------------------+ | |
| MediaSource(2) | | |
| (RTSP,FILE, | | |
| HTTP,HLS, --pull-+->-- Ingester(3) -(rtmp)----+-> SRS |
| Device, | | |
| ......) | | |
+----------------------+ | |
| MediaSource(2) | | |
| (MPEGTSoverUDP | | |
| HTTP-FLV, --push-+->- StreamCaster(4) -(rtmp)-+-> SRS |
| SRT, | | |
| ......) | | |
+----------------------+ | |
| FFMPEG --push(srt)--+->- SRTModule(5) ---(rtmp)-+-> SRS |
+----------------------+----------------------------+----------------+
```
Remark:
1. Encoder: Encoder pushs RTMP stream to SRS.
1. MediaSource: Supports any media source, ingesting by ffmpeg.
1. Ingester: Forks a ffmpeg(or other tools) to ingest as rtmp to SRS, please read [Ingest][v4_CN_Ingest].
1. Streamer: Remuxs other protocols to RTMP, please read [Streamer][v4_CN_Streamer].
1. EXEC: Like NGINX-RTMP, EXEC forks external tools for events, please read [ng-exec][v4_CN_NgExec].
1. SRTModule: A isolate module which run in [hybrid](https://github.com/ossrs/srs/issues/1147#issuecomment-577574883) model.
Please read [ARCHITECTURE](trunk/doc/Architecture.md#architecture).
## Ports
The ports used by SRS, kernel services:
* `tcp://1935`, for RTMP live streaming server([CN][v4_CN_DeliveryRTMP],[EN][v4_EN_DeliveryRTMP]).
* `tcp://1985`, HTTP API server, for HTTP-API([CN][v4_CN_HTTPApi], [EN][v4_EN_HTTPApi]), WebRTC([CN][v4_CN_WebRTC], [EN][v4_EN_WebRTC]), etc.
* `tcp://8080`, HTTP live streaming server, HTTP-FLV([CN][v4_CN_SampleHttpFlv], [EN][v4_EN_SampleHttpFlv]), HLS([CN][v4_CN_SampleHLS], [EN][v4_EN_SampleHLS]) as such.
* `udp://8000`, WebRTC Media([CN][v4_CN_WebRTC], [EN][v4_EN_WebRTC]) server.
For optional HTTPS services, which might be provided by other web servers:
* `tcp://8088`, HTTPS live streaming server.
* `tcp://1990`, HTTPS API server.
For optional stream caster services, to push streams to SRS:
* `udp://8935`, Stream Caster: [Push MPEGTS over UDP](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-mpeg-ts-over-udp) server.
* `tcp://554`, Stream Caster: [Push RTSP](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-rtsp-to-srs) server.
* `tcp://8936`, Stream Caster: [Push HTTP-FLV](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-http-flv-to-srs) server.
* `udp://10080`, Stream Caster: [Push SRT Media](https://github.com/ossrs/srs/issues/1147#issuecomment-577469119) server.
For external services to work with SRS:
* `udp://1989`, [WebRTC Signaling](https://github.com/ossrs/signaling#usage) server.
Please read [PORTS](trunk/doc/Resources.md#ports).
## APIs
The API used by SRS:
* `/api/v1/` The HTTP API path.
* `/rtc/v1/` The HTTP API path for RTC.
* `/sig/v1/` The [demo signaling](https://github.com/ossrs/signaling) API.
Other API used by [ossrs.net](https://ossrs.net):
* `/gif/v1` The statistic API.
* `/service/v1/` The latest available version API.
* `/ws-service/v1/` The latest available version API, by websocket.
* `/im-service/v1/` The latest available version API, by IM.
* `/code-service/v1/` The latest available version API, by Code verification.
Please read [APIS](trunk/doc/Resources.md#apis).
## Mirrors
Gitee: [https://gitee.com/ossrs/srs](https://gitee.com/ossrs/srs), the GIT usage([CN][v4_CN_Git], [EN][v4_EN_Git])
```
git clone https://gitee.com/ossrs/srs.git &&
cd srs && git remote set-url origin https://github.com/ossrs/srs.git && git pull
```
> Remark: For users in China, recomment to use mirror from CSDN or OSChina, because they are much faster.
Gitlab: [https://gitlab.com/winlinvip/srs-gitlab](https://gitlab.com/winlinvip/srs-gitlab), the GIT usage([CN][v4_CN_Git], [EN][v4_EN_Git])
```
git clone https://gitlab.com/winlinvip/srs-gitlab.git srs &&
cd srs && git remote set-url origin https://github.com/ossrs/srs.git && git pull
```
Github: [https://github.com/ossrs/srs](https://github.com/ossrs/srs), the GIT usage([CN][v4_CN_Git], [EN][v4_EN_Git])
```
git clone https://github.com/ossrs/srs.git
```
| Branch | Cost | Size | CMD |
| --- | --- | --- | --- |
| 3.0release | 2m19.931s | 262MB | git clone -b 3.0release https://gitee.com/ossrs/srs.git |
| 3.0release | 0m56.515s | 95MB | git clone -b 3.0release --depth=1 https://gitee.com/ossrs/srs.git |
| develop | 2m22.430s | 234MB | git clone -b develop https://gitee.com/ossrs/srs.git |
| develop | 0m46.421s | 42MB | git clone -b develop --depth=1 https://gitee.com/ossrs/srs.git |
| min | 2m22.865s | 217MB | git clone -b min https://gitee.com/ossrs/srs.git |
| min | 0m36.472s | 11MB | git clone -b min --depth=1 https://gitee.com/ossrs/srs.git |
## System Requirements
Supported operating systems and hardware:
* Linux, with x86, x86-64 or arm.
* Mac, with intel chip.
* Other OS, such as Windows, please use [docker](https://github.com/ossrs/srs-docker/tree/v4#usage).
Please read [MIRRORS](trunk/doc/Resources.md#mirrors).
Beijing, 2013.10<br/>
Winlin

View file

@ -4,9 +4,11 @@ go 1.15
require (
github.com/ossrs/go-oryx-lib v0.0.8
github.com/pion/interceptor v0.0.9
github.com/pion/interceptor v0.0.10
github.com/pion/logging v0.2.2
github.com/pion/rtcp v1.2.6
github.com/pion/rtp v1.6.2
github.com/pion/sdp/v3 v3.0.4
github.com/pion/webrtc/v3 v3.0.4
github.com/pion/transport v0.12.2
github.com/pion/webrtc/v3 v3.0.13
)

View file

@ -14,6 +14,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.5 h1:kxhtnfFVi+rYdOALN0B3k9UT86zVJKfBimRaciULW4I=
github.com/google/uuid v1.1.5/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
@ -31,10 +33,16 @@ github.com/pion/datachannel v1.4.21 h1:3ZvhNyfmxsAqltQrApLPQMhSFNA+aT87RqyCq4OXm
github.com/pion/datachannel v1.4.21/go.mod h1:oiNyP4gHx2DIwRzX/MFyH0Rz/Gz05OgBlayAI2hAWjg=
github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
github.com/pion/dtls/v2 v2.0.8 h1:reGe8rNIMfO/UAeFLqO61tl64t154Qfkr4U3Gzu1tsg=
github.com/pion/dtls/v2 v2.0.8/go.mod h1:QuDII+8FVvk9Dp5t5vYIMTo7hh7uBkra+8QIm7QGm10=
github.com/pion/ice/v2 v2.0.14 h1:FxXxauyykf89SWAtkQCfnHkno6G8+bhRkNguSh9zU+4=
github.com/pion/ice/v2 v2.0.14/go.mod h1:wqaUbOq5ObDNU5ox1hRsEst0rWfsKuH1zXjQFEWiZwM=
github.com/pion/ice/v2 v2.0.15 h1:KZrwa2ciL9od8+TUVJiYTNsCW9J5lktBjGwW1MacEnQ=
github.com/pion/ice/v2 v2.0.15/go.mod h1:ZIiVGevpgAxF/cXiIVmuIUtCb3Xs4gCzCbXB6+nFkSI=
github.com/pion/interceptor v0.0.9 h1:fk5hTdyLO3KURQsf/+RjMpEm4NE3yeTY9Kh97b5BvwA=
github.com/pion/interceptor v0.0.9/go.mod h1:dHgEP5dtxOTf21MObuBAjJeAayPxLUAZjerGH8Xr07c=
github.com/pion/interceptor v0.0.10 h1:dXFyFWRJFwmzQqyn0U8dUAbOJu+JJnMVAqxmvTu30B4=
github.com/pion/interceptor v0.0.10/go.mod h1:qzeuWuD/ZXvPqOnxNcnhWfkCZ2e1kwwslicyyPnhoK4=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=
@ -58,6 +66,7 @@ github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRh
github.com/pion/transport v0.10.0/go.mod h1:BnHnUipd0rZQyTVB2SBGojFHT9CBt5C5TcsJSQGkvSE=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pion/transport v0.12.0/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.12.1/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.12.2 h1:WYEjhloRHt1R86LhUKjC5y+P52Y11/QqEUalvtzVoys=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA=
@ -67,6 +76,8 @@ github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths
github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4=
github.com/pion/webrtc/v3 v3.0.4 h1:Tiw3H9fpfcwkvaxonB+Gv1DG9tmgYBQaM1vBagDHP40=
github.com/pion/webrtc/v3 v3.0.4/go.mod h1:1TmFSLpPYFTFXFHPtoq9eGP1ASTa9LC6FBh7sUY8cd4=
github.com/pion/webrtc/v3 v3.0.13 h1:iyR3xf4eQLLatfvAOhjf/vHBi8x9y1TGeJqOHq7TjE4=
github.com/pion/webrtc/v3 v3.0.13/go.mod h1:+7cDZgV7jKkm4H+f0ki2wiMSuZtyFlezKLfBR2hntcQ=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -80,6 +91,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -90,17 +103,24 @@ golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7 h1:3uJsdck53FDIpWwLeAXlia9p4C8j0BO2xZrqzKpL0D8=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View file

@ -24,8 +24,6 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3"
"io"
"io/ioutil"
"math/rand"
@ -36,7 +34,11 @@ import (
"testing"
"time"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/flv"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/pion/interceptor"
"github.com/pion/rtcp"
@ -1936,3 +1938,162 @@ func TestRTCServerVersion(t *testing.T) {
return
}
}
func TestRtcPublishFlvPlay(t *testing.T) {
ctx := logger.WithContext(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Duration(*srsTimeout)*time.Millisecond)
var r0, r1, r2, r3 error
defer func(ctx context.Context) {
if err := filterTestError(ctx.Err(), r0, r1, r2, r3); err != nil {
t.Errorf("Fail for err %+v", err)
} else {
logger.Tf(ctx, "test done with err %+v", err)
}
}(ctx)
var resources []io.Closer
defer func() {
for _, resource := range resources {
_ = resource.Close()
}
}()
var wg sync.WaitGroup
defer wg.Wait()
// The event notify.
var thePublisher *testPublisher
mainReady, mainReadyCancel := context.WithCancel(context.Background())
publishReady, publishReadyCancel := context.WithCancel(context.Background())
streamSuffix := fmt.Sprintf("basic-publish-flvplay-%v-%v", os.Getpid(), rand.Int())
// Objects init.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
doInit := func() (err error) {
// Initialize publisher with private api.
if thePublisher, err = newTestPublisher(registerDefaultCodecs, func(pub *testPublisher) error {
pub.streamSuffix = streamSuffix
pub.iceReadyCancel = publishReadyCancel
resources = append(resources, pub)
return pub.Setup(*srsVnetClientIP)
}); err != nil {
return err
}
// Init done.
mainReadyCancel()
<-ctx.Done()
return nil
}
if err := doInit(); err != nil {
r1 = err
}
}()
// Run publisher.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-mainReady.Done():
r2 = thePublisher.Run(logger.WithContext(ctx), cancel)
logger.Tf(ctx, "pub done")
}
}()
// Run player.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
select {
case <-ctx.Done():
case <-publishReady.Done():
var url string = "http://127.0.0.1:8080" + *srsStream + "-" + streamSuffix + ".flv"
logger.Tf(ctx, "Run play flv url=%v", url)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
logger.Tf(ctx, "New request for flv %v failed, err=%v", url, err)
return
}
client := http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.Tf(ctx, "Http get flv %v failed, err=%v", url, err)
return
}
var f flv.Demuxer
if f, err = flv.NewDemuxer(resp.Body); err != nil {
logger.Tf(ctx, "Create flv demuxer for %v failed, err=%v", url, err)
return
}
defer f.Close()
var version uint8
var hasVideo, hasAudio bool
if version, hasVideo, hasAudio, err = f.ReadHeader(); err != nil {
logger.Tf(ctx, "Flv demuxer read header failed, err=%v", err)
return
}
// Optional, user can check the header.
_ = version
_ = hasAudio
_ = hasVideo
var nnVideo, nnAudio int
var prevVideoTimestamp, prevAudioTimestamp int64
for {
var tagType flv.TagType
var tagSize, timestamp uint32
if tagType, tagSize, timestamp, err = f.ReadTagHeader(); err != nil {
logger.Tf(ctx, "Flv demuxer read tag header failed, err=%v", err)
return
}
var tag []byte
if tag, err = f.ReadTag(tagSize); err != nil {
logger.Tf(ctx, "Flv demuxer read tag failed, err=%v", err)
return
}
if tagType == flv.TagTypeAudio {
nnAudio++
prevAudioTimestamp = (int64)(timestamp)
} else if tagType == flv.TagTypeVideo {
nnVideo++
prevVideoTimestamp = (int64)(timestamp)
}
if nnAudio >= 10 && nnVideo >= 10 {
avDiff := prevVideoTimestamp - prevAudioTimestamp
// Check timestamp gap between video and audio, make sure audio timestamp align to video timestamp.
if avDiff <= 50 && avDiff >= -50 {
logger.Tf(ctx, "Flv recv %v audio, %v video, timestamp gap=%v", nnAudio, nnVideo, avDiff)
cancel()
break
}
}
_ = tag
}
}
}()
}

View file

@ -26,8 +26,8 @@ var (
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space[:])
h.Write(data)
h.Write(space[:]) //nolint:errcheck
h.Write(data) //nolint:errcheck
s := h.Sum(nil)
var uuid UUID
copy(uuid[:], s)

View file

@ -9,7 +9,7 @@ import (
"fmt"
)
// Scan implements sql.Scanner so UUIDs can be read from databases transparently
// Scan implements sql.Scanner so UUIDs can be read from databases transparently.
// Currently, database types that map to string and []byte are supported. Please
// consult database-specific driver documentation for matching types.
func (uuid *UUID) Scan(src interface{}) error {

View file

@ -35,6 +35,12 @@ const (
var rander = rand.Reader // random function
type invalidLengthError struct{ len int }
func (err invalidLengthError) Error() string {
return fmt.Sprintf("invalid UUID length: %d", err.len)
}
// Parse decodes s into a UUID or returns an error. Both the standard UUID
// forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
@ -68,7 +74,7 @@ func Parse(s string) (UUID, error) {
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
return uuid, invalidLengthError{len(s)}
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
@ -112,7 +118,7 @@ func ParseBytes(b []byte) (UUID, error) {
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
return uuid, invalidLengthError{len(b)}
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx

View file

@ -0,0 +1,516 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx AAC package includes some utilites.
package aac
import (
"github.com/ossrs/go-oryx-lib/errors"
)
// The ADTS is a format of AAC.
// We can encode the RAW AAC frame in ADTS muxer.
// We can also decode the ADTS data to RAW AAC frame.
type ADTS interface {
// Set the ASC, the codec information.
// Before encoding raw frame, user must set the asc.
SetASC(asc []byte) (err error)
// Encode the raw aac frame to adts data.
// @remark User must set the asc first.
Encode(raw []byte) (adts []byte, err error)
// Decode the adts data to raw frame.
// @remark User can get the asc after decode ok.
// @remark When left if not nil, user must decode it again.
Decode(adts []byte) (raw, left []byte, err error)
// Get the ASC, the codec information.
// When decode a adts data or set the asc, user can use this API to get it.
ASC() *AudioSpecificConfig
}
// The AAC object type in RAW AAC frame.
// Refer to @doc ISO_IEC_14496-3-AAC-2001.pdf, @page 23, @section 1.5.1.1 Audio object type definition
type ObjectType uint8
const (
ObjectTypeForbidden ObjectType = iota
ObjectTypeMain
ObjectTypeLC
ObjectTypeSSR
ObjectTypeHE ObjectType = 5 // HE=LC+SBR
ObjectTypeHEv2 ObjectType = 29 // HEv2=LC+SBR+PS
)
func (v ObjectType) String() string {
switch v {
case ObjectTypeMain:
return "Main"
case ObjectTypeLC:
return "LC"
case ObjectTypeSSR:
return "SSR"
case ObjectTypeHE:
return "HE"
case ObjectTypeHEv2:
return "HEv2"
default:
return "Forbidden"
}
}
func (v ObjectType) ToProfile() Profile {
switch v {
case ObjectTypeMain:
return ProfileMain
case ObjectTypeHE, ObjectTypeHEv2, ObjectTypeLC:
return ProfileLC
case ObjectTypeSSR:
return ProfileSSR
default:
return ProfileForbidden
}
}
// The profile of AAC in ADTS.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 40, @section 7.1 Profiles
type Profile uint8
const (
ProfileMain Profile = iota
ProfileLC
ProfileSSR
ProfileForbidden
)
func (v Profile) String() string {
switch v {
case ProfileMain:
return "Main"
case ProfileLC:
return "LC"
case ProfileSSR:
return "SSR"
default:
return "Forbidden"
}
}
func (v Profile) ToObjectType() ObjectType {
switch v {
case ProfileMain:
return ObjectTypeMain
case ProfileLC:
return ObjectTypeLC
case ProfileSSR:
return ObjectTypeSSR
default:
return ObjectTypeForbidden
}
}
// The aac sample rate index.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 46, @section Table 35 Sampling frequency
type SampleRateIndex uint8
const (
SampleRateIndex96kHz SampleRateIndex = iota
SampleRateIndex88kHz
SampleRateIndex64kHz
SampleRateIndex48kHz
SampleRateIndex44kHz
SampleRateIndex32kHz
SampleRateIndex24kHz
SampleRateIndex22kHz
SampleRateIndex16kHz
SampleRateIndex12kHz
SampleRateIndex11kHz
SampleRateIndex8kHz
SampleRateIndex7kHz
SampleRateIndexReserved0
SampleRateIndexReserved1
SampleRateIndexReserved2
SampleRateIndexReserved3
SampleRateIndexForbidden
)
func (v SampleRateIndex) String() string {
switch v {
case SampleRateIndex96kHz:
return "96kHz"
case SampleRateIndex88kHz:
return "88kHz"
case SampleRateIndex64kHz:
return "64kHz"
case SampleRateIndex48kHz:
return "48kHz"
case SampleRateIndex44kHz:
return "44kHz"
case SampleRateIndex32kHz:
return "32kHz"
case SampleRateIndex24kHz:
return "24kHz"
case SampleRateIndex22kHz:
return "22kHz"
case SampleRateIndex16kHz:
return "16kHz"
case SampleRateIndex12kHz:
return "12kHz"
case SampleRateIndex11kHz:
return "11kHz"
case SampleRateIndex8kHz:
return "8kHz"
case SampleRateIndex7kHz:
return "7kHz"
case SampleRateIndexReserved0, SampleRateIndexReserved1, SampleRateIndexReserved2, SampleRateIndexReserved3:
return "Reserved"
default:
return "Forbidden"
}
}
func (v SampleRateIndex) ToHz() int {
aacSR := []int{
96000, 88200, 64000, 48000,
44100, 32000, 24000, 22050,
16000, 12000, 11025, 8000,
7350, 0, 0, 0,
/* To avoid overflow by forbidden */
0,
}
return aacSR[v]
}
// The aac channel.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 72, @section Table 42 Implicit speaker mapping
type Channels uint8
const (
ChannelForbidden Channels = iota
// center front speaker
// FFMPEG: mono FC
ChannelMono
// left, right front speakers
// FFMPEG: stereo FL+FR
ChannelStereo
// center front speaker, left, right front speakers
// FFMPEG: 2.1 FL+FR+LFE
// FFMPEG: 3.0 FL+FR+FC
// FFMPEG: 3.0(back) FL+FR+BC
Channel3
// center front speaker, left, right center front speakers, rear surround
// FFMPEG: 4.0 FL+FR+FC+BC
// FFMPEG: quad FL+FR+BL+BR
// FFMPEG: quad(side) FL+FR+SL+SR
// FFMPEG: 3.1 FL+FR+FC+LFE
Channel4
// center front speaker, left, right front speakers, left surround, right surround rear speakers
// FFMPEG: 5.0 FL+FR+FC+BL+BR
// FFMPEG: 5.0(side) FL+FR+FC+SL+SR
// FFMPEG: 4.1 FL+FR+FC+LFE+BC
Channel5
// center front speaker, left, right front speakers, left surround, right surround rear speakers,
// front low frequency effects speaker
// FFMPEG: 5.1 FL+FR+FC+LFE+BL+BR
// FFMPEG: 5.1(side) FL+FR+FC+LFE+SL+SR
// FFMPEG: 6.0 FL+FR+FC+BC+SL+SR
// FFMPEG: 6.0(front) FL+FR+FLC+FRC+SL+SR
// FFMPEG: hexagonal FL+FR+FC+BL+BR+BC
Channel5_1 // speakers: 6
// center front speaker, left, right center front speakers, left, right outside front speakers,
// left surround, right surround rear speakers, front low frequency effects speaker
// FFMPEG: 7.1 FL+FR+FC+LFE+BL+BR+SL+SR
// FFMPEG: 7.1(wide) FL+FR+FC+LFE+BL+BR+FLC+FRC
// FFMPEG: 7.1(wide-side) FL+FR+FC+LFE+FLC+FRC+SL+SR
Channel7_1 // speakers: 7
// FFMPEG: 6.1 FL+FR+FC+LFE+BC+SL+SR
// FFMPEG: 6.1(back) FL+FR+FC+LFE+BL+BR+BC
// FFMPEG: 6.1(front) FL+FR+LFE+FLC+FRC+SL+SR
// FFMPEG: 7.0 FL+FR+FC+BL+BR+SL+SR
// FFMPEG: 7.0(front) FL+FR+FC+FLC+FRC+SL+SR
)
func (v Channels) String() string {
switch v {
case ChannelMono:
return "Mono(FC)"
case ChannelStereo:
return "Stereo(FL+FR)"
case Channel3:
return "FL+FR+FC"
case Channel4:
return "FL+FR+FC+BC"
case Channel5:
return "FL+FR+FC+SL+SR"
case Channel5_1:
return "FL+FR+FC+LFE+SL+SR"
case Channel7_1:
return "FL+FR+FC+LFE+BL+BR+SL+SR"
default:
return "Forbidden"
}
}
// Please use NewADTS() and interface ADTS instead.
// It's only exposed for example.
type ADTSImpl struct {
asc AudioSpecificConfig
}
func NewADTS() (ADTS, error) {
return &ADTSImpl{}, nil
}
func (v *ADTSImpl) SetASC(asc []byte) (err error) {
return v.asc.UnmarshalBinary(asc)
}
func (v *ADTSImpl) Encode(raw []byte) (data []byte, err error) {
if err = v.asc.validate(); err != nil {
return nil, errors.WithMessage(err, "adts encode")
}
// write the ADTS header.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 26, @section 6.2 Audio Data Transport Stream, ADTS
// byte_alignment()
// adts_fixed_header:
// 12bits syncword,
// 16bits left.
// adts_variable_header:
// 28bits
// 12+16+28=56bits
// adts_error_check:
// 16bits if protection_absent
// 56+16=72bits
// if protection_absent:
// require(7bytes)=56bits
// else
// require(9bytes)=72bits
aacFixedHeader := make([]byte, 7)
p := aacFixedHeader
// Syncword 12 bslbf
p[0] = byte(0xff)
// 4bits left.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 27, @section 6.2.1 Fixed Header of ADTS
// ID 1 bslbf
// Layer 2 uimsbf
// protection_absent 1 bslbf
p[1] = byte(0xf1)
// profile 2 uimsbf
// sampling_frequency_index 4 uimsbf
// private_bit 1 bslbf
// channel_configuration 3 uimsbf
// original/copy 1 bslbf
// home 1 bslbf
profile := v.asc.Object.ToProfile()
p[2] = byte((profile<<6)&0xc0) | byte((v.asc.SampleRate<<2)&0x3c) | byte((v.asc.Channels>>2)&0x01)
// 4bits left.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 27, @section 6.2.2 Variable Header of ADTS
// copyright_identification_bit 1 bslbf
// copyright_identification_start 1 bslbf
aacFrameLength := uint16(len(raw) + len(aacFixedHeader))
p[3] = byte((v.asc.Channels<<6)&0xc0) | byte((aacFrameLength>>11)&0x03)
// aac_frame_length 13 bslbf: Length of the frame including headers and error_check in bytes.
// use the left 2bits as the 13 and 12 bit,
// the aac_frame_length is 13bits, so we move 13-2=11.
p[4] = byte(aacFrameLength >> 3)
// adts_buffer_fullness 11 bslbf
p[5] = byte(aacFrameLength<<5) & byte(0xe0)
// no_raw_data_blocks_in_frame 2 uimsbf
p[6] = byte(0xfc)
return append(p, raw...), nil
}
func (v *ADTSImpl) Decode(data []byte) (raw, left []byte, err error) {
// write the ADTS header.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 26, @section 6.2 Audio Data Transport Stream, ADTS
// @see https://github.com/ossrs/srs/issues/212#issuecomment-64145885
// byte_alignment()
p := data
if len(p) <= 7 {
return nil, nil, errors.Errorf("requires 7+ but only %v bytes", len(p))
}
// matched 12bits 0xFFF,
// @remark, we must cast the 0xff to char to compare.
if p[0] != 0xff || p[1]&0xf0 != 0xf0 {
return nil, nil, errors.Errorf("invalid signature %#x", uint8(p[1]&0xf0))
}
// Syncword 12 bslbf
_ = p[0]
// 4bits left.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 27, @section 6.2.1 Fixed Header of ADTS
// ID 1 bslbf
// layer 2 uimsbf
// protection_absent 1 bslbf
pat := uint8(p[1]) & 0x0f
id := (pat >> 3) & 0x01
//layer := (pat >> 1) & 0x03
protectionAbsent := pat & 0x01
// ID: MPEG identifier, set to '1' if the audio data in the ADTS stream are MPEG-2 AAC (See ISO/IEC 13818-7)
// and set to '0' if the audio data are MPEG-4. See also ISO/IEC 11172-3, subclause 2.4.2.3.
if id != 0x01 {
// well, some system always use 0, but actually is aac format.
// for example, houjian vod ts always set the aac id to 0, actually 1.
// we just ignore it, and alwyas use 1(aac) to demux.
id = 0x01
}
sfiv := uint16(p[2])<<8 | uint16(p[3])
// profile 2 uimsbf
// sampling_frequency_index 4 uimsbf
// private_bit 1 bslbf
// channel_configuration 3 uimsbf
// original/copy 1 bslbf
// home 1 bslbf
profile := Profile(uint8(sfiv>>14) & 0x03)
samplingFrequencyIndex := uint8(sfiv>>10) & 0x0f
//private_bit := (t >> 9) & 0x01
channelConfiguration := uint8(sfiv>>6) & 0x07
//original := uint8(sfiv >> 5) & 0x01
//home := uint8(sfiv >> 4) & 0x01
// 4bits left.
// Refer to @doc ISO_IEC_13818-7-AAC-2004.pdf, @page 27, @section 6.2.2 Variable Header of ADTS
// copyright_identification_bit 1 bslbf
// copyright_identification_start 1 bslbf
//fh_copyright_identification_bit = uint8(sfiv >> 3) & 0x01
//fh_copyright_identification_start = uint8(sfiv >> 2) & 0x01
// frame_length 13 bslbf: Length of the frame including headers and error_check in bytes.
// use the left 2bits as the 13 and 12 bit,
// the frame_length is 13bits, so we move 13-2=11.
frameLength := (sfiv << 11) & 0x1800
abfv := uint32(p[4])<<16 | uint32(p[5])<<8 | uint32(p[6])
p = p[7:]
// frame_length 13 bslbf: consume the first 13-2=11bits
// the fh2 is 24bits, so we move right 24-11=13.
frameLength |= uint16((abfv >> 13) & 0x07ff)
// adts_buffer_fullness 11 bslbf
//fh_adts_buffer_fullness = (abfv >> 2) & 0x7ff
// number_of_raw_data_blocks_in_frame 2 uimsbf
//number_of_raw_data_blocks_in_frame = abfv & 0x03
// adts_error_check(), 1.A.2.2.3 Error detection
if protectionAbsent == 0 {
if len(p) <= 2 {
return nil, nil, errors.Errorf("requires 2+ but only %v bytes", len(p))
}
// crc_check 16 Rpchof
p = p[2:]
}
v.asc.Object = profile.ToObjectType()
v.asc.Channels = Channels(channelConfiguration)
v.asc.SampleRate = SampleRateIndex(samplingFrequencyIndex)
nbRaw := int(frameLength - 7)
if len(p) < nbRaw {
return nil, nil, errors.Errorf("requires %v but only %v bytes", nbRaw, len(p))
}
raw = p[:nbRaw]
left = p[nbRaw:]
if err = v.asc.validate(); err != nil {
return nil, nil, errors.WithMessage(err, "adts decode")
}
return
}
func (v *ADTSImpl) ASC() *AudioSpecificConfig {
return &v.asc
}
// Convert the ASC(Audio Specific Configuration).
// Refer to @doc ISO_IEC_14496-3-AAC-2001.pdf, @page 33, @section 1.6.2.1 AudioSpecificConfig
type AudioSpecificConfig struct {
Object ObjectType // AAC object type.
SampleRate SampleRateIndex // AAC sample rate, not the FLV sampling rate.
Channels Channels // AAC channel configuration.
}
func (v *AudioSpecificConfig) validate() (err error) {
switch v.Object {
case ObjectTypeMain, ObjectTypeLC, ObjectTypeSSR, ObjectTypeHE, ObjectTypeHEv2:
default:
return errors.Errorf("invalid object %#x", uint8(v.Object))
}
if v.SampleRate < SampleRateIndex88kHz || v.SampleRate > SampleRateIndex7kHz {
return errors.Errorf("invalid sample-rate %#x", uint8(v.SampleRate))
}
if v.Channels < ChannelMono || v.Channels > Channel7_1 {
return errors.Errorf("invalid channels %#x", uint8(v.Channels))
}
return
}
func (v *AudioSpecificConfig) UnmarshalBinary(data []byte) (err error) {
// AudioSpecificConfig
// Refer to @doc ISO_IEC_14496-3-AAC-2001.pdf, @page 33, @section 1.6.2.1 AudioSpecificConfig
//
// only need to decode the first 2bytes:
// audioObjectType, 5bits.
// samplingFrequencyIndex, aac_sample_rate, 4bits.
// channelConfiguration, aac_channels, 4bits
//
// @see SrsAacTransmuxer::write_audio
if len(data) < 2 {
return errors.Errorf("requires 2 but only %v bytes", len(data))
}
t0, t1 := uint8(data[0]), uint8(data[1])
v.Object = ObjectType((t0 >> 3) & 0x1f)
v.SampleRate = SampleRateIndex(((t0 << 1) & 0x0e) | ((t1 >> 7) & 0x01))
v.Channels = Channels((t1 >> 3) & 0x0f)
return v.validate()
}
func (v *AudioSpecificConfig) MarshalBinary() (data []byte, err error) {
if err = v.validate(); err != nil {
return
}
// AudioSpecificConfig
// Refer to @doc ISO_IEC_14496-3-AAC-2001.pdf, @page 33, @section 1.6.2.1 AudioSpecificConfig
//
// only need to decode the first 2bytes:
// audioObjectType, 5bits.
// samplingFrequencyIndex, aac_sample_rate, 4bits.
// channelConfiguration, aac_channels, 4bits
return []byte{
byte(byte(v.Object)&0x1f)<<3 | byte(byte(v.SampleRate)&0x0e)>>1,
byte(byte(v.SampleRate)&0x01)<<7 | byte(byte(v.Channels)&0x0f)<<3,
}, nil
}

View file

@ -0,0 +1,747 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx FLV package support bytes from/to FLV tags.
package flv
import (
"bytes"
"errors"
"github.com/ossrs/go-oryx-lib/aac"
"io"
"strings"
)
// FLV Tag Type is the type of tag,
// refer to @doc video_file_format_spec_v10.pdf, @page 9, @section FLV tags
type TagType uint8
const (
TagTypeForbidden TagType = 0
TagTypeAudio TagType = 8
TagTypeVideo TagType = 9
TagTypeScriptData TagType = 18
)
func (v TagType) String() string {
switch v {
case TagTypeVideo:
return "Video"
case TagTypeAudio:
return "Audio"
case TagTypeScriptData:
return "Data"
default:
return "Forbidden"
}
}
// FLV Demuxer is used to demux FLV file.
// Refer to @doc video_file_format_spec_v10.pdf, @page 74, @section Annex E. The FLV File Format
// A FLV file must consist the bellow parts:
// 1. A FLV header, refer to @doc video_file_format_spec_v10.pdf, @page 8, @section The FLV header
// 2. One or more tags, refer to @doc video_file_format_spec_v10.pdf, @page 9, @section FLV tags
// @remark We always ignore the previous tag size.
type Demuxer interface {
// Read the FLV header, return the version of FLV, whether hasVideo or hasAudio in header.
ReadHeader() (version uint8, hasVideo, hasAudio bool, err error)
// Read the FLV tag header, return the tag information, especially the tag size,
// then user can read the tag payload.
ReadTagHeader() (tagType TagType, tagSize, timestamp uint32, err error)
// Read the FLV tag body, drop the next 4 bytes previous tag size.
ReadTag(tagSize uint32) (tag []byte, err error)
// Close the demuxer.
Close() error
}
// When FLV signature is not "FLV"
var errSignature = errors.New("FLV signatures are illegal")
// Create a demuxer object.
func NewDemuxer(r io.Reader) (Demuxer, error) {
return &demuxer{
r: r,
}, nil
}
type demuxer struct {
r io.Reader
}
func (v *demuxer) ReadHeader() (version uint8, hasVideo, hasAudio bool, err error) {
h := &bytes.Buffer{}
if _, err = io.CopyN(h, v.r, 13); err != nil {
return
}
p := h.Bytes()
if !bytes.Equal([]byte{byte('F'), byte('L'), byte('V')}, p[:3]) {
err = errSignature
return
}
version = uint8(p[3])
hasVideo = (p[4] & 0x01) == 0x01
hasAudio = ((p[4] >> 2) & 0x01) == 0x01
return
}
func (v *demuxer) ReadTagHeader() (tagType TagType, tagSize uint32, timestamp uint32, err error) {
h := &bytes.Buffer{}
if _, err = io.CopyN(h, v.r, 11); err != nil {
return
}
p := h.Bytes()
tagType = TagType(p[0])
tagSize = uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
timestamp = uint32(p[7])<<24 | uint32(p[4])<<16 | uint32(p[5])<<8 | uint32(p[6])
return
}
func (v *demuxer) ReadTag(tagSize uint32) (tag []byte, err error) {
h := &bytes.Buffer{}
if _, err = io.CopyN(h, v.r, int64(tagSize+4)); err != nil {
return
}
p := h.Bytes()
tag = p[0 : len(p)-4]
return
}
func (v *demuxer) Close() error {
return nil
}
// The FLV muxer is used to write packet in FLV protocol.
// Refer to @doc video_file_format_spec_v10.pdf, @page 74, @section Annex E. The FLV File Format
type Muxer interface {
// Write the FLV header.
WriteHeader(hasVideo, hasAudio bool) (err error)
// Write A FLV tag.
WriteTag(tagType TagType, timestamp uint32, tag []byte) (err error)
// Close the muxer.
Close() error
}
// Create a muxer object.
func NewMuxer(w io.Writer) (Muxer, error) {
return &muxer{
w: w,
}, nil
}
type muxer struct {
w io.Writer
}
func (v *muxer) WriteHeader(hasVideo, hasAudio bool) (err error) {
var flags byte
if hasVideo {
flags |= 0x01
}
if hasAudio {
flags |= 0x04
}
r := bytes.NewReader([]byte{
byte('F'), byte('L'), byte('V'),
0x01,
flags,
0x00, 0x00, 0x00, 0x09,
0x00, 0x00, 0x00, 0x00,
})
if _, err = io.Copy(v.w, r); err != nil {
return
}
return
}
func (v *muxer) WriteTag(tagType TagType, timestamp uint32, tag []byte) (err error) {
// Tag header.
tagSize := uint32(len(tag))
r := bytes.NewReader([]byte{
byte(tagType),
byte(tagSize >> 16), byte(tagSize >> 8), byte(tagSize),
byte(timestamp >> 16), byte(timestamp >> 8), byte(timestamp),
byte(timestamp >> 24),
0x00, 0x00, 0x00,
})
if _, err = io.Copy(v.w, r); err != nil {
return
}
// TAG
if _, err = io.Copy(v.w, bytes.NewReader(tag)); err != nil {
return
}
// Previous tag size.
pts := uint32(11 + len(tag))
r = bytes.NewReader([]byte{
byte(pts >> 24), byte(pts >> 16), byte(pts >> 8), byte(pts),
})
if _, err = io.Copy(v.w, r); err != nil {
return
}
return
}
func (v *muxer) Close() error {
return nil
}
// The Audio AAC frame trait, whether sequence header(ASC) or raw data.
// Refer to @doc video_file_format_spec_v10.pdf, @page 77, @section E.4.2 Audio Tags
type AudioFrameTrait uint8
const (
// For AAC, the frame trait.
AudioFrameTraitSequenceHeader AudioFrameTrait = 0 // 0 = AAC sequence header
AudioFrameTraitRaw AudioFrameTrait = 1 // 1 = AAC raw
// For Opus, the frame trait, may has more than one traits.
AudioFrameTraitOpusRaw AudioFrameTrait = 0x02 // 2, Has RAW Opus data.
AudioFrameTraitOpusSamplingRate AudioFrameTrait = 0x04 // 4, Has Opus SamplingRate.
AudioFrameTraitOpusAudioLevel AudioFrameTrait = 0x08 // 8, Has audio level data, 16bits.
AudioFrameTraitForbidden AudioFrameTrait = 0xff
)
func (v AudioFrameTrait) String() string {
if v > AudioFrameTraitRaw && v < AudioFrameTraitForbidden {
var s []string
if (v & AudioFrameTraitOpusRaw) == AudioFrameTraitOpusRaw {
s = append(s, "RAW")
}
if (v & AudioFrameTraitOpusSamplingRate) == AudioFrameTraitOpusSamplingRate {
s = append(s, "SR")
}
if (v & AudioFrameTraitOpusAudioLevel) == AudioFrameTraitOpusAudioLevel {
s = append(s, "AL")
}
return strings.Join(s, "|")
}
switch v {
case AudioFrameTraitSequenceHeader:
return "SequenceHeader"
case AudioFrameTraitRaw:
return "Raw"
default:
return "Forbidden"
}
}
// The audio channels, FLV named it the SoundType.
// Refer to @doc video_file_format_spec_v10.pdf, @page 77, @section E.4.2 Audio Tags
type AudioChannels uint8
const (
AudioChannelsMono AudioChannels = iota // 0 = Mono sound
AudioChannelsStereo // 1 = Stereo sound
AudioChannelsForbidden
)
func (v AudioChannels) String() string {
switch v {
case AudioChannelsMono:
return "Mono"
case AudioChannelsStereo:
return "Stereo"
default:
return "Forbidden"
}
}
func (v *AudioChannels) From(a aac.Channels) {
switch a {
case aac.ChannelMono:
*v = AudioChannelsMono
case aac.ChannelStereo:
*v = AudioChannelsStereo
case aac.Channel3, aac.Channel4, aac.Channel5, aac.Channel5_1, aac.Channel7_1:
*v = AudioChannelsStereo
default:
*v = AudioChannelsForbidden
}
}
// The audio sample bits, FLV named it the SoundSize.
// Refer to @doc video_file_format_spec_v10.pdf, @page 76, @section E.4.2 Audio Tags
type AudioSampleBits uint8
const (
AudioSampleBits8bits AudioSampleBits = iota // 0 = 8-bit samples
AudioSampleBits16bits // 1 = 16-bit samples
AudioSampleBitsForbidden
)
func (v AudioSampleBits) String() string {
switch v {
case AudioSampleBits8bits:
return "8-bits"
case AudioSampleBits16bits:
return "16-bits"
default:
return "Forbidden"
}
}
// The audio sampling rate, FLV named it the SoundRate.
// Refer to @doc video_file_format_spec_v10.pdf, @page 76, @section E.4.2 Audio Tags
type AudioSamplingRate uint8
const (
// For FLV, only support 5, 11, 22, 44KHz sampling rate.
AudioSamplingRate5kHz AudioSamplingRate = iota // 0 = 5.5 kHz
AudioSamplingRate11kHz // 1 = 11 kHz
AudioSamplingRate22kHz // 2 = 22 kHz
AudioSamplingRate44kHz // 3 = 44 kHz
// For Opus, support 8, 12, 16, 24, 48KHz
// We will write a UINT8 sampling rate after FLV audio tag header.
// @doc https://tools.ietf.org/html/rfc6716#section-2
AudioSamplingRateNB8kHz = 8 // NB (narrowband)
AudioSamplingRateMB12kHz = 12 // MB (medium-band)
AudioSamplingRateWB16kHz = 16 // WB (wideband)
AudioSamplingRateSWB24kHz = 24 // SWB (super-wideband)
AudioSamplingRateFB48kHz = 48 // FB (fullband)
AudioSamplingRateForbidden
)
func (v AudioSamplingRate) String() string {
switch v {
case AudioSamplingRate5kHz:
return "5.5kHz"
case AudioSamplingRate11kHz:
return "11kHz"
case AudioSamplingRate22kHz:
return "22kHz"
case AudioSamplingRate44kHz:
return "44kHz"
case AudioSamplingRateNB8kHz:
return "NB8kHz"
case AudioSamplingRateMB12kHz:
return "MB12kHz"
case AudioSamplingRateWB16kHz:
return "WB16kHz"
case AudioSamplingRateSWB24kHz:
return "SWB24kHz"
case AudioSamplingRateFB48kHz:
return "FB48kHz"
default:
return "Forbidden"
}
}
// Parse the FLV sampling rate to Hz.
func (v AudioSamplingRate) ToHz() int {
flvSR := []int{5512, 11025, 22050, 44100}
return flvSR[v]
}
// For FLV, convert aac sample rate index to FLV sampling rate.
func (v *AudioSamplingRate) From(a aac.SampleRateIndex) {
switch a {
case aac.SampleRateIndex96kHz, aac.SampleRateIndex88kHz, aac.SampleRateIndex64kHz:
*v = AudioSamplingRate44kHz
case aac.SampleRateIndex48kHz:
*v = AudioSamplingRate44kHz
case aac.SampleRateIndex44kHz, aac.SampleRateIndex32kHz:
*v = AudioSamplingRate44kHz
case aac.SampleRateIndex24kHz, aac.SampleRateIndex22kHz, aac.SampleRateIndex16kHz:
*v = AudioSamplingRate22kHz
case aac.SampleRateIndex12kHz, aac.SampleRateIndex11kHz, aac.SampleRateIndex8kHz:
*v = AudioSamplingRate11kHz
case aac.SampleRateIndex7kHz:
*v = AudioSamplingRate5kHz
default:
*v = AudioSamplingRateForbidden
}
}
// Parse the Opus sampling rate to Hz.
func (v AudioSamplingRate) OpusToHz() int {
opusSR := []int{8000, 12000, 16000, 24000, 48000}
return opusSR[v]
}
// For Opus, convert aac sample rate index to FLV sampling rate.
func (v *AudioSamplingRate) OpusFrom(a aac.SampleRateIndex) {
switch a {
case aac.SampleRateIndex96kHz, aac.SampleRateIndex88kHz, aac.SampleRateIndex64kHz:
*v = AudioSamplingRateFB48kHz
case aac.SampleRateIndex48kHz, aac.SampleRateIndex44kHz, aac.SampleRateIndex32kHz:
*v = AudioSamplingRateFB48kHz
case aac.SampleRateIndex24kHz, aac.SampleRateIndex22kHz:
*v = AudioSamplingRateSWB24kHz
case aac.SampleRateIndex16kHz:
*v = AudioSamplingRateWB16kHz
case aac.SampleRateIndex12kHz, aac.SampleRateIndex11kHz:
*v = AudioSamplingRateMB12kHz
case aac.SampleRateIndex8kHz, aac.SampleRateIndex7kHz:
*v = AudioSamplingRateNB8kHz
default:
*v = AudioSamplingRateForbidden
}
}
// The audio codec id, FLV named it the SoundFormat.
// Refer to @doc video_file_format_spec_v10.pdf, @page 76, @section E.4.2 Audio Tags
// It's 4bits, that is 0-16.
type AudioCodec uint8
const (
AudioCodecLinearPCM AudioCodec = iota // 0 = Linear PCM, platform endian
AudioCodecADPCM // 1 = ADPCM
AudioCodecMP3 // 2 = MP3
AudioCodecLinearPCMle // 3 = Linear PCM, little endian
AudioCodecNellymoser16kHz // 4 = Nellymoser 16 kHz mono
AudioCodecNellymoser8kHz // 5 = Nellymoser 8 kHz mono
AudioCodecNellymoser // 6 = Nellymoser
AudioCodecG711Alaw // 7 = G.711 A-law logarithmic PCM
AudioCodecG711MuLaw // 8 = G.711 mu-law logarithmic PCM
AudioCodecReserved // 9 = reserved
AudioCodecAAC // 10 = AAC
AudioCodecSpeex // 11 = Speex
AudioCodecUndefined12
// For FLV, it's undefined, we define it as Opus for WebRTC.
AudioCodecOpus // 13 = Opus
AudioCodecMP3In8kHz // 14 = MP3 8 kHz
AudioCodecDeviceSpecific // 15 = Device-specific sound
AudioCodecForbidden
)
func (v AudioCodec) String() string {
switch v {
case AudioCodecLinearPCM:
return "LinearPCM(platform-endian)"
case AudioCodecADPCM:
return "ADPCM"
case AudioCodecMP3:
return "MP3"
case AudioCodecLinearPCMle:
return "LinearPCM(little-endian)"
case AudioCodecNellymoser16kHz:
return "Nellymoser(16kHz-mono)"
case AudioCodecNellymoser8kHz:
return "Nellymoser(8kHz-mono)"
case AudioCodecNellymoser:
return "Nellymoser"
case AudioCodecG711Alaw:
return "G.711(A-law)"
case AudioCodecG711MuLaw:
return "G.711(mu-law)"
case AudioCodecAAC:
return "AAC"
case AudioCodecSpeex:
return "Speex"
case AudioCodecOpus:
return "Opus"
case AudioCodecMP3In8kHz:
return "MP3(8kHz)"
case AudioCodecDeviceSpecific:
return "DeviceSpecific"
default:
return "Forbidden"
}
}
type AudioFrame struct {
SoundFormat AudioCodec
SoundRate AudioSamplingRate
SoundSize AudioSampleBits
SoundType AudioChannels
Trait AudioFrameTrait
AudioLevel uint16
Raw []byte
}
// The packager used to codec the FLV audio tag body.
// Refer to @doc video_file_format_spec_v10.pdf, @page 76, @section E.4.2 Audio Tags
type AudioPackager interface {
// Encode the audio frame to FLV audio tag.
Encode(frame *AudioFrame) (tag []byte, err error)
// Decode the FLV audio tag to audio frame.
Decode(tag []byte) (frame *AudioFrame, err error)
}
var errDataNotEnough = errors.New("Data not enough")
type audioPackager struct {
}
func NewAudioPackager() (AudioPackager, error) {
return &audioPackager{}, nil
}
func (v *audioPackager) Encode(frame *AudioFrame) (tag []byte, err error) {
audioTagHeader := []byte{
byte(frame.SoundFormat)<<4 | byte(frame.SoundRate)<<2 | byte(frame.SoundSize)<<1 | byte(frame.SoundType),
}
// For Opus, we put the sampling rate after trait,
// so we set the sound rate in audio tag to 0.
if frame.SoundFormat == AudioCodecOpus {
audioTagHeader[0] &= 0xf3
}
if frame.SoundFormat == AudioCodecAAC {
return append(append(audioTagHeader, byte(frame.Trait)), frame.Raw...), nil
} else if frame.SoundFormat == AudioCodecOpus {
var b bytes.Buffer
b.Write(audioTagHeader)
b.WriteByte(byte(frame.Trait))
if (frame.Trait & AudioFrameTraitOpusSamplingRate) == AudioFrameTraitOpusSamplingRate {
b.WriteByte(byte(frame.SoundRate))
}
if (frame.Trait & AudioFrameTraitOpusAudioLevel) == AudioFrameTraitOpusAudioLevel {
b.WriteByte(byte(frame.AudioLevel >> 8))
b.WriteByte(byte(frame.AudioLevel))
}
b.Write(frame.Raw)
return b.Bytes(), nil
} else {
return append(audioTagHeader, frame.Raw...), nil
}
}
func (v *audioPackager) Decode(tag []byte) (frame *AudioFrame, err error) {
// Refer to @doc video_file_format_spec_v10.pdf, @page 76, @section E.4.2 Audio Tags
// @see SrsFormat::audio_aac_demux
if len(tag) < 2 {
err = errDataNotEnough
return
}
t := uint8(tag[0])
frame = &AudioFrame{}
frame.SoundFormat = AudioCodec(uint8(t>>4) & 0x0f)
frame.SoundRate = AudioSamplingRate(uint8(t>>2) & 0x03)
frame.SoundSize = AudioSampleBits(uint8(t>>1) & 0x01)
frame.SoundType = AudioChannels(t & 0x01)
if frame.SoundFormat == AudioCodecAAC {
frame.Trait = AudioFrameTrait(tag[1])
frame.Raw = tag[2:]
} else if frame.SoundFormat == AudioCodecOpus {
frame.Trait = AudioFrameTrait(tag[1])
p := tag[2:]
// For Opus, we put sampling rate after trait.
if (frame.Trait & AudioFrameTraitOpusSamplingRate) == AudioFrameTraitOpusSamplingRate {
if len(p) < 1 {
return nil, errDataNotEnough
}
frame.SoundRate = AudioSamplingRate(p[0])
p = p[1:]
}
// For Opus, we put audio level after trait.
if (frame.Trait & AudioFrameTraitOpusAudioLevel) == AudioFrameTraitOpusAudioLevel {
if len(p) < 2 {
return nil, errDataNotEnough
}
frame.AudioLevel = uint16(p[0])<<8 | uint16(p[1])
p = p[2:]
}
frame.Raw = p
} else {
frame.Raw = tag[1:]
}
return
}
// The video frame type.
// Refer to @doc video_file_format_spec_v10.pdf, @page 78, @section E.4.3 Video Tags
type VideoFrameType uint8
const (
VideoFrameTypeForbidden VideoFrameType = iota
VideoFrameTypeKeyframe // 1 = key frame (for AVC, a seekable frame)
VideoFrameTypeInterframe // 2 = inter frame (for AVC, a non-seekable frame)
VideoFrameTypeDisposable // 3 = disposable inter frame (H.263 only)
VideoFrameTypeGenerated // 4 = generated key frame (reserved for server use only)
VideoFrameTypeInfo // 5 = video info/command frame
)
func (v VideoFrameType) String() string {
switch v {
case VideoFrameTypeKeyframe:
return "Keyframe"
case VideoFrameTypeInterframe:
return "Interframe"
case VideoFrameTypeDisposable:
return "DisposableInterframe"
case VideoFrameTypeGenerated:
return "GeneratedKeyframe"
case VideoFrameTypeInfo:
return "Info"
default:
return "Forbidden"
}
}
// The video codec id.
// Refer to @doc video_file_format_spec_v10.pdf, @page 78, @section E.4.3 Video Tags
// It's 4bits, that is 0-16.
type VideoCodec uint8
const (
VideoCodecForbidden VideoCodec = iota + 1
VideoCodecH263 // 2 = Sorenson H.263
VideoCodecScreen // 3 = Screen video
VideoCodecOn2VP6 // 4 = On2 VP6
VideoCodecOn2VP6Alpha // 5 = On2 VP6 with alpha channel
VideoCodecScreen2 // 6 = Screen video version 2
VideoCodecAVC // 7 = AVC
// See page 79 at @doc https://github.com/CDN-Union/H265/blob/master/Document/video_file_format_spec_v10_1_ksyun_20170615.doc
VideoCodecHEVC VideoCodec = 12 // 12 = HEVC
)
func (v VideoCodec) String() string {
switch v {
case VideoCodecH263:
return "H.263"
case VideoCodecScreen:
return "Screen"
case VideoCodecOn2VP6:
return "VP6"
case VideoCodecOn2VP6Alpha:
return "On2VP6(alpha)"
case VideoCodecScreen2:
return "Screen2"
case VideoCodecAVC:
return "AVC"
case VideoCodecHEVC:
return "HEVC"
default:
return "Forbidden"
}
}
// The video AVC frame trait, whethere sequence header or not.
// Refer to @doc video_file_format_spec_v10.pdf, @page 78, @section E.4.3 Video Tags
// If AVC or HEVC, it's 8bits.
type VideoFrameTrait uint8
const (
VideoFrameTraitSequenceHeader VideoFrameTrait = iota // 0 = AVC/HEVC sequence header
VideoFrameTraitNALU // 1 = AVC/HEVC NALU
VideoFrameTraitSequenceEOF // 2 = AVC/HEVC end of sequence (lower level NALU sequence ender is
VideoFrameTraitForbidden
)
func (v VideoFrameTrait) String() string {
switch v {
case VideoFrameTraitSequenceHeader:
return "SequenceHeader"
case VideoFrameTraitNALU:
return "NALU"
case VideoFrameTraitSequenceEOF:
return "SequenceEOF"
default:
return "Forbidden"
}
}
type VideoFrame struct {
CodecID VideoCodec
FrameType VideoFrameType
Trait VideoFrameTrait
CTS int32
Raw []byte
}
func NewVideoFrame() *VideoFrame {
return &VideoFrame{}
}
// The packager used to codec the FLV video tag body.
// Refer to @doc video_file_format_spec_v10.pdf, @page 78, @section E.4.3 Video Tags
type VideoPackager interface {
// Decode the FLV video tag to video frame.
// @remark For RTMP/FLV: pts = dts + cts, where dts is timestamp in packet/tag.
Decode(tag []byte) (frame *VideoFrame, err error)
// Encode the video frame to FLV video tag.
Encode(frame *VideoFrame) (tag []byte, err error)
}
type videoPackager struct {
}
func NewVideoPackager() (VideoPackager, error) {
return &videoPackager{}, nil
}
func (v *videoPackager) Decode(tag []byte) (frame *VideoFrame, err error) {
if len(tag) < 5 {
err = errDataNotEnough
return
}
p := tag
frame = &VideoFrame{}
frame.FrameType = VideoFrameType(byte(p[0]>>4) & 0x0f)
frame.CodecID = VideoCodec(byte(p[0]) & 0x0f)
if frame.CodecID == VideoCodecAVC || frame.CodecID == VideoCodecHEVC {
frame.Trait = VideoFrameTrait(p[1])
frame.CTS = int32(uint32(p[2])<<16 | uint32(p[3])<<8 | uint32(p[4]))
frame.Raw = tag[5:]
} else {
frame.Raw = tag[1:]
}
return
}
func (v videoPackager) Encode(frame *VideoFrame) (tag []byte, err error) {
if frame.CodecID == VideoCodecAVC || frame.CodecID == VideoCodecHEVC {
return append([]byte{
byte(frame.FrameType)<<4 | byte(frame.CodecID), byte(frame.Trait),
byte(frame.CTS >> 16), byte(frame.CTS >> 8), byte(frame.CTS),
}, frame.Raw...), nil
} else {
return append([]byte{
byte(frame.FrameType)<<4 | byte(frame.CodecID),
}, frame.Raw...), nil
}
}

View file

@ -1,2 +1,24 @@
vendor
*-fuzz.zip
### JetBrains IDE ###
#####################
.idea/
### Emacs Temporary Files ###
#############################
*~
### Folders ###
###############
bin/
vendor/
node_modules/
### Files ###
#############
*.ivf
*.ogg
tags
cover.out
*.sw[poe]
*.wasm
examples/sfu-ws/cert.pem
examples/sfu-ws/key.pem

View file

@ -53,6 +53,7 @@ We would love contributes that fall under the 'Planned Features' and fixing any
* TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655])
* TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487])
* TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487])
[rfc5289]: https://tools.ietf.org/html/rfc5289
[rfc8422]: https://tools.ietf.org/html/rfc8422
@ -146,6 +147,9 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
* [ZHENK](https://github.com/scorpionknifes)
* [Carson Hoffman](https://github.com/CarsonHoffman)
* [Vadim Filimonov](https://github.com/fffilimonov)
* [Jim Wert](https://github.com/bocajim)
* [Alvaro Viebrantz](https://github.com/alvarowolfx)
* [Kegan Dougal](https://github.com/Kegsay)
### License
MIT License - see [LICENSE](LICENSE) for full text

View file

@ -1,145 +0,0 @@
package dtls
import "fmt"
type alertLevel byte
const (
alertLevelWarning alertLevel = 1
alertLevelFatal alertLevel = 2
)
func (a alertLevel) String() string {
switch a {
case alertLevelWarning:
return "LevelWarning"
case alertLevelFatal:
return "LevelFatal"
default:
return "Invalid alert level"
}
}
type alertDescription byte
const (
alertCloseNotify alertDescription = 0
alertUnexpectedMessage alertDescription = 10
alertBadRecordMac alertDescription = 20
alertDecryptionFailed alertDescription = 21
alertRecordOverflow alertDescription = 22
alertDecompressionFailure alertDescription = 30
alertHandshakeFailure alertDescription = 40
alertNoCertificate alertDescription = 41
alertBadCertificate alertDescription = 42
alertUnsupportedCertificate alertDescription = 43
alertCertificateRevoked alertDescription = 44
alertCertificateExpired alertDescription = 45
alertCertificateUnknown alertDescription = 46
alertIllegalParameter alertDescription = 47
alertUnknownCA alertDescription = 48
alertAccessDenied alertDescription = 49
alertDecodeError alertDescription = 50
alertDecryptError alertDescription = 51
alertExportRestriction alertDescription = 60
alertProtocolVersion alertDescription = 70
alertInsufficientSecurity alertDescription = 71
alertInternalError alertDescription = 80
alertUserCanceled alertDescription = 90
alertNoRenegotiation alertDescription = 100
alertUnsupportedExtension alertDescription = 110
)
func (a alertDescription) String() string {
switch a {
case alertCloseNotify:
return "CloseNotify"
case alertUnexpectedMessage:
return "UnexpectedMessage"
case alertBadRecordMac:
return "BadRecordMac"
case alertDecryptionFailed:
return "DecryptionFailed"
case alertRecordOverflow:
return "RecordOverflow"
case alertDecompressionFailure:
return "DecompressionFailure"
case alertHandshakeFailure:
return "HandshakeFailure"
case alertNoCertificate:
return "NoCertificate"
case alertBadCertificate:
return "BadCertificate"
case alertUnsupportedCertificate:
return "UnsupportedCertificate"
case alertCertificateRevoked:
return "CertificateRevoked"
case alertCertificateExpired:
return "CertificateExpired"
case alertCertificateUnknown:
return "CertificateUnknown"
case alertIllegalParameter:
return "IllegalParameter"
case alertUnknownCA:
return "UnknownCA"
case alertAccessDenied:
return "AccessDenied"
case alertDecodeError:
return "DecodeError"
case alertDecryptError:
return "DecryptError"
case alertExportRestriction:
return "ExportRestriction"
case alertProtocolVersion:
return "ProtocolVersion"
case alertInsufficientSecurity:
return "InsufficientSecurity"
case alertInternalError:
return "InternalError"
case alertUserCanceled:
return "UserCanceled"
case alertNoRenegotiation:
return "NoRenegotiation"
case alertUnsupportedExtension:
return "UnsupportedExtension"
default:
return "Invalid alert description"
}
}
// One of the content types supported by the TLS record layer is the
// alert type. Alert messages convey the severity of the message
// (warning or fatal) and a description of the alert. Alert messages
// with a level of fatal result in the immediate termination of the
// connection. In this case, other connections corresponding to the
// session may continue, but the session identifier MUST be invalidated,
// preventing the failed session from being used to establish new
// connections. Like other messages, alert messages are encrypted and
// compressed, as specified by the current connection state.
// https://tools.ietf.org/html/rfc5246#section-7.2
type alert struct {
alertLevel alertLevel
alertDescription alertDescription
}
func (a alert) contentType() contentType {
return contentTypeAlert
}
func (a *alert) Marshal() ([]byte, error) {
return []byte{byte(a.alertLevel), byte(a.alertDescription)}, nil
}
func (a *alert) Unmarshal(data []byte) error {
if len(data) != 2 {
return errBufferTooSmall
}
a.alertLevel = alertLevel(data[0])
a.alertDescription = alertDescription(data[1])
return nil
}
func (a *alert) String() string {
return fmt.Sprintf("Alert %s: %s", a.alertLevel, a.alertDescription)
}

View file

@ -1,23 +0,0 @@
package dtls
// Application data messages are carried by the record layer and are
// fragmented, compressed, and encrypted based on the current connection
// state. The messages are treated as transparent data to the record
// layer.
// https://tools.ietf.org/html/rfc5246#section-10
type applicationData struct {
data []byte
}
func (a applicationData) contentType() contentType {
return contentTypeApplicationData
}
func (a *applicationData) Marshal() ([]byte, error) {
return append([]byte{}, a.data...), nil
}
func (a *applicationData) Unmarshal(data []byte) error {
a.data = append([]byte{}, data...)
return nil
}

View file

@ -1,25 +0,0 @@
package dtls
// The change cipher spec protocol exists to signal transitions in
// ciphering strategies. The protocol consists of a single message,
// which is encrypted and compressed under the current (not the pending)
// connection state. The message consists of a single byte of value 1.
// https://tools.ietf.org/html/rfc5246#section-7.1
type changeCipherSpec struct {
}
func (c changeCipherSpec) contentType() contentType {
return contentTypeChangeCipherSpec
}
func (c *changeCipherSpec) Marshal() ([]byte, error) {
return []byte{0x01}, nil
}
func (c *changeCipherSpec) Unmarshal(data []byte) error {
if len(data) == 1 && data[0] == 0x01 {
return nil
}
return errInvalidCipherSpec
}

View file

@ -1,73 +1,72 @@
package dtls
import (
"encoding/binary"
"fmt"
"hash"
"github.com/pion/dtls/v2/internal/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// CipherSuiteID is an ID for our supported CipherSuites
type CipherSuiteID uint16
type CipherSuiteID = ciphersuite.ID
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = 0xc0ac //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0ae //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02b //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0xc02f //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc00a //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = 0xc014 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = 0xc0a4 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = 0xc0a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = 0x00a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
)
// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite
type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType
// AuthenticationType Enums
const (
CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate
CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey
CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous
)
var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14
func (c CipherSuiteID) String() string {
switch c {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8"
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
case TLS_PSK_WITH_AES_128_CCM:
return "TLS_PSK_WITH_AES_128_CCM"
case TLS_PSK_WITH_AES_128_CCM_8:
return "TLS_PSK_WITH_AES_128_CCM_8"
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
default:
return fmt.Sprintf("unknown(%v)", uint16(c))
}
}
type cipherSuite interface {
// CipherSuite is an interface that all DTLS CipherSuites must satisfy
type CipherSuite interface {
// String of CipherSuite, only used for logging
String() string
// ID of CipherSuite.
ID() CipherSuiteID
certificateType() clientCertificateType
hashFunc() func() hash.Hash
isPSK() bool
isInitialized() bool
// Generate the internal encryption state
init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
// What type of Certificate does this CipherSuite use
CertificateType() clientcertificate.Type
encrypt(pkt *recordLayer, raw []byte) ([]byte, error)
decrypt(in []byte) ([]byte, error)
// What Hash function is used during verification
HashFunc() func() hash.Hash
// AuthenticationType controls what authentication method is using during the handshake
AuthenticationType() CipherSuiteAuthenticationType
// Called when keying material has been generated, should initialize the internal cipher
Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
IsInitialized() bool
Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
Decrypt(in []byte) ([]byte, error)
}
// CipherSuiteName provides the same functionality as tls.CipherSuiteName
@ -76,7 +75,7 @@ type cipherSuite interface {
// Our implementation differs slightly in that it takes in a CiperSuiteID,
// like the rest of our library, instead of a uint16 like crypto/tls.
func CipherSuiteName(id CipherSuiteID) string {
suite := cipherSuiteForID(id)
suite := cipherSuiteForID(id, nil)
if suite != nil {
return suite.String()
}
@ -86,87 +85,78 @@ func CipherSuiteName(id CipherSuiteID) string {
// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
// A cipherSuite is a specific combination of key agreement, cipher and MAC
// function.
func cipherSuiteForID(id CipherSuiteID) cipherSuite {
switch id {
func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite {
switch id { //nolint:exhaustive
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm()
return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm()
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8()
return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8()
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}
return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSEcdheRsaWithAes128GcmSha256{}
return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{}
return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return &cipherSuiteTLSEcdheRsaWithAes256CbcSha{}
return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}
case TLS_PSK_WITH_AES_128_CCM:
return newCipherSuiteTLSPskWithAes128Ccm()
return ciphersuite.NewTLSPskWithAes128Ccm()
case TLS_PSK_WITH_AES_128_CCM_8:
return newCipherSuiteTLSPskWithAes128Ccm8()
return ciphersuite.NewTLSPskWithAes128Ccm8()
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return &cipherSuiteTLSPskWithAes128GcmSha256{}
return &ciphersuite.TLSPskWithAes128GcmSha256{}
case TLS_PSK_WITH_AES_128_CBC_SHA256:
return &ciphersuite.TLSPskWithAes128CbcSha256{}
}
if customCiphers != nil {
for _, c := range customCiphers() {
if c.ID() == id {
return c
}
}
}
return nil
}
// CipherSuites we support in order of preference
func defaultCipherSuites() []cipherSuite {
return []cipherSuite{
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
func defaultCipherSuites() []CipherSuite {
return []CipherSuite{
&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
&ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{},
&ciphersuite.TLSEcdheRsaWithAes256CbcSha{},
}
}
func allCipherSuites() []cipherSuite {
return []cipherSuite{
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm(),
newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8(),
&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
&cipherSuiteTLSEcdheEcdsaWithAes256CbcSha{},
&cipherSuiteTLSEcdheRsaWithAes256CbcSha{},
newCipherSuiteTLSPskWithAes128Ccm(),
newCipherSuiteTLSPskWithAes128Ccm8(),
&cipherSuiteTLSPskWithAes128GcmSha256{},
func allCipherSuites() []CipherSuite {
return []CipherSuite{
ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(),
ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(),
&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
&ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{},
&ciphersuite.TLSEcdheRsaWithAes256CbcSha{},
ciphersuite.NewTLSPskWithAes128Ccm(),
ciphersuite.NewTLSPskWithAes128Ccm8(),
&ciphersuite.TLSPskWithAes128GcmSha256{},
}
}
func decodeCipherSuites(buf []byte) ([]cipherSuite, error) {
if len(buf) < 2 {
return nil, errDTLSPacketInvalidLength
}
cipherSuitesCount := int(binary.BigEndian.Uint16(buf[0:])) / 2
rtrn := []cipherSuite{}
for i := 0; i < cipherSuitesCount; i++ {
if len(buf) < (i*2 + 4) {
return nil, errBufferTooSmall
}
id := CipherSuiteID(binary.BigEndian.Uint16(buf[(i*2)+2:]))
if c := cipherSuiteForID(id); c != nil {
rtrn = append(rtrn, c)
}
}
return rtrn, nil
}
func encodeCipherSuites(cipherSuites []cipherSuite) []byte {
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuites)*2))
func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 {
rtrn := []uint16{}
for _, c := range cipherSuites {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(c.ID()))
rtrn = append(rtrn, uint16(c.ID()))
}
return out
return rtrn
}
func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNonPSK bool) ([]cipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) {
cipherSuites := []cipherSuite{}
func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) {
cipherSuites := []CipherSuite{}
for _, id := range ids {
c := cipherSuiteForID(id)
c := cipherSuiteForID(id, nil)
if c == nil {
return nil, &invalidCipherSuite{id}
}
@ -176,11 +166,11 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNo
}
var (
cipherSuites []cipherSuite
cipherSuites []CipherSuite
err error
i int
)
if len(userSelectedSuites) != 0 {
if userSelectedSuites != nil {
cipherSuites, err = cipherSuitesForIDs(userSelectedSuites)
if err != nil {
return nil, err
@ -189,18 +179,35 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNo
cipherSuites = defaultCipherSuites()
}
// Put CustomCipherSuites before ID selected suites
if customCipherSuites != nil {
cipherSuites = append(customCipherSuites(), cipherSuites...)
}
var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool
for _, c := range cipherSuites {
if excludePSK && c.isPSK() || excludeNonPSK && !c.isPSK() {
switch {
case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
foundCertificateSuite = true
case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey:
foundPSKSuite = true
case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
foundAnonymousSuite = true
default:
continue
}
cipherSuites[i] = c
i++
}
cipherSuites = cipherSuites[:i]
if len(cipherSuites) == 0 {
switch {
case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite:
return nil, errNoAvailableCertificateCipherSuite
case includePSKSuites && !foundPSKSuite:
return nil, errNoAvailablePSKCipherSuite
case i == 0:
return nil, errNoAvailableCipherSuites
}
return cipherSuites, nil
return cipherSuites[:i], nil
}

View file

@ -1,93 +0,0 @@
package dtls
import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteAes128Ccm struct {
ccm atomic.Value // *cryptoCCM
clientCertificateType clientCertificateType
id CipherSuiteID
psk bool
cryptoCCMTagLen cryptoCCMTagLen
}
func newCipherSuiteAes128Ccm(clientCertificateType clientCertificateType, id CipherSuiteID, psk bool, cryptoCCMTagLen cryptoCCMTagLen) *cipherSuiteAes128Ccm {
return &cipherSuiteAes128Ccm{
clientCertificateType: clientCertificateType,
id: id,
psk: psk,
cryptoCCMTagLen: cryptoCCMTagLen,
}
}
func (c *cipherSuiteAes128Ccm) certificateType() clientCertificateType {
return c.clientCertificateType
}
func (c *cipherSuiteAes128Ccm) ID() CipherSuiteID {
return c.id
}
func (c *cipherSuiteAes128Ccm) String() string {
return c.id.String()
}
func (c *cipherSuiteAes128Ccm) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteAes128Ccm) isPSK() bool {
return c.psk
}
func (c *cipherSuiteAes128Ccm) isInitialized() bool {
return c.ccm.Load() != nil
}
func (c *cipherSuiteAes128Ccm) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var ccm *cryptoCCM
if isClient {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
ccm, err = newCryptoCCM(c.cryptoCCMTagLen, keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.ccm.Store(ccm)
return err
}
var errCipherSuiteNotInit = errors.New("CipherSuite has not been initialized")
func (c *cipherSuiteAes128Ccm) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).encrypt(pkt, raw)
}
func (c *cipherSuiteAes128Ccm) decrypt(raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return ccm.(*cryptoCCM).decrypt(raw)
}

View file

@ -6,8 +6,12 @@ import (
"crypto/tls"
)
// VersionDTLS12 is the DTLS version in the same style as
// VersionTLSXX from crypto/tls
const VersionDTLS12 = 0xfefd
// Convert from our cipherSuite interface to a tls.CipherSuite struct
func toTLSCipherSuite(c cipherSuite) *tls.CipherSuite {
func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite {
return &tls.CipherSuite{
ID: uint16(c.ID()),
Name: c.String(),

View file

@ -1,5 +0,0 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, cryptoCCMTagLength)
}

View file

@ -1,5 +0,0 @@
package dtls
func newCipherSuiteTLSEcdheEcdsaWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateTypeECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, cryptoCCM8TagLength)
}

View file

@ -1,77 +0,0 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256 struct {
gcm atomic.Value // *cryptoGCM
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isInitialized() bool {
return c.gcm.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var gcm *cryptoGCM
if isClient {
gcm, err = newCryptoGCM(keys.clientWriteKey, keys.clientWriteIV, keys.serverWriteKey, keys.serverWriteIV)
} else {
gcm, err = newCryptoGCM(keys.serverWriteKey, keys.serverWriteIV, keys.clientWriteKey, keys.clientWriteIV)
}
c.gcm.Store(gcm)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) decrypt(raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return gcm.(*cryptoGCM).decrypt(raw)
}

View file

@ -1,83 +0,0 @@
package dtls
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
)
type cipherSuiteTLSEcdheEcdsaWithAes256CbcSha struct {
cbc atomic.Value // *cryptoCBC
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeECDSASign
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) hashFunc() func() hash.Hash {
return sha256.New
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isPSK() bool {
return false
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isInitialized() bool {
return c.cbc.Load() != nil
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 20
prfKeyLen = 32
prfIvLen = 16
)
keys, err := prfEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.hashFunc())
if err != nil {
return err
}
var cbc *cryptoCBC
if isClient {
cbc, err = newCryptoCBC(
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
)
} else {
cbc, err = newCryptoCBC(
keys.serverWriteKey, keys.serverWriteIV, keys.serverMACKey,
keys.clientWriteKey, keys.clientWriteIV, keys.clientMACKey,
)
}
c.cbc.Store(cbc)
return err
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).encrypt(pkt, raw)
}
func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*cryptoCBC).decrypt(raw)
}

View file

@ -1,17 +0,0 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSEcdheRsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
}

View file

@ -1,17 +0,0 @@
package dtls
type cipherSuiteTLSEcdheRsaWithAes256CbcSha struct {
cipherSuiteTLSEcdheEcdsaWithAes256CbcSha
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) certificateType() clientCertificateType {
return clientCertificateTypeRSASign
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) ID() CipherSuiteID {
return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
}
func (c *cipherSuiteTLSEcdheRsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
}

View file

@ -1,5 +0,0 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM, true, cryptoCCMTagLength)
}

View file

@ -1,5 +0,0 @@
package dtls
func newCipherSuiteTLSPskWithAes128Ccm8() *cipherSuiteAes128Ccm {
return newCipherSuiteAes128Ccm(clientCertificateType(0), TLS_PSK_WITH_AES_128_CCM_8, true, cryptoCCM8TagLength)
}

View file

@ -1,21 +0,0 @@
package dtls
type cipherSuiteTLSPskWithAes128GcmSha256 struct {
cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) certificateType() clientCertificateType {
return clientCertificateType(0)
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) ID() CipherSuiteID {
return TLS_PSK_WITH_AES_128_GCM_SHA256
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) String() string {
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
}
func (c *cipherSuiteTLSPskWithAes128GcmSha256) isPSK() bool {
return true
}

View file

@ -1,16 +0,0 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type clientCertificateType byte
const (
clientCertificateTypeRSASign clientCertificateType = 1
clientCertificateTypeECDSASign clientCertificateType = 64
)
func clientCertificateTypes() map[clientCertificateType]bool {
return map[clientCertificateType]bool{
clientCertificateTypeRSASign: true,
clientCertificateTypeECDSASign: true,
}
}

View file

@ -1,49 +1,9 @@
package dtls
type compressionMethodID byte
import "github.com/pion/dtls/v2/pkg/protocol"
const (
compressionMethodNull compressionMethodID = 0
)
type compressionMethod struct {
id compressionMethodID
}
func compressionMethods() map[compressionMethodID]*compressionMethod {
return map[compressionMethodID]*compressionMethod{
compressionMethodNull: {id: compressionMethodNull},
func defaultCompressionMethods() []*protocol.CompressionMethod {
return []*protocol.CompressionMethod{
{},
}
}
func defaultCompressionMethods() []*compressionMethod {
return []*compressionMethod{
compressionMethods()[compressionMethodNull],
}
}
func decodeCompressionMethods(buf []byte) ([]*compressionMethod, error) {
if len(buf) < 1 {
return nil, errDTLSPacketInvalidLength
}
compressionMethodsCount := int(buf[0])
c := []*compressionMethod{}
for i := 0; i < compressionMethodsCount; i++ {
if len(buf) <= i+1 {
return nil, errBufferTooSmall
}
id := compressionMethodID(buf[i+1])
if compressionMethod, ok := compressionMethods()[id]; ok {
c = append(c, compressionMethod)
}
}
return c, nil
}
func encodeCompressionMethods(c []*compressionMethod) []byte {
out := []byte{byte(len(c))}
for i := len(c); i > 0; i-- {
out = append(out, byte(c[i-1].id))
}
return out
}

View file

@ -6,11 +6,14 @@ import (
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"io"
"time"
"github.com/pion/logging"
)
const keyLogLabelTLS12 = "CLIENT_RANDOM"
// Config is used to configure a DTLS client or server.
// After a Config is passed to a DTLS function it must not be modified.
type Config struct {
@ -23,6 +26,11 @@ type Config struct {
// If CipherSuites is nil, a default list is used
CipherSuites []CipherSuiteID
// CustomCipherSuites is a list of CipherSuites that can be
// provided by the user. This allow users to user Ciphers that are reserved
// for private usage.
CustomCipherSuites func() []CipherSuite
// SignatureSchemes contains the signature and hash schemes that the peer requests to verify.
SignatureSchemes []tls.SignatureScheme
@ -107,6 +115,14 @@ type Config struct {
// Packet with sequence number older than this value compared to the latest
// accepted packet will be discarded. (default is 64)
ReplayProtectionWindow int
// KeyLogWriter optionally specifies a destination for TLS master secrets
// in NSS key log format that can be used to allow external programs
// such as Wireshark to decrypt TLS connections.
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
// Use of KeyLogWriter compromises security and should only be
// used for debugging.
KeyLogWriter io.Writer
}
func defaultConnectContextMaker() (context.Context, func()) {
@ -154,8 +170,6 @@ func validateConfig(config *Config) error {
switch {
case config == nil:
return errNoConfigProvided
case len(config.Certificates) > 0 && config.PSK != nil:
return errPSKAndCertificate
case config.PSKIdentityHint != nil && config.PSK == nil:
return errIdentityNoPSK
}
@ -174,6 +188,6 @@ func validateConfig(config *Config) error {
}
}
_, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
_, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
return err
}

View file

@ -11,8 +11,14 @@ import (
"time"
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/connctx"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/logging"
"github.com/pion/transport/connctx"
"github.com/pion/transport/deadline"
"github.com/pion/transport/replaydetector"
)
@ -20,17 +26,12 @@ import (
const (
initialTickerInterval = time.Second
cookieLength = 20
defaultNamedCurve = namedCurveX25519
defaultNamedCurve = elliptic.X25519
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
)
var (
errApplicationDataEpochZero = errors.New("ApplicationData with epoch of 0")
errUnhandledContextType = errors.New("unhandled contentType")
)
func invalidKeyingLabels() map[string]bool {
return map[string]bool{
"client finished": true,
@ -86,12 +87,12 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
return nil, errNilNextConn
}
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
if err != nil {
return nil, err
}
signatureSchemes, err := parseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}
@ -172,9 +173,11 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
verifyPeerCertificate: config.VerifyPeerCertificate,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
}
var initialFlight flightVal
@ -260,11 +263,8 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con
// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
switch {
case config == nil:
if config == nil {
return nil, errNoConfigProvided
case config.PSK == nil && len(config.Certificates) == 0:
return nil, errServerMustHaveCertificate
}
return createConn(ctx, conn, config, false, nil)
@ -322,13 +322,13 @@ func (c *Conn) Write(p []byte) (int, error) {
return len(p), c.writePackets(c.writeDeadline, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Epoch: c.getLocalEpoch(),
Version: protocol.Version1_2,
},
content: &applicationData{
data: p,
Content: &protocol.ApplicationData{
Data: p,
},
},
shouldEncrypt: true,
@ -370,16 +370,16 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
var rawPackets [][]byte
for _, p := range pkts {
if h, ok := p.record.content.(*handshake); ok {
if h, ok := p.record.Content.(*handshake.Handshake); ok {
handshakeRaw, err := p.record.Marshal()
if err != nil {
return err
}
c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
srvCliStr(c.state.isClient), h.handshakeHeader.handshakeType.String(),
p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence)
c.handshakeCache.push(handshakeRaw[recordLayerHeaderSize:], p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence, h.handshakeHeader.handshakeType, c.state.isClient)
srvCliStr(c.state.isClient), h.Header.Type.String(),
p.record.Header.Epoch, h.Header.MessageSequence)
c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
rawHandshakePackets, err := c.processHandshakePacket(p, h)
if err != nil {
@ -400,7 +400,7 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
compactedRawPackets := c.compactRawPackets(rawPackets)
for _, compactedRawPackets := range compactedRawPackets {
if _, err := c.nextConn.Write(ctx, compactedRawPackets); err != nil {
if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
return netError(err)
}
}
@ -426,18 +426,18 @@ func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
}
func (c *Conn) processPacket(p *packet) ([]byte, error) {
epoch := p.record.recordLayerHeader.epoch
epoch := p.record.Header.Epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
if seq > recordlayer.MaxSequenceNumber {
// RFC 6347 Section 4.1.0
// The implementation must either abandon an association or rehandshake
// prior to allowing the sequence number to wrap.
return nil, errSequenceNumberOverflow
}
p.record.recordLayerHeader.sequenceNumber = seq
p.record.Header.SequenceNumber = seq
rawPacket, err := p.record.Marshal()
if err != nil {
@ -446,7 +446,7 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) {
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
@ -455,43 +455,43 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) {
return rawPacket, nil
}
func (c *Conn) processHandshakePacket(p *packet, h *handshake) ([][]byte, error) {
func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
rawPackets := make([][]byte, 0)
handshakeFragments, err := c.fragmentHandshake(h)
if err != nil {
return nil, err
}
epoch := p.record.recordLayerHeader.epoch
epoch := p.record.Header.Epoch
for len(c.state.localSequenceNumber) <= int(epoch) {
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
}
for _, handshakeFragment := range handshakeFragments {
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
if seq > maxSequenceNumber {
if seq > recordlayer.MaxSequenceNumber {
return nil, errSequenceNumberOverflow
}
recordLayerHeader := &recordLayerHeader{
protocolVersion: p.record.recordLayerHeader.protocolVersion,
contentType: p.record.recordLayerHeader.contentType,
contentLen: uint16(len(handshakeFragment)),
epoch: p.record.recordLayerHeader.epoch,
sequenceNumber: seq,
recordlayerHeader := &recordlayer.Header{
Version: p.record.Header.Version,
ContentType: p.record.Header.ContentType,
ContentLen: uint16(len(handshakeFragment)),
Epoch: p.record.Header.Epoch,
SequenceNumber: seq,
}
recordLayerHeaderBytes, err := recordLayerHeader.Marshal()
recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
if err != nil {
return nil, err
}
p.record.recordLayerHeader = *recordLayerHeader
p.record.Header = *recordlayerHeader
rawPacket := append(recordLayerHeaderBytes, handshakeFragment...)
rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.encrypt(p.record, rawPacket)
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
if err != nil {
return nil, err
}
@ -503,8 +503,8 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake) ([][]byte, error)
return rawPackets, nil
}
func (c *Conn) fragmentHandshake(h *handshake) ([][]byte, error) {
content, err := h.handshakeMessage.Marshal()
func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
content, err := h.Message.Marshal()
if err != nil {
return nil, err
}
@ -522,22 +522,22 @@ func (c *Conn) fragmentHandshake(h *handshake) ([][]byte, error) {
for _, contentFragment := range contentFragments {
contentFragmentLen := len(contentFragment)
handshakeHeaderFragment := &handshakeHeader{
handshakeType: h.handshakeHeader.handshakeType,
length: h.handshakeHeader.length,
messageSequence: h.handshakeHeader.messageSequence,
fragmentOffset: uint32(offset),
fragmentLength: uint32(contentFragmentLen),
headerFragment := &handshake.Header{
Type: h.Header.Type,
Length: h.Header.Length,
MessageSequence: h.Header.MessageSequence,
FragmentOffset: uint32(offset),
FragmentLength: uint32(contentFragmentLen),
}
offset += contentFragmentLen
handshakeHeaderFragmentRaw, err := handshakeHeaderFragment.Marshal()
headerFragmentRaw, err := headerFragment.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshake := append(handshakeHeaderFragmentRaw, contentFragment...)
fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
}
@ -556,12 +556,12 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
defer poolReadBuffer.Put(bufptr)
b := *bufptr
i, err := c.nextConn.Read(ctx, b)
i, err := c.nextConn.ReadContext(ctx, b)
if err != nil {
return netError(err)
}
pkts, err := unpackDatagram(b[:i])
pkts, err := recordlayer.UnpackDatagram(b[:i])
if err != nil {
return err
}
@ -570,7 +570,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(p, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
err = alertErr
}
@ -609,7 +609,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err == nil {
err = alertErr
}
@ -628,8 +628,8 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
return nil
}
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) { //nolint:gocognit
h := &recordLayerHeader{}
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
if err := h.Unmarshal(buf); err != nil {
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
@ -639,10 +639,10 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
// Validate epoch
remoteEpoch := c.getRemoteEpoch()
if h.epoch > remoteEpoch {
if h.epoch > remoteEpoch+1 {
if h.Epoch > remoteEpoch {
if h.Epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
}
@ -654,22 +654,22 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
}
// Anti-replay protection
for len(c.state.replayDetector) <= int(h.epoch) {
for len(c.state.replayDetector) <= int(h.Epoch) {
c.state.replayDetector = append(c.state.replayDetector,
replaydetector.New(c.replayProtectionWindow, maxSequenceNumber),
replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
)
}
markPacketAsValid, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber)
markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
if !ok {
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
h.Epoch, h.SequenceNumber,
)
return false, nil, nil
}
// Decrypt
if h.epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
@ -678,7 +678,7 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
}
var err error
buf, err = c.state.cipherSuite.decrypt(buf)
buf, err = c.state.cipherSuite.Decrypt(buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
@ -694,35 +694,35 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
} else if isHandshake {
markPacketAsValid()
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
rawHandshake := &handshake{}
rawHandshake := &handshake.Handshake{}
if err := rawHandshake.Unmarshal(out); err != nil {
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
continue
}
_ = c.handshakeCache.push(out, epoch, rawHandshake.handshakeHeader.messageSequence, rawHandshake.handshakeHeader.handshakeType, !c.state.isClient)
_ = c.handshakeCache.push(out, epoch, rawHandshake.Header.MessageSequence, rawHandshake.Header.Type, !c.state.isClient)
}
return true, nil, nil
}
r := &recordLayer{}
r := &recordlayer.RecordLayer{}
if err := r.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
}
switch content := r.content.(type) {
case *alert:
switch content := r.Content.(type) {
case *alert.Alert:
c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
var a *alert
if content.alertDescription == alertCloseNotify {
var a *alert.Alert
if content.Description == alert.CloseNotify {
// Respond with a close_notify [RFC5246 Section 7.2.1]
a = &alert{alertLevelWarning, alertCloseNotify}
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
}
markPacketAsValid()
return false, a, &errAlert{content}
case *changeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
@ -730,27 +730,27 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
return false, nil, nil
}
newRemoteEpoch := h.epoch + 1
newRemoteEpoch := h.Epoch + 1
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
if c.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
markPacketAsValid()
}
case *applicationData:
if h.epoch == 0 {
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, errApplicationDataEpochZero
case *protocol.ApplicationData:
if h.Epoch == 0 {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
}
markPacketAsValid()
select {
case c.decrypted <- content.data:
case c.decrypted <- content.Data:
case <-c.closed.Done():
}
default:
return false, &alert{alertLevelFatal, alertUnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.contentType())
return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
}
return false, nil, nil
}
@ -759,17 +759,17 @@ func (c *Conn) recvHandshake() <-chan chan struct{} {
return c.handshakeRecv
}
func (c *Conn) notify(ctx context.Context, level alertLevel, desc alertDescription) error {
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
return c.writePackets(ctx, []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
epoch: c.getLocalEpoch(),
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Epoch: c.getLocalEpoch(),
Version: protocol.Version1_2,
},
content: &alert{
alertLevel: level,
alertDescription: desc,
Content: &alert.Alert{
Level: level,
Description: desc,
},
},
shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
@ -892,7 +892,7 @@ func (c *Conn) translateHandshakeCtxError(err error) error {
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
return nil
}
return &HandshakeError{err}
return &HandshakeError{Err: err}
}
func (c *Conn) close(byUser bool) error {
@ -902,7 +902,7 @@ func (c *Conn) close(byUser bool) error {
if c.isHandshakeCompletedSuccessfully() && byUser {
// Discard error from notify() to return non-error on the first user call of Close()
// even if the underlying connection is already closed.
_ = c.notify(context.Background(), alertLevelWarning, alertCloseNotify)
_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
}
c.closeLock.Lock()

View file

@ -1,17 +0,0 @@
package dtls
// https://tools.ietf.org/html/rfc4346#section-6.2.1
type contentType uint8
const (
contentTypeChangeCipherSpec contentType = 20
contentTypeAlert contentType = 21
contentTypeHandshake contentType = 22
contentTypeApplicationData contentType = 23
)
type content interface {
contentType() contentType
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}

View file

@ -12,13 +12,16 @@ import (
"encoding/binary"
"math/big"
"time"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/hash"
)
type ecdsaSignature struct {
R, S *big.Int
}
func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve) []byte {
func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte {
serverECDHParams := make([]byte, 4)
serverECDHParams[0] = 3 // named curve
binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve))
@ -38,24 +41,24 @@ func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve na
// hash/signature algorithm pair that appears in that extension
//
// https://tools.ietf.org/html/rfc5246#section-7.4.2
func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve namedCurve, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) {
msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve)
switch p := privateKey.(type) {
case ed25519.PrivateKey:
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, msg, crypto.Hash(0))
case *ecdsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
hashed := hashAlgorithm.Digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
case *rsa.PrivateKey:
hashed := hashAlgorithm.digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
hashed := hashAlgorithm.Digest(msg)
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
}
return nil, errKeySignatureGenerateUnimplemented
}
func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAlgorithm, rawCertificates [][]byte) error { //nolint:dupl
func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
@ -78,7 +81,7 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAl
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hashed := hashAlgorithm.digest(message)
hashed := hashAlgorithm.Digest(message)
if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
@ -86,8 +89,8 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAl
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hashed := hashAlgorithm.digest(message)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hashed, remoteKeySignature)
hashed := hashAlgorithm.Digest(message)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
@ -104,7 +107,7 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hashAl
// CertificateVerify message is sent to explicitly verify possession of
// the private key in the certificate.
// https://tools.ietf.org/html/rfc5246#section-7.3
func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hashAlgorithm) ([]byte, error) {
func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) {
h := sha256.New()
if _, err := h.Write(handshakeBodies); err != nil {
return nil, err
@ -116,15 +119,15 @@ func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.Private
// https://crypto.stackexchange.com/a/55483
return p.Sign(rand.Reader, hashed, crypto.Hash(0))
case *ecdsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
case *rsa.PrivateKey:
return p.Sign(rand.Reader, hashed, hashAlgorithm.cryptoHash())
return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash())
}
return nil, errInvalidSignatureAlgorithm
}
func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hashAlgorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl
func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl
if len(rawCertificates) == 0 {
return errLengthMismatch
}
@ -147,7 +150,7 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hashAlgorithm
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return errInvalidECDSASignature
}
hash := hashAlgorithm.digest(handshakeBodies)
hash := hashAlgorithm.Digest(handshakeBodies)
if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) {
return errKeySignatureMismatch
}
@ -155,8 +158,8 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hashAlgorithm
case *rsa.PublicKey:
switch certificate.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
hash := hashAlgorithm.digest(handshakeBodies)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.cryptoHash(), hash, remoteKeySignature)
hash := hashAlgorithm.Digest(handshakeBodies)
return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hash, remoteKeySignature)
default:
return errKeySignatureVerifyUnimplemented
}
@ -216,17 +219,3 @@ func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName
}
return certificate[0].Verify(opts)
}
func generateAEADAdditionalData(h *recordLayerHeader, payloadLen int) []byte {
var additionalData [13]byte
// SequenceNumber MUST be set first
// we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48)
binary.BigEndian.PutUint64(additionalData[:], h.sequenceNumber)
binary.BigEndian.PutUint16(additionalData[:], h.epoch)
additionalData[8] = byte(h.contentType)
additionalData[9] = h.protocolVersion.major
additionalData[10] = h.protocolVersion.minor
binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen))
return additionalData[:]
}

View file

@ -1,133 +0,0 @@
package dtls
import ( //nolint:gci
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1" //nolint:gosec
"encoding/binary"
)
// block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// State needed to handle encrypted input/output
type cryptoCBC struct {
writeCBC, readCBC cbcMode
writeMac, readMac []byte
}
// Currently hardcoded to be SHA1 only
var cryptoCBCMacFunc = sha1.New //nolint:gochecknoglobals
func newCryptoCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte) (*cryptoCBC, error) {
writeBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
readBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
return &cryptoCBC{
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
writeMac: localMac,
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
readMac: remoteMac,
}, nil
}
func (c *cryptoCBC) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
blockSize := c.writeCBC.BlockSize()
// Generate + Append MAC
h := pkt.recordLayerHeader
MAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, payload, c.writeMac)
if err != nil {
return nil, err
}
payload = append(payload, MAC...)
// Generate + Append padding
padding := make([]byte, blockSize-len(payload)%blockSize)
paddingLen := len(padding)
for i := 0; i < paddingLen; i++ {
padding[i] = byte(paddingLen - 1)
}
payload = append(payload, padding...)
// Generate IV
iv := make([]byte, blockSize)
if _, err := rand.Read(iv); err != nil {
return nil, err
}
// Set IV + Encrypt + Prepend IV
c.writeCBC.SetIV(iv)
c.writeCBC.CryptBlocks(payload, payload)
payload = append(iv, payload...)
// Prepend unencrypte header with encrypted payload
raw = append(raw, payload...)
// Update recordLayer size to include IV+MAC+Padding
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
return raw, nil
}
func (c *cryptoCBC) decrypt(in []byte) ([]byte, error) {
body := in[recordLayerHeaderSize:]
blockSize := c.readCBC.BlockSize()
mac := cryptoCBCMacFunc()
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(body)%blockSize != 0 || len(body) < blockSize+max(mac.Size()+1, blockSize):
return nil, errNotEnoughRoomForNonce
}
// Set + remove per record IV
c.readCBC.SetIV(body[:blockSize])
body = body[blockSize:]
// Decrypt
c.readCBC.CryptBlocks(body, body)
// Padding+MAC needs to be checked in constant time
// Otherwise we reveal information about the level of correctness
paddingLen, paddingGood := examinePadding(body)
macSize := mac.Size()
if len(body) < macSize {
return nil, errInvalidMAC
}
dataEnd := len(body) - macSize - paddingLen
expectedMAC := body[dataEnd : dataEnd+macSize]
actualMAC, err := prfMac(h.epoch, h.sequenceNumber, h.contentType, h.protocolVersion, body[:dataEnd], c.readMac)
// Compute Local MAC and compare
if paddingGood != 255 || err != nil || !hmac.Equal(actualMAC, expectedMAC) {
return nil, errInvalidMAC
}
return append(in[:recordLayerHeaderSize], body[:dataEnd]...), nil
}

View file

@ -1,94 +0,0 @@
package dtls
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
)
const (
cryptoGCMTagLength = 16
cryptoGCMNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoGCM struct {
localGCM, remoteGCM cipher.AEAD
localWriteIV, remoteWriteIV []byte
}
func newCryptoGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoGCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localGCM, err := cipher.NewGCM(localBlock)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteGCM, err := cipher.NewGCM(remoteBlock)
if err != nil {
return nil, err
}
return &cryptoGCM{
localGCM: localGCM,
localWriteIV: localWriteIV,
remoteGCM: remoteGCM,
remoteWriteIV: remoteWriteIV,
}, nil
}
func (c *cryptoGCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
nonce := make([]byte, cryptoGCMNonceLength)
copy(nonce, c.localWriteIV[:4])
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
encryptedPayload := c.localGCM.Seal(nil, nonce, payload, additionalData)
r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
copy(r, raw)
copy(r[len(raw):], nonce[4:])
copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(r[recordLayerHeaderSize-2:], uint16(len(r)-recordLayerHeaderSize))
return r, nil
}
func (c *cryptoGCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := make([]byte, 0, cryptoGCMNonceLength)
nonce = append(append(nonce, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-cryptoGCMTagLength)
out, err = c.remoteGCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
}

View file

@ -1,14 +0,0 @@
package dtls
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type ellipticCurveType byte
const (
ellipticCurveTypeNamedCurve ellipticCurveType = 0x03
)
func ellipticCurveTypes() map[ellipticCurveType]bool {
return map[ellipticCurveType]bool{
ellipticCurveTypeNamedCurve: true,
}
}

View file

@ -8,101 +8,78 @@ import (
"net"
"os"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"golang.org/x/xerrors"
)
// Typed errors
var (
ErrConnClosed = &FatalError{errors.New("conn is closed")} //nolint:goerr113
ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113
errDeadlineExceeded = &TimeoutError{xerrors.Errorf("read/write timeout: %w", context.DeadlineExceeded)}
errDeadlineExceeded = &TimeoutError{Err: xerrors.Errorf("read/write timeout: %w", context.DeadlineExceeded)}
errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113
errBufferTooSmall = &TemporaryError{errors.New("buffer is too small")} //nolint:goerr113
errContextUnsupported = &TemporaryError{errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113
errDTLSPacketInvalidLength = &TemporaryError{errors.New("packet is too short")} //nolint:goerr113
errHandshakeInProgress = &TemporaryError{errors.New("handshake is in progress")} //nolint:goerr113
errInvalidContentType = &TemporaryError{errors.New("invalid content type")} //nolint:goerr113
errInvalidMAC = &TemporaryError{errors.New("invalid mac")} //nolint:goerr113
errInvalidPacketLength = &TemporaryError{errors.New("packet length and declared length do not match")} //nolint:goerr113
errReservedExportKeyingMaterial = &TemporaryError{errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113
errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113
errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113
errReservedExportKeyingMaterial = &TemporaryError{Err: errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113
errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} //nolint:goerr113
errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113
errCertificateVerifyNoCertificate = &FatalError{errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113
errCipherSuiteNoIntersection = &FatalError{errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113
errCipherSuiteUnset = &FatalError{errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113
errClientCertificateNotVerified = &FatalError{errors.New("client sent certificate but did not verify it")} //nolint:goerr113
errClientCertificateRequired = &FatalError{errors.New("server required client verification, but got none")} //nolint:goerr113
errClientNoMatchingSRTPProfile = &FatalError{errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113
errClientRequiredButNoServerEMS = &FatalError{errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113
errCompressionMethodUnset = &FatalError{errors.New("server hello can not be created without a compression method")} //nolint:goerr113
errCookieMismatch = &FatalError{errors.New("client+server cookie does not match")} //nolint:goerr113
errCookieTooLong = &FatalError{errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113
errIdentityNoPSK = &FatalError{errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113
errInvalidCertificate = &FatalError{errors.New("no certificate provided")} //nolint:goerr113
errInvalidCipherSpec = &FatalError{errors.New("cipher spec invalid")} //nolint:goerr113
errInvalidCipherSuite = &FatalError{errors.New("invalid or unknown cipher suite")} //nolint:goerr113
errInvalidClientKeyExchange = &FatalError{errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113
errInvalidCompressionMethod = &FatalError{errors.New("invalid or unknown compression method")} //nolint:goerr113
errInvalidECDSASignature = &FatalError{errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113
errInvalidEllipticCurveType = &FatalError{errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113
errInvalidExtensionType = &FatalError{errors.New("invalid extension type")} //nolint:goerr113
errInvalidHashAlgorithm = &FatalError{errors.New("invalid hash algorithm")} //nolint:goerr113
errInvalidNamedCurve = &FatalError{errors.New("invalid named curve")} //nolint:goerr113
errInvalidPrivateKey = &FatalError{errors.New("invalid private key type")} //nolint:goerr113
errInvalidSNIFormat = &FatalError{errors.New("invalid server name format")} //nolint:goerr113
errInvalidSignatureAlgorithm = &FatalError{errors.New("invalid signature algorithm")} //nolint:goerr113
errKeySignatureMismatch = &FatalError{errors.New("expected and actual key signature do not match")} //nolint:goerr113
errNilNextConn = &FatalError{errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113
errNoAvailableCipherSuites = &FatalError{errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113
errNoAvailableSignatureSchemes = &FatalError{errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113
errNoCertificates = &FatalError{errors.New("no certificates configured")} //nolint:goerr113
errNoConfigProvided = &FatalError{errors.New("no config provided")} //nolint:goerr113
errNoSupportedEllipticCurves = &FatalError{errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113
errUnsupportedProtocolVersion = &FatalError{errors.New("unsupported protocol version")} //nolint:goerr113
errPSKAndCertificate = &FatalError{errors.New("Certificate and PSK provided")} //nolint:stylecheck
errPSKAndIdentityMustBeSetForClient = &FatalError{errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113
errRequestedButNoSRTPExtension = &FatalError{errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113
errServerMustHaveCertificate = &FatalError{errors.New("Certificate is mandatory for server")} //nolint:stylecheck
errServerNoMatchingSRTPProfile = &FatalError{errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
errServerRequiredButNoClientEMS = &FatalError{errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
errVerifyDataMismatch = &FatalError{errors.New("expected and actual verify data does not match")} //nolint:goerr113
errCertificateVerifyNoCertificate = &FatalError{Err: errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113
errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113
errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113
errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113
errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113
errClientRequiredButNoServerEMS = &FatalError{Err: errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113
errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113
errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113
errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113
errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113
errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113
errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113
errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113
errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113
errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113
errNoAvailableCipherSuites = &FatalError{Err: errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113
errNoAvailablePSKCipherSuite = &FatalError{Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite")} //nolint:goerr113
errNoAvailableCertificateCipherSuite = &FatalError{Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite")} //nolint:goerr113
errNoAvailableSignatureSchemes = &FatalError{Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113
errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113
errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113
errNoSupportedEllipticCurves = &FatalError{Err: errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113
errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113
errPSKAndIdentityMustBeSetForClient = &FatalError{Err: errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113
errRequestedButNoSRTPExtension = &FatalError{Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113
errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113
errHandshakeMessageUnset = &InternalError{errors.New("handshake message unset, unable to marshal")} //nolint:goerr113
errInvalidFlight = &InternalError{errors.New("invalid flight number")} //nolint:goerr113
errKeySignatureGenerateUnimplemented = &InternalError{errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
errKeySignatureVerifyUnimplemented = &InternalError{errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113
errLengthMismatch = &InternalError{errors.New("data length and declared length do not match")} //nolint:goerr113
errNotEnoughRoomForNonce = &InternalError{errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
errNotImplemented = &InternalError{errors.New("feature has not been implemented yet")} //nolint:goerr113
errSequenceNumberOverflow = &InternalError{errors.New("sequence number overflow")} //nolint:goerr113
errUnableToMarshalFragmented = &InternalError{errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113
errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113
errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113
errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113
)
// FatalError indicates that the DTLS connection is no longer available.
// It is mainly caused by wrong configuration of server or client.
type FatalError struct {
Err error
}
type FatalError = protocol.FatalError
// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available.
// It is mainly caused by bugs or tried to use unimplemented features.
type InternalError struct {
Err error
}
type InternalError = protocol.InternalError
// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary.
type TemporaryError struct {
Err error
}
type TemporaryError = protocol.TemporaryError
// TimeoutError indicates that the request was timed out.
type TimeoutError struct {
Err error
}
type TimeoutError = protocol.TimeoutError
// HandshakeError indicates that the handshake failed.
type HandshakeError struct {
Err error
}
type HandshakeError = protocol.HandshakeError
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuite struct {
@ -120,87 +97,22 @@ func (e *invalidCipherSuite) Is(err error) bool {
return false
}
// Timeout implements net.Error.Timeout()
func (*FatalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*FatalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *FatalError) Unwrap() error { return e.Err }
func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*InternalError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*InternalError) Temporary() bool { return false }
// Unwrap implements Go1.13 error unwrapper.
func (e *InternalError) Unwrap() error { return e.Err }
func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TemporaryError) Timeout() bool { return false }
// Temporary implements net.Error.Temporary()
func (*TemporaryError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TemporaryError) Unwrap() error { return e.Err }
func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (*TimeoutError) Timeout() bool { return true }
// Temporary implements net.Error.Temporary()
func (*TimeoutError) Temporary() bool { return true }
// Unwrap implements Go1.13 error unwrapper.
func (e *TimeoutError) Unwrap() error { return e.Err }
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Timeout()
}
return false
}
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
}
return false
}
// Unwrap implements Go1.13 error unwrapper.
func (e *HandshakeError) Unwrap() error { return e.Err }
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }
// errAlert wraps DTLS alert notification as an error
type errAlert struct {
*alert
*alert.Alert
}
func (e *errAlert) Error() string {
return fmt.Sprintf("alert: %s", e.alert.String())
return fmt.Sprintf("alert: %s", e.Alert.String())
}
func (e *errAlert) IsFatalOrCloseNotify() bool {
return e.alertLevel == alertLevelFatal || e.alertDescription == alertCloseNotify
return e.Level == alert.Fatal || e.Description == alert.CloseNotify
}
func (e *errAlert) Is(err error) bool {
if other, ok := err.(*errAlert); ok {
return e.alertLevel == other.alertLevel && e.alertDescription == other.alertDescription
return e.Level == other.Level && e.Description == other.Description
}
return false
}
@ -216,14 +128,14 @@ func netError(err error) error {
case (*net.OpError):
if se, ok := e.Err.(*os.SyscallError); ok {
if se.Timeout() {
return &TimeoutError{err}
return &TimeoutError{Err: err}
}
if isOpErrorTemporary(se) {
return &TemporaryError{err}
return &TemporaryError{Err: err}
}
}
case (net.Error):
return err
}
return &FatalError{err}
return &FatalError{Err: err}
}

View file

@ -1,88 +0,0 @@
package dtls
import (
"encoding/binary"
)
// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
type extensionValue uint16
const (
extensionServerNameValue extensionValue = 0
extensionSupportedEllipticCurvesValue extensionValue = 10
extensionSupportedPointFormatsValue extensionValue = 11
extensionSupportedSignatureAlgorithmsValue extensionValue = 13
extensionUseSRTPValue extensionValue = 14
extensionUseExtendedMasterSecretValue extensionValue = 23
extensionRenegotiationInfoValue extensionValue = 65281
)
type extension interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
extensionValue() extensionValue
}
func decodeExtensions(buf []byte) ([]extension, error) {
if len(buf) < 2 {
return nil, errBufferTooSmall
}
declaredLen := binary.BigEndian.Uint16(buf)
if len(buf)-2 != int(declaredLen) {
return nil, errLengthMismatch
}
extensions := []extension{}
unmarshalAndAppend := func(data []byte, e extension) error {
err := e.Unmarshal(data)
if err != nil {
return err
}
extensions = append(extensions, e)
return nil
}
for offset := 2; offset < len(buf); {
if len(buf) < (offset + 2) {
return nil, errBufferTooSmall
}
var err error
switch extensionValue(binary.BigEndian.Uint16(buf[offset:])) {
case extensionServerNameValue:
err = unmarshalAndAppend(buf[offset:], &extensionServerName{})
case extensionSupportedEllipticCurvesValue:
err = unmarshalAndAppend(buf[offset:], &extensionSupportedEllipticCurves{})
case extensionUseSRTPValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseSRTP{})
case extensionUseExtendedMasterSecretValue:
err = unmarshalAndAppend(buf[offset:], &extensionUseExtendedMasterSecret{})
case extensionRenegotiationInfoValue:
err = unmarshalAndAppend(buf[offset:], &extensionRenegotiationInfo{})
default:
}
if err != nil {
return nil, err
}
if len(buf) < (offset + 4) {
return nil, errBufferTooSmall
}
extensionLength := binary.BigEndian.Uint16(buf[offset+2:])
offset += (4 + int(extensionLength))
}
return extensions, nil
}
func encodeExtensions(e []extension) ([]byte, error) {
extensions := []byte{}
for _, e := range e {
raw, err := e.Marshal()
if err != nil {
return nil, err
}
extensions = append(extensions, raw...)
}
out := []byte{0x00, 0x00}
binary.BigEndian.PutUint16(out, uint16(len(extensions)))
return append(out, extensions...), nil
}

View file

@ -1,37 +0,0 @@
package dtls
import "encoding/binary"
const (
extensionRenegotiationInfoHeaderSize = 5
)
// https://tools.ietf.org/html/rfc5746
type extensionRenegotiationInfo struct {
renegotiatedConnection uint8
}
func (e extensionRenegotiationInfo) extensionValue() extensionValue {
return extensionRenegotiationInfoValue
}
func (e *extensionRenegotiationInfo) Marshal() ([]byte, error) {
out := make([]byte, extensionRenegotiationInfoHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1)) // length
out[4] = e.renegotiatedConnection
return out, nil
}
func (e *extensionRenegotiationInfo) Unmarshal(data []byte) error {
if len(data) < extensionRenegotiationInfoHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.renegotiatedConnection = data[4]
return nil
}

View file

@ -1,70 +0,0 @@
package dtls
import (
"strings"
"golang.org/x/crypto/cryptobyte"
)
const extensionServerNameTypeDNSHostName = 0
type extensionServerName struct {
serverName string
}
func (e extensionServerName) extensionValue() extensionValue {
return extensionServerNameValue
}
func (e *extensionServerName) Marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(uint16(e.extensionValue()))
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(extensionServerNameTypeDNSHostName)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(e.serverName))
})
})
})
return b.Bytes()
}
func (e *extensionServerName) Unmarshal(data []byte) error {
s := cryptobyte.String(data)
var extension uint16
s.ReadUint16(&extension)
if extensionValue(extension) != e.extensionValue() {
return errInvalidExtensionType
}
var extData cryptobyte.String
s.ReadUint16LengthPrefixed(&extData)
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return errInvalidSNIFormat
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return errInvalidSNIFormat
}
if nameType != extensionServerNameTypeDNSHostName {
continue
}
if len(e.serverName) != 0 {
// Multiple names of the same name_type are prohibited.
return errInvalidSNIFormat
}
e.serverName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(e.serverName, ".") {
return errInvalidSNIFormat
}
}
return nil
}

View file

@ -1,54 +0,0 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedGroupsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422#section-5.1.1
type extensionSupportedEllipticCurves struct {
ellipticCurves []namedCurve
}
func (e extensionSupportedEllipticCurves) extensionValue() extensionValue {
return extensionSupportedEllipticCurvesValue
}
func (e *extensionSupportedEllipticCurves) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedGroupsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.ellipticCurves)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.ellipticCurves)*2))
for _, v := range e.ellipticCurves {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
return out, nil
}
func (e *extensionSupportedEllipticCurves) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedGroupsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
groupCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(groupCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < groupCount; i++ {
supportedGroupID := namedCurve(binary.BigEndian.Uint16(data[(extensionSupportedGroupsHeaderSize + (i * 2)):]))
if _, ok := namedCurves()[supportedGroupID]; ok {
e.ellipticCurves = append(e.ellipticCurves, supportedGroupID)
}
}
return nil
}

View file

@ -1,56 +0,0 @@
package dtls
import "encoding/binary"
const (
extensionSupportedPointFormatsSize = 5
)
type ellipticCurvePointFormat byte
const ellipticCurvePointFormatUncompressed ellipticCurvePointFormat = 0
// https://tools.ietf.org/html/rfc4492#section-5.1.2
type extensionSupportedPointFormats struct {
pointFormats []ellipticCurvePointFormat
}
func (e extensionSupportedPointFormats) extensionValue() extensionValue {
return extensionSupportedPointFormatsValue
}
func (e *extensionSupportedPointFormats) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedPointFormatsSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(1+(len(e.pointFormats))))
out[4] = byte(len(e.pointFormats))
for _, v := range e.pointFormats {
out = append(out, byte(v))
}
return out, nil
}
func (e *extensionSupportedPointFormats) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedPointFormatsSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
pointFormatCount := int(binary.BigEndian.Uint16(data[4:]))
if extensionSupportedGroupsHeaderSize+(pointFormatCount) > len(data) {
return errLengthMismatch
}
for i := 0; i < pointFormatCount; i++ {
p := ellipticCurvePointFormat(data[extensionSupportedPointFormatsSize+i])
switch p {
case ellipticCurvePointFormatUncompressed:
e.pointFormats = append(e.pointFormats, p)
default:
}
}
return nil
}

View file

@ -1,60 +0,0 @@
package dtls
import (
"encoding/binary"
)
const (
extensionSupportedSignatureAlgorithmsHeaderSize = 6
)
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
type extensionSupportedSignatureAlgorithms struct {
signatureHashAlgorithms []signatureHashAlgorithm
}
func (e extensionSupportedSignatureAlgorithms) extensionValue() extensionValue {
return extensionSupportedSignatureAlgorithmsValue
}
func (e *extensionSupportedSignatureAlgorithms) Marshal() ([]byte, error) {
out := make([]byte, extensionSupportedSignatureAlgorithmsHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.signatureHashAlgorithms)*2)))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.signatureHashAlgorithms)*2))
for _, v := range e.signatureHashAlgorithms {
out = append(out, []byte{0x00, 0x00}...)
out[len(out)-2] = byte(v.hash)
out[len(out)-1] = byte(v.signature)
}
return out, nil
}
func (e *extensionSupportedSignatureAlgorithms) Unmarshal(data []byte) error {
if len(data) <= extensionSupportedSignatureAlgorithmsHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
algorithmCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedSignatureAlgorithmsHeaderSize+(algorithmCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < algorithmCount; i++ {
supportedHashAlgorithm := hashAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)])
supportedSignatureAlgorithm := signatureAlgorithm(data[extensionSupportedSignatureAlgorithmsHeaderSize+(i*2)+1])
if _, ok := hashAlgorithms()[supportedHashAlgorithm]; ok {
if _, ok := signatureAlgorithms()[supportedSignatureAlgorithm]; ok {
e.signatureHashAlgorithms = append(e.signatureHashAlgorithms, signatureHashAlgorithm{
supportedHashAlgorithm,
supportedSignatureAlgorithm,
})
}
}
}
return nil
}

View file

@ -1,40 +0,0 @@
package dtls
import "encoding/binary"
const (
extensionUseExtendedMasterSecretHeaderSize = 4
)
// https://tools.ietf.org/html/rfc8422
type extensionUseExtendedMasterSecret struct {
supported bool
}
func (e extensionUseExtendedMasterSecret) extensionValue() extensionValue {
return extensionUseExtendedMasterSecretValue
}
func (e *extensionUseExtendedMasterSecret) Marshal() ([]byte, error) {
if !e.supported {
return []byte{}, nil
}
out := make([]byte, extensionUseExtendedMasterSecretHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(0)) // length
return out, nil
}
func (e *extensionUseExtendedMasterSecret) Unmarshal(data []byte) error {
if len(data) < extensionUseExtendedMasterSecretHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
e.supported = true
return nil
}

View file

@ -1,53 +0,0 @@
package dtls
import "encoding/binary"
const (
extensionUseSRTPHeaderSize = 6
)
// https://tools.ietf.org/html/rfc8422
type extensionUseSRTP struct {
protectionProfiles []SRTPProtectionProfile
}
func (e extensionUseSRTP) extensionValue() extensionValue {
return extensionUseSRTPValue
}
func (e *extensionUseSRTP) Marshal() ([]byte, error) {
out := make([]byte, extensionUseSRTPHeaderSize)
binary.BigEndian.PutUint16(out, uint16(e.extensionValue()))
binary.BigEndian.PutUint16(out[2:], uint16(2+(len(e.protectionProfiles)*2)+ /* MKI Length */ 1))
binary.BigEndian.PutUint16(out[4:], uint16(len(e.protectionProfiles)*2))
for _, v := range e.protectionProfiles {
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v))
}
out = append(out, 0x00) /* MKI Length */
return out, nil
}
func (e *extensionUseSRTP) Unmarshal(data []byte) error {
if len(data) <= extensionUseSRTPHeaderSize {
return errBufferTooSmall
} else if extensionValue(binary.BigEndian.Uint16(data)) != e.extensionValue() {
return errInvalidExtensionType
}
profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2)
if extensionSupportedGroupsHeaderSize+(profileCount*2) > len(data) {
return errLengthMismatch
}
for i := 0; i < profileCount; i++ {
supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):]))
if _, ok := srtpProtectionProfiles()[supportedProfile]; ok {
e.protectionProfiles = append(e.protectionProfiles, supportedProfile)
}
}
return nil
}

View file

@ -3,11 +3,17 @@ package dtls
import (
"context"
"crypto/rand"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
)
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
seq, msgs, ok := cache.fullPullMap(0,
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
)
if !ok {
// No valid message received. Keep reading
@ -15,61 +21,68 @@ func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
state.handshakeRecvSequence = seq
var clientHello *handshakeMessageClientHello
var clientHello *handshake.MessageClientHello
// Validate type
if clientHello, ok = msgs[handshakeTypeClientHello].(*handshakeMessageClientHello); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
if !clientHello.Version.Equal(protocol.Version1_2) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
}
state.remoteRandom = clientHello.random
state.remoteRandom = clientHello.Random
if state.cipherSuite, ok = findMatchingCipherSuite(clientHello.cipherSuites, cfg.localCipherSuites); !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errCipherSuiteNoIntersection
cipherSuites := []CipherSuite{}
for _, id := range clientHello.CipherSuiteIDs {
if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
cipherSuites = append(cipherSuites, c)
}
}
for _, extension := range clientHello.extensions {
switch e := extension.(type) {
case *extensionSupportedEllipticCurves:
if len(e.ellipticCurves) == 0 {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errNoSupportedEllipticCurves
if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
}
for _, val := range clientHello.Extensions {
switch e := val.(type) {
case *extension.SupportedEllipticCurves:
if len(e.EllipticCurves) == 0 {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
}
state.namedCurve = e.ellipticCurves[0]
case *extensionUseSRTP:
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, cfg.localSRTPProtectionProfiles)
state.namedCurve = e.EllipticCurves[0]
case *extension.UseSRTP:
profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
if !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerNoMatchingSRTPProfile
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
case *extensionUseExtendedMasterSecret:
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
}
case *extensionServerName:
state.serverName = e.serverName // remote server name
case *extension.ServerName:
state.serverName = e.ServerName // remote server name
}
}
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errServerRequiredButNoClientEMS
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
}
if state.localKeypair == nil {
var err error
state.localKeypair, err = generateKeypair(state.namedCurve)
state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
if err != nil {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
}
}
return flight2, nil, nil
}
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
// Initialize
state.cookie = make([]byte, cookieLength)
if _, err := rand.Read(state.cookie); err != nil {
@ -81,7 +94,7 @@ func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
state.remoteEpoch.Store(zeroEpoch)
state.namedCurve = defaultNamedCurve
if err := state.localRandom.populate(); err != nil {
if err := state.localRandom.Populate(); err != nil {
return nil, nil, err
}

View file

@ -2,101 +2,108 @@ package dtls
import (
"context"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
// HelloVerifyRequest can be skipped by the server,
// so allow ServerHello during flight1 also
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeHelloVerifyRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
if _, ok := msgs[handshakeTypeServerHello]; ok {
if _, ok := msgs[handshake.TypeServerHello]; ok {
// Flight1 and flight2 were skipped.
// Parse as flight3.
return flight3Parse(ctx, c, state, cache, cfg)
}
if h, ok := msgs[handshakeTypeHelloVerifyRequest].(*handshakeMessageHelloVerifyRequest); ok {
if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
// DTLS 1.2 clients must not assume that the server will use the protocol version
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
if !h.version.Equal(protocolVersion1_0) && !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
}
state.cookie = append([]byte{}, h.cookie...)
state.cookie = append([]byte{}, h.Cookie...)
state.handshakeRecvSequence = seq
return flight3, nil, nil
}
return 0, &alert{alertLevelFatal, alertInternalError}, nil
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
var zeroEpoch uint16
state.localEpoch.Store(zeroEpoch)
state.remoteEpoch.Store(zeroEpoch)
state.namedCurve = defaultNamedCurve
state.cookie = nil
if err := state.localRandom.populate(); err != nil {
if err := state.localRandom.Populate(); err != nil {
return nil, nil, err
}
extensions := []extension{
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: cfg.localSignatureSchemes,
extensions := []extension.Extension{
&extension.SupportedSignatureAlgorithms{
SignatureHashAlgorithms: cfg.localSignatureSchemes,
},
&extensionRenegotiationInfo{
renegotiatedConnection: 0,
&extension.RenegotiationInfo{
RenegotiatedConnection: 0,
},
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
extensions = append(extensions, []extension.Extension{
&extension.SupportedEllipticCurves{
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
&extension.SupportedPointFormats{
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
},
}...)
}
if len(cfg.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: cfg.localSRTPProtectionProfiles,
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
})
}
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
extensions = append(extensions, &extension.UseExtendedMasterSecret{
Supported: true,
})
}
if len(cfg.serverName) > 0 {
extensions = append(extensions, &extensionServerName{serverName: cfg.serverName})
extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
}
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion1_2,
cookie: state.cookie,
random: state.localRandom,
cipherSuites: cfg.localCipherSuites,
compressionMethods: defaultCompressionMethods(),
extensions: extensions,
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
},

View file

@ -3,11 +3,16 @@ package dtls
import (
"bytes"
"context"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
)
if !ok {
// Client may retransmit the first ClientHello when HelloVerifyRequest is dropped.
@ -16,38 +21,38 @@ func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
state.handshakeRecvSequence = seq
var clientHello *handshakeMessageClientHello
var clientHello *handshake.MessageClientHello
// Validate type
if clientHello, ok = msgs[handshakeTypeClientHello].(*handshakeMessageClientHello); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
if !clientHello.Version.Equal(protocol.Version1_2) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
}
if len(clientHello.cookie) == 0 {
if len(clientHello.Cookie) == 0 {
return 0, nil, nil
}
if !bytes.Equal(state.cookie, clientHello.cookie) {
return 0, &alert{alertLevelFatal, alertAccessDenied}, errCookieMismatch
if !bytes.Equal(state.cookie, clientHello.Cookie) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch
}
return flight4, nil, nil
}
func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
state.handshakeSendSequence = 0
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageHelloVerifyRequest{
version: protocolVersion1_2,
cookie: state.cookie,
Content: &handshake.Handshake{
Message: &handshake.MessageHelloVerifyRequest{
Version: protocol.Version1_2,
Cookie: state.cookie,
},
},
},

View file

@ -2,23 +2,31 @@ package dtls
import (
"context"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) { //nolint:gocognit
func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
// Clients may receive multiple HelloVerifyRequest messages with different cookies.
// Clients SHOULD handle this by sending a new ClientHello with a cookie in response
// to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeHelloVerifyRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
)
if ok {
if h, msgOk := msgs[handshakeTypeHelloVerifyRequest].(*handshakeMessageHelloVerifyRequest); msgOk {
if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
// DTLS 1.2 clients must not assume that the server will use the protocol version
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
if !h.version.Equal(protocolVersion1_0) && !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
}
state.cookie = append([]byte{}, h.cookie...)
state.cookie = append([]byte{}, h.Cookie...)
state.handshakeRecvSequence = seq
return flight3, nil, nil
}
@ -26,17 +34,17 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
if cfg.localPSKCallback != nil {
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
)
} else {
seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
)
}
if !ok {
@ -45,130 +53,139 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
state.handshakeRecvSequence = seq
if h, ok := msgs[handshakeTypeServerHello].(*handshakeMessageServerHello); ok {
if !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
if h, ok := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); ok {
if !h.Version.Equal(protocol.Version1_2) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
}
for _, extension := range h.extensions {
switch e := extension.(type) {
case *extensionUseSRTP:
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, cfg.localSRTPProtectionProfiles)
for _, v := range h.Extensions {
switch e := v.(type) {
case *extension.UseSRTP:
profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
if !ok {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, errClientNoMatchingSRTPProfile
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
}
state.srtpProtectionProfile = profile
case *extensionUseExtendedMasterSecret:
case *extension.UseExtendedMasterSecret:
if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
state.extendedMasterSecret = true
}
}
}
if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errClientRequiredButNoServerEMS
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
}
if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errRequestedButNoSRTPExtension
}
if _, ok := findMatchingCipherSuite([]cipherSuite{h.cipherSuite}, cfg.localCipherSuites); !ok {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errCipherSuiteNoIntersection
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
}
state.cipherSuite = h.cipherSuite
state.remoteRandom = h.random
cfg.log.Tracef("[handshake] use cipher suite: %s", h.cipherSuite.String())
remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
if remoteCipherSuite == nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
}
selectedCipherSuite, ok := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
if !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
}
state.cipherSuite = selectedCipherSuite
state.remoteRandom = h.Random
cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
}
if h, ok := msgs[handshakeTypeCertificate].(*handshakeMessageCertificate); ok {
state.PeerCertificates = h.certificate
if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
state.PeerCertificates = h.Certificate
} else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
}
if h, ok := msgs[handshakeTypeServerKeyExchange].(*handshakeMessageServerKeyExchange); ok {
if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
if err != nil {
return 0, alertPtr, err
}
}
if _, ok := msgs[handshakeTypeCertificateRequest].(*handshakeMessageCertificateRequest); ok {
if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
state.remoteRequestedCertificate = true
}
return flight5, nil, nil
}
func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshakeMessageServerKeyExchange) (*alert, error) {
func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
var err error
if cfg.localPSKCallback != nil {
var psk []byte
if psk, err = cfg.localPSKCallback(h.identityHint); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.preMasterSecret = prfPSKPreMasterSecret(psk)
state.IdentityHint = h.IdentityHint
state.preMasterSecret = prf.PSKPreMasterSecret(psk)
} else {
if state.localKeypair, err = generateKeypair(h.namedCurve); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if state.preMasterSecret, err = prfPreMasterSecret(h.publicKey, state.localKeypair.privateKey, state.localKeypair.curve); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
return nil, nil
}
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
extensions := []extension{
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: cfg.localSignatureSchemes,
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
extensions := []extension.Extension{
&extension.SupportedSignatureAlgorithms{
SignatureHashAlgorithms: cfg.localSignatureSchemes,
},
&extensionRenegotiationInfo{
renegotiatedConnection: 0,
&extension.RenegotiationInfo{
RenegotiatedConnection: 0,
},
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
extensions = append(extensions, []extension.Extension{
&extension.SupportedEllipticCurves{
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
&extension.SupportedPointFormats{
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
},
}...)
}
if len(cfg.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: cfg.localSRTPProtectionProfiles,
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: cfg.localSRTPProtectionProfiles,
})
}
if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
extensions = append(extensions, &extension.UseExtendedMasterSecret{
Supported: true,
})
}
if len(cfg.serverName) > 0 {
extensions = append(extensions, &extensionServerName{serverName: cfg.serverName})
extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
}
return []*packet{
{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion1_2,
cookie: state.cookie,
random: state.localRandom,
cipherSuites: cfg.localCipherSuites,
compressionMethods: defaultCompressionMethods(),
extensions: extensions,
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
},

View file

@ -3,13 +3,23 @@ package dtls
import (
"context"
"crypto/x509"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) { //nolint:gocognit
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, true},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, true},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
)
if !ok {
// No valid message received. Keep reading
@ -17,113 +27,114 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
// Validate type
var clientKeyExchange *handshakeMessageClientKeyExchange
if clientKeyExchange, ok = msgs[handshakeTypeClientKeyExchange].(*handshakeMessageClientKeyExchange); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
var clientKeyExchange *handshake.MessageClientKeyExchange
if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
if h, hasCert := msgs[handshakeTypeCertificate].(*handshakeMessageCertificate); hasCert {
state.PeerCertificates = h.certificate
if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
state.PeerCertificates = h.Certificate
}
if h, hasCertVerify := msgs[handshakeTypeCertificateVerify].(*handshakeMessageCertificateVerify); hasCertVerify {
if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errCertificateVerifyNoCertificate
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
}
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
)
// Verify that the pair of hash algorithm and signiture is listed.
var validSignatureScheme bool
for _, ss := range cfg.localSignatureSchemes {
if ss.hash == h.hashAlgorithm && ss.signature == h.signatureAlgorithm {
if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
validSignatureScheme = true
break
}
}
if !validSignatureScheme {
return 0, &alert{alertLevelFatal, alertInsufficientSecurity}, errNoAvailableSignatureSchemes
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
}
if err := verifyCertificateVerify(plainText, h.hashAlgorithm, h.signature, state.PeerCertificates); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
var chains [][]*x509.Certificate
var err error
var verified bool
if cfg.clientAuth >= VerifyClientCertIfGiven {
if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
verified = true
}
if cfg.verifyPeerCertificate != nil {
if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
return 0, &alert{alertLevelFatal, alertBadCertificate}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
state.peerCertificatesVerified = verified
}
if !state.cipherSuite.isInitialized() {
serverRandom := state.localRandom.marshalFixed()
clientRandom := state.remoteRandom.marshalFixed()
if !state.cipherSuite.IsInitialized() {
serverRandom := state.localRandom.MarshalFixed()
clientRandom := state.remoteRandom.MarshalFixed()
var err error
var preMasterSecret []byte
if cfg.localPSKCallback != nil {
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
var psk []byte
if psk, err = cfg.localPSKCallback(clientKeyExchange.identityHint); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
preMasterSecret = prfPSKPreMasterSecret(psk)
state.IdentityHint = clientKeyExchange.IdentityHint
preMasterSecret = prf.PSKPreMasterSecret(psk)
} else {
preMasterSecret, err = prfPreMasterSecret(clientKeyExchange.publicKey, state.localKeypair.privateKey, state.localKeypair.curve)
preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
if err != nil {
return 0, &alert{alertLevelFatal, alertIllegalParameter}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
}
}
if state.extendedMasterSecret {
var sessionHash []byte
sessionHash, err = cache.sessionHash(state.cipherSuite.hashFunc(), cfg.initialEpoch)
sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.masterSecret, err = prfExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.hashFunc())
state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
} else {
state.masterSecret, err = prfMasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.hashFunc())
state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
if err := state.cipherSuite.init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
}
// Now, encrypted packets can be handled
if err := c.handleQueuedPackets(ctx); err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
seq, msgs, ok = cache.fullPullMap(seq,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
if !ok {
// No valid message received. Keep reading
@ -131,25 +142,29 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
state.handshakeRecvSequence = seq
if _, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
return flight6, nil, nil
}
switch cfg.clientAuth {
case RequireAnyClientCert:
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errClientCertificateRequired
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
}
case VerifyClientCertIfGiven:
if state.PeerCertificates != nil && !state.peerCertificatesVerified {
return 0, &alert{alertLevelFatal, alertBadCertificate}, errClientCertificateNotVerified
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
}
case RequireAndVerifyClientCert:
if state.PeerCertificates == nil {
return 0, &alert{alertLevelFatal, alertNoCertificate}, errClientCertificateRequired
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
}
if !state.peerCertificatesVerified {
return 0, &alert{alertLevelFatal, alertBadCertificate}, errClientCertificateNotVerified
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
}
case NoClientCert, RequestClientCert:
return flight6, nil, nil
@ -158,96 +173,100 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return flight6, nil, nil
}
func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
extensions := []extension{}
func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
extensions := []extension.Extension{&extension.RenegotiationInfo{
RenegotiatedConnection: 0,
}}
if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
extensions = append(extensions, &extensionUseExtendedMasterSecret{
supported: true,
extensions = append(extensions, &extension.UseExtendedMasterSecret{
Supported: true,
})
}
if state.srtpProtectionProfile != 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
extensions = append(extensions, &extension.UseSRTP{
ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
})
}
if cfg.localPSKCallback == nil {
extensions = append(extensions, []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
extensions = append(extensions, []extension.Extension{
&extension.SupportedEllipticCurves{
EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
&extension.SupportedPointFormats{
PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
},
}...)
}
var pkts []*packet
cipherSuiteID := uint16(state.cipherSuite.ID())
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerHello{
version: protocolVersion1_2,
random: state.localRandom,
cipherSuite: state.cipherSuite,
compressionMethod: defaultCompressionMethods()[0],
extensions: extensions,
Content: &handshake.Handshake{
Message: &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
},
},
},
})
if cfg.localPSKCallback == nil {
switch {
case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
certificate, err := cfg.getCertificate(cfg.serverName)
if err != nil {
return nil, &alert{alertLevelFatal, alertHandshakeFailure}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
}
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificate{
certificate: certificate.Certificate,
Content: &handshake.Handshake{
Message: &handshake.MessageCertificate{
Certificate: certificate.Certificate,
},
},
},
})
serverRandom := state.localRandom.marshalFixed()
clientRandom := state.remoteRandom.marshalFixed()
serverRandom := state.localRandom.MarshalFixed()
clientRandom := state.remoteRandom.MarshalFixed()
// Find compatible signature scheme
signatureHashAlgo, err := selectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
if err != nil {
return nil, &alert{alertLevelFatal, alertInsufficientSecurity}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
}
signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.publicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.hash)
signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.localKeySignature = signature
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerKeyExchange{
ellipticCurveType: ellipticCurveTypeNamedCurve,
namedCurve: state.namedCurve,
publicKey: state.localKeypair.publicKey,
hashAlgorithm: signatureHashAlgo.hash,
signatureAlgorithm: signatureHashAlgo.signature,
signature: state.localKeySignature,
Content: &handshake.Handshake{
Message: &handshake.MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: state.namedCurve,
PublicKey: state.localKeypair.PublicKey,
HashAlgorithm: signatureHashAlgo.Hash,
SignatureAlgorithm: signatureHashAlgo.Signature,
Signature: state.localKeySignature,
},
},
},
@ -255,33 +274,48 @@ func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
if cfg.clientAuth > NoClientCert {
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificateRequest{
certificateTypes: []clientCertificateType{clientCertificateTypeRSASign, clientCertificateTypeECDSASign},
signatureHashAlgorithms: cfg.localSignatureSchemes,
Content: &handshake.Handshake{
Message: &handshake.MessageCertificateRequest{
CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
SignatureHashAlgorithms: cfg.localSignatureSchemes,
},
},
},
})
}
} else if cfg.localPSKIdentityHint != nil {
case cfg.localPSKIdentityHint != nil:
// To help the client in selecting which identity to use, the server
// can provide a "PSK identity hint" in the ServerKeyExchange message.
// If no hint is provided, the ServerKeyExchange message is omitted.
//
// https://tools.ietf.org/html/rfc4279#section-2
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerKeyExchange{
identityHint: cfg.localPSKIdentityHint,
Content: &handshake.Handshake{
Message: &handshake.MessageServerKeyExchange{
IdentityHint: cfg.localPSKIdentityHint,
},
},
},
})
case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
pkts = append(pkts, &packet{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: state.namedCurve,
PublicKey: state.localKeypair.PublicKey,
},
},
},
@ -289,12 +323,12 @@ func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
}
pkts = append(pkts, &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageServerHelloDone{},
Content: &handshake.Handshake{
Message: &handshake.MessageServerHelloDone{},
},
},
})

View file

@ -5,52 +5,59 @@ import (
"context"
"crypto"
"crypto/x509"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, false, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
var finished *handshakeMessageFinished
if finished, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
var finished *handshake.MessageFinished
if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
expectedVerifyData, err := prfVerifyDataServer(state.masterSecret, plainText, state.cipherSuite.hashFunc())
expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
if err != nil {
return 0, &alert{alertLevelFatal, alertInternalError}, err
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if !bytes.Equal(expectedVerifyData, finished.verifyData) {
return 0, &alert{alertLevelFatal, alertHandshakeFailure}, errVerifyDataMismatch
if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
}
return flight5, nil, nil
}
func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) { //nolint:gocognit
func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
var certBytes [][]byte
var privateKey crypto.PrivateKey
if len(cfg.localCertificates) > 0 {
certificate, err := cfg.getCertificate(cfg.serverName)
if err != nil {
return nil, &alert{alertLevelFatal, alertHandshakeFailure}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
}
certBytes = certificate.Certificate
privateKey = certificate.PrivateKey
@ -61,62 +68,62 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
if state.remoteRequestedCertificate {
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificate{
certificate: certBytes,
Content: &handshake.Handshake{
Message: &handshake.MessageCertificate{
Certificate: certBytes,
},
},
},
})
}
clientKeyExchange := &handshakeMessageClientKeyExchange{}
clientKeyExchange := &handshake.MessageClientKeyExchange{}
if cfg.localPSKCallback == nil {
clientKeyExchange.publicKey = state.localKeypair.publicKey
clientKeyExchange.PublicKey = state.localKeypair.PublicKey
} else {
clientKeyExchange.identityHint = cfg.localPSKIdentityHint
clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: clientKeyExchange,
Content: &handshake.Handshake{
Message: clientKeyExchange,
},
},
})
serverKeyExchangeData := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
)
serverKeyExchange := &handshakeMessageServerKeyExchange{}
serverKeyExchange := &handshake.MessageServerKeyExchange{}
// handshakeMessageServerKeyExchange is optional for PSK
if len(serverKeyExchangeData) == 0 {
alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshakeMessageServerKeyExchange{})
alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
if err != nil {
return nil, alertPtr, err
}
} else {
rawHandshake := &handshake{}
rawHandshake := &handshake.Handshake{}
err := rawHandshake.Unmarshal(serverKeyExchangeData)
if err != nil {
return nil, &alert{alertLevelFatal, alertUnexpectedMessage}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
}
switch h := rawHandshake.handshakeMessage.(type) {
case *handshakeMessageServerKeyExchange:
switch h := rawHandshake.Message.(type) {
case *handshake.MessageServerKeyExchange:
serverKeyExchange = h
default:
return nil, &alert{alertLevelFatal, alertUnexpectedMessage}, errInvalidContentType
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
}
}
@ -124,15 +131,15 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
merged := []byte{}
seqPred := uint16(state.handshakeSendSequence)
for _, p := range pkts {
h, ok := p.record.content.(*handshake)
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return nil, &alert{alertLevelFatal, alertInternalError}, errInvalidContentType
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
}
h.handshakeHeader.messageSequence = seqPred
h.Header.MessageSequence = seqPred
seqPred++
raw, err := h.Marshal()
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
merged = append(merged, raw...)
}
@ -146,98 +153,98 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
// private key in the certificate.
if state.remoteRequestedCertificate && len(cfg.localCertificates) > 0 {
plainText := append(cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
), merged...)
// Find compatible signature scheme
signatureHashAlgo, err := selectSignatureScheme(cfg.localSignatureSchemes, privateKey)
signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
if err != nil {
return nil, &alert{alertLevelFatal, alertInsufficientSecurity}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
}
certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.hash)
certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.localCertificatesVerify = certVerify
p := &packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageCertificateVerify{
hashAlgorithm: signatureHashAlgo.hash,
signatureAlgorithm: signatureHashAlgo.signature,
signature: state.localCertificatesVerify,
Content: &handshake.Handshake{
Message: &handshake.MessageCertificateVerify{
HashAlgorithm: signatureHashAlgo.Hash,
SignatureAlgorithm: signatureHashAlgo.Signature,
Signature: state.localCertificatesVerify,
},
},
},
}
pkts = append(pkts, p)
h, ok := p.record.content.(*handshake)
h, ok := p.record.Content.(*handshake.Handshake)
if !ok {
return nil, &alert{alertLevelFatal, alertInternalError}, errInvalidContentType
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
}
h.handshakeHeader.messageSequence = seqPred
h.Header.MessageSequence = seqPred
// seqPred++ // this is the last use of seqPred
raw, err := h.Marshal()
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
merged = append(merged, raw...)
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &changeCipherSpec{},
Content: &protocol.ChangeCipherSpec{},
},
})
if len(state.localVerifyData) == 0 {
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
var err error
state.localVerifyData, err = prfVerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.hashFunc())
state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
epoch: 1,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
Epoch: 1,
},
content: &handshake{
handshakeMessage: &handshakeMessageFinished{
verifyData: state.localVerifyData,
Content: &handshake.Handshake{
Message: &handshake.MessageFinished{
VerifyData: state.localVerifyData,
},
},
},
@ -248,66 +255,69 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
return pkts, nil, nil
}
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshakeMessageServerKeyExchange, sendingPlainText []byte) (*alert, error) { //nolint:gocognit
if state.cipherSuite.isInitialized() {
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
if state.cipherSuite.IsInitialized() {
return nil, nil
}
clientRandom := state.localRandom.marshalFixed()
serverRandom := state.remoteRandom.marshalFixed()
clientRandom := state.localRandom.MarshalFixed()
serverRandom := state.remoteRandom.MarshalFixed()
var err error
if state.extendedMasterSecret {
var sessionHash []byte
sessionHash, err = cache.sessionHash(state.cipherSuite.hashFunc(), cfg.initialEpoch, sendingPlainText)
sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
if err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.masterSecret, err = prfExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.hashFunc())
state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
if err != nil {
return &alert{alertLevelFatal, alertIllegalParameter}, err
return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
}
} else {
state.masterSecret, err = prfMasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.hashFunc())
state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
if err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
if cfg.localPSKCallback == nil {
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
// Verify that the pair of hash algorithm and signiture is listed.
var validSignatureScheme bool
for _, ss := range cfg.localSignatureSchemes {
if ss.hash == h.hashAlgorithm && ss.signature == h.signatureAlgorithm {
if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
validSignatureScheme = true
break
}
}
if !validSignatureScheme {
return &alert{alertLevelFatal, alertInsufficientSecurity}, errNoAvailableSignatureSchemes
return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
}
expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.publicKey, h.namedCurve)
if err = verifyKeySignature(expectedMsg, h.signature, h.hashAlgorithm, state.PeerCertificates); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
var chains [][]*x509.Certificate
if !cfg.insecureSkipVerify {
if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
if cfg.verifyPeerCertificate != nil {
if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
return &alert{alertLevelFatal, alertBadCertificate}, err
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
}
if err = state.cipherSuite.init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
return &alert{alertLevelFatal, alertInternalError}, err
if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
return nil, nil
}

View file

@ -2,69 +2,75 @@ package dtls
import (
"context"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
func flight6Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert, error) {
func flight6Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1,
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
if _, ok = msgs[handshakeTypeFinished].(*handshakeMessageFinished); !ok {
return 0, &alert{alertLevelFatal, alertInternalError}, nil
if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
// Other party retransmitted the last flight.
return flight6, nil, nil
}
func flight6Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert, error) {
func flight6Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
var pkts []*packet
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
content: &changeCipherSpec{},
Content: &protocol.ChangeCipherSpec{},
},
})
if len(state.localVerifyData) == 0 {
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshakeTypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshakeTypeFinished, cfg.initialEpoch + 1, true, false},
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
var err error
state.localVerifyData, err = prfVerifyDataServer(state.masterSecret, plainText, state.cipherSuite.hashFunc())
state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
if err != nil {
return nil, &alert{alertLevelFatal, alertInternalError}, err
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
pkts = append(pkts,
&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
epoch: 1,
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
Epoch: 1,
},
content: &handshake{
handshakeMessage: &handshakeMessageFinished{
verifyData: state.localVerifyData,
Content: &handshake.Handshake{
Message: &handshake.MessageFinished{
VerifyData: state.localVerifyData,
},
},
},

View file

@ -2,13 +2,15 @@ package dtls
import (
"context"
"github.com/pion/dtls/v2/pkg/protocol/alert"
)
// Parse received handshakes and return next flightVal
type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert, error)
type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert.Alert, error)
// Generate flights
type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert, error)
type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error)
func (f flightVal) getFlightParser() (flightParser, error) {
switch f {

View file

@ -1,8 +1,14 @@
package dtls
import (
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
type fragment struct {
recordLayerHeader recordLayerHeader
handshakeHeader handshakeHeader
recordLayerHeader recordlayer.Header
handshakeHeader handshake.Header
data []byte
}
@ -27,29 +33,29 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) {
}
// fragment isn't a handshake, we don't need to handle it
if frag.recordLayerHeader.contentType != contentTypeHandshake {
if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
return false, nil
}
for buf = buf[recordLayerHeaderSize:]; len(buf) != 0; frag = new(fragment) {
for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
return false, err
}
if _, ok := f.cache[frag.handshakeHeader.messageSequence]; !ok {
f.cache[frag.handshakeHeader.messageSequence] = []*fragment{}
if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
}
// end index should be the length of handshake header but if the handshake
// was fragmented, we should keep them all
end := int(handshakeHeaderLength + frag.handshakeHeader.length)
end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
if size := len(buf); end > size {
end = size
}
// Discard all headers, when rebuilding the packet we will re-build
frag.data = append([]byte{}, buf[handshakeHeaderLength:end]...)
f.cache[frag.handshakeHeader.messageSequence] = append(f.cache[frag.handshakeHeader.messageSequence], frag)
frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
buf = buf[end:]
}
@ -68,9 +74,9 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
rawMessage := []byte{}
appendMessage = func(targetOffset uint32) bool {
for _, f := range frags {
if f.handshakeHeader.fragmentOffset == targetOffset {
fragmentEnd := (f.handshakeHeader.fragmentOffset + f.handshakeHeader.fragmentLength)
if fragmentEnd != f.handshakeHeader.length {
if f.handshakeHeader.FragmentOffset == targetOffset {
fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
if fragmentEnd != f.handshakeHeader.Length {
if !appendMessage(fragmentEnd) {
return false
}
@ -89,15 +95,15 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
}
firstHeader := frags[0].handshakeHeader
firstHeader.fragmentOffset = 0
firstHeader.fragmentLength = firstHeader.length
firstHeader.FragmentOffset = 0
firstHeader.FragmentLength = firstHeader.Length
rawHeader, err := firstHeader.Marshal()
if err != nil {
return nil, 0
}
messageEpoch := frags[0].recordLayerHeader.epoch
messageEpoch := frags[0].recordLayerHeader.Epoch
delete(f.cache, f.currentMessageSequenceNumber)
f.currentMessageSequenceNumber++

View file

@ -4,7 +4,7 @@ package dtls
import "fmt"
func partialHeaderMismatch(a, b recordLayerHeader) bool {
func partialHeaderMismatch(a, b recordlayer.Header) bool {
// Ignoring content length for now.
a.contentLen = b.contentLen
return a != b
@ -26,10 +26,10 @@ func FuzzRecordLayer(data []byte) int {
if err = nr.Unmarshal(data); err != nil {
panic(err) // nolint
}
if partialHeaderMismatch(nr.recordLayerHeader, r.recordLayerHeader) {
if partialHeaderMismatch(nr.recordlayer.Header, r.recordlayer.Header) {
panic( // nolint
fmt.Sprintf("header mismatch: %+v != %+v",
nr.recordLayerHeader, r.recordLayerHeader,
nr.recordlayer.Header, r.recordlayer.Header,
),
)
}

View file

@ -2,10 +2,10 @@ module github.com/pion/dtls/v2
require (
github.com/pion/logging v0.2.2
github.com/pion/transport v0.10.1
github.com/pion/transport v0.12.2
github.com/pion/udp v0.1.0
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad
golang.org/x/net v0.0.0-20210119194325-5f4716e94777
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
)

View file

@ -4,8 +4,8 @@ github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/transport v0.10.0 h1:9M12BSneJm6ggGhJyWpDveFOstJsTiQjkLf4M44rm80=
github.com/pion/transport v0.10.0/go.mod h1:BnHnUipd0rZQyTVB2SBGojFHT9CBt5C5TcsJSQGkvSE=
github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM=
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
github.com/pion/transport v0.12.2 h1:WYEjhloRHt1R86LhUKjC5y+P52Y11/QqEUalvtzVoys=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI=
github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -18,20 +18,23 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 h1:pLI5jrR7OSLijeIDcmRxNmw2api+jEfxLoykJVice/E=
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102 h1:42cLlJJdEh+ySyeUUbEQ5bsTiq8voBeTuweGVkY6Puw=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7 h1:3uJsdck53FDIpWwLeAXlia9p4C8j0BO2xZrqzKpL0D8=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew=
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -1,136 +0,0 @@
package dtls
// https://tools.ietf.org/html/rfc5246#section-7.4
type handshakeType uint8
const (
handshakeTypeHelloRequest handshakeType = 0
handshakeTypeClientHello handshakeType = 1
handshakeTypeServerHello handshakeType = 2
handshakeTypeHelloVerifyRequest handshakeType = 3
handshakeTypeCertificate handshakeType = 11
handshakeTypeServerKeyExchange handshakeType = 12
handshakeTypeCertificateRequest handshakeType = 13
handshakeTypeServerHelloDone handshakeType = 14
handshakeTypeCertificateVerify handshakeType = 15
handshakeTypeClientKeyExchange handshakeType = 16
handshakeTypeFinished handshakeType = 20
// msg_len for Handshake messages assumes an extra 12 bytes for
// sequence, fragment and version information
handshakeMessageHeaderLength = 12
)
type handshakeMessage interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) error
handshakeType() handshakeType
}
func (h handshakeType) String() string {
switch h {
case handshakeTypeHelloRequest:
return "HelloRequest"
case handshakeTypeClientHello:
return "ClientHello"
case handshakeTypeServerHello:
return "ServerHello"
case handshakeTypeHelloVerifyRequest:
return "HelloVerifyRequest"
case handshakeTypeCertificate:
return "TypeCertificate"
case handshakeTypeServerKeyExchange:
return "ServerKeyExchange"
case handshakeTypeCertificateRequest:
return "CertificateRequest"
case handshakeTypeServerHelloDone:
return "ServerHelloDone"
case handshakeTypeCertificateVerify:
return "CertificateVerify"
case handshakeTypeClientKeyExchange:
return "ClientKeyExchange"
case handshakeTypeFinished:
return "Finished"
}
return ""
}
// The handshake protocol is responsible for selecting a cipher spec and
// generating a master secret, which together comprise the primary
// cryptographic parameters associated with a secure session. The
// handshake protocol can also optionally authenticate parties who have
// certificates signed by a trusted certificate authority.
// https://tools.ietf.org/html/rfc5246#section-7.3
type handshake struct {
handshakeHeader handshakeHeader
handshakeMessage handshakeMessage
}
func (h handshake) contentType() contentType {
return contentTypeHandshake
}
func (h *handshake) Marshal() ([]byte, error) {
if h.handshakeMessage == nil {
return nil, errHandshakeMessageUnset
} else if h.handshakeHeader.fragmentOffset != 0 {
return nil, errUnableToMarshalFragmented
}
msg, err := h.handshakeMessage.Marshal()
if err != nil {
return nil, err
}
h.handshakeHeader.length = uint32(len(msg))
h.handshakeHeader.fragmentLength = h.handshakeHeader.length
h.handshakeHeader.handshakeType = h.handshakeMessage.handshakeType()
header, err := h.handshakeHeader.Marshal()
if err != nil {
return nil, err
}
return append(header, msg...), nil
}
func (h *handshake) Unmarshal(data []byte) error {
if err := h.handshakeHeader.Unmarshal(data); err != nil {
return err
}
reportedLen := bigEndianUint24(data[1:])
if uint32(len(data)-handshakeMessageHeaderLength) != reportedLen {
return errLengthMismatch
} else if reportedLen != h.handshakeHeader.fragmentLength {
return errLengthMismatch
}
switch handshakeType(data[0]) {
case handshakeTypeHelloRequest:
return errNotImplemented
case handshakeTypeClientHello:
h.handshakeMessage = &handshakeMessageClientHello{}
case handshakeTypeHelloVerifyRequest:
h.handshakeMessage = &handshakeMessageHelloVerifyRequest{}
case handshakeTypeServerHello:
h.handshakeMessage = &handshakeMessageServerHello{}
case handshakeTypeCertificate:
h.handshakeMessage = &handshakeMessageCertificate{}
case handshakeTypeServerKeyExchange:
h.handshakeMessage = &handshakeMessageServerKeyExchange{}
case handshakeTypeCertificateRequest:
h.handshakeMessage = &handshakeMessageCertificateRequest{}
case handshakeTypeServerHelloDone:
h.handshakeMessage = &handshakeMessageServerHelloDone{}
case handshakeTypeClientKeyExchange:
h.handshakeMessage = &handshakeMessageClientKeyExchange{}
case handshakeTypeFinished:
h.handshakeMessage = &handshakeMessageFinished{}
case handshakeTypeCertificateVerify:
h.handshakeMessage = &handshakeMessageCertificateVerify{}
default:
return errNotImplemented
}
return h.handshakeMessage.Unmarshal(data[handshakeMessageHeaderLength:])
}

View file

@ -2,10 +2,13 @@ package dtls
import (
"sync"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
)
type handshakeCacheItem struct {
typ handshakeType
typ handshake.Type
isClient bool
epoch uint16
messageSequence uint16
@ -13,7 +16,7 @@ type handshakeCacheItem struct {
}
type handshakeCachePullRule struct {
typ handshakeType
typ handshake.Type
epoch uint16
isClient bool
optional bool
@ -28,7 +31,7 @@ func newHandshakeCache() *handshakeCache {
return &handshakeCache{}
}
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshakeType, isClient bool) bool { //nolint
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) bool { //nolint
h.mu.Lock()
defer h.mu.Unlock()
@ -74,11 +77,11 @@ func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCache
}
// fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
func (h *handshakeCache) fullPullMap(startSeq int, rules ...handshakeCachePullRule) (int, map[handshakeType]handshakeMessage, bool) {
func (h *handshakeCache) fullPullMap(startSeq int, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) {
h.mu.Lock()
defer h.mu.Unlock()
ci := make(map[handshakeType]*handshakeCacheItem)
ci := make(map[handshake.Type]*handshakeCacheItem)
for _, r := range rules {
var item *handshakeCacheItem
for _, c := range h.cache {
@ -97,7 +100,7 @@ func (h *handshakeCache) fullPullMap(startSeq int, rules ...handshakeCachePullRu
}
ci[r.typ] = item
}
out := make(map[handshakeType]handshakeMessage)
out := make(map[handshake.Type]handshake.Message)
seq := startSeq
for _, r := range rules {
t := r.typ
@ -105,16 +108,16 @@ func (h *handshakeCache) fullPullMap(startSeq int, rules ...handshakeCachePullRu
if i == nil {
continue
}
rawHandshake := &handshake{}
rawHandshake := &handshake.Handshake{}
if err := rawHandshake.Unmarshal(i.data); err != nil {
return startSeq, nil, false
}
if uint16(seq) != rawHandshake.handshakeHeader.messageSequence {
if uint16(seq) != rawHandshake.Header.MessageSequence {
// There is a gap. Some messages are not arrived.
return startSeq, nil, false
}
seq++
out[t] = rawHandshake.handshakeMessage
out[t] = rawHandshake.Message
}
return seq, out, true
}
@ -133,19 +136,19 @@ func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
// sessionHash returns the session hash for Extended Master Secret support
// https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
func (h *handshakeCache) sessionHash(hf hashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
merged := []byte{}
// Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
handshakeBuffer := h.pull(
handshakeCachePullRule{handshakeTypeClientHello, epoch, true, false},
handshakeCachePullRule{handshakeTypeServerHello, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, epoch, false, false},
handshakeCachePullRule{handshakeTypeServerKeyExchange, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificateRequest, epoch, false, false},
handshakeCachePullRule{handshakeTypeServerHelloDone, epoch, false, false},
handshakeCachePullRule{handshakeTypeCertificate, epoch, true, false},
handshakeCachePullRule{handshakeTypeClientKeyExchange, epoch, true, false},
handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false},
handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false},
handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false},
handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false},
handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false},
handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false},
)
for _, p := range handshakeBuffer {

View file

@ -1,41 +0,0 @@
package dtls
import (
"encoding/binary"
)
// msg_len for Handshake messages assumes an extra 12 bytes for
// sequence, fragment and version information
const handshakeHeaderLength = 12
type handshakeHeader struct {
handshakeType handshakeType
length uint32 // uint24 in spec
messageSequence uint16
fragmentOffset uint32 // uint24 in spec
fragmentLength uint32 // uint24 in spec
}
func (h *handshakeHeader) Marshal() ([]byte, error) {
out := make([]byte, handshakeMessageHeaderLength)
out[0] = byte(h.handshakeType)
putBigEndianUint24(out[1:], h.length)
binary.BigEndian.PutUint16(out[4:], h.messageSequence)
putBigEndianUint24(out[6:], h.fragmentOffset)
putBigEndianUint24(out[9:], h.fragmentLength)
return out, nil
}
func (h *handshakeHeader) Unmarshal(data []byte) error {
if len(data) < handshakeHeaderLength {
return errBufferTooSmall
}
h.handshakeType = handshakeType(data[0])
h.length = bigEndianUint24(data[1:])
h.messageSequence = binary.BigEndian.Uint16(data[4:])
h.fragmentOffset = bigEndianUint24(data[6:])
h.fragmentLength = bigEndianUint24(data[9:])
return nil
}

View file

@ -1,55 +0,0 @@
package dtls
type handshakeMessageCertificate struct {
certificate [][]byte
}
func (h handshakeMessageCertificate) handshakeType() handshakeType {
return handshakeTypeCertificate
}
const (
handshakeMessageCertificateLengthFieldSize = 3
)
func (h *handshakeMessageCertificate) Marshal() ([]byte, error) {
out := make([]byte, handshakeMessageCertificateLengthFieldSize)
for _, r := range h.certificate {
// Certificate Length
out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...)
putBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r)))
// Certificate body
out = append(out, append([]byte{}, r...)...)
}
// Total Payload Size
putBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:])))
return out, nil
}
func (h *handshakeMessageCertificate) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateLengthFieldSize {
return errBufferTooSmall
}
if certificateBodyLen := int(bigEndianUint24(data)); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) {
return errLengthMismatch
}
offset := handshakeMessageCertificateLengthFieldSize
for offset < len(data) {
certificateLen := int(bigEndianUint24(data[offset:]))
offset += handshakeMessageCertificateLengthFieldSize
if offset+certificateLen > len(data) {
return errLengthMismatch
}
h.certificate = append(h.certificate, append([]byte{}, data[offset:offset+certificateLen]...))
offset += certificateLen
}
return nil
}

View file

@ -1,91 +0,0 @@
package dtls
import (
"encoding/binary"
)
/*
A non-anonymous server can optionally request a certificate from
the client, if appropriate for the selected cipher suite. This
message, if sent, will immediately follow the ServerKeyExchange
message (if it is sent; otherwise, this message follows the
server's Certificate message).
*/
type handshakeMessageCertificateRequest struct {
certificateTypes []clientCertificateType
signatureHashAlgorithms []signatureHashAlgorithm
}
const (
handshakeMessageCertificateRequestMinLength = 5
)
func (h handshakeMessageCertificateRequest) handshakeType() handshakeType {
return handshakeTypeCertificateRequest
}
func (h *handshakeMessageCertificateRequest) Marshal() ([]byte, error) {
out := []byte{byte(len(h.certificateTypes))}
for _, v := range h.certificateTypes {
out = append(out, byte(v))
}
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(h.signatureHashAlgorithms)*2))
for _, v := range h.signatureHashAlgorithms {
out = append(out, byte(v.hash))
out = append(out, byte(v.signature))
}
out = append(out, []byte{0x00, 0x00}...) // Distinguished Names Length
return out, nil
}
func (h *handshakeMessageCertificateRequest) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateRequestMinLength {
return errBufferTooSmall
}
offset := 0
certificateTypesLength := int(data[0])
offset++
if (offset + certificateTypesLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < certificateTypesLength; i++ {
certType := clientCertificateType(data[offset+i])
if _, ok := clientCertificateTypes()[certType]; ok {
h.certificateTypes = append(h.certificateTypes, certType)
}
}
offset += certificateTypesLength
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureHashAlgorithmsLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if (offset + signatureHashAlgorithmsLength) > len(data) {
return errBufferTooSmall
}
for i := 0; i < signatureHashAlgorithmsLength; i += 2 {
if len(data) < (offset + i + 2) {
return errBufferTooSmall
}
hash := hashAlgorithm(data[offset+i])
signature := signatureAlgorithm(data[offset+i+1])
if _, ok := hashAlgorithms()[hash]; !ok {
continue
} else if _, ok := signatureAlgorithms()[signature]; !ok {
continue
}
h.signatureHashAlgorithms = append(h.signatureHashAlgorithms, signatureHashAlgorithm{signature: signature, hash: hash})
}
return nil
}

View file

@ -1,51 +0,0 @@
package dtls
import (
"encoding/binary"
)
type handshakeMessageCertificateVerify struct {
hashAlgorithm hashAlgorithm
signatureAlgorithm signatureAlgorithm
signature []byte
}
const handshakeMessageCertificateVerifyMinLength = 4
func (h handshakeMessageCertificateVerify) handshakeType() handshakeType {
return handshakeTypeCertificateVerify
}
func (h *handshakeMessageCertificateVerify) Marshal() ([]byte, error) {
out := make([]byte, 1+1+2+len(h.signature))
out[0] = byte(h.hashAlgorithm)
out[1] = byte(h.signatureAlgorithm)
binary.BigEndian.PutUint16(out[2:], uint16(len(h.signature)))
copy(out[4:], h.signature)
return out, nil
}
func (h *handshakeMessageCertificateVerify) Unmarshal(data []byte) error {
if len(data) < handshakeMessageCertificateVerifyMinLength {
return errBufferTooSmall
}
h.hashAlgorithm = hashAlgorithm(data[0])
if _, ok := hashAlgorithms()[h.hashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
h.signatureAlgorithm = signatureAlgorithm(data[1])
if _, ok := signatureAlgorithms()[h.signatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
signatureLength := int(binary.BigEndian.Uint16(data[2:]))
if (signatureLength + 4) != len(data) {
return errBufferTooSmall
}
h.signature = append([]byte{}, data[4:]...)
return nil
}

View file

@ -1,119 +0,0 @@
package dtls
import (
"encoding/binary"
)
/*
When a client first connects to a server it is required to send
the client hello as its first message. The client can also send a
client hello in response to a hello request or on its own
initiative in order to renegotiate the security parameters in an
existing connection.
*/
type handshakeMessageClientHello struct {
version protocolVersion
random handshakeRandom
cookie []byte
cipherSuites []cipherSuite
compressionMethods []*compressionMethod
extensions []extension
}
const handshakeMessageClientHelloVariableWidthStart = 34
func (h handshakeMessageClientHello) handshakeType() handshakeType {
return handshakeTypeClientHello
}
func (h *handshakeMessageClientHello) Marshal() ([]byte, error) {
if len(h.cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, handshakeMessageClientHelloVariableWidthStart)
out[0] = h.version.major
out[1] = h.version.minor
rand := h.random.marshalFixed()
copy(out[2:], rand[:])
out = append(out, 0x00) // SessionID
out = append(out, byte(len(h.cookie)))
out = append(out, h.cookie...)
out = append(out, encodeCipherSuites(h.cipherSuites)...)
out = append(out, encodeCompressionMethods(h.compressionMethods)...)
extensions, err := encodeExtensions(h.extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
func (h *handshakeMessageClientHello) Unmarshal(data []byte) error {
if len(data) < 2+handshakeRandomLength {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
var random [handshakeRandomLength]byte
copy(random[:], data[2:])
h.random.unmarshalFixed(random)
// rest of packet has variable width sections
currOffset := handshakeMessageClientHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
currOffset++
if len(data) <= currOffset {
return errBufferTooSmall
}
n := int(data[currOffset-1])
if len(data) <= currOffset+n {
return errBufferTooSmall
}
h.cookie = append([]byte{}, data[currOffset:currOffset+n]...)
currOffset += len(h.cookie)
// Cipher Suites
if len(data) < currOffset {
return errBufferTooSmall
}
cipherSuites, err := decodeCipherSuites(data[currOffset:])
if err != nil {
return err
}
h.cipherSuites = cipherSuites
if len(data) < currOffset+2 {
return errBufferTooSmall
}
currOffset += int(binary.BigEndian.Uint16(data[currOffset:])) + 2
// Compression Methods
if len(data) < currOffset {
return errBufferTooSmall
}
compressionMethods, err := decodeCompressionMethods(data[currOffset:])
if err != nil {
return err
}
h.compressionMethods = compressionMethods
if len(data) < currOffset {
return errBufferTooSmall
}
currOffset += int(data[currOffset]) + 1
// Extensions
extensions, err := decodeExtensions(data[currOffset:])
if err != nil {
return err
}
h.extensions = extensions
return nil
}

View file

@ -1,46 +0,0 @@
package dtls
import (
"encoding/binary"
)
type handshakeMessageClientKeyExchange struct {
identityHint []byte
publicKey []byte
}
func (h handshakeMessageClientKeyExchange) handshakeType() handshakeType {
return handshakeTypeClientKeyExchange
}
func (h *handshakeMessageClientKeyExchange) Marshal() ([]byte, error) {
switch {
case (h.identityHint != nil && h.publicKey != nil) || (h.identityHint == nil && h.publicKey == nil):
return nil, errInvalidClientKeyExchange
case h.publicKey != nil:
return append([]byte{byte(len(h.publicKey))}, h.publicKey...), nil
default:
out := append([]byte{0x00, 0x00}, h.identityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
}
func (h *handshakeMessageClientKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
h.identityHint = append([]byte{}, data[2:]...)
return nil
}
if publicKeyLength := int(data[0]); len(data) != publicKeyLength+1 {
return errBufferTooSmall
}
h.publicKey = append([]byte{}, data[1:]...)
return nil
}

View file

@ -1,18 +0,0 @@
package dtls
type handshakeMessageFinished struct {
verifyData []byte
}
func (h handshakeMessageFinished) handshakeType() handshakeType {
return handshakeTypeFinished
}
func (h *handshakeMessageFinished) Marshal() ([]byte, error) {
return append([]byte{}, h.verifyData...), nil
}
func (h *handshakeMessageFinished) Unmarshal(data []byte) error {
h.verifyData = append([]byte{}, data...)
return nil
}

View file

@ -1,57 +0,0 @@
package dtls
/*
The definition of HelloVerifyRequest is as follows:
struct {
ProtocolVersion server_version;
opaque cookie<0..2^8-1>;
} HelloVerifyRequest;
The HelloVerifyRequest message type is hello_verify_request(3).
When the client sends its ClientHello message to the server, the server
MAY respond with a HelloVerifyRequest message. This message contains
a stateless cookie generated using the technique of [PHOTURIS]. The
client MUST retransmit the ClientHello with the cookie added.
https://tools.ietf.org/html/rfc6347#section-4.2.1
*/
type handshakeMessageHelloVerifyRequest struct {
version protocolVersion
cookie []byte
}
func (h handshakeMessageHelloVerifyRequest) handshakeType() handshakeType {
return handshakeTypeHelloVerifyRequest
}
func (h *handshakeMessageHelloVerifyRequest) Marshal() ([]byte, error) {
if len(h.cookie) > 255 {
return nil, errCookieTooLong
}
out := make([]byte, 3+len(h.cookie))
out[0] = h.version.major
out[1] = h.version.minor
out[2] = byte(len(h.cookie))
copy(out[3:], h.cookie)
return out, nil
}
func (h *handshakeMessageHelloVerifyRequest) Unmarshal(data []byte) error {
if len(data) < 3 {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
cookieLength := data[2]
if len(data) < (int(cookieLength) + 3) {
return errBufferTooSmall
}
h.cookie = make([]byte, cookieLength)
copy(h.cookie, data[3:3+cookieLength])
return nil
}

View file

@ -1,102 +0,0 @@
package dtls
import (
"encoding/binary"
)
/*
The server will send this message in response to a ClientHello
message when it was able to find an acceptable set of algorithms.
If it cannot find such a match, it will respond with a handshake
failure alert.
https://tools.ietf.org/html/rfc5246#section-7.4.1.3
*/
type handshakeMessageServerHello struct {
version protocolVersion
random handshakeRandom
cipherSuite cipherSuite
compressionMethod *compressionMethod
extensions []extension
}
const handshakeMessageServerHelloVariableWidthStart = 2 + handshakeRandomLength
func (h handshakeMessageServerHello) handshakeType() handshakeType {
return handshakeTypeServerHello
}
func (h *handshakeMessageServerHello) Marshal() ([]byte, error) {
if h.cipherSuite == nil {
return nil, errCipherSuiteUnset
} else if h.compressionMethod == nil {
return nil, errCompressionMethodUnset
}
out := make([]byte, handshakeMessageServerHelloVariableWidthStart)
out[0] = h.version.major
out[1] = h.version.minor
rand := h.random.marshalFixed()
copy(out[2:], rand[:])
out = append(out, 0x00) // SessionID
out = append(out, []byte{0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(h.cipherSuite.ID()))
out = append(out, byte(h.compressionMethod.id))
extensions, err := encodeExtensions(h.extensions)
if err != nil {
return nil, err
}
return append(out, extensions...), nil
}
func (h *handshakeMessageServerHello) Unmarshal(data []byte) error {
if len(data) < 2+handshakeRandomLength {
return errBufferTooSmall
}
h.version.major = data[0]
h.version.minor = data[1]
var random [handshakeRandomLength]byte
copy(random[:], data[2:])
h.random.unmarshalFixed(random)
currOffset := handshakeMessageServerHelloVariableWidthStart
currOffset += int(data[currOffset]) + 1 // SessionID
if len(data) < (currOffset + 2) {
return errBufferTooSmall
}
if c := cipherSuiteForID(CipherSuiteID(binary.BigEndian.Uint16(data[currOffset:]))); c != nil {
h.cipherSuite = c
currOffset += 2
} else {
return errInvalidCipherSuite
}
if len(data) < currOffset {
return errBufferTooSmall
}
if compressionMethod, ok := compressionMethods()[compressionMethodID(data[currOffset])]; ok {
h.compressionMethod = compressionMethod
currOffset++
} else {
return errInvalidCompressionMethod
}
if len(data) <= currOffset {
h.extensions = []extension{}
return nil
}
extensions, err := decodeExtensions(data[currOffset:])
if err != nil {
return err
}
h.extensions = extensions
return nil
}

View file

@ -1,16 +0,0 @@
package dtls
type handshakeMessageServerHelloDone struct {
}
func (h handshakeMessageServerHelloDone) handshakeType() handshakeType {
return handshakeTypeServerHelloDone
}
func (h *handshakeMessageServerHelloDone) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (h *handshakeMessageServerHelloDone) Unmarshal(data []byte) error {
return nil
}

View file

@ -1,104 +0,0 @@
package dtls
import (
"encoding/binary"
)
// Structure supports ECDH and PSK
type handshakeMessageServerKeyExchange struct {
identityHint []byte
ellipticCurveType ellipticCurveType
namedCurve namedCurve
publicKey []byte
hashAlgorithm hashAlgorithm
signatureAlgorithm signatureAlgorithm
signature []byte
}
func (h handshakeMessageServerKeyExchange) handshakeType() handshakeType {
return handshakeTypeServerKeyExchange
}
func (h *handshakeMessageServerKeyExchange) Marshal() ([]byte, error) {
if h.identityHint != nil {
out := append([]byte{0x00, 0x00}, h.identityHint...)
binary.BigEndian.PutUint16(out, uint16(len(out)-2))
return out, nil
}
out := []byte{byte(h.ellipticCurveType), 0x00, 0x00}
binary.BigEndian.PutUint16(out[1:], uint16(h.namedCurve))
out = append(out, byte(len(h.publicKey)))
out = append(out, h.publicKey...)
out = append(out, []byte{byte(h.hashAlgorithm), byte(h.signatureAlgorithm), 0x00, 0x00}...)
binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(h.signature)))
out = append(out, h.signature...)
return out, nil
}
func (h *handshakeMessageServerKeyExchange) Unmarshal(data []byte) error {
if len(data) < 2 {
return errBufferTooSmall
}
// If parsed as PSK return early and only populate PSK Identity Hint
if pskLength := binary.BigEndian.Uint16(data); len(data) == int(pskLength+2) {
h.identityHint = append([]byte{}, data[2:]...)
return nil
}
if _, ok := ellipticCurveTypes()[ellipticCurveType(data[0])]; ok {
h.ellipticCurveType = ellipticCurveType(data[0])
} else {
return errInvalidEllipticCurveType
}
if len(data[1:]) < 2 {
return errBufferTooSmall
}
h.namedCurve = namedCurve(binary.BigEndian.Uint16(data[1:3]))
if _, ok := namedCurves()[h.namedCurve]; !ok {
return errInvalidNamedCurve
}
if len(data) < 4 {
return errBufferTooSmall
}
publicKeyLength := int(data[3])
offset := 4 + publicKeyLength
if len(data) < offset {
return errBufferTooSmall
}
h.publicKey = append([]byte{}, data[4:offset]...)
if len(data) <= offset {
return errBufferTooSmall
}
h.hashAlgorithm = hashAlgorithm(data[offset])
if _, ok := hashAlgorithms()[h.hashAlgorithm]; !ok {
return errInvalidHashAlgorithm
}
offset++
if len(data) <= offset {
return errBufferTooSmall
}
h.signatureAlgorithm = signatureAlgorithm(data[offset])
if _, ok := signatureAlgorithms()[h.signatureAlgorithm]; !ok {
return errInvalidSignatureAlgorithm
}
offset++
if len(data) < offset+2 {
return errBufferTooSmall
}
signatureLength := int(binary.BigEndian.Uint16(data[offset:]))
offset += 2
if len(data) < offset+signatureLength {
return errBufferTooSmall
}
h.signature = append([]byte{}, data[offset:offset+signatureLength]...)
return nil
}

View file

@ -1,44 +0,0 @@
package dtls
import (
"crypto/rand"
"encoding/binary"
"time"
)
const (
randomBytesLength = 28
handshakeRandomLength = randomBytesLength + 4
)
// https://tools.ietf.org/html/rfc4346#section-7.4.1.2
type handshakeRandom struct {
gmtUnixTime time.Time
randomBytes [randomBytesLength]byte
}
func (h *handshakeRandom) marshalFixed() [handshakeRandomLength]byte {
var out [handshakeRandomLength]byte
binary.BigEndian.PutUint32(out[0:], uint32(h.gmtUnixTime.Unix()))
copy(out[4:], h.randomBytes[:])
return out
}
func (h *handshakeRandom) unmarshalFixed(data [handshakeRandomLength]byte) {
h.gmtUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0)
copy(h.randomBytes[:], data[4:])
}
// populate fills the handshakeRandom with random values
// may be called multiple times
func (h *handshakeRandom) populate() error {
h.gmtUnixTime = time.Now()
tmp := make([]byte, randomBytesLength)
_, err := rand.Read(tmp)
copy(h.randomBytes[:], tmp)
return err
}

View file

@ -4,10 +4,14 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/logging"
)
@ -46,8 +50,6 @@ import (
// Read retransmit
// Retransmit last flight
var errInvalidFSMTransition = errors.New("invalid state machine transition")
type handshakeState uint8
const (
@ -88,10 +90,10 @@ type handshakeFSM struct {
type handshakeConfig struct {
localPSKCallback PSKCallback
localPSKIdentityHint []byte
localCipherSuites []cipherSuite // Available CipherSuites
localSignatureSchemes []signatureHashAlgorithm // Available signature schemes
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
localCipherSuites []CipherSuite // Available CipherSuites
localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
serverName string
clientAuth ClientAuthType // If we are a client should we request a client certificate
localCertificates []tls.Certificate
@ -101,9 +103,11 @@ type handshakeConfig struct {
rootCAs *x509.CertPool
clientCAs *x509.CertPool
retransmitInterval time.Duration
customCipherSuites func() []CipherSuite
onFlightState func(flightVal, handshakeState)
log logging.LeveledLogger
keyLogWriter io.Writer
initialEpoch uint16
@ -111,13 +115,25 @@ type handshakeConfig struct {
}
type flightConn interface {
notify(ctx context.Context, level alertLevel, desc alertDescription) error
notify(ctx context.Context, level alert.Level, desc alert.Description) error
writePackets(context.Context, []*packet) error
recvHandshake() <-chan chan struct{}
setLocalEpoch(epoch uint16)
handleQueuedPackets(context.Context) error
}
func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
if c.keyLogWriter == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
if err != nil {
c.log.Debugf("failed to write key log file: %s", err)
}
}
func srvCliStr(isClient bool) string {
if isClient {
return "client"
@ -175,20 +191,20 @@ func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeStat
s.flights = nil
// Prepare flights
var (
a *alert
a *alert.Alert
err error
pkts []*packet
)
gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
if errFlight != nil {
err = errFlight
a = &alert{alertLevelFatal, alertInternalError}
a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
} else {
pkts, a, err = gen(c, s.state, s.cache, s.cfg)
s.retransmit = retransmit
}
if a != nil {
if alertErr := c.notify(ctx, a.alertLevel, a.alertDescription); alertErr != nil {
if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
if err != nil {
err = alertErr
}
@ -202,12 +218,12 @@ func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeStat
epoch := s.cfg.initialEpoch
nextEpoch := epoch
for _, p := range s.flights {
p.record.recordLayerHeader.epoch += epoch
if p.record.recordLayerHeader.epoch > nextEpoch {
nextEpoch = p.record.recordLayerHeader.epoch
p.record.Header.Epoch += epoch
if p.record.Header.Epoch > nextEpoch {
nextEpoch = p.record.Header.Epoch
}
if h, ok := p.record.content.(*handshake); ok {
h.handshakeHeader.messageSequence = uint16(s.state.handshakeSendSequence)
if h, ok := p.record.Content.(*handshake.Handshake); ok {
h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
s.state.handshakeSendSequence++
}
}
@ -233,7 +249,7 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState,
func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
parse, errFlight := s.currentFlight.getFlightParser()
if errFlight != nil {
if alertErr := c.notify(ctx, alertLevelFatal, alertInternalError); alertErr != nil {
if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
if errFlight != nil {
return handshakeErrored, alertErr
}
@ -248,7 +264,7 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
close(done)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err != nil {
err = alertErr
}
@ -281,7 +297,7 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
parse, errFlight := s.currentFlight.getFlightParser()
if errFlight != nil {
if alertErr := c.notify(ctx, alertLevelFatal, alertInternalError); alertErr != nil {
if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
if errFlight != nil {
return handshakeErrored, alertErr
}
@ -295,7 +311,7 @@ func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState
nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
close(done)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
if err != nil {
err = alertErr
}

View file

@ -1,116 +0,0 @@
package dtls
import ( //nolint:gci
"crypto"
"crypto/md5" //nolint:gosec
"crypto/sha1" //nolint:gosec
"crypto/sha256"
"crypto/sha512"
)
// hashAlgorithm is used to indicate the hash algorithm used
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
type hashAlgorithm uint16
// Supported hash hash algorithms
const (
hashAlgorithmMD2 hashAlgorithm = 0 // Blacklisted
hashAlgorithmMD5 hashAlgorithm = 1 // Blacklisted
hashAlgorithmSHA1 hashAlgorithm = 2 // Blacklisted
hashAlgorithmSHA224 hashAlgorithm = 3
hashAlgorithmSHA256 hashAlgorithm = 4
hashAlgorithmSHA384 hashAlgorithm = 5
hashAlgorithmSHA512 hashAlgorithm = 6
hashAlgorithmEd25519 hashAlgorithm = 8
)
// String makes hashAlgorithm printable
func (h hashAlgorithm) String() string {
switch h {
case hashAlgorithmMD2:
return "md2"
case hashAlgorithmMD5:
return "md5" // [RFC3279]
case hashAlgorithmSHA1:
return "sha-1" // [RFC3279]
case hashAlgorithmSHA224:
return "sha-224" // [RFC4055]
case hashAlgorithmSHA256:
return "sha-256" // [RFC4055]
case hashAlgorithmSHA384:
return "sha-384" // [RFC4055]
case hashAlgorithmSHA512:
return "sha-512" // [RFC4055]
case hashAlgorithmEd25519:
return "null"
default:
return "unknown or unsupported hash algorithm"
}
}
func (h hashAlgorithm) digest(b []byte) []byte {
switch h {
case hashAlgorithmMD5:
hash := md5.Sum(b) // #nosec
return hash[:]
case hashAlgorithmSHA1:
hash := sha1.Sum(b) // #nosec
return hash[:]
case hashAlgorithmSHA224:
hash := sha256.Sum224(b)
return hash[:]
case hashAlgorithmSHA256:
hash := sha256.Sum256(b)
return hash[:]
case hashAlgorithmSHA384:
hash := sha512.Sum384(b)
return hash[:]
case hashAlgorithmSHA512:
hash := sha512.Sum512(b)
return hash[:]
default:
return nil
}
}
func (h hashAlgorithm) insecure() bool {
switch h {
case hashAlgorithmMD2, hashAlgorithmMD5, hashAlgorithmSHA1:
return true
default:
return false
}
}
func (h hashAlgorithm) cryptoHash() crypto.Hash {
switch h {
case hashAlgorithmMD5:
return crypto.MD5
case hashAlgorithmSHA1:
return crypto.SHA1
case hashAlgorithmSHA224:
return crypto.SHA224
case hashAlgorithmSHA256:
return crypto.SHA256
case hashAlgorithmSHA384:
return crypto.SHA384
case hashAlgorithmSHA512:
return crypto.SHA512
case hashAlgorithmEd25519:
return crypto.Hash(0)
default:
return crypto.Hash(0)
}
}
func hashAlgorithms() map[hashAlgorithm]struct{} {
return map[hashAlgorithm]struct{}{
hashAlgorithmMD5: {},
hashAlgorithmSHA1: {},
hashAlgorithmSHA224: {},
hashAlgorithmSHA256: {},
hashAlgorithmSHA384: {},
hashAlgorithmSHA512: {},
hashAlgorithmEd25519: {},
}
}

View file

@ -0,0 +1,108 @@
package ciphersuite
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// Aes128Ccm is a base class used by multiple AES-CCM Ciphers
type Aes128Ccm struct {
ccm atomic.Value // *cryptoCCM
clientCertificateType clientcertificate.Type
id ID
psk bool
cryptoCCMTagLen ciphersuite.CCMTagLen
}
func newAes128Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, cryptoCCMTagLen ciphersuite.CCMTagLen) *Aes128Ccm {
return &Aes128Ccm{
clientCertificateType: clientCertificateType,
id: id,
psk: psk,
cryptoCCMTagLen: cryptoCCMTagLen,
}
}
// CertificateType returns what type of certificate this CipherSuite exchanges
func (c *Aes128Ccm) CertificateType() clientcertificate.Type {
return c.clientCertificateType
}
// ID returns the ID of the CipherSuite
func (c *Aes128Ccm) ID() ID {
return c.id
}
func (c *Aes128Ccm) String() string {
return c.id.String()
}
// HashFunc returns the hashing func for this CipherSuite
func (c *Aes128Ccm) HashFunc() func() hash.Hash {
return sha256.New
}
// AuthenticationType controls what authentication method is using during the handshake
func (c *Aes128Ccm) AuthenticationType() AuthenticationType {
if c.psk {
return AuthenticationTypePreSharedKey
}
return AuthenticationTypeCertificate
}
// IsInitialized returns if the CipherSuite has keying material and can
// encrypt/decrypt packets
func (c *Aes128Ccm) IsInitialized() bool {
return c.ccm.Load() != nil
}
// Init initializes the internal Cipher with keying material
func (c *Aes128Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
if err != nil {
return err
}
var ccm *ciphersuite.CCM
if isClient {
ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV)
} else {
ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV)
}
c.ccm.Store(ccm)
return err
}
// Encrypt encrypts a single TLS RecordLayer
func (c *Aes128Ccm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return ccm.(*ciphersuite.CCM).Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
func (c *Aes128Ccm) Decrypt(raw []byte) ([]byte, error) {
ccm := c.ccm.Load()
if ccm == nil {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return ccm.(*ciphersuite.CCM).Decrypt(raw)
}

View file

@ -0,0 +1,71 @@
// Package ciphersuite provides TLS Ciphers as registered with the IANA https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4
package ciphersuite
import (
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/protocol"
)
var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} //nolint:goerr113
// ID is an ID for our supported CipherSuites
type ID uint16
func (i ID) String() string {
switch i {
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"
case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8"
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
case TLS_PSK_WITH_AES_128_CCM:
return "TLS_PSK_WITH_AES_128_CCM"
case TLS_PSK_WITH_AES_128_CCM_8:
return "TLS_PSK_WITH_AES_128_CCM_8"
case TLS_PSK_WITH_AES_128_GCM_SHA256:
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
case TLS_PSK_WITH_AES_128_CBC_SHA256:
return "TLS_PSK_WITH_AES_128_CBC_SHA256"
default:
return fmt.Sprintf("unknown(%v)", uint16(i))
}
}
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:golint,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:golint,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae //nolint:golint,stylecheck
)
// AuthenticationType controls what authentication method is using during the handshake
type AuthenticationType int
// AuthenticationType Enums
const (
AuthenticationTypeCertificate AuthenticationType = iota + 1
AuthenticationTypePreSharedKey
AuthenticationTypeAnonymous
)

View file

@ -0,0 +1,11 @@
package ciphersuite
import (
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
)
// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher
func NewTLSEcdheEcdsaWithAes128Ccm() *Aes128Ccm {
return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, ciphersuite.CCMTagLength)
}

View file

@ -0,0 +1,11 @@
package ciphersuite
import (
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
)
// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite
func NewTLSEcdheEcdsaWithAes128Ccm8() *Aes128Ccm {
return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, ciphersuite.CCMTagLength8)
}

View file

@ -0,0 +1,92 @@
package ciphersuite
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite
type TLSEcdheEcdsaWithAes128GcmSha256 struct {
gcm atomic.Value // *cryptoGCM
}
// CertificateType returns what type of certficate this CipherSuite exchanges
func (c *TLSEcdheEcdsaWithAes128GcmSha256) CertificateType() clientcertificate.Type {
return clientcertificate.ECDSASign
}
// ID returns the ID of the CipherSuite
func (c *TLSEcdheEcdsaWithAes128GcmSha256) ID() ID {
return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
}
func (c *TLSEcdheEcdsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
}
// HashFunc returns the hashing func for this CipherSuite
func (c *TLSEcdheEcdsaWithAes128GcmSha256) HashFunc() func() hash.Hash {
return sha256.New
}
// AuthenticationType controls what authentication method is using during the handshake
func (c *TLSEcdheEcdsaWithAes128GcmSha256) AuthenticationType() AuthenticationType {
return AuthenticationTypeCertificate
}
// IsInitialized returns if the CipherSuite has keying material and can
// encrypt/decrypt packets
func (c *TLSEcdheEcdsaWithAes128GcmSha256) IsInitialized() bool {
return c.gcm.Load() != nil
}
// Init initializes the internal Cipher with keying material
func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
prfKeyLen = 16
prfIvLen = 4
)
keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
if err != nil {
return err
}
var gcm *ciphersuite.GCM
if isClient {
gcm, err = ciphersuite.NewGCM(keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV)
} else {
gcm, err = ciphersuite.NewGCM(keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV)
}
c.gcm.Store(gcm)
return err
}
// Encrypt encrypts a single TLS RecordLayer
func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return gcm.(*ciphersuite.GCM).Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(raw []byte) ([]byte, error) {
gcm := c.gcm.Load()
if gcm == nil {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return gcm.(*ciphersuite.GCM).Decrypt(raw)
}

View file

@ -0,0 +1,101 @@
package ciphersuite
import (
"crypto/sha1" //nolint: gosec,gci
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite
type TLSEcdheEcdsaWithAes256CbcSha struct {
cbc atomic.Value // *cryptoCBC
}
// CertificateType returns what type of certficate this CipherSuite exchanges
func (c *TLSEcdheEcdsaWithAes256CbcSha) CertificateType() clientcertificate.Type {
return clientcertificate.ECDSASign
}
// ID returns the ID of the CipherSuite
func (c *TLSEcdheEcdsaWithAes256CbcSha) ID() ID {
return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
}
func (c *TLSEcdheEcdsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
}
// HashFunc returns the hashing func for this CipherSuite
func (c *TLSEcdheEcdsaWithAes256CbcSha) HashFunc() func() hash.Hash {
return sha256.New
}
// AuthenticationType controls what authentication method is using during the handshake
func (c *TLSEcdheEcdsaWithAes256CbcSha) AuthenticationType() AuthenticationType {
return AuthenticationTypeCertificate
}
// IsInitialized returns if the CipherSuite has keying material and can
// encrypt/decrypt packets
func (c *TLSEcdheEcdsaWithAes256CbcSha) IsInitialized() bool {
return c.cbc.Load() != nil
}
// Init initializes the internal Cipher with keying material
func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 20
prfKeyLen = 32
prfIvLen = 16
)
keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
if err != nil {
return err
}
var cbc *ciphersuite.CBC
if isClient {
cbc, err = ciphersuite.NewCBC(
keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
sha1.New,
)
} else {
cbc, err = ciphersuite.NewCBC(
keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
sha1.New,
)
}
c.cbc.Store(cbc)
return err
}
// Encrypt encrypts a single TLS RecordLayer
func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Decrypt(raw)
}

View file

@ -0,0 +1,22 @@
package ciphersuite
import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite
type TLSEcdheRsaWithAes128GcmSha256 struct {
TLSEcdheEcdsaWithAes128GcmSha256
}
// CertificateType returns what type of certificate this CipherSuite exchanges
func (c *TLSEcdheRsaWithAes128GcmSha256) CertificateType() clientcertificate.Type {
return clientcertificate.RSASign
}
// ID returns the ID of the CipherSuite
func (c *TLSEcdheRsaWithAes128GcmSha256) ID() ID {
return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
}
func (c *TLSEcdheRsaWithAes128GcmSha256) String() string {
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
}

View file

@ -0,0 +1,22 @@
package ciphersuite
import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite
type TLSEcdheRsaWithAes256CbcSha struct {
TLSEcdheEcdsaWithAes256CbcSha
}
// CertificateType returns what type of certificate this CipherSuite exchanges
func (c *TLSEcdheRsaWithAes256CbcSha) CertificateType() clientcertificate.Type {
return clientcertificate.RSASign
}
// ID returns the ID of the CipherSuite
func (c *TLSEcdheRsaWithAes256CbcSha) ID() ID {
return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
}
func (c *TLSEcdheRsaWithAes256CbcSha) String() string {
return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
}

View file

@ -0,0 +1,100 @@
package ciphersuite
import (
"crypto/sha256"
"fmt"
"hash"
"sync/atomic"
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite
type TLSPskWithAes128CbcSha256 struct {
cbc atomic.Value // *cryptoCBC
}
// CertificateType returns what type of certificate this CipherSuite exchanges
func (c *TLSPskWithAes128CbcSha256) CertificateType() clientcertificate.Type {
return clientcertificate.Type(0)
}
// ID returns the ID of the CipherSuite
func (c *TLSPskWithAes128CbcSha256) ID() ID {
return TLS_PSK_WITH_AES_128_CBC_SHA256
}
func (c *TLSPskWithAes128CbcSha256) String() string {
return "TLS_PSK_WITH_AES_128_CBC_SHA256"
}
// HashFunc returns the hashing func for this CipherSuite
func (c *TLSPskWithAes128CbcSha256) HashFunc() func() hash.Hash {
return sha256.New
}
// AuthenticationType controls what authentication method is using during the handshake
func (c *TLSPskWithAes128CbcSha256) AuthenticationType() AuthenticationType {
return AuthenticationTypePreSharedKey
}
// IsInitialized returns if the CipherSuite has keying material and can
// encrypt/decrypt packets
func (c *TLSPskWithAes128CbcSha256) IsInitialized() bool {
return c.cbc.Load() != nil
}
// Init initializes the internal Cipher with keying material
func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 32
prfKeyLen = 16
prfIvLen = 16
)
keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
if err != nil {
return err
}
var cbc *ciphersuite.CBC
if isClient {
cbc, err = ciphersuite.NewCBC(
keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
c.HashFunc(),
)
} else {
cbc, err = ciphersuite.NewCBC(
keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
c.HashFunc(),
)
}
c.cbc.Store(cbc)
return err
}
// Encrypt encrypts a single TLS RecordLayer
func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
func (c *TLSPskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Decrypt(raw)
}

View file

@ -0,0 +1,11 @@
package ciphersuite
import (
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
)
// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite
func NewTLSPskWithAes128Ccm() *Aes128Ccm {
return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM, true, ciphersuite.CCMTagLength)
}

View file

@ -0,0 +1,11 @@
package ciphersuite
import (
"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
)
// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite
func NewTLSPskWithAes128Ccm8() *Aes128Ccm {
return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM_8, true, ciphersuite.CCMTagLength8)
}

View file

@ -0,0 +1,27 @@
package ciphersuite
import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite
type TLSPskWithAes128GcmSha256 struct {
TLSEcdheEcdsaWithAes128GcmSha256
}
// CertificateType returns what type of certificate this CipherSuite exchanges
func (c *TLSPskWithAes128GcmSha256) CertificateType() clientcertificate.Type {
return clientcertificate.Type(0)
}
// ID returns the ID of the CipherSuite
func (c *TLSPskWithAes128GcmSha256) ID() ID {
return TLS_PSK_WITH_AES_128_GCM_SHA256
}
func (c *TLSPskWithAes128GcmSha256) String() string {
return "TLS_PSK_WITH_AES_128_GCM_SHA256"
}
// AuthenticationType controls what authentication method is using during the handshake
func (c *TLSPskWithAes128GcmSha256) AuthenticationType() AuthenticationType {
return AuthenticationTypePreSharedKey
}

View file

@ -0,0 +1,39 @@
// Package util contains small helpers used across the repo
package util
import (
"encoding/binary"
)
// BigEndianUint24 returns the value of a big endian uint24
func BigEndianUint24(raw []byte) uint32 {
if len(raw) < 3 {
return 0
}
rawCopy := make([]byte, 4)
copy(rawCopy[1:], raw)
return binary.BigEndian.Uint32(rawCopy)
}
// PutBigEndianUint24 encodes a uint24 and places into out
func PutBigEndianUint24(out []byte, in uint32) {
tmp := make([]byte, 4)
binary.BigEndian.PutUint32(tmp, in)
copy(out, tmp[1:])
}
// PutBigEndianUint48 encodes a uint64 and places into out
func PutBigEndianUint48(out []byte, in uint64) {
tmp := make([]byte, 8)
binary.BigEndian.PutUint64(tmp, in)
copy(out, tmp[2:])
}
// Max returns the larger value
func Max(a, b int) int {
if a > b {
return a
}
return b
}

View file

@ -3,6 +3,8 @@ package dtls
import (
"net"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/udp"
)
@ -14,15 +16,15 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e
lc := udp.ListenConfig{
AcceptFilter: func(packet []byte) bool {
pkts, err := unpackDatagram(packet)
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return false
}
h := &recordLayerHeader{}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {
return false
}
return h.contentType == contentTypeHandshake
return h.ContentType == protocol.ContentTypeHandshake
},
}
parent, err := lc.Listen(network, laddr)

View file

@ -1,62 +0,0 @@
package dtls
import (
"crypto/elliptic"
"crypto/rand"
"golang.org/x/crypto/curve25519"
)
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
type namedCurve uint16
type namedCurveKeypair struct {
curve namedCurve
publicKey []byte
privateKey []byte
}
const (
namedCurveP256 namedCurve = 0x0017
namedCurveP384 namedCurve = 0x0018
namedCurveX25519 namedCurve = 0x001d
)
func namedCurves() map[namedCurve]bool {
return map[namedCurve]bool{
namedCurveX25519: true,
namedCurveP256: true,
namedCurveP384: true,
}
}
func generateKeypair(c namedCurve) (*namedCurveKeypair, error) {
switch c { //nolint:golint
case namedCurveX25519:
tmp := make([]byte, 32)
if _, err := rand.Read(tmp); err != nil {
return nil, err
}
var public, private [32]byte
copy(private[:], tmp)
curve25519.ScalarBaseMult(&public, &private)
return &namedCurveKeypair{namedCurveX25519, public[:], private[:]}, nil
case namedCurveP256:
return ellipticCurveKeypair(namedCurveP256, elliptic.P256(), elliptic.P256())
case namedCurveP384:
return ellipticCurveKeypair(namedCurveP384, elliptic.P384(), elliptic.P384())
default:
return nil, errInvalidNamedCurve
}
}
func ellipticCurveKeypair(nc namedCurve, c1, c2 elliptic.Curve) (*namedCurveKeypair, error) {
privateKey, x, y, err := elliptic.GenerateKey(c1, rand.Reader)
if err != nil {
return nil, err
}
return &namedCurveKeypair{nc, elliptic.Marshal(c2, x, y), privateKey}, nil
}

View file

@ -1,7 +1,9 @@
package dtls
import "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
type packet struct {
record *recordLayer
record *recordlayer.RecordLayer
shouldEncrypt bool
resetLocalSequenceNumber bool
}

View file

@ -0,0 +1,164 @@
package ciphersuite
import ( //nolint:gci
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"encoding/binary"
"hash"
"github.com/pion/dtls/v2/internal/util"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
// block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
type CBC struct {
writeCBC, readCBC cbcMode
writeMac, readMac []byte
h prf.HashFunc
}
// NewCBC creates a DTLS CBC Cipher
func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) {
writeBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
readBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
return &CBC{
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
writeMac: localMac,
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
readMac: remoteMac,
h: h,
}, nil
}
// Encrypt encrypt a DTLS RecordLayer message
func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[recordlayer.HeaderSize:]
raw = raw[:recordlayer.HeaderSize]
blockSize := c.writeCBC.BlockSize()
// Generate + Append MAC
h := pkt.Header
MAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h)
if err != nil {
return nil, err
}
payload = append(payload, MAC...)
// Generate + Append padding
padding := make([]byte, blockSize-len(payload)%blockSize)
paddingLen := len(padding)
for i := 0; i < paddingLen; i++ {
padding[i] = byte(paddingLen - 1)
}
payload = append(payload, padding...)
// Generate IV
iv := make([]byte, blockSize)
if _, err := rand.Read(iv); err != nil {
return nil, err
}
// Set IV + Encrypt + Prepend IV
c.writeCBC.SetIV(iv)
c.writeCBC.CryptBlocks(payload, payload)
payload = append(iv, payload...)
// Prepend unencrypte header with encrypted payload
raw = append(raw, payload...)
// Update recordLayer size to include IV+MAC+Padding
binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
return raw, nil
}
// Decrypt decrypts a DTLS RecordLayer message
func (c *CBC) Decrypt(in []byte) ([]byte, error) {
body := in[recordlayer.HeaderSize:]
blockSize := c.readCBC.BlockSize()
mac := c.h()
var h recordlayer.Header
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.ContentType == protocol.ContentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize):
return nil, errNotEnoughRoomForNonce
}
// Set + remove per record IV
c.readCBC.SetIV(body[:blockSize])
body = body[blockSize:]
// Decrypt
c.readCBC.CryptBlocks(body, body)
// Padding+MAC needs to be checked in constant time
// Otherwise we reveal information about the level of correctness
paddingLen, paddingGood := examinePadding(body)
if paddingGood != 255 {
return nil, errInvalidMAC
}
macSize := mac.Size()
if len(body) < macSize {
return nil, errInvalidMAC
}
dataEnd := len(body) - macSize - paddingLen
expectedMAC := body[dataEnd : dataEnd+macSize]
actualMAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h)
// Compute Local MAC and compare
if err != nil || !hmac.Equal(actualMAC, expectedMAC) {
return nil, errInvalidMAC
}
return append(in[:recordlayer.HeaderSize], body[:dataEnd]...), nil
}
func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) {
h := hmac.New(hf, key)
msg := make([]byte, 13)
binary.BigEndian.PutUint16(msg, epoch)
util.PutBigEndianUint48(msg[2:], sequenceNumber)
msg[8] = byte(contentType)
msg[9] = protocolVersion.Major
msg[10] = protocolVersion.Minor
binary.BigEndian.PutUint16(msg[11:], uint16(len(payload)))
if _, err := h.Write(msg); err != nil {
return nil, err
} else if _, err := h.Write(payload); err != nil {
return nil, err
}
return h.Sum(nil), nil
}

View file

@ -1,38 +1,40 @@
package dtls
package ciphersuite
import (
"crypto/aes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/crypto/ccm"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
var errDecryptPacket = errors.New("decryptPacket")
type cryptoCCMTagLen int
// CCMTagLen is the length of Authentication Tag
type CCMTagLen int
// CCM Enums
const (
cryptoCCM8TagLength cryptoCCMTagLen = 8
cryptoCCMTagLength cryptoCCMTagLen = 16
cryptoCCMNonceLength = 12
CCMTagLength8 CCMTagLen = 8
CCMTagLength CCMTagLen = 16
ccmNonceLength = 12
)
// State needed to handle encrypted input/output
type cryptoCCM struct {
// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
type CCM struct {
localCCM, remoteCCM ccm.CCM
localWriteIV, remoteWriteIV []byte
tagLen cryptoCCMTagLen
tagLen CCMTagLen
}
func newCryptoCCM(tagLen cryptoCCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*cryptoCCM, error) {
// NewCCM creates a DTLS GCM Cipher
func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localCCM, err := ccm.NewCCM(localBlock, int(tagLen), cryptoCCMNonceLength)
localCCM, err := ccm.NewCCM(localBlock, int(tagLen), ccmNonceLength)
if err != nil {
return nil, err
}
@ -41,12 +43,12 @@ func newCryptoCCM(tagLen cryptoCCMTagLen, localKey, localWriteIV, remoteKey, rem
if err != nil {
return nil, err
}
remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), cryptoCCMNonceLength)
remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), ccmNonceLength)
if err != nil {
return nil, err
}
return &cryptoCCM{
return &CCM{
localCCM: localCCM,
localWriteIV: localWriteIV,
remoteCCM: remoteCCM,
@ -55,46 +57,48 @@ func newCryptoCCM(tagLen cryptoCCMTagLen, localKey, localWriteIV, remoteKey, rem
}, nil
}
func (c *cryptoCCM) encrypt(pkt *recordLayer, raw []byte) ([]byte, error) {
payload := raw[recordLayerHeaderSize:]
raw = raw[:recordLayerHeaderSize]
// Encrypt encrypt a DTLS RecordLayer message
func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[recordlayer.HeaderSize:]
raw = raw[:recordlayer.HeaderSize]
nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.recordLayerHeader, len(payload))
additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
encryptedPayload = append(nonce[4:], encryptedPayload...)
raw = append(raw, encryptedPayload...)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(raw[recordLayerHeaderSize-2:], uint16(len(raw)-recordLayerHeaderSize))
binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
return raw, nil
}
func (c *cryptoCCM) decrypt(in []byte) ([]byte, error) {
var h recordLayerHeader
// Decrypt decrypts a DTLS RecordLayer message
func (c *CCM) Decrypt(in []byte) ([]byte, error) {
var h recordlayer.Header
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.contentType == contentTypeChangeCipherSpec:
case h.ContentType == protocol.ContentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordLayerHeaderSize):
case len(in) <= (8 + recordlayer.HeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordLayerHeaderSize:recordLayerHeaderSize+8]...)
out := in[recordLayerHeaderSize+8:]
nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
out := in[recordlayer.HeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen))
out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordLayerHeaderSize], out...), nil
return append(in[:recordlayer.HeaderSize], out...), nil
}

View file

@ -0,0 +1,72 @@
// Package ciphersuite provides the crypto operations needed for a DTLS CipherSuite
package ciphersuite
import (
"encoding/binary"
"errors"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
var (
errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113
errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113
)
func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte {
var additionalData [13]byte
// SequenceNumber MUST be set first
// we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48)
binary.BigEndian.PutUint64(additionalData[:], h.SequenceNumber)
binary.BigEndian.PutUint16(additionalData[:], h.Epoch)
additionalData[8] = byte(h.ContentType)
additionalData[9] = h.Version.Major
additionalData[10] = h.Version.Minor
binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen))
return additionalData[:]
}
// examinePadding returns, in constant time, the length of the padding to remove
// from the end of payload. It also returns a byte which is equal to 255 if the
// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
//
// https://github.com/golang/go/blob/039c2081d1178f90a8fa2f4e6958693129f8de33/src/crypto/tls/conn.go#L245
func examinePadding(payload []byte) (toRemove int, good byte) {
if len(payload) < 1 {
return 0, 0
}
paddingLen := payload[len(payload)-1]
t := uint(len(payload)-1) - uint(paddingLen)
// if len(payload) >= (paddingLen - 1) then the MSB of t is zero
good = byte(int32(^t) >> 31)
// The maximum possible padding length plus the actual length field
toCheck := 256
// The length of the padded data is public, so we can use an if here
if toCheck > len(payload) {
toCheck = len(payload)
}
for i := 0; i < toCheck; i++ {
t := uint(paddingLen) - uint(i)
// if i <= paddingLen then the MSB of t is zero
mask := byte(int32(^t) >> 31)
b := payload[len(payload)-1-i]
good &^= mask&paddingLen ^ mask&b
}
// We AND together the bits of good and replicate the result across
// all the bits.
good &= good << 4
good &= good << 2
good &= good << 1
good = uint8(int8(good) >> 7)
toRemove = int(paddingLen) + 1
return toRemove, good
}

View file

@ -0,0 +1,100 @@
package ciphersuite
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
const (
gcmTagLength = 16
gcmNonceLength = 12
)
// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
type GCM struct {
localGCM, remoteGCM cipher.AEAD
localWriteIV, remoteWriteIV []byte
}
// NewGCM creates a DTLS GCM Cipher
func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) {
localBlock, err := aes.NewCipher(localKey)
if err != nil {
return nil, err
}
localGCM, err := cipher.NewGCM(localBlock)
if err != nil {
return nil, err
}
remoteBlock, err := aes.NewCipher(remoteKey)
if err != nil {
return nil, err
}
remoteGCM, err := cipher.NewGCM(remoteBlock)
if err != nil {
return nil, err
}
return &GCM{
localGCM: localGCM,
localWriteIV: localWriteIV,
remoteGCM: remoteGCM,
remoteWriteIV: remoteWriteIV,
}, nil
}
// Encrypt encrypt a DTLS RecordLayer message
func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
payload := raw[recordlayer.HeaderSize:]
raw = raw[:recordlayer.HeaderSize]
nonce := make([]byte, gcmNonceLength)
copy(nonce, g.localWriteIV[:4])
if _, err := rand.Read(nonce[4:]); err != nil {
return nil, err
}
additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
encryptedPayload := g.localGCM.Seal(nil, nonce, payload, additionalData)
r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
copy(r, raw)
copy(r[len(raw):], nonce[4:])
copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
// Update recordLayer size to include explicit nonce
binary.BigEndian.PutUint16(r[recordlayer.HeaderSize-2:], uint16(len(r)-recordlayer.HeaderSize))
return r, nil
}
// Decrypt decrypts a DTLS RecordLayer message
func (g *GCM) Decrypt(in []byte) ([]byte, error) {
var h recordlayer.Header
err := h.Unmarshal(in)
switch {
case err != nil:
return nil, err
case h.ContentType == protocol.ContentTypeChangeCipherSpec:
// Nothing to encrypt with ChangeCipherSpec
return in, nil
case len(in) <= (8 + recordlayer.HeaderSize):
return nil, errNotEnoughRoomForNonce
}
nonce := make([]byte, 0, gcmNonceLength)
nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
out := in[recordlayer.HeaderSize+8:]
additionalData := generateAEADAdditionalData(&h, len(out)-gcmTagLength)
out, err = g.remoteGCM.Open(out[:0], nonce, out, additionalData)
if err != nil {
return nil, fmt.Errorf("%w: %v", errDecryptPacket, err)
}
return append(in[:recordlayer.HeaderSize], out...), nil
}

View file

@ -0,0 +1,22 @@
// Package clientcertificate provides all the support Client Certificate types
package clientcertificate
// Type is used to communicate what
// type of certificate is being transported
//
//https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2
type Type byte
// ClientCertificateType enums
const (
RSASign Type = 1
ECDSASign Type = 64
)
// Types returns all valid ClientCertificate Types
func Types() map[Type]bool {
return map[Type]bool{
RSASign: true,
ECDSASign: true,
}
}

View file

@ -0,0 +1,99 @@
// Package elliptic provides elliptic curve cryptography for DTLS
package elliptic
import (
"crypto/elliptic"
"crypto/rand"
"errors"
"golang.org/x/crypto/curve25519"
)
var errInvalidNamedCurve = errors.New("invalid named curve")
// CurvePointFormat is used to represent the IANA registered curve points
//
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
type CurvePointFormat byte
// CurvePointFormat enums
const (
CurvePointFormatUncompressed CurvePointFormat = 0
)
// Keypair is a Curve with a Private/Public Keypair
type Keypair struct {
Curve Curve
PublicKey []byte
PrivateKey []byte
}
// CurveType is used to represent the IANA registered curve types for TLS
//
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10
type CurveType byte
// CurveType enums
const (
CurveTypeNamedCurve CurveType = 0x03
)
// CurveTypes returns all known curves
func CurveTypes() map[CurveType]struct{} {
return map[CurveType]struct{}{
CurveTypeNamedCurve: {},
}
}
// Curve is used to represent the IANA registered curves for TLS
//
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
type Curve uint16
// Curve enums
const (
P256 Curve = 0x0017
P384 Curve = 0x0018
X25519 Curve = 0x001d
)
// Curves returns all curves we implement
func Curves() map[Curve]bool {
return map[Curve]bool{
X25519: true,
P256: true,
P384: true,
}
}
// GenerateKeypair generates a keypair for the given Curve
func GenerateKeypair(c Curve) (*Keypair, error) {
switch c { //nolint:golint
case X25519:
tmp := make([]byte, 32)
if _, err := rand.Read(tmp); err != nil {
return nil, err
}
var public, private [32]byte
copy(private[:], tmp)
curve25519.ScalarBaseMult(&public, &private)
return &Keypair{X25519, public[:], private[:]}, nil
case P256:
return ellipticCurveKeypair(P256, elliptic.P256(), elliptic.P256())
case P384:
return ellipticCurveKeypair(P384, elliptic.P384(), elliptic.P384())
default:
return nil, errInvalidNamedCurve
}
}
func ellipticCurveKeypair(nc Curve, c1, c2 elliptic.Curve) (*Keypair, error) {
privateKey, x, y, err := elliptic.GenerateKey(c1, rand.Reader)
if err != nil {
return nil, err
}
return &Keypair{nc, elliptic.Marshal(c2, x, y), privateKey}, nil
}

View file

@ -0,0 +1,126 @@
// Package hash provides TLS HashAlgorithm as defined in TLS 1.2
package hash
import ( //nolint:gci
"crypto"
"crypto/md5" //nolint:gosec
"crypto/sha1" //nolint:gosec
"crypto/sha256"
"crypto/sha512"
)
// Algorithm is used to indicate the hash algorithm used
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
type Algorithm uint16
// Supported hash algorithms
const (
None Algorithm = 0 // Blacklisted
MD5 Algorithm = 1 // Blacklisted
SHA1 Algorithm = 2 // Blacklisted
SHA224 Algorithm = 3
SHA256 Algorithm = 4
SHA384 Algorithm = 5
SHA512 Algorithm = 6
Ed25519 Algorithm = 8
)
// String makes hashAlgorithm printable
func (a Algorithm) String() string {
switch a {
case None:
return "none"
case MD5:
return "md5" // [RFC3279]
case SHA1:
return "sha-1" // [RFC3279]
case SHA224:
return "sha-224" // [RFC4055]
case SHA256:
return "sha-256" // [RFC4055]
case SHA384:
return "sha-384" // [RFC4055]
case SHA512:
return "sha-512" // [RFC4055]
case Ed25519:
return "null"
default:
return "unknown or unsupported hash algorithm"
}
}
// Digest performs a digest on the passed value
func (a Algorithm) Digest(b []byte) []byte {
switch a {
case None:
return nil
case MD5:
hash := md5.Sum(b) // #nosec
return hash[:]
case SHA1:
hash := sha1.Sum(b) // #nosec
return hash[:]
case SHA224:
hash := sha256.Sum224(b)
return hash[:]
case SHA256:
hash := sha256.Sum256(b)
return hash[:]
case SHA384:
hash := sha512.Sum384(b)
return hash[:]
case SHA512:
hash := sha512.Sum512(b)
return hash[:]
default:
return nil
}
}
// Insecure returns if the given HashAlgorithm is considered secure in DTLS 1.2
func (a Algorithm) Insecure() bool {
switch a {
case None, MD5, SHA1:
return true
default:
return false
}
}
// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm
func (a Algorithm) CryptoHash() crypto.Hash {
switch a {
case None:
return crypto.Hash(0)
case MD5:
return crypto.MD5
case SHA1:
return crypto.SHA1
case SHA224:
return crypto.SHA224
case SHA256:
return crypto.SHA256
case SHA384:
return crypto.SHA384
case SHA512:
return crypto.SHA512
case Ed25519:
return crypto.Hash(0)
default:
return crypto.Hash(0)
}
}
// Algorithms returns all the supported Hash Algorithms
func Algorithms() map[Algorithm]struct{} {
return map[Algorithm]struct{}{
None: {},
MD5: {},
SHA1: {},
SHA224: {},
SHA256: {},
SHA384: {},
SHA512: {},
Ed25519: {},
}
}

View file

@ -0,0 +1,224 @@
// Package prf implements TLS 1.2 Pseudorandom functions
package prf
import ( //nolint:gci
ellipticStdlib "crypto/elliptic"
"crypto/hmac"
"encoding/binary"
"errors"
"fmt"
"hash"
"math"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol"
"golang.org/x/crypto/curve25519"
)
const (
masterSecretLabel = "master secret"
extendedMasterSecretLabel = "extended master secret"
keyExpansionLabel = "key expansion"
verifyDataClientLabel = "client finished"
verifyDataServerLabel = "server finished"
)
// HashFunc allows callers to decide what hash is used in PRF
type HashFunc func() hash.Hash
// EncryptionKeys is all the state needed for a TLS CipherSuite
type EncryptionKeys struct {
MasterSecret []byte
ClientMACKey []byte
ServerMACKey []byte
ClientWriteKey []byte
ServerWriteKey []byte
ClientWriteIV []byte
ServerWriteIV []byte
}
var errInvalidNamedCurve = &protocol.FatalError{Err: errors.New("invalid named curve")} //nolint:goerr113
func (e *EncryptionKeys) String() string {
return fmt.Sprintf(`encryptionKeys:
- masterSecret: %#v
- clientMACKey: %#v
- serverMACKey: %#v
- clientWriteKey: %#v
- serverWriteKey: %#v
- clientWriteIV: %#v
- serverWriteIV: %#v
`,
e.MasterSecret,
e.ClientMACKey,
e.ServerMACKey,
e.ClientWriteKey,
e.ServerWriteKey,
e.ClientWriteIV,
e.ServerWriteIV)
}
// PSKPreMasterSecret generates the PSK Premaster Secret
// The premaster secret is formed as follows: if the PSK is N octets
// long, concatenate a uint16 with the value N, N zero octets, a second
// uint16 with the value N, and the PSK itself.
//
// https://tools.ietf.org/html/rfc4279#section-2
func PSKPreMasterSecret(psk []byte) []byte {
pskLen := uint16(len(psk))
out := append(make([]byte, 2+pskLen+2), psk...)
binary.BigEndian.PutUint16(out, pskLen)
binary.BigEndian.PutUint16(out[2+pskLen:], pskLen)
return out
}
// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve
func PreMasterSecret(publicKey, privateKey []byte, curve elliptic.Curve) ([]byte, error) {
switch curve {
case elliptic.X25519:
return curve25519.X25519(privateKey, publicKey)
case elliptic.P256:
return ellipticCurvePreMasterSecret(publicKey, privateKey, ellipticStdlib.P256(), ellipticStdlib.P256())
case elliptic.P384:
return ellipticCurvePreMasterSecret(publicKey, privateKey, ellipticStdlib.P384(), ellipticStdlib.P384())
default:
return nil, errInvalidNamedCurve
}
}
func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 ellipticStdlib.Curve) ([]byte, error) {
x, y := ellipticStdlib.Unmarshal(c1, publicKey)
if x == nil || y == nil {
return nil, errInvalidNamedCurve
}
result, _ := c2.ScalarMult(x, y, privateKey)
preMasterSecret := make([]byte, (c2.Params().BitSize+7)>>3)
resultBytes := result.Bytes()
copy(preMasterSecret[len(preMasterSecret)-len(resultBytes):], resultBytes)
return preMasterSecret, nil
}
// PHash is PRF is the SHA-256 hash function is used for all cipher suites
// defined in this TLS 1.2 document and in TLS documents published prior to this
// document when TLS 1.2 is negotiated. New cipher suites MUST explicitly
// specify a PRF and, in general, SHOULD use the TLS PRF with SHA-256 or a
// stronger standard hash function.
//
// P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
// HMAC_hash(secret, A(2) + seed) +
// HMAC_hash(secret, A(3) + seed) + ...
//
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
//
// P_hash can be iterated as many times as necessary to produce the
// required quantity of data. For example, if P_SHA256 is being used to
// create 80 bytes of data, it will have to be iterated three times
// (through A(3)), creating 96 bytes of output data; the last 16 bytes
// of the final iteration will then be discarded, leaving 80 bytes of
// output data.
//
// https://tools.ietf.org/html/rfc4346w
func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) {
hmacSHA256 := func(key, data []byte) ([]byte, error) {
mac := hmac.New(h, key)
if _, err := mac.Write(data); err != nil {
return nil, err
}
return mac.Sum(nil), nil
}
var err error
lastRound := seed
out := []byte{}
iterations := int(math.Ceil(float64(requestedLength) / float64(h().Size())))
for i := 0; i < iterations; i++ {
lastRound, err = hmacSHA256(secret, lastRound)
if err != nil {
return nil, err
}
withSecret, err := hmacSHA256(secret, append(lastRound, seed...))
if err != nil {
return nil, err
}
out = append(out, withSecret...)
}
return out[:requestedLength], nil
}
// ExtendedMasterSecret generates a Extended MasterSecret as defined in
// https://tools.ietf.org/html/rfc7627
func ExtendedMasterSecret(preMasterSecret, sessionHash []byte, h HashFunc) ([]byte, error) {
seed := append([]byte(extendedMasterSecretLabel), sessionHash...)
return PHash(preMasterSecret, seed, 48, h)
}
// MasterSecret generates a TLS 1.2 MasterSecret
func MasterSecret(preMasterSecret, clientRandom, serverRandom []byte, h HashFunc) ([]byte, error) {
seed := append(append([]byte(masterSecretLabel), clientRandom...), serverRandom...)
return PHash(preMasterSecret, seed, 48, h)
}
// GenerateEncryptionKeys is the final step TLS 1.2 PRF. Given all state generated so far generates
// the final keys need for encryption
func GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int, h HashFunc) (*EncryptionKeys, error) {
seed := append(append([]byte(keyExpansionLabel), serverRandom...), clientRandom...)
keyMaterial, err := PHash(masterSecret, seed, (2*macLen)+(2*keyLen)+(2*ivLen), h)
if err != nil {
return nil, err
}
clientMACKey := keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMACKey := keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientWriteKey := keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverWriteKey := keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientWriteIV := keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverWriteIV := keyMaterial[:ivLen]
return &EncryptionKeys{
MasterSecret: masterSecret,
ClientMACKey: clientMACKey,
ServerMACKey: serverMACKey,
ClientWriteKey: clientWriteKey,
ServerWriteKey: serverWriteKey,
ClientWriteIV: clientWriteIV,
ServerWriteIV: serverWriteIV,
}, nil
}
func prfVerifyData(masterSecret, handshakeBodies []byte, label string, hashFunc HashFunc) ([]byte, error) {
h := hashFunc()
if _, err := h.Write(handshakeBodies); err != nil {
return nil, err
}
seed := append([]byte(label), h.Sum(nil)...)
return PHash(masterSecret, seed, 12, hashFunc)
}
// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message
func VerifyDataClient(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) {
return prfVerifyData(masterSecret, handshakeBodies, verifyDataClientLabel, h)
}
// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message
func VerifyDataServer(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) {
return prfVerifyData(masterSecret, handshakeBodies, verifyDataServerLabel, h)
}

View file

@ -0,0 +1,24 @@
// Package signature provides our implemented Signature Algorithms
package signature
// Algorithm as defined in TLS 1.2
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16
type Algorithm uint16
// SignatureAlgorithm enums
const (
Anonymous Algorithm = 0
RSA Algorithm = 1
ECDSA Algorithm = 3
Ed25519 Algorithm = 7
)
// Algorithms returns all implemented Signature Algorithms
func Algorithms() map[Algorithm]struct{} {
return map[Algorithm]struct{}{
Anonymous: {},
RSA: {},
ECDSA: {},
Ed25519: {},
}
}

View file

@ -0,0 +1,9 @@
package signaturehash
import "errors"
var (
errNoAvailableSignatureSchemes = errors.New("connection can not be created, no SignatureScheme satisfy this Config")
errInvalidSignatureAlgorithm = errors.New("invalid signature algorithm")
errInvalidHashAlgorithm = errors.New("invalid hash algorithm")
)

View file

@ -0,0 +1,93 @@
// Package signaturehash provides the SignatureHashAlgorithm as defined in TLS 1.2
package signaturehash
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/tls"
"github.com/pion/dtls/v2/pkg/crypto/hash"
"github.com/pion/dtls/v2/pkg/crypto/signature"
"golang.org/x/xerrors"
)
// Algorithm is a signature/hash algorithm pairs which may be used in
// digital signatures.
//
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
type Algorithm struct {
Hash hash.Algorithm
Signature signature.Algorithm
}
// Algorithms are all the know SignatureHash Algorithms
func Algorithms() []Algorithm {
return []Algorithm{
{hash.SHA256, signature.ECDSA},
{hash.SHA384, signature.ECDSA},
{hash.SHA512, signature.ECDSA},
{hash.SHA256, signature.RSA},
{hash.SHA384, signature.RSA},
{hash.SHA512, signature.RSA},
{hash.Ed25519, signature.Ed25519},
}
}
// SelectSignatureScheme returns most preferred and compatible scheme.
func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algorithm, error) {
for _, ss := range sigs {
if ss.isCompatible(privateKey) {
return ss, nil
}
}
return Algorithm{}, errNoAvailableSignatureSchemes
}
// isCompatible checks that given private key is compatible with the signature scheme.
func (a *Algorithm) isCompatible(privateKey crypto.PrivateKey) bool {
switch privateKey.(type) {
case ed25519.PrivateKey:
return a.Signature == signature.Ed25519
case *ecdsa.PrivateKey:
return a.Signature == signature.ECDSA
case *rsa.PrivateKey:
return a.Signature == signature.RSA
default:
return false
}
}
// ParseSignatureSchemes translates []tls.SignatureScheme to []signatureHashAlgorithm.
// It returns default signature scheme list if no SignatureScheme is passed.
func ParseSignatureSchemes(sigs []tls.SignatureScheme, insecureHashes bool) ([]Algorithm, error) {
if len(sigs) == 0 {
return Algorithms(), nil
}
out := []Algorithm{}
for _, ss := range sigs {
sig := signature.Algorithm(ss & 0xFF)
if _, ok := signature.Algorithms()[sig]; !ok {
return nil,
xerrors.Errorf("SignatureScheme %04x: %w", ss, errInvalidSignatureAlgorithm)
}
h := hash.Algorithm(ss >> 8)
if _, ok := hash.Algorithms()[h]; !ok || (ok && h == hash.None) {
return nil, xerrors.Errorf("SignatureScheme %04x: %w", ss, errInvalidHashAlgorithm)
}
if h.Insecure() && !insecureHashes {
continue
}
out = append(out, Algorithm{
Hash: h,
Signature: sig,
})
}
if len(out) == 0 {
return nil, errNoAvailableSignatureSchemes
}
return out, nil
}

View file

@ -0,0 +1,160 @@
// Package alert implements TLS alert protocol https://tools.ietf.org/html/rfc5246#section-7.2
package alert
import (
"errors"
"fmt"
"github.com/pion/dtls/v2/pkg/protocol"
)
var errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113
// Level is the level of the TLS Alert
type Level byte
// Level enums
const (
Warning Level = 1
Fatal Level = 2
)
func (l Level) String() string {
switch l {
case Warning:
return "Warning"
case Fatal:
return "Fatal"
default:
return "Invalid alert level"
}
}
// Description is the extended info of the TLS Alert
type Description byte
// Description enums
const (
CloseNotify Description = 0
UnexpectedMessage Description = 10
BadRecordMac Description = 20
DecryptionFailed Description = 21
RecordOverflow Description = 22
DecompressionFailure Description = 30
HandshakeFailure Description = 40
NoCertificate Description = 41
BadCertificate Description = 42
UnsupportedCertificate Description = 43
CertificateRevoked Description = 44
CertificateExpired Description = 45
CertificateUnknown Description = 46
IllegalParameter Description = 47
UnknownCA Description = 48
AccessDenied Description = 49
DecodeError Description = 50
DecryptError Description = 51
ExportRestriction Description = 60
ProtocolVersion Description = 70
InsufficientSecurity Description = 71
InternalError Description = 80
UserCanceled Description = 90
NoRenegotiation Description = 100
UnsupportedExtension Description = 110
)
func (d Description) String() string {
switch d {
case CloseNotify:
return "CloseNotify"
case UnexpectedMessage:
return "UnexpectedMessage"
case BadRecordMac:
return "BadRecordMac"
case DecryptionFailed:
return "DecryptionFailed"
case RecordOverflow:
return "RecordOverflow"
case DecompressionFailure:
return "DecompressionFailure"
case HandshakeFailure:
return "HandshakeFailure"
case NoCertificate:
return "NoCertificate"
case BadCertificate:
return "BadCertificate"
case UnsupportedCertificate:
return "UnsupportedCertificate"
case CertificateRevoked:
return "CertificateRevoked"
case CertificateExpired:
return "CertificateExpired"
case CertificateUnknown:
return "CertificateUnknown"
case IllegalParameter:
return "IllegalParameter"
case UnknownCA:
return "UnknownCA"
case AccessDenied:
return "AccessDenied"
case DecodeError:
return "DecodeError"
case DecryptError:
return "DecryptError"
case ExportRestriction:
return "ExportRestriction"
case ProtocolVersion:
return "ProtocolVersion"
case InsufficientSecurity:
return "InsufficientSecurity"
case InternalError:
return "InternalError"
case UserCanceled:
return "UserCanceled"
case NoRenegotiation:
return "NoRenegotiation"
case UnsupportedExtension:
return "UnsupportedExtension"
default:
return "Invalid alert description"
}
}
// Alert is one of the content types supported by the TLS record layer.
// Alert messages convey the severity of the message
// (warning or fatal) and a description of the alert. Alert messages
// with a level of fatal result in the immediate termination of the
// connection. In this case, other connections corresponding to the
// session may continue, but the session identifier MUST be invalidated,
// preventing the failed session from being used to establish new
// connections. Like other messages, alert messages are encrypted and
// compressed, as specified by the current connection state.
// https://tools.ietf.org/html/rfc5246#section-7.2
type Alert struct {
Level Level
Description Description
}
// ContentType returns the ContentType of this Content
func (a Alert) ContentType() protocol.ContentType {
return protocol.ContentTypeAlert
}
// Marshal returns the encoded alert
func (a *Alert) Marshal() ([]byte, error) {
return []byte{byte(a.Level), byte(a.Description)}, nil
}
// Unmarshal populates the alert from binary data
func (a *Alert) Unmarshal(data []byte) error {
if len(data) != 2 {
return errBufferTooSmall
}
a.Level = Level(data[0])
a.Description = Description(data[1])
return nil
}
func (a *Alert) String() string {
return fmt.Sprintf("Alert %s: %s", a.Level, a.Description)
}

Some files were not shown because too many files have changed in this diff Show more