1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-02-15 04:42:04 +00:00

Merge branch '4.0release' into merge/develop

This commit is contained in:
winlin 2021-05-02 21:46:06 +08:00
commit 64243ddb8a
101 changed files with 18902 additions and 9 deletions

View file

@ -80,6 +80,25 @@ Other important wiki:
* Usage: How to improve edge performance for multiple CPUs? ([CN][v4_CN_REUSEPORT], [EN][v4_EN_REUSEPORT])
* Usage: How to file a bug or contact us? ([CN][v4_CN_Contact], [EN][v4_EN_Contact])
## Ports
The ports used by SRS:
* tcp://1935, for RTMP live streaming server.
* tcp://1985, HTTP API server.
* tcp://1990, HTTPS API server.
* tcp://8080, HTTP live streaming server.
* tcp://8088, HTTPS live streaming server.
* udp://8000, [WebRTC Media](https://github.com/ossrs/srs/wiki/v4_CN_WebRTC) server.
* udp://1980, [WebRTC Signaling](https://github.com/ossrs/signaling#usage) server.
* udp://8935, Stream Caster: [Push MPEGTS over UDP](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-mpeg-ts-over-udp) server.
* tcp://554, Stream Caster: [Push RTSP](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-rtsp-to-srs) server.
* tcp://8936, Stream Caster: [Push HTTP-FLV](https://github.com/ossrs/srs/wiki/v4_CN_Streamer#push-http-flv-to-srs) server.
* tcp://5060, Stream Caster: [Push GB28181 SIP](https://github.com/ossrs/srs/issues/1500#issuecomment-606695679) server.
* udp://9000, Stream Caster: [Push GB28181 Media(bundle)](https://github.com/ossrs/srs/issues/1500#issuecomment-606695679) server.
* udp://58200-58300, Stream Caster: [Push GB28181 Media(no-bundle)](https://github.com/ossrs/srs/issues/1500#issuecomment-606695679) server.
* udp://10080, Stream Caster: [Push SRT Media](https://github.com/ossrs/srs/issues/1147#issuecomment-577469119) server.
## Features
- [x] Using coroutine by ST, it's really simple and stupid enough.
@ -163,8 +182,9 @@ Other important wiki:
## V4 changes
* v5.0, 2021-04-20, Support RTC2RTMP bridger and shared FastTimer. 4.0.95
* v5.0, 2021-04-20, Refine transcoder to support aac2opus and opus2aac. 4.0.94
* v4.0, 2021-05-02, Add one to one demo. 4.0.96
* v4.0, 2021-04-20, Support RTC2RTMP bridger and shared FastTimer. 4.0.95
* v4.0, 2021-04-20, Refine transcoder to support aac2opus and opus2aac. 4.0.94
* v4.0, 2021-05-01, Timer: Extract and apply shared FastTimer. 4.0.93
* v4.0, 2021-04-29, RTC: Support AV1 for Chrome M90. 4.0.91
* v4.0, 2021-04-24, Change push-RTSP as deprecated feature.

View file

@ -0,0 +1,7 @@
httpx-static
letsencrypt.cache
*.crt
*.pem
*.key
objs
.format.txt

19
trunk/3rdparty/httpx-static/Makefile vendored Normal file
View file

@ -0,0 +1,19 @@
.PHONY: help default clean httpx-static
default: httpx-static
clean:
rm -f ./objs/httpx-static
.format.txt: *.go
gofmt -w .
echo "done" > .format.txt
httpx-static: ./objs/httpx-static
./objs/httpx-static: .format.txt *.go Makefile
go build -mod=vendor -o objs/httpx-static .
help:
@echo "Usage: make [httpx-static]"
@echo " httpx-static Make the httpx-static to ./objs/httpx-static"

51
trunk/3rdparty/httpx-static/README.md vendored Normal file
View file

@ -0,0 +1,51 @@
# HTTPX
A HTTP/HTTPS Server, support letsencrypt or self-sign HTTPS and proxying HTTP as HTTPS.
> Remark: Requires GO1.8+
## Usage
*HTTP*: Start a HTTP static server
```
go get github.com/ossrs/go-oryx/httpx-static &&
cd $GOPATH/src/github.com/ossrs/go-oryx/httpx-static &&
$GOPATH/bin/httpx-static -http 8080 -root `pwd`/html
```
Open http://localhost:8080/ in browser.
*HTTPS self-sign*: Start a HTTPS static server
```
go get github.com/ossrs/go-oryx/httpx-static &&
cd $GOPATH/src/github.com/ossrs/go-oryx/httpx-static &&
openssl genrsa -out server.key 2048 &&
subj="/C=CN/ST=Beijing/L=Beijing/O=Me/OU=Me/CN=me.org" &&
openssl req -new -x509 -key server.key -out server.crt -days 365 -subj $subj &&
$GOPATH/bin/httpx-static -https 8443 -root `pwd`/html
```
Open https://localhost:8443/ in browser.
> Remark: Click `ADVANCED` => `Proceed to localhost (unsafe)`.
*HTTPS proxy*: Proxy http as https
```
go get github.com/ossrs/go-oryx/httpx-static &&
cd $GOPATH/src/github.com/ossrs/go-oryx/httpx-static &&
openssl genrsa -out server.key 2048 &&
subj="/C=CN/ST=Beijing/L=Beijing/O=Me/OU=Me/CN=me.org" &&
openssl req -new -x509 -key server.key -out server.crt -days 365 -subj $subj &&
$GOPATH/bin/httpx-static -https 8443 -root `pwd`/html -proxy http://ossrs.net:1985/api/v1
```
Open https://localhost:8443/api/v1/summaries in browser.
## History
* v0.0.3, 2017-11-03, Support multiple proxy HTTP to HTTPS.
Winlin 2017

5
trunk/3rdparty/httpx-static/go.mod vendored Normal file
View file

@ -0,0 +1,5 @@
module github.com/ossrs/go-oryx/httpx-static
go 1.16
require github.com/ossrs/go-oryx-lib v0.0.8

2
trunk/3rdparty/httpx-static/go.sum vendored Normal file
View file

@ -0,0 +1,2 @@
github.com/ossrs/go-oryx-lib v0.0.8 h1:k8ml3ZLsjIMoQEdZdWuy8zkU0w/fbJSyHvT/s9NyeCc=
github.com/ossrs/go-oryx-lib v0.0.8/go.mod h1:i2tH4TZBzAw5h+HwGrNOKvP/nmZgSQz0OEnLLdzcT/8=

View file

@ -0,0 +1 @@
HTTP/HTTPS static server with API proxy.

461
trunk/3rdparty/httpx-static/main.go vendored Normal file
View file

@ -0,0 +1,461 @@
/*
The MIT License (MIT)
Copyright (c) 2019 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
/*
This the main entrance of https-proxy, proxy to api or other http server.
*/
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
oe "github.com/ossrs/go-oryx-lib/errors"
oh "github.com/ossrs/go-oryx-lib/http"
"github.com/ossrs/go-oryx-lib/https"
ol "github.com/ossrs/go-oryx-lib/logger"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"strconv"
"strings"
"sync"
)
type Strings []string
func (v *Strings) String() string {
return fmt.Sprintf("strings [%v]", strings.Join(*v, ","))
}
func (v *Strings) Set(value string) error {
*v = append(*v, value)
return nil
}
func shouldProxyURL(srcPath, proxyPath string) bool {
if !strings.HasSuffix(srcPath, "/") {
// /api to /api/
// /api.js to /api.js/
// /api/100 to /api/100/
srcPath += "/"
}
if !strings.HasSuffix(proxyPath, "/") {
// /api/ to /api/
// to match /api/ or /api/100
// and not match /api.js/
proxyPath += "/"
}
return strings.HasPrefix(srcPath, proxyPath)
}
func NewComplexProxy(ctx context.Context, proxyUrl *url.URL, originalRequest *http.Request) http.Handler {
proxy := &httputil.ReverseProxy{}
// Create a proxy which attach a isolate logger.
elogger := log.New(os.Stderr, fmt.Sprintf("%v ", originalRequest.RemoteAddr), log.LstdFlags)
proxy.ErrorLog = elogger
proxy.Director = func(r *http.Request) {
// about the x-real-schema, we proxy to backend to identify the client schema.
if rschema := r.Header.Get("X-Real-Schema"); rschema == "" {
if r.TLS == nil {
r.Header.Set("X-Real-Schema", "http")
} else {
r.Header.Set("X-Real-Schema", "https")
}
}
// about x-real-ip and x-forwarded-for or
// about X-Real-IP and X-Forwarded-For or
// https://segmentfault.com/q/1010000002409659
// https://distinctplace.com/2014/04/23/story-behind-x-forwarded-for-and-x-real-ip-headers/
// @remark http proxy will set the X-Forwarded-For.
if rip := r.Header.Get("X-Real-IP"); rip == "" {
if rip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
r.Header.Set("X-Real-IP", rip)
}
}
r.URL.Scheme = proxyUrl.Scheme
r.URL.Host = proxyUrl.Host
ra, url, rip := r.RemoteAddr, r.URL.String(), r.Header.Get("X-Real-Ip")
ol.Tf(ctx, "proxy http rip=%v, addr=%v %v %v with headers %v", rip, ra, r.Method, url, r.Header)
}
proxy.ModifyResponse = func(w *http.Response) error {
// we already added this header, it will cause chrome failed when duplicated.
if w.Header.Get("Access-Control-Allow-Origin") == "*" {
w.Header.Del("Access-Control-Allow-Origin")
}
return nil
}
return proxy
}
func run(ctx context.Context) error {
oh.Server = fmt.Sprintf("%v/%v", Signature(), Version())
fmt.Println(oh.Server, "HTTP/HTTPS static server with API proxy.")
var httpPorts Strings
flag.Var(&httpPorts, "t", "http listen")
flag.Var(&httpPorts, "http", "http listen at. 0 to disable http.")
var httpsPorts Strings
flag.Var(&httpsPorts, "s", "https listen")
flag.Var(&httpsPorts, "https", "https listen at. 0 to disable https. 443 to serve. ")
var httpsDomains string
flag.StringVar(&httpsDomains, "d", "", "https the allow domains")
flag.StringVar(&httpsDomains, "domains", "", "https the allow domains, empty to allow all. for example: ossrs.net,www.ossrs.net")
var html string
flag.StringVar(&html, "r", "./html", "the www web root")
flag.StringVar(&html, "root", "./html", "the www web root. support relative dir to argv[0].")
var cacheFile string
flag.StringVar(&cacheFile, "e", "./letsencrypt.cache", "https the cache for letsencrypt")
flag.StringVar(&cacheFile, "cache", "./letsencrypt.cache", "https the cache for letsencrypt. support relative dir to argv[0].")
var useLetsEncrypt bool
flag.BoolVar(&useLetsEncrypt, "l", false, "whether use letsencrypt CA")
flag.BoolVar(&useLetsEncrypt, "lets", false, "whether use letsencrypt CA. self sign if not.")
var ssKey string
flag.StringVar(&ssKey, "k", "", "https self-sign key")
flag.StringVar(&ssKey, "ssk", "", "https self-sign key")
var ssCert string
flag.StringVar(&ssCert, "c", "", `https self-sign cert`)
flag.StringVar(&ssCert, "ssc", "", `https self-sign cert`)
var oproxies Strings
flag.Var(&oproxies, "p", "proxy ruler")
flag.Var(&oproxies, "proxy", "one or more proxy the matched path to backend, for example, -proxy http://127.0.0.1:8888/api/webrtc")
var sdomains, skeys, scerts Strings
flag.Var(&sdomains, "sdomain", "the SSL hostname")
flag.Var(&skeys, "skey", "the SSL key for domain")
flag.Var(&scerts, "scert", "the SSL cert for domain")
flag.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v -t http -s https -d domains -r root -e cache -l lets -k ssk -c ssc -p proxy", os.Args[0]))
fmt.Println(fmt.Sprintf(" "))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -t, -http string"))
fmt.Println(fmt.Sprintf(" Listen at port for HTTP server. Default: 0, disable HTTP."))
fmt.Println(fmt.Sprintf(" -s, -https string"))
fmt.Println(fmt.Sprintf(" Listen at port for HTTPS server. Default: 0, disable HTTPS."))
fmt.Println(fmt.Sprintf(" -r, -root string"))
fmt.Println(fmt.Sprintf(" The www root path. Supports relative to argv[0]=%v. Default: ./html", path.Dir(os.Args[0])))
fmt.Println(fmt.Sprintf(" -p, -proxy string"))
fmt.Println(fmt.Sprintf(" Proxy path to backend. For example: http://127.0.0.1:8888/api/webrtc"))
fmt.Println(fmt.Sprintf("Options for HTTPS(letsencrypt cert):"))
fmt.Println(fmt.Sprintf(" -l, -lets=bool"))
fmt.Println(fmt.Sprintf(" Whether use letsencrypt CA. Default: false"))
fmt.Println(fmt.Sprintf(" -e, -cache string"))
fmt.Println(fmt.Sprintf(" The letsencrypt cache. Default: ./letsencrypt.cache"))
fmt.Println(fmt.Sprintf(" -d, -domains string"))
fmt.Println(fmt.Sprintf(" Set the validate HTTPS domain. For example: ossrs.net,www.ossrs.net"))
fmt.Println(fmt.Sprintf("Options for HTTPS(file-based cert):"))
fmt.Println(fmt.Sprintf(" -k, -ssk string"))
fmt.Println(fmt.Sprintf(" The self-sign or validate file-based key file."))
fmt.Println(fmt.Sprintf(" -c, -ssc string"))
fmt.Println(fmt.Sprintf(" The self-sign or validate file-based cert file."))
fmt.Println(fmt.Sprintf(" -sdomain string"))
fmt.Println(fmt.Sprintf(" For multiple HTTPS site, the domain name. For example: ossrs.net"))
fmt.Println(fmt.Sprintf(" -skey string"))
fmt.Println(fmt.Sprintf(" For multiple HTTPS site, the key file."))
fmt.Println(fmt.Sprintf(" -scert string"))
fmt.Println(fmt.Sprintf(" For multiple HTTPS site, the cert file."))
fmt.Println(fmt.Sprintf("For example:"))
fmt.Println(fmt.Sprintf(" %v -t 8080 -s 9443 -r ./html", os.Args[0]))
fmt.Println(fmt.Sprintf(" %v -t 8080 -s 9443 -r ./html -p http://ossrs.net:1985/api/v1/versions", os.Args[0]))
fmt.Println(fmt.Sprintf("Generate cert for self-sign HTTPS:"))
fmt.Println(fmt.Sprintf(" openssl genrsa -out server.key 2048"))
fmt.Println(fmt.Sprintf(` openssl req -new -x509 -key server.key -out server.crt -days 365 -subj "/C=CN/ST=Beijing/L=Beijing/O=Me/OU=Me/CN=me.org"`))
fmt.Println(fmt.Sprintf("For example:"))
fmt.Println(fmt.Sprintf(" %v -s 9443 -r ./html -sdomain ossrs.net -skey ossrs.net.key -scert ossrs.net.pem", os.Args[0]))
}
flag.Parse()
if useLetsEncrypt && len(httpsPorts) == 0 {
return oe.Errorf("for letsencrypt, https=%v must be 0(disabled) or 443(enabled)", httpsPorts)
}
if len(httpPorts) == 0 && len(httpsPorts) == 0 {
flag.Usage()
os.Exit(-1)
}
var proxyUrls []*url.URL
proxies := make(map[string]*url.URL)
for _, oproxy := range []string(oproxies) {
if oproxy == "" {
return oe.Errorf("empty proxy in %v", oproxies)
}
proxyUrl, err := url.Parse(oproxy)
if err != nil {
return oe.Wrapf(err, "parse proxy %v", oproxy)
}
if _, ok := proxies[proxyUrl.Path]; ok {
return oe.Errorf("proxy %v duplicated", proxyUrl.Path)
}
proxyUrls = append(proxyUrls, proxyUrl)
proxies[proxyUrl.Path] = proxyUrl
ol.Tf(ctx, "Proxy %v to %v", proxyUrl.Path, oproxy)
}
if !path.IsAbs(cacheFile) && path.IsAbs(os.Args[0]) {
cacheFile = path.Join(path.Dir(os.Args[0]), cacheFile)
}
if !path.IsAbs(html) && path.IsAbs(os.Args[0]) {
html = path.Join(path.Dir(os.Args[0]), html)
}
fs := http.FileServer(http.Dir(html))
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
oh.SetHeader(w)
if o := r.Header.Get("Origin"); len(o) > 0 {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, HEAD, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Expose-Headers", "Server,range,Content-Length,Content-Range")
w.Header().Set("Access-Control-Allow-Headers", "origin,range,accept-encoding,referer,Cache-Control,X-Proxy-Authorization,X-Requested-With,Content-Type")
}
// For matched OPTIONS, directly return without response.
if r.Method == "OPTIONS" {
return
}
if proxyUrls == nil {
if r.URL.Path == "/httpx/v1/versions" {
oh.WriteVersion(w, r, Version())
return
}
fs.ServeHTTP(w, r)
return
}
for _, proxyUrl := range proxyUrls {
if !shouldProxyURL(r.URL.Path, proxyUrl.Path) {
continue
}
if proxy, ok := proxies[proxyUrl.Path]; ok {
p := NewComplexProxy(ctx, proxy, r)
p.ServeHTTP(w, r)
return
}
}
fs.ServeHTTP(w, r)
})
var protos []string
if len(httpPorts) > 0 {
protos = append(protos, fmt.Sprintf("http(:%v)", strings.Join(httpPorts, ",")))
}
for _, httpsPort := range httpsPorts {
if httpsPort == "0" {
continue
}
s := httpsDomains
if httpsDomains == "" {
s = "all domains"
}
if useLetsEncrypt {
protos = append(protos, fmt.Sprintf("https(:%v, %v, %v)", httpsPort, s, cacheFile))
} else {
protos = append(protos, fmt.Sprintf("https(:%v)", httpsPort))
}
if useLetsEncrypt {
protos = append(protos, "letsencrypt")
} else if ssKey != "" {
protos = append(protos, fmt.Sprintf("self-sign(%v, %v)", ssKey, ssCert))
} else if len(sdomains) == 0 {
return oe.New("no ssl config")
}
for i := 0; i < len(sdomains); i++ {
sdomain, skey, scert := sdomains[i], skeys[i], scerts[i]
if f, err := os.Open(scert); err != nil {
return oe.Wrapf(err, "open cert %v for %v err %+v", scert, sdomain, err)
} else {
f.Close()
}
if f, err := os.Open(skey); err != nil {
return oe.Wrapf(err, "open key %v for %v err %+v", skey, sdomain, err)
} else {
f.Close()
}
protos = append(protos, fmt.Sprintf("ssl(%v,%v,%v)", sdomain, skey, scert))
}
}
ol.Tf(ctx, "%v html root at %v", strings.Join(protos, ", "), string(html))
if len(httpsPorts) > 0 && !useLetsEncrypt && ssKey != "" {
if f, err := os.Open(ssCert); err != nil {
return oe.Wrapf(err, "open cert %v err %+v", ssCert, err)
} else {
f.Close()
}
if f, err := os.Open(ssKey); err != nil {
return oe.Wrapf(err, "open key %v err %+v", ssKey, err)
} else {
f.Close()
}
}
wg := sync.WaitGroup{}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var httpServers []*http.Server
for _, v := range httpPorts {
httpPort, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return oe.Wrapf(err, "parse %v", v)
}
wg.Add(1)
go func(httpPort int) {
defer wg.Done()
ctx = ol.WithContext(ctx)
if httpPort == 0 {
ol.W(ctx, "http server disabled")
return
}
defer cancel()
hs := &http.Server{Addr: fmt.Sprintf(":%v", httpPort), Handler: nil}
httpServers = append(httpServers, hs)
ol.Tf(ctx, "http serve at %v", httpPort)
if err := hs.ListenAndServe(); err != nil {
ol.Ef(ctx, "http serve err %+v", err)
return
}
ol.T("http server ok")
}(int(httpPort))
}
for _, v := range httpsPorts {
httpsPort, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return oe.Wrapf(err, "parse %v", v)
}
wg.Add(1)
go func() {
defer wg.Done()
ctx = ol.WithContext(ctx)
if httpsPort == 0 {
ol.W(ctx, "https server disabled")
return
}
defer cancel()
var err error
var m https.Manager
if useLetsEncrypt {
var domains []string
if httpsDomains != "" {
domains = strings.Split(httpsDomains, ",")
}
if m, err = https.NewLetsencryptManager("", domains, cacheFile); err != nil {
ol.Ef(ctx, "create letsencrypt manager err %+v", err)
return
}
} else if ssKey != "" {
if m, err = https.NewSelfSignManager(ssCert, ssKey); err != nil {
ol.Ef(ctx, "create self-sign manager err %+v", err)
return
}
} else if len(sdomains) > 0 {
if m, err = NewCertsManager(sdomains, skeys, scerts); err != nil {
ol.Ef(ctx, "create ssl managers err %+v", err)
return
}
}
hss := &http.Server{
Addr: fmt.Sprintf(":%v", httpsPort),
TLSConfig: &tls.Config{
GetCertificate: m.GetCertificate,
},
}
httpServers = append(httpServers, hss)
ol.Tf(ctx, "https serve at %v", httpsPort)
if err = hss.ListenAndServeTLS("", ""); err != nil {
ol.Ef(ctx, "https serve err %+v", err)
return
}
ol.T("https serve ok")
}()
}
select {
case <-ctx.Done():
for _, server := range httpServers {
server.Close()
}
}
wg.Wait()
return nil
}
func main() {
ctx := ol.WithContext(context.Background())
if err := run(ctx); err != nil {
ol.Ef(ctx, "run err %+v", err)
os.Exit(-1)
}
}

View file

@ -0,0 +1,49 @@
/*
The MIT License (MIT)
Copyright (c) 2019 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
package main
import "testing"
func TestShouldProxyURL(t *testing.T) {
vvs := []struct {
source string
proxy string
expect bool
}{
{"talks/v1", "talks/v1", true},
{"talks/v1/iceconfig", "talks/v1", true},
{"talks/v1/iceconfig.js", "talks/v1", true},
{"talks/v1.js", "talks/v1", false},
{"talks/iceconfig", "talks/v1", false},
{"talks/v1", "api/v1", false},
{"talks/v1/iceconfig", "api/v1", false},
}
for _, vv := range vvs {
if v := shouldProxyURL(vv.source, vv.proxy); v != vv.expect {
t.Errorf("source=%v, proxy=%v, expect=%v", vv.source, vv.proxy, vv.expect)
}
}
}

63
trunk/3rdparty/httpx-static/mcerts.go vendored Normal file
View file

@ -0,0 +1,63 @@
/*
The MIT License (MIT)
Copyright (c) 2019 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
package main
import (
"crypto/tls"
"fmt"
oe "github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/https"
)
type certsManager struct {
// Key is hostname.
certs map[string]https.Manager
}
func NewCertsManager(domains, keys, certs []string) (m https.Manager, err error) {
v := &certsManager{
certs: make(map[string]https.Manager),
}
for i := 0; i < len(domains); i++ {
domain, key, cert := domains[i], keys[i], certs[i]
if m, err = https.NewSelfSignManager(cert, key); err != nil {
return nil, oe.Wrapf(err, "create cert for %v by %v, %v", domain, cert, key)
} else {
v.certs[domain] = m
}
}
return v, nil
}
func (v *certsManager) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if cert, ok := v.certs[clientHello.ServerName]; ok {
return cert.GetCertificate(clientHello)
}
return nil, fmt.Errorf("no cert for %v", clientHello.ServerName)
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-2017 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,23 @@
Copyright (c) 2015, Dave Cheney <dave@cheney.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,52 @@
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors)
Package errors provides simple error handling primitives.
`go get github.com/pkg/errors`
The traditional error handling idiom in Go is roughly akin to
```go
if err != nil {
return err
}
```
which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error.
## Adding context to an error
The errors.Wrap function returns a new error that adds context to the original error. For example
```go
_, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "read failed")
}
```
## Retrieving the cause of an error
Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`.
```go
type causer interface {
Cause() error
}
```
`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example:
```go
switch err := errors.Cause(err).(type) {
case *MyError:
// handle specifically
default:
// unknown error
}
```
[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors).
## Contributing
We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high.
Before proposing a change, please discuss your change by raising an issue.
## Licence
BSD-2-Clause

View file

@ -0,0 +1,270 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

View file

@ -0,0 +1,187 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

View file

@ -0,0 +1,87 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx http package, the response parse service.
package http
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
)
// Read http api by HTTP GET and parse the code/data.
func ApiRequest(url string) (code int, body []byte, err error) {
if body, err = apiGet(url); err != nil {
return
}
if code, _, err = apiParse(url, body); err != nil {
return
}
return
}
// Read http api by HTTP GET.
func apiGet(url string) (body []byte, err error) {
var resp *http.Response
if resp, err = http.Get(url); err != nil {
err = fmt.Errorf("api get failed, url=%v, err is %v", url, err)
return
}
defer resp.Body.Close()
if body, err = ioutil.ReadAll(resp.Body); err != nil {
err = fmt.Errorf("api read failed, url=%v, err is %v", url, err)
return
}
return
}
// Parse the standard response {code:int,data:object}.
func apiParse(url string, body []byte) (code int, data interface{}, err error) {
obj := make(map[string]interface{})
if err = json.Unmarshal(body, &obj); err != nil {
err = fmt.Errorf("api parse failed, url=%v, body=%v, err is %v", url, string(body), err)
return
}
if value, ok := obj["code"]; !ok {
err = fmt.Errorf("api no code, url=%v, body=%v", url, string(body))
return
} else if value, ok := value.(float64); !ok {
err = fmt.Errorf("api code not number, code=%v, url=%v, body=%v", value, url, string(body))
return
} else {
code = int(value)
}
data, _ = obj["data"]
if code != 0 {
err = fmt.Errorf("api error, code=%v, url=%v, body=%v, data=%v", code, url, string(body), data)
return
}
return
}

View file

@ -0,0 +1,274 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx http package provides standard request and response in json.
// Error, when error, use this handler.
// CplxError, for complex error, use this handler.
// Data, when no error, use this handler.
// SystemError, application level error code.
// SetHeader, for direclty response the raw stream.
// The standard server response:
// code, an int error code.
// data, specifies the data.
// The api for simple api:
// WriteVersion, to directly response the version.
// WriteData, to directly write the data in json.
// WriteError, to directly write the error.
// WriteCplxError, to directly write the complex error.
// The global variables:
// oh.Server, to set the response header["Server"].
package http
import (
"encoding/json"
"fmt"
ol "github.com/ossrs/go-oryx-lib/logger"
"net/http"
"os"
"strconv"
"strings"
)
// header["Content-Type"] in response.
const (
HttpJson = "application/json"
HttpJavaScript = "application/javascript"
)
// header["Server"] in response.
var Server = "Oryx"
// system int error.
type SystemError int
func (v SystemError) Error() string {
return fmt.Sprintf("System error=%d", int(v))
}
// system conplex error.
type SystemComplexError struct {
// the system error code.
Code SystemError `json:"code"`
// the description for this error.
Message string `json:"data"`
}
func (v SystemComplexError) Error() string {
return fmt.Sprintf("%v, %v", v.Code.Error(), v.Message)
}
// application level, with code.
type AppError interface {
Code() int
error
}
// HTTP Status Code
type HTTPStatus interface {
Status() int
}
// http standard error response.
// @remark for not SystemError, we will use logger.E to print it.
// @remark user can use WriteError() for simple api.
func Error(ctx ol.Context, err error) http.Handler {
// for complex error, use code instead.
if v, ok := err.(SystemComplexError); ok {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonHandler(ctx, FilterCplxSystemError(ctx, w, r, v)).ServeHTTP(w, r)
})
}
// for int error, use code instead.
if v, ok := err.(SystemError); ok {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonHandler(ctx, FilterSystemError(ctx, w, r, v)).ServeHTTP(w, r)
})
}
// for application error, use code instead.
if v, ok := err.(AppError); ok {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonHandler(ctx, FilterAppError(ctx, w, r, v)).ServeHTTP(w, r)
})
}
// unknown error, log and response detail
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
SetHeader(w)
w.Header().Set("Content-Type", HttpJson)
status := http.StatusInternalServerError
if v, ok := err.(HTTPStatus); ok {
status = v.Status()
}
http.Error(w, FilterError(ctx, w, r, err), status)
})
}
// Wrapper for complex error use Error(ctx, SystemComplexError{})
// @remark user can use WriteCplxError() for simple api.
func CplxError(ctx ol.Context, code SystemError, message string) http.Handler {
return Error(ctx, SystemComplexError{code, message})
}
// http normal response.
// @remark user can use nil v to response success, which data is null.
// @remark user can use WriteData() for simple api.
func Data(ctx ol.Context, v interface{}) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonHandler(ctx, FilterData(ctx, w, r, v)).ServeHTTP(w, r)
})
}
// set http header, for directly use the w,
// for example, user want to directly write raw text.
func SetHeader(w http.ResponseWriter) {
w.Header().Set("Server", Server)
}
// response json directly.
func jsonHandler(ctx ol.Context, rv interface{}) http.Handler {
var err error
var b []byte
if b, err = json.Marshal(rv); err != nil {
return Error(ctx, err)
}
status := http.StatusOK
if v, ok := rv.(HTTPStatus); ok {
status = v.Status()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
SetHeader(w)
q := r.URL.Query()
if cb := q.Get("callback"); cb != "" {
w.Header().Set("Content-Type", HttpJavaScript)
if status != http.StatusOK {
w.WriteHeader(status)
}
// TODO: Handle error.
fmt.Fprintf(w, "%s(%s)", cb, string(b))
} else {
w.Header().Set("Content-Type", HttpJson)
if status != http.StatusOK {
w.WriteHeader(status)
}
// TODO: Handle error.
w.Write(b)
}
})
}
// response the standard version info:
// {code, server, data} where server is the server pid, and data is below object:
// {major, minor, revision, extra, version, signature}
// @param version in {major.minor.revision-extra}, where -extra is optional,
// for example: 1.0.0 or 1.0.0-0 or 1.0.0-1
func WriteVersion(w http.ResponseWriter, r *http.Request, version string) {
var major, minor, revision, extra int
versions := strings.Split(version, "-")
if len(versions) > 1 {
extra, _ = strconv.Atoi(versions[1])
}
versions = strings.Split(versions[0], ".")
if len(versions) > 0 {
major, _ = strconv.Atoi(versions[0])
}
if len(versions) > 1 {
minor, _ = strconv.Atoi(versions[1])
}
if len(versions) > 2 {
revision, _ = strconv.Atoi(versions[2])
}
Data(nil, map[string]interface{}{
"major": major,
"minor": minor,
"revision": revision,
"extra": extra,
"version": version,
"signature": Server,
}).ServeHTTP(w, r)
}
// Directly write json data, a wrapper for Data().
// @remark user can use Data() for group of complex apis.
func WriteData(ctx ol.Context, w http.ResponseWriter, r *http.Request, v interface{}) {
Data(ctx, v).ServeHTTP(w, r)
}
// Directly write success json response, same to WriteData(ctx, w, r, nil).
func Success(ctx ol.Context, w http.ResponseWriter, r *http.Request) {
WriteData(ctx, w, r, nil)
}
// Directly write error, a wrapper for Error().
// @remark user can use Error() for group of complex apis.
func WriteError(ctx ol.Context, w http.ResponseWriter, r *http.Request, err error) {
Error(ctx, err).ServeHTTP(w, r)
}
// Directly write complex error, a wrappter for CplxError().
// @remark user can use CplxError() for group of complex apis.
func WriteCplxError(ctx ol.Context, w http.ResponseWriter, r *http.Request, code SystemError, message string) {
CplxError(ctx, code, message).ServeHTTP(w, r)
}
// for hijack to define the response structure.
// user can redefine these functions for special response.
var FilterCplxSystemError = func(ctx ol.Context, w http.ResponseWriter, r *http.Request, o SystemComplexError) interface{} {
ol.Ef(ctx, "Serve %v failed, err is %+v", r.URL, o)
return o
}
var FilterSystemError = func(ctx ol.Context, w http.ResponseWriter, r *http.Request, o SystemError) interface{} {
ol.Ef(ctx, "Serve %v failed, err is %+v", r.URL, o)
return map[string]int{"code": int(o)}
}
var FilterAppError = func(ctx ol.Context, w http.ResponseWriter, r *http.Request, err AppError) interface{} {
ol.Ef(ctx, "Serve %v failed, err is %+v", r.URL, err)
return map[string]interface{}{"code": err.Code(), "data": err.Error()}
}
var FilterError = func(ctx ol.Context, w http.ResponseWriter, r *http.Request, err error) string {
ol.Ef(ctx, "Serve %v failed, err is %+v", r.URL, err)
return err.Error()
}
var FilterData = func(ctx ol.Context, w http.ResponseWriter, r *http.Request, o interface{}) interface{} {
rv := map[string]interface{}{
"code": 0,
"server": os.Getpid(),
"data": o,
}
// for string, directly use it without convert,
// for the type covert by golang maybe modify the content.
if v, ok := o.(string); ok {
rv["data"] = v
}
return rv
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2015 Sebastian Erhart
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,18 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
// Challenge is a string that identifies a particular type and version of ACME challenge.
type Challenge string
const (
// HTTP01 is the "http-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http
// Note: HTTP01ChallengePath returns the URL path to fulfill this challenge
HTTP01 = Challenge("http-01")
// TLSSNI01 is the "tls-sni-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#tls-with-server-name-indication-tls-sni
// Note: TLSSNI01ChallengeCert returns a certificate to fulfill this challenge
TLSSNI01 = Challenge("tls-sni-01")
// DNS01 is the "dns-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#dns
// Note: DNS01Record returns a DNS record which will fulfill this challenge
DNS01 = Challenge("dns-01")
)

View file

@ -0,0 +1,640 @@
// Package acme implements the ACME protocol for Let's Encrypt and other conforming providers.
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"crypto"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"regexp"
"strconv"
"strings"
"time"
)
var (
// Logger is an optional custom logger.
Logger *log.Logger
)
// logf writes a log entry. It uses Logger if not
// nil, otherwise it uses the default log.Logger.
func logf(format string, args ...interface{}) {
if Logger != nil {
Logger.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// User interface is to be implemented by users of this library.
// It is used by the client type to get user specific information.
type User interface {
GetEmail() string
GetRegistration() *RegistrationResource
GetPrivateKey() crypto.PrivateKey
}
// Interface for all challenge solvers to implement.
type solver interface {
Solve(challenge challenge, domain string) error
}
type validateFunc func(j *jws, domain, uri string, chlng challenge) error
// Client is the user-friendy way to ACME
type Client struct {
directory directory
user User
jws *jws
keyType KeyType
issuerCert []byte
solvers map[Challenge]solver
}
// NewClient creates a new ACME client on behalf of the user. The client will depend on
// the ACME directory located at caDirURL for the rest of its actions. It will
// generate private keys for certificates of size keyBits.
func NewClient(caDirURL string, user User, keyType KeyType) (*Client, error) {
privKey := user.GetPrivateKey()
if privKey == nil {
return nil, errors.New("private key was nil")
}
var dir directory
if _, err := getJSON(caDirURL, &dir); err != nil {
return nil, fmt.Errorf("get directory at '%s': %v", caDirURL, err)
}
if dir.NewRegURL == "" {
return nil, errors.New("directory missing new registration URL")
}
if dir.NewAuthzURL == "" {
return nil, errors.New("directory missing new authz URL")
}
if dir.NewCertURL == "" {
return nil, errors.New("directory missing new certificate URL")
}
if dir.RevokeCertURL == "" {
return nil, errors.New("directory missing revoke certificate URL")
}
jws := &jws{privKey: privKey, directoryURL: caDirURL}
// REVIEW: best possibility?
// Add all available solvers with the right index as per ACME
// spec to this map. Otherwise they won`t be found.
solvers := make(map[Challenge]solver)
solvers[HTTP01] = &httpChallenge{jws: jws, validate: validate, provider: &HTTPProviderServer{}}
solvers[TLSSNI01] = &tlsSNIChallenge{jws: jws, validate: validate, provider: &TLSProviderServer{}}
return &Client{directory: dir, user: user, jws: jws, keyType: keyType, solvers: solvers}, nil
}
// SetChallengeProvider specifies a custom provider that will make the solution available
func (c *Client) SetChallengeProvider(challenge Challenge, p ChallengeProvider) error {
switch challenge {
case HTTP01:
c.solvers[challenge] = &httpChallenge{jws: c.jws, validate: validate, provider: p}
case TLSSNI01:
c.solvers[challenge] = &tlsSNIChallenge{jws: c.jws, validate: validate, provider: p}
default:
return fmt.Errorf("Unknown challenge %v", challenge)
}
return nil
}
// SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges.
// If this option is not used, the default port 80 and all interfaces will be used.
// To only specify a port and no interface use the ":port" notation.
func (c *Client) SetHTTPAddress(iface string) error {
host, port, err := net.SplitHostPort(iface)
if err != nil {
return err
}
if chlng, ok := c.solvers[HTTP01]; ok {
chlng.(*httpChallenge).provider = NewHTTPProviderServer(host, port)
}
return nil
}
// SetTLSAddress specifies a custom interface:port to be used for TLS based challenges.
// If this option is not used, the default port 443 and all interfaces will be used.
// To only specify a port and no interface use the ":port" notation.
func (c *Client) SetTLSAddress(iface string) error {
host, port, err := net.SplitHostPort(iface)
if err != nil {
return err
}
if chlng, ok := c.solvers[TLSSNI01]; ok {
chlng.(*tlsSNIChallenge).provider = NewTLSProviderServer(host, port)
}
return nil
}
// ExcludeChallenges explicitly removes challenges from the pool for solving.
func (c *Client) ExcludeChallenges(challenges []Challenge) {
// Loop through all challenges and delete the requested one if found.
for _, challenge := range challenges {
delete(c.solvers, challenge)
}
}
// Register the current account to the ACME server.
func (c *Client) Register() (*RegistrationResource, error) {
if c == nil || c.user == nil {
return nil, errors.New("acme: cannot register a nil client or user")
}
logf("[INFO] acme: Registering account for %s", c.user.GetEmail())
regMsg := registrationMessage{
Resource: "new-reg",
}
if c.user.GetEmail() != "" {
regMsg.Contact = []string{"mailto:" + c.user.GetEmail()}
} else {
regMsg.Contact = []string{}
}
var serverReg Registration
hdr, err := postJSON(c.jws, c.directory.NewRegURL, regMsg, &serverReg)
if err != nil {
return nil, err
}
reg := &RegistrationResource{Body: serverReg}
links := parseLinks(hdr["Link"])
reg.URI = hdr.Get("Location")
if links["terms-of-service"] != "" {
reg.TosURL = links["terms-of-service"]
}
if links["next"] != "" {
reg.NewAuthzURL = links["next"]
} else {
return nil, errors.New("acme: The server did not return 'next' link to proceed")
}
return reg, nil
}
// AgreeToTOS updates the Client registration and sends the agreement to
// the server.
func (c *Client) AgreeToTOS() error {
reg := c.user.GetRegistration()
reg.Body.Agreement = c.user.GetRegistration().TosURL
reg.Body.Resource = "reg"
_, err := postJSON(c.jws, c.user.GetRegistration().URI, c.user.GetRegistration().Body, nil)
return err
}
// ObtainCertificate tries to obtain a single certificate using all domains passed into it.
// The first domain in domains is used for the CommonName field of the certificate, all other
// domains are added using the Subject Alternate Names extension. A new private key is generated
// for every invocation of this function. If you do not want that you can supply your own private key
// in the privKey parameter. If this parameter is non-nil it will be used instead of generating a new one.
// If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle.
// This function will never return a partial certificate. If one domain in the list fails,
// the whole certificate will fail.
func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey) (CertificateResource, map[string]error) {
if bundle {
logf("[INFO][%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", "))
} else {
logf("[INFO][%s] acme: Obtaining SAN certificate", strings.Join(domains, ", "))
}
challenges, failures := c.getChallenges(domains)
// If any challenge fails - return. Do not generate partial SAN certificates.
if len(failures) > 0 {
return CertificateResource{}, failures
}
errs := c.solveChallenges(challenges)
// If any challenge fails - return. Do not generate partial SAN certificates.
if len(errs) > 0 {
return CertificateResource{}, errs
}
logf("[INFO][%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
cert, err := c.requestCertificate(challenges, bundle, privKey)
if err != nil {
for _, chln := range challenges {
failures[chln.Domain] = err
}
}
return cert, failures
}
// RevokeCertificate takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Client) RevokeCertificate(certificate []byte) error {
certificates, err := parsePEMBundle(certificate)
if err != nil {
return err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return fmt.Errorf("Certificate bundle starts with a CA certificate")
}
encodedCert := base64.URLEncoding.EncodeToString(x509Cert.Raw)
_, err = postJSON(c.jws, c.directory.RevokeCertURL, revokeCertMessage{Resource: "revoke-cert", Certificate: encodedCert}, nil)
return err
}
// RenewCertificate takes a CertificateResource and tries to renew the certificate.
// If the renewal process succeeds, the new certificate will ge returned in a new CertResource.
// Please be aware that this function will return a new certificate in ANY case that is not an error.
// If the server does not provide us with a new cert on a GET request to the CertURL
// this function will start a new-cert flow where a new certificate gets generated.
// If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle.
// For private key reuse the PrivateKey property of the passed in CertificateResource should be non-nil.
func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (CertificateResource, error) {
// Input certificate is PEM encoded. Decode it here as we may need the decoded
// cert later on in the renewal process. The input may be a bundle or a single certificate.
certificates, err := parsePEMBundle(cert.Certificate)
if err != nil {
return CertificateResource{}, err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return CertificateResource{}, fmt.Errorf("[%s] Certificate bundle starts with a CA certificate", cert.Domain)
}
// This is just meant to be informal for the user.
timeLeft := x509Cert.NotAfter.Sub(time.Now().UTC())
logf("[INFO][%s] acme: Trying renewal with %d hours remaining", cert.Domain, int(timeLeft.Hours()))
// The first step of renewal is to check if we get a renewed cert
// directly from the cert URL.
resp, err := httpGet(cert.CertURL)
if err != nil {
return CertificateResource{}, err
}
defer resp.Body.Close()
serverCertBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return CertificateResource{}, err
}
serverCert, err := x509.ParseCertificate(serverCertBytes)
if err != nil {
return CertificateResource{}, err
}
// If the server responds with a different certificate we are effectively renewed.
// TODO: Further test if we can actually use the new certificate (Our private key works)
if !x509Cert.Equal(serverCert) {
logf("[INFO][%s] acme: Server responded with renewed certificate", cert.Domain)
issuedCert := pemEncode(derCertificateBytes(serverCertBytes))
// If bundle is true, we want to return a certificate bundle.
// To do this, we need the issuer certificate.
if bundle {
// The issuer certificate link is always supplied via an "up" link
// in the response headers of a new certificate.
links := parseLinks(resp.Header["Link"])
issuerCert, err := c.getIssuerCertificate(links["up"])
if err != nil {
// If we fail to acquire the issuer cert, return the issued certificate - do not fail.
logf("[ERROR][%s] acme: Could not bundle issuer certificate: %v", cert.Domain, err)
} else {
// Success - append the issuer cert to the issued cert.
issuerCert = pemEncode(derCertificateBytes(issuerCert))
issuedCert = append(issuedCert, issuerCert...)
}
}
cert.Certificate = issuedCert
return cert, nil
}
var privKey crypto.PrivateKey
if cert.PrivateKey != nil {
privKey, err = parsePEMPrivateKey(cert.PrivateKey)
if err != nil {
return CertificateResource{}, err
}
}
var domains []string
var failures map[string]error
// check for SAN certificate
if len(x509Cert.DNSNames) > 1 {
domains = append(domains, x509Cert.Subject.CommonName)
for _, sanDomain := range x509Cert.DNSNames {
if sanDomain == x509Cert.Subject.CommonName {
continue
}
domains = append(domains, sanDomain)
}
} else {
domains = append(domains, x509Cert.Subject.CommonName)
}
newCert, failures := c.ObtainCertificate(domains, bundle, privKey)
return newCert, failures[cert.Domain]
}
// Looks through the challenge combinations to find a solvable match.
// Then solves the challenges in series and returns.
func (c *Client) solveChallenges(challenges []authorizationResource) map[string]error {
// loop through the resources, basically through the domains.
failures := make(map[string]error)
for _, authz := range challenges {
// no solvers - no solving
if solvers := c.chooseSolvers(authz.Body, authz.Domain); solvers != nil {
for i, solver := range solvers {
// TODO: do not immediately fail if one domain fails to validate.
err := solver.Solve(authz.Body.Challenges[i], authz.Domain)
if err != nil {
failures[authz.Domain] = err
}
}
} else {
failures[authz.Domain] = fmt.Errorf("[%s] acme: Could not determine solvers", authz.Domain)
}
}
return failures
}
// Checks all combinations from the server and returns an array of
// solvers which should get executed in series.
func (c *Client) chooseSolvers(auth authorization, domain string) map[int]solver {
for _, combination := range auth.Combinations {
solvers := make(map[int]solver)
for _, idx := range combination {
if solver, ok := c.solvers[auth.Challenges[idx].Type]; ok {
solvers[idx] = solver
} else {
logf("[INFO][%s] acme: Could not find solver for: %s", domain, auth.Challenges[idx].Type)
}
}
// If we can solve the whole combination, return the solvers
if len(solvers) == len(combination) {
return solvers
}
}
return nil
}
// Get the challenges needed to proof our identifier to the ACME server.
func (c *Client) getChallenges(domains []string) ([]authorizationResource, map[string]error) {
resc, errc := make(chan authorizationResource), make(chan domainError)
for _, domain := range domains {
go func(domain string) {
authMsg := authorization{Resource: "new-authz", Identifier: identifier{Type: "dns", Value: domain}}
var authz authorization
hdr, err := postJSON(c.jws, c.user.GetRegistration().NewAuthzURL, authMsg, &authz)
if err != nil {
errc <- domainError{Domain: domain, Error: err}
return
}
links := parseLinks(hdr["Link"])
if links["next"] == "" {
logf("[ERROR][%s] acme: Server did not provide next link to proceed", domain)
return
}
resc <- authorizationResource{Body: authz, NewCertURL: links["next"], AuthURL: hdr.Get("Location"), Domain: domain}
}(domain)
}
responses := make(map[string]authorizationResource)
failures := make(map[string]error)
for i := 0; i < len(domains); i++ {
select {
case res := <-resc:
responses[res.Domain] = res
case err := <-errc:
failures[err.Domain] = err.Error
}
}
challenges := make([]authorizationResource, 0, len(responses))
for _, domain := range domains {
if challenge, ok := responses[domain]; ok {
challenges = append(challenges, challenge)
}
}
close(resc)
close(errc)
return challenges, failures
}
func (c *Client) requestCertificate(authz []authorizationResource, bundle bool, privKey crypto.PrivateKey) (CertificateResource, error) {
if len(authz) == 0 {
return CertificateResource{}, errors.New("Passed no authorizations to requestCertificate!")
}
commonName := authz[0]
var err error
if privKey == nil {
privKey, err = generatePrivateKey(c.keyType)
if err != nil {
return CertificateResource{}, err
}
}
var san []string
var authURLs []string
for _, auth := range authz[1:] {
san = append(san, auth.Domain)
authURLs = append(authURLs, auth.AuthURL)
}
// TODO: should the CSR be customizable?
csr, err := generateCsr(privKey, commonName.Domain, san)
if err != nil {
return CertificateResource{}, err
}
csrString := base64.URLEncoding.EncodeToString(csr)
jsonBytes, err := json.Marshal(csrMessage{Resource: "new-cert", Csr: csrString, Authorizations: authURLs})
if err != nil {
return CertificateResource{}, err
}
resp, err := c.jws.post(commonName.NewCertURL, jsonBytes)
if err != nil {
return CertificateResource{}, err
}
privateKeyPem := pemEncode(privKey)
cerRes := CertificateResource{
Domain: commonName.Domain,
CertURL: resp.Header.Get("Location"),
PrivateKey: privateKeyPem}
for {
switch resp.StatusCode {
case 201, 202:
cert, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024))
resp.Body.Close()
if err != nil {
return CertificateResource{}, err
}
// The server returns a body with a length of zero if the
// certificate was not ready at the time this request completed.
// Otherwise the body is the certificate.
if len(cert) > 0 {
cerRes.CertStableURL = resp.Header.Get("Content-Location")
cerRes.AccountRef = c.user.GetRegistration().URI
issuedCert := pemEncode(derCertificateBytes(cert))
// If bundle is true, we want to return a certificate bundle.
// To do this, we need the issuer certificate.
if bundle {
// The issuer certificate link is always supplied via an "up" link
// in the response headers of a new certificate.
links := parseLinks(resp.Header["Link"])
issuerCert, err := c.getIssuerCertificate(links["up"])
if err != nil {
// If we fail to acquire the issuer cert, return the issued certificate - do not fail.
logf("[WARNING][%s] acme: Could not bundle issuer certificate: %v", commonName.Domain, err)
} else {
// Success - append the issuer cert to the issued cert.
issuerCert = pemEncode(derCertificateBytes(issuerCert))
issuedCert = append(issuedCert, issuerCert...)
}
}
cerRes.Certificate = issuedCert
logf("[INFO][%s] Server responded with a certificate.", commonName.Domain)
return cerRes, nil
}
// The certificate was granted but is not yet issued.
// Check retry-after and loop.
ra := resp.Header.Get("Retry-After")
retryAfter, err := strconv.Atoi(ra)
if err != nil {
return CertificateResource{}, err
}
logf("[INFO][%s] acme: Server responded with status 202; retrying after %ds", commonName.Domain, retryAfter)
time.Sleep(time.Duration(retryAfter) * time.Second)
break
default:
return CertificateResource{}, handleHTTPError(resp)
}
resp, err = httpGet(cerRes.CertURL)
if err != nil {
return CertificateResource{}, err
}
}
}
// getIssuerCertificate requests the issuer certificate and caches it for
// subsequent requests.
func (c *Client) getIssuerCertificate(url string) ([]byte, error) {
logf("[INFO] acme: Requesting issuer cert from %s", url)
if c.issuerCert != nil {
return c.issuerCert, nil
}
resp, err := httpGet(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
issuerBytes, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024))
if err != nil {
return nil, err
}
_, err = x509.ParseCertificate(issuerBytes)
if err != nil {
return nil, err
}
c.issuerCert = issuerBytes
return issuerBytes, err
}
func parseLinks(links []string) map[string]string {
aBrkt := regexp.MustCompile("[<>]")
slver := regexp.MustCompile("(.+) *= *\"(.+)\"")
linkMap := make(map[string]string)
for _, link := range links {
link = aBrkt.ReplaceAllString(link, "")
parts := strings.Split(link, ";")
matches := slver.FindStringSubmatch(parts[1])
if len(matches) > 0 {
linkMap[matches[2]] = parts[0]
}
}
return linkMap
}
// validate makes the ACME server start validating a
// challenge response, only returning once it is done.
func validate(j *jws, domain, uri string, chlng challenge) error {
var challengeResponse challenge
hdr, err := postJSON(j, uri, chlng, &challengeResponse)
if err != nil {
return err
}
// After the path is sent, the ACME server will access our server.
// Repeatedly check the server for an updated status on our request.
for {
switch challengeResponse.Status {
case "valid":
logf("[INFO][%s] The server validated our request", domain)
return nil
case "pending":
break
case "invalid":
return handleChallengeError(challengeResponse)
default:
return errors.New("The server returned an unexpected state.")
}
ra, err := strconv.Atoi(hdr.Get("Retry-After"))
if err != nil {
// The ACME server MUST return a Retry-After.
// If it doesn't, we'll just poll hard.
ra = 1
}
time.Sleep(time.Duration(ra) * time.Second)
hdr, err = getJSON(uri, &challengeResponse)
if err != nil {
return err
}
}
}

View file

@ -0,0 +1,325 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"net/http"
"strings"
"time"
"github.com/ossrs/go-oryx-lib/https/crypto/ocsp"
)
// KeyType represents the key algo as well as the key size or curve to use.
type KeyType string
type derCertificateBytes []byte
// Constants for all key types we support.
const (
EC256 = KeyType("P256")
EC384 = KeyType("P384")
RSA2048 = KeyType("2048")
RSA4096 = KeyType("4096")
RSA8192 = KeyType("8192")
)
const (
// OCSPGood means that the certificate is valid.
OCSPGood = ocsp.Good
// OCSPRevoked means that the certificate has been deliberately revoked.
OCSPRevoked = ocsp.Revoked
// OCSPUnknown means that the OCSP responder doesn't know about the certificate.
OCSPUnknown = ocsp.Unknown
// OCSPServerFailed means that the OCSP responder failed to process the request.
OCSPServerFailed = ocsp.ServerFailed
)
// GetOCSPForCert takes a PEM encoded cert or cert bundle returning the raw OCSP response,
// the parsed response, and an error, if any. The returned []byte can be passed directly
// into the OCSPStaple property of a tls.Certificate. If the bundle only contains the
// issued certificate, this function will try to get the issuer certificate from the
// IssuingCertificateURL in the certificate. If the []byte and/or ocsp.Response return
// values are nil, the OCSP status may be assumed OCSPUnknown.
func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) {
certificates, err := parsePEMBundle(bundle)
if err != nil {
return nil, nil, err
}
// We expect the certificate slice to be ordered downwards the chain.
// SRV CRT -> CA. We need to pull the leaf and issuer certs out of it,
// which should always be the first two certificates. If there's no
// OCSP server listed in the leaf cert, there's nothing to do. And if
// we have only one certificate so far, we need to get the issuer cert.
issuedCert := certificates[0]
if len(issuedCert.OCSPServer) == 0 {
return nil, nil, errors.New("no OCSP server specified in cert")
}
if len(certificates) == 1 {
// TODO: build fallback. If this fails, check the remaining array entries.
if len(issuedCert.IssuingCertificateURL) == 0 {
return nil, nil, errors.New("no issuing certificate URL")
}
resp, err := httpGet(issuedCert.IssuingCertificateURL[0])
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
issuerBytes, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024))
if err != nil {
return nil, nil, err
}
issuerCert, err := x509.ParseCertificate(issuerBytes)
if err != nil {
return nil, nil, err
}
// Insert it into the slice on position 0
// We want it ordered right SRV CRT -> CA
certificates = append(certificates, issuerCert)
}
issuerCert := certificates[1]
// Finally kick off the OCSP request.
ocspReq, err := ocsp.CreateRequest(issuedCert, issuerCert, nil)
if err != nil {
return nil, nil, err
}
reader := bytes.NewReader(ocspReq)
req, err := httpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader)
if err != nil {
return nil, nil, err
}
defer req.Body.Close()
ocspResBytes, err := ioutil.ReadAll(limitReader(req.Body, 1024*1024))
ocspRes, err := ocsp.ParseResponse(ocspResBytes, issuerCert)
if err != nil {
return nil, nil, err
}
if ocspRes.Certificate == nil {
err = ocspRes.CheckSignatureFrom(issuerCert)
if err != nil {
return nil, nil, err
}
}
return ocspResBytes, ocspRes, nil
}
func getKeyAuthorization(token string, key interface{}) (string, error) {
var publicKey crypto.PublicKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
publicKey = k.Public()
case *rsa.PrivateKey:
publicKey = k.Public()
}
// Generate the Key Authorization for the challenge
jwk := keyAsJWK(publicKey)
if jwk == nil {
return "", errors.New("Could not generate JWK from key.")
}
thumbBytes, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", err
}
// unpad the base64URL
keyThumb := base64.URLEncoding.EncodeToString(thumbBytes)
index := strings.Index(keyThumb, "=")
if index != -1 {
keyThumb = keyThumb[:index]
}
return token + "." + keyThumb, nil
}
// parsePEMBundle parses a certificate bundle from top to bottom and returns
// a slice of x509 certificates. This function will error if no certificates are found.
func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
var certificates []*x509.Certificate
var certDERBlock *pem.Block
for {
certDERBlock, bundle = pem.Decode(bundle)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert, err := x509.ParseCertificate(certDERBlock.Bytes)
if err != nil {
return nil, err
}
certificates = append(certificates, cert)
}
}
if len(certificates) == 0 {
return nil, errors.New("No certificates were found while parsing the bundle.")
}
return certificates, nil
}
func parsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
keyBlock, _ := pem.Decode(key)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
default:
return nil, errors.New("Unknown PEM header value")
}
}
func generatePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
switch keyType {
case EC256:
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case EC384:
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case RSA2048:
return rsa.GenerateKey(rand.Reader, 2048)
case RSA4096:
return rsa.GenerateKey(rand.Reader, 4096)
case RSA8192:
return rsa.GenerateKey(rand.Reader, 8192)
}
return nil, fmt.Errorf("Invalid KeyType: %s", keyType)
}
func generateCsr(privateKey crypto.PrivateKey, domain string, san []string) ([]byte, error) {
template := x509.CertificateRequest{
Subject: pkix.Name{
CommonName: domain,
},
}
if len(san) > 0 {
template.DNSNames = san
}
return x509.CreateCertificateRequest(rand.Reader, &template, privateKey)
}
func pemEncode(data interface{}) []byte {
var pemBlock *pem.Block
switch key := data.(type) {
case *ecdsa.PrivateKey:
keyBytes, _ := x509.MarshalECPrivateKey(key)
pemBlock = &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}
case *rsa.PrivateKey:
pemBlock = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
break
case derCertificateBytes:
pemBlock = &pem.Block{Type: "CERTIFICATE", Bytes: []byte(data.(derCertificateBytes))}
}
return pem.EncodeToMemory(pemBlock)
}
func pemDecode(data []byte) (*pem.Block, error) {
pemBlock, _ := pem.Decode(data)
if pemBlock == nil {
return nil, fmt.Errorf("Pem decode did not yield a valid block. Is the certificate in the right format?")
}
return pemBlock, nil
}
func pemDecodeTox509(pem []byte) (*x509.Certificate, error) {
pemBlock, err := pemDecode(pem)
if pemBlock == nil {
return nil, err
}
return x509.ParseCertificate(pemBlock.Bytes)
}
// GetPEMCertExpiration returns the "NotAfter" date of a PEM encoded certificate.
// The certificate has to be PEM encoded. Any other encodings like DER will fail.
func GetPEMCertExpiration(cert []byte) (time.Time, error) {
pemBlock, err := pemDecode(cert)
if pemBlock == nil {
return time.Time{}, err
}
return getCertExpiration(pemBlock.Bytes)
}
// getCertExpiration returns the "NotAfter" date of a DER encoded certificate.
func getCertExpiration(cert []byte) (time.Time, error) {
pCert, err := x509.ParseCertificate(cert)
if err != nil {
return time.Time{}, err
}
return pCert.NotAfter, nil
}
func generatePemCert(privKey *rsa.PrivateKey, domain string) ([]byte, error) {
derBytes, err := generateDerCert(privKey, time.Time{}, domain)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), nil
}
func generateDerCert(privKey *rsa.PrivateKey, expiration time.Time, domain string) ([]byte, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, err
}
if expiration.IsZero() {
expiration = time.Now().Add(365)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "ACME Challenge TEMP",
},
NotBefore: time.Now(),
NotAfter: expiration,
KeyUsage: x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{domain},
}
return x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
}
func limitReader(rd io.ReadCloser, numBytes int64) io.ReadCloser {
return http.MaxBytesReader(nil, rd, numBytes)
}

View file

@ -0,0 +1,75 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
const (
tosAgreementError = "Must agree to subscriber agreement before any further actions"
)
// RemoteError is the base type for all errors specific to the ACME protocol.
type RemoteError struct {
StatusCode int `json:"status,omitempty"`
Type string `json:"type"`
Detail string `json:"detail"`
}
func (e RemoteError) Error() string {
return fmt.Sprintf("acme: Error %d - %s - %s", e.StatusCode, e.Type, e.Detail)
}
// TOSError represents the error which is returned if the user needs to
// accept the TOS.
// TODO: include the new TOS url if we can somehow obtain it.
type TOSError struct {
RemoteError
}
type domainError struct {
Domain string
Error error
}
type challengeError struct {
RemoteError
records []validationRecord
}
func (c challengeError) Error() string {
var errStr string
for _, validation := range c.records {
errStr = errStr + fmt.Sprintf("\tValidation for %s:%s\n\tResolved to:\n\t\t%s\n\tUsed: %s\n\n",
validation.Hostname, validation.Port, strings.Join(validation.ResolvedAddresses, "\n\t\t"), validation.UsedAddress)
}
return fmt.Sprintf("%s\nError Detail:\n%s", c.RemoteError.Error(), errStr)
}
func handleHTTPError(resp *http.Response) error {
var errorDetail RemoteError
decoder := json.NewDecoder(resp.Body)
err := decoder.Decode(&errorDetail)
if err != nil {
return err
}
errorDetail.StatusCode = resp.StatusCode
// Check for errors we handle specifically
if errorDetail.StatusCode == http.StatusForbidden && errorDetail.Detail == tosAgreementError {
return TOSError{errorDetail}
}
return errorDetail
}
func handleChallengeError(chlng challenge) error {
return challengeError{chlng.Error, chlng.ValidationRecords}
}

View file

@ -0,0 +1,119 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
)
// UserAgent (if non-empty) will be tacked onto the User-Agent string in requests.
var UserAgent string
// defaultClient is an HTTP client with a reasonable timeout value.
var defaultClient = http.Client{Timeout: 10 * time.Second}
const (
// defaultGoUserAgent is the Go HTTP package user agent string. Too
// bad it isn't exported. If it changes, we should update it here, too.
defaultGoUserAgent = "Go-http-client/1.1"
// ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme"
)
// httpHead performs a HEAD request with a proper User-Agent string.
// The response body (resp.Body) is already closed when this function returns.
func httpHead(url string) (resp *http.Response, err error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", userAgent())
resp, err = defaultClient.Do(req)
if err != nil {
return resp, err
}
resp.Body.Close()
return resp, err
}
// httpPost performs a POST request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
req.Header.Set("User-Agent", userAgent())
return defaultClient.Do(req)
}
// httpGet performs a GET request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpGet(url string) (resp *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", userAgent())
return defaultClient.Do(req)
}
// getJSON performs an HTTP GET request and parses the response body
// as JSON, into the provided respBody object.
func getJSON(uri string, respBody interface{}) (http.Header, error) {
resp, err := httpGet(uri)
if err != nil {
return nil, fmt.Errorf("failed to get %q: %v", uri, err)
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return resp.Header, handleHTTPError(resp)
}
return resp.Header, json.NewDecoder(resp.Body).Decode(respBody)
}
// postJSON performs an HTTP POST request and parses the response body
// as JSON, into the provided respBody object.
func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) {
jsonBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, errors.New("Failed to marshal network message...")
}
resp, err := j.post(uri, jsonBytes)
if err != nil {
return nil, fmt.Errorf("Failed to post JWS message. -> %v", err)
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return resp.Header, handleHTTPError(resp)
}
if respBody == nil {
return resp.Header, nil
}
return resp.Header, json.NewDecoder(resp.Body).Decode(respBody)
}
// userAgent builds and returns the User-Agent string to use in requests.
func userAgent() string {
ua := fmt.Sprintf("%s (%s; %s) %s %s", defaultGoUserAgent, runtime.GOOS, runtime.GOARCH, ourUserAgent, UserAgent)
return strings.TrimSpace(ua)
}

View file

@ -0,0 +1,43 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"fmt"
"log"
)
type httpChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
// HTTP01ChallengePath returns the URL path for the `http-01` challenge
func HTTP01ChallengePath(token string) string {
return "/.well-known/acme-challenge/" + token
}
func (s *httpChallenge) Solve(chlng challenge, domain string) error {
logf("[INFO][%s] acme: Trying to solve HTTP-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
if err != nil {
return err
}
err = s.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] error presenting token: %v", domain, err)
}
defer func() {
err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Printf("[%s] error cleaning up: %v", domain, err)
}
}()
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}

View file

@ -0,0 +1,81 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"fmt"
"net"
"net/http"
"strings"
)
// HTTPProviderServer implements ChallengeProvider for `http-01` challenge
// It may be instantiated without using the NewHTTPProviderServer function if
// you want only to use the default values.
type HTTPProviderServer struct {
iface string
port string
done chan bool
listener net.Listener
}
// NewHTTPProviderServer creates a new HTTPProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 80 respectively.
func NewHTTPProviderServer(iface, port string) *HTTPProviderServer {
return &HTTPProviderServer{iface: iface, port: port}
}
// Present starts a web server and makes the token available at `HTTP01ChallengePath(token)` for web requests.
func (s *HTTPProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
s.port = "80"
}
var err error
s.listener, err = net.Listen("tcp", net.JoinHostPort(s.iface, s.port))
if err != nil {
return fmt.Errorf("Could not start HTTP server for challenge -> %v", err)
}
s.done = make(chan bool)
go s.serve(domain, token, keyAuth)
return nil
}
// CleanUp closes the HTTP server and removes the token from `HTTP01ChallengePath(token)`
func (s *HTTPProviderServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil {
return nil
}
s.listener.Close()
<-s.done
return nil
}
func (s *HTTPProviderServer) serve(domain, token, keyAuth string) {
path := HTTP01ChallengePath(token)
// The handler validates the HOST header and request type.
// For validation it then writes the token the server returned with the challenge
mux := http.NewServeMux()
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.Host, domain) && r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(keyAuth))
logf("[INFO][%s] Served key authentication", domain)
} else {
logf("[INFO] Received request for domain %s with method %s", r.Host, r.Method)
w.Write([]byte("TEST"))
}
})
httpServer := &http.Server{
Handler: mux,
}
// Once httpServer is shut down we don't want any lingering
// connections, so disable KeepAlives.
httpServer.SetKeepAlivesEnabled(false)
httpServer.Serve(s.listener)
s.done <- true
}

View file

@ -0,0 +1,109 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"fmt"
"net/http"
"github.com/ossrs/go-oryx-lib/https/jose"
)
type jws struct {
directoryURL string
privKey crypto.PrivateKey
nonces []string
}
func keyAsJWK(key interface{}) *jose.JsonWebKey {
switch k := key.(type) {
case *ecdsa.PublicKey:
return &jose.JsonWebKey{Key: k, Algorithm: "EC"}
case *rsa.PublicKey:
return &jose.JsonWebKey{Key: k, Algorithm: "RSA"}
default:
return nil
}
}
// Posts a JWS signed message to the specified URL
func (j *jws) post(url string, content []byte) (*http.Response, error) {
signedContent, err := j.signContent(content)
if err != nil {
return nil, err
}
resp, err := httpPost(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize())))
if err != nil {
return nil, err
}
j.getNonceFromResponse(resp)
return resp, err
}
func (j *jws) signContent(content []byte) (*jose.JsonWebSignature, error) {
var alg jose.SignatureAlgorithm
switch k := j.privKey.(type) {
case *rsa.PrivateKey:
alg = jose.RS256
case *ecdsa.PrivateKey:
if k.Curve == elliptic.P256() {
alg = jose.ES256
} else if k.Curve == elliptic.P384() {
alg = jose.ES384
}
}
signer, err := jose.NewSigner(alg, j.privKey)
if err != nil {
return nil, err
}
signer.SetNonceSource(j)
signed, err := signer.Sign(content)
if err != nil {
return nil, err
}
return signed, nil
}
func (j *jws) getNonceFromResponse(resp *http.Response) error {
nonce := resp.Header.Get("Replay-Nonce")
if nonce == "" {
return fmt.Errorf("Server did not respond with a proper nonce header.")
}
j.nonces = append(j.nonces, nonce)
return nil
}
func (j *jws) getNonce() error {
resp, err := httpHead(j.directoryURL)
if err != nil {
return err
}
return j.getNonceFromResponse(resp)
}
func (j *jws) Nonce() (string, error) {
nonce := ""
if len(j.nonces) == 0 {
err := j.getNonce()
if err != nil {
return nonce, err
}
}
nonce, j.nonces = j.nonces[len(j.nonces)-1], j.nonces[:len(j.nonces)-1]
return nonce, nil
}

View file

@ -0,0 +1,117 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"time"
"github.com/ossrs/go-oryx-lib/https/jose"
)
type directory struct {
NewAuthzURL string `json:"new-authz"`
NewCertURL string `json:"new-cert"`
NewRegURL string `json:"new-reg"`
RevokeCertURL string `json:"revoke-cert"`
}
type recoveryKeyMessage struct {
Length int `json:"length,omitempty"`
Client jose.JsonWebKey `json:"client,omitempty"`
Server jose.JsonWebKey `json:"client,omitempty"`
}
type registrationMessage struct {
Resource string `json:"resource"`
Contact []string `json:"contact"`
// RecoveryKey recoveryKeyMessage `json:"recoveryKey,omitempty"`
}
// Registration is returned by the ACME server after the registration
// The client implementation should save this registration somewhere.
type Registration struct {
Resource string `json:"resource,omitempty"`
ID int `json:"id"`
Key jose.JsonWebKey `json:"key"`
Contact []string `json:"contact"`
Agreement string `json:"agreement,omitempty"`
Authorizations string `json:"authorizations,omitempty"`
Certificates string `json:"certificates,omitempty"`
// RecoveryKey recoveryKeyMessage `json:"recoveryKey,omitempty"`
}
// RegistrationResource represents all important informations about a registration
// of which the client needs to keep track itself.
type RegistrationResource struct {
Body Registration `json:"body,omitempty"`
URI string `json:"uri,omitempty"`
NewAuthzURL string `json:"new_authzr_uri,omitempty"`
TosURL string `json:"terms_of_service,omitempty"`
}
type authorizationResource struct {
Body authorization
Domain string
NewCertURL string
AuthURL string
}
type authorization struct {
Resource string `json:"resource,omitempty"`
Identifier identifier `json:"identifier"`
Status string `json:"status,omitempty"`
Expires time.Time `json:"expires,omitempty"`
Challenges []challenge `json:"challenges,omitempty"`
Combinations [][]int `json:"combinations,omitempty"`
}
type identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
type validationRecord struct {
URI string `json:"url,omitempty"`
Hostname string `json:"hostname,omitempty"`
Port string `json:"port,omitempty"`
ResolvedAddresses []string `json:"addressesResolved,omitempty"`
UsedAddress string `json:"addressUsed,omitempty"`
}
type challenge struct {
Resource string `json:"resource,omitempty"`
Type Challenge `json:"type,omitempty"`
Status string `json:"status,omitempty"`
URI string `json:"uri,omitempty"`
Token string `json:"token,omitempty"`
KeyAuthorization string `json:"keyAuthorization,omitempty"`
TLS bool `json:"tls,omitempty"`
Iterations int `json:"n,omitempty"`
Error RemoteError `json:"error,omitempty"`
ValidationRecords []validationRecord `json:"validationRecord,omitempty"`
}
type csrMessage struct {
Resource string `json:"resource,omitempty"`
Csr string `json:"csr"`
Authorizations []string `json:"authorizations"`
}
type revokeCertMessage struct {
Resource string `json:"resource"`
Certificate string `json:"certificate"`
}
// CertificateResource represents a CA issued certificate.
// PrivateKey and Certificate are both already PEM encoded
// and can be directly written to disk. Certificate may
// be a certificate bundle, depending on the options supplied
// to create it.
type CertificateResource struct {
Domain string `json:"domain"`
CertURL string `json:"certUrl"`
CertStableURL string `json:"certStableUrl"`
AccountRef string `json:"accountRef,omitempty"`
PrivateKey []byte `json:"-"`
Certificate []byte `json:"-"`
}

View file

@ -0,0 +1,30 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import "time"
// ChallengeProvider enables implementing a custom challenge
// provider. Present presents the solution to a challenge available to
// be solved. CleanUp will be called by the challenge if Present ends
// in a non-error state.
type ChallengeProvider interface {
Present(domain, token, keyAuth string) error
CleanUp(domain, token, keyAuth string) error
}
// ChallengeProviderTimeout allows for implementing a
// ChallengeProvider where an unusually long timeout is required when
// waiting for an ACME challenge to be satisfied, such as when
// checking for DNS record progagation. If an implementor of a
// ChallengeProvider provides a Timeout method, then the return values
// of the Timeout method will be used when appropriate by the acme
// package. The interval value is the time between checks.
//
// The default values used for timeout and interval are 60 seconds and
// 2 seconds respectively. These are used when no Timeout method is
// defined for the ChallengeProvider.
type ChallengeProviderTimeout interface {
ChallengeProvider
Timeout() (timeout, interval time.Duration)
}

View file

@ -0,0 +1,75 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
)
type tlsSNIChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error {
// FIXME: https://github.com/ietf-wg-acme/acme/pull/22
// Currently we implement this challenge to track boulder, not the current spec!
logf("[INFO][%s] acme: Trying to solve TLS-SNI-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, t.jws.privKey)
if err != nil {
return err
}
err = t.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] error presenting token: %v", domain, err)
}
defer func() {
err := t.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Printf("[%s] error cleaning up: %v", domain, err)
}
}()
return t.validate(t.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}
// TLSSNI01ChallengeCert returns a certificate and target domain for the `tls-sni-01` challenge
func TLSSNI01ChallengeCertDomain(keyAuth string) (tls.Certificate, string, error) {
// generate a new RSA key for the certificates
tempPrivKey, err := generatePrivateKey(RSA2048)
if err != nil {
return tls.Certificate{}, "", err
}
rsaPrivKey := tempPrivKey.(*rsa.PrivateKey)
rsaPrivPEM := pemEncode(rsaPrivKey)
zBytes := sha256.Sum256([]byte(keyAuth))
z := hex.EncodeToString(zBytes[:sha256.Size])
domain := fmt.Sprintf("%s.%s.acme.invalid", z[:32], z[32:])
tempCertPEM, err := generatePemCert(rsaPrivKey, domain)
if err != nil {
return tls.Certificate{}, "", err
}
certificate, err := tls.X509KeyPair(tempCertPEM, rsaPrivPEM)
if err != nil {
return tls.Certificate{}, "", err
}
return certificate, domain, nil
}
// TLSSNI01ChallengeCert returns a certificate for the `tls-sni-01` challenge
func TLSSNI01ChallengeCert(keyAuth string) (tls.Certificate, error) {
cert, _, err := TLSSNI01ChallengeCertDomain(keyAuth)
return cert, err
}

View file

@ -0,0 +1,64 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"crypto/tls"
"fmt"
"net"
"net/http"
)
// TLSProviderServer implements ChallengeProvider for `TLS-SNI-01` challenge
// It may be instantiated without using the NewTLSProviderServer function if
// you want only to use the default values.
type TLSProviderServer struct {
iface string
port string
done chan bool
listener net.Listener
}
// NewTLSProviderServer creates a new TLSProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewTLSProviderServer(iface, port string) *TLSProviderServer {
return &TLSProviderServer{iface: iface, port: port}
}
// Present makes the keyAuth available as a cert
func (s *TLSProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
s.port = "443"
}
cert, err := TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
}
tlsConf := new(tls.Config)
tlsConf.Certificates = []tls.Certificate{cert}
s.listener, err = tls.Listen("tcp", net.JoinHostPort(s.iface, s.port), tlsConf)
if err != nil {
return fmt.Errorf("Could not start HTTPS server for challenge -> %v", err)
}
s.done = make(chan bool)
go func() {
http.Serve(s.listener, nil)
s.done <- true
}()
return nil
}
// CleanUp closes the HTTP server.
func (s *TLSProviderServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil {
return nil
}
s.listener.Close()
<-s.done
return nil
}

View file

@ -0,0 +1,31 @@
// fork from https://github.com/rsc/letsencrypt/tree/master/vendor/github.com/xenolf/lego/acme
// fork from https://github.com/xenolf/lego/tree/master/acme
package acme
import (
"fmt"
"time"
)
// WaitFor polls the given function 'f', once every 'interval', up to 'timeout'.
func WaitFor(timeout, interval time.Duration, f func() (bool, error)) error {
var lastErr string
timeup := time.After(timeout)
for {
select {
case <-timeup:
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
default:
}
stop, err := f()
if stop {
return nil
}
if err != nil {
lastErr = err.Error()
}
time.Sleep(interval)
}
}

View file

@ -0,0 +1,644 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// fork from golang.org/x/crypto/ocsp
// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses
// are signed messages attesting to the validity of a certificate for a small
// period of time. This is used to manage revocation for X.509 certificates.
package ocsp
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"math/big"
"strconv"
"time"
)
var idPKIXOCSPBasic = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 5, 5, 7, 48, 1, 1})
// ResponseStatus contains the result of an OCSP request. See
// https://tools.ietf.org/html/rfc6960#section-2.3
type ResponseStatus int
const (
Success ResponseStatus = 0
Malformed ResponseStatus = 1
InternalError ResponseStatus = 2
TryLater ResponseStatus = 3
// Status code four is unused in OCSP. See
// https://tools.ietf.org/html/rfc6960#section-4.2.1
SignatureRequired ResponseStatus = 5
Unauthorized ResponseStatus = 6
)
func (r ResponseStatus) String() string {
switch r {
case Success:
return "success"
case Malformed:
return "malformed"
case InternalError:
return "internal error"
case TryLater:
return "try later"
case SignatureRequired:
return "signature required"
case Unauthorized:
return "unauthorized"
default:
return "unknown OCSP status: " + strconv.Itoa(int(r))
}
}
// ResponseError is an error that may be returned by ParseResponse to indicate
// that the response itself is an error, not just that its indicating that a
// certificate is revoked, unknown, etc.
type ResponseError struct {
Status ResponseStatus
}
func (r ResponseError) Error() string {
return "ocsp: error from server: " + r.Status.String()
}
// These are internal structures that reflect the ASN.1 structure of an OCSP
// response. See RFC 2560, section 4.2.
type certID struct {
HashAlgorithm pkix.AlgorithmIdentifier
NameHash []byte
IssuerKeyHash []byte
SerialNumber *big.Int
}
// https://tools.ietf.org/html/rfc2560#section-4.1.1
type ocspRequest struct {
TBSRequest tbsRequest
}
type tbsRequest struct {
Version int `asn1:"explicit,tag:0,default:0,optional"`
RequestorName pkix.RDNSequence `asn1:"explicit,tag:1,optional"`
RequestList []request
}
type request struct {
Cert certID
}
type responseASN1 struct {
Status asn1.Enumerated
Response responseBytes `asn1:"explicit,tag:0,optional"`
}
type responseBytes struct {
ResponseType asn1.ObjectIdentifier
Response []byte
}
type basicResponse struct {
TBSResponseData responseData
SignatureAlgorithm pkix.AlgorithmIdentifier
Signature asn1.BitString
Certificates []asn1.RawValue `asn1:"explicit,tag:0,optional"`
}
type responseData struct {
Raw asn1.RawContent
Version int `asn1:"optional,default:0,explicit,tag:0"`
RawResponderID asn1.RawValue
ProducedAt time.Time `asn1:"generalized"`
Responses []singleResponse
}
type singleResponse struct {
CertID certID
Good asn1.Flag `asn1:"tag:0,optional"`
Revoked revokedInfo `asn1:"tag:1,optional"`
Unknown asn1.Flag `asn1:"tag:2,optional"`
ThisUpdate time.Time `asn1:"generalized"`
NextUpdate time.Time `asn1:"generalized,explicit,tag:0,optional"`
SingleExtensions []pkix.Extension `asn1:"explicit,tag:1,optional"`
}
type revokedInfo struct {
RevocationTime time.Time `asn1:"generalized"`
Reason asn1.Enumerated `asn1:"explicit,tag:0,optional"`
}
var (
oidSignatureMD2WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 2}
oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4}
oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}
oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11}
oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12}
oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13}
oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3}
oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2}
oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1}
oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2}
oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3}
oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4}
)
var hashOIDs = map[crypto.Hash]asn1.ObjectIdentifier{
crypto.SHA1: asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}),
crypto.SHA256: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}),
crypto.SHA384: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 2}),
crypto.SHA512: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 3}),
}
// TODO(rlb): This is also from crypto/x509, so same comment as AGL's below
var signatureAlgorithmDetails = []struct {
algo x509.SignatureAlgorithm
oid asn1.ObjectIdentifier
pubKeyAlgo x509.PublicKeyAlgorithm
hash crypto.Hash
}{
{x509.MD2WithRSA, oidSignatureMD2WithRSA, x509.RSA, crypto.Hash(0) /* no value for MD2 */},
{x509.MD5WithRSA, oidSignatureMD5WithRSA, x509.RSA, crypto.MD5},
{x509.SHA1WithRSA, oidSignatureSHA1WithRSA, x509.RSA, crypto.SHA1},
{x509.SHA256WithRSA, oidSignatureSHA256WithRSA, x509.RSA, crypto.SHA256},
{x509.SHA384WithRSA, oidSignatureSHA384WithRSA, x509.RSA, crypto.SHA384},
{x509.SHA512WithRSA, oidSignatureSHA512WithRSA, x509.RSA, crypto.SHA512},
{x509.DSAWithSHA1, oidSignatureDSAWithSHA1, x509.DSA, crypto.SHA1},
{x509.DSAWithSHA256, oidSignatureDSAWithSHA256, x509.DSA, crypto.SHA256},
{x509.ECDSAWithSHA1, oidSignatureECDSAWithSHA1, x509.ECDSA, crypto.SHA1},
{x509.ECDSAWithSHA256, oidSignatureECDSAWithSHA256, x509.ECDSA, crypto.SHA256},
{x509.ECDSAWithSHA384, oidSignatureECDSAWithSHA384, x509.ECDSA, crypto.SHA384},
{x509.ECDSAWithSHA512, oidSignatureECDSAWithSHA512, x509.ECDSA, crypto.SHA512},
}
// TODO(rlb): This is also from crypto/x509, so same comment as AGL's below
func signingParamsForPublicKey(pub interface{}, requestedSigAlgo x509.SignatureAlgorithm) (hashFunc crypto.Hash, sigAlgo pkix.AlgorithmIdentifier, err error) {
var pubType x509.PublicKeyAlgorithm
switch pub := pub.(type) {
case *rsa.PublicKey:
pubType = x509.RSA
hashFunc = crypto.SHA256
sigAlgo.Algorithm = oidSignatureSHA256WithRSA
sigAlgo.Parameters = asn1.RawValue{
Tag: 5,
}
case *ecdsa.PublicKey:
pubType = x509.ECDSA
switch pub.Curve {
case elliptic.P224(), elliptic.P256():
hashFunc = crypto.SHA256
sigAlgo.Algorithm = oidSignatureECDSAWithSHA256
case elliptic.P384():
hashFunc = crypto.SHA384
sigAlgo.Algorithm = oidSignatureECDSAWithSHA384
case elliptic.P521():
hashFunc = crypto.SHA512
sigAlgo.Algorithm = oidSignatureECDSAWithSHA512
default:
err = errors.New("x509: unknown elliptic curve")
}
default:
err = errors.New("x509: only RSA and ECDSA keys supported")
}
if err != nil {
return
}
if requestedSigAlgo == 0 {
return
}
found := false
for _, details := range signatureAlgorithmDetails {
if details.algo == requestedSigAlgo {
if details.pubKeyAlgo != pubType {
err = errors.New("x509: requested SignatureAlgorithm does not match private key type")
return
}
sigAlgo.Algorithm, hashFunc = details.oid, details.hash
if hashFunc == 0 {
err = errors.New("x509: cannot sign with hash function requested")
return
}
found = true
break
}
}
if !found {
err = errors.New("x509: unknown SignatureAlgorithm")
}
return
}
// TODO(agl): this is taken from crypto/x509 and so should probably be exported
// from crypto/x509 or crypto/x509/pkix.
func getSignatureAlgorithmFromOID(oid asn1.ObjectIdentifier) x509.SignatureAlgorithm {
for _, details := range signatureAlgorithmDetails {
if oid.Equal(details.oid) {
return details.algo
}
}
return x509.UnknownSignatureAlgorithm
}
// TODO(rlb): This is not taken from crypto/x509, but it's of the same general form.
func getHashAlgorithmFromOID(target asn1.ObjectIdentifier) crypto.Hash {
for hash, oid := range hashOIDs {
if oid.Equal(target) {
return hash
}
}
return crypto.Hash(0)
}
func getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier {
for hash, oid := range hashOIDs {
if hash == target {
return oid
}
}
return nil
}
// This is the exposed reflection of the internal OCSP structures.
// The status values that can be expressed in OCSP. See RFC 6960.
const (
// Good means that the certificate is valid.
Good = iota
// Revoked means that the certificate has been deliberately revoked.
Revoked
// Unknown means that the OCSP responder doesn't know about the certificate.
Unknown
// ServerFailed is unused and was never used (see
// https://go-review.googlesource.com/#/c/18944). ParseResponse will
// return a ResponseError when an error response is parsed.
ServerFailed
)
// The enumerated reasons for revoking a certificate. See RFC 5280.
const (
Unspecified = iota
KeyCompromise = iota
CACompromise = iota
AffiliationChanged = iota
Superseded = iota
CessationOfOperation = iota
CertificateHold = iota
_ = iota
RemoveFromCRL = iota
PrivilegeWithdrawn = iota
AACompromise = iota
)
// Request represents an OCSP request. See RFC 6960.
type Request struct {
HashAlgorithm crypto.Hash
IssuerNameHash []byte
IssuerKeyHash []byte
SerialNumber *big.Int
}
// Marshal marshals the OCSP request to ASN.1 DER encoded form.
func (req *Request) Marshal() ([]byte, error) {
hashAlg := getOIDFromHashAlgorithm(req.HashAlgorithm)
if hashAlg == nil {
return nil, errors.New("Unknown hash algorithm")
}
return asn1.Marshal(ocspRequest{
tbsRequest{
Version: 0,
RequestList: []request{
{
Cert: certID{
pkix.AlgorithmIdentifier{
Algorithm: hashAlg,
Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */},
},
req.IssuerNameHash,
req.IssuerKeyHash,
req.SerialNumber,
},
},
},
},
})
}
// Response represents an OCSP response containing a single SingleResponse. See
// RFC 6960.
type Response struct {
// Status is one of {Good, Revoked, Unknown}
Status int
SerialNumber *big.Int
ProducedAt, ThisUpdate, NextUpdate, RevokedAt time.Time
RevocationReason int
Certificate *x509.Certificate
// TBSResponseData contains the raw bytes of the signed response. If
// Certificate is nil then this can be used to verify Signature.
TBSResponseData []byte
Signature []byte
SignatureAlgorithm x509.SignatureAlgorithm
// IssuerHash is the hash used to compute the IssuerNameHash and IssuerKeyHash.
// Valid values are crypto.SHA1, crypto.SHA256, crypto.SHA384, and crypto.SHA512.
// If zero, the default is crypto.SHA1.
IssuerHash crypto.Hash
// RawResponderName optionally contains the DER-encoded subject of the
// responder certificate. Exactly one of RawResponderName and
// ResponderKeyHash is set.
RawResponderName []byte
// ResponderKeyHash optionally contains the SHA-1 hash of the
// responder's public key. Exactly one of RawResponderName and
// ResponderKeyHash is set.
ResponderKeyHash []byte
// Extensions contains raw X.509 extensions from the singleExtensions field
// of the OCSP response. When parsing certificates, this can be used to
// extract non-critical extensions that are not parsed by this package. When
// marshaling OCSP responses, the Extensions field is ignored, see
// ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any marshaled
// OCSP response (in the singleExtensions field). Values override any
// extensions that would otherwise be produced based on the other fields. The
// ExtraExtensions field is not populated when parsing certificates, see
// Extensions.
ExtraExtensions []pkix.Extension
}
// These are pre-serialized error responses for the various non-success codes
// defined by OCSP. The Unauthorized code in particular can be used by an OCSP
// responder that supports only pre-signed responses as a response to requests
// for certificates with unknown status. See RFC 5019.
var (
MalformedRequestErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x01}
InternalErrorErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x02}
TryLaterErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x03}
SigRequredErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x05}
UnauthorizedErrorResponse = []byte{0x30, 0x03, 0x0A, 0x01, 0x06}
)
// CheckSignatureFrom checks that the signature in resp is a valid signature
// from issuer. This should only be used if resp.Certificate is nil. Otherwise,
// the OCSP response contained an intermediate certificate that created the
// signature. That signature is checked by ParseResponse and only
// resp.Certificate remains to be validated.
func (resp *Response) CheckSignatureFrom(issuer *x509.Certificate) error {
return issuer.CheckSignature(resp.SignatureAlgorithm, resp.TBSResponseData, resp.Signature)
}
// ParseError results from an invalid OCSP response.
type ParseError string
func (p ParseError) Error() string {
return string(p)
}
// ParseRequest parses an OCSP request in DER form. It only supports
// requests for a single certificate. Signed requests are not supported.
// If a request includes a signature, it will result in a ParseError.
func ParseRequest(bytes []byte) (*Request, error) {
var req ocspRequest
rest, err := asn1.Unmarshal(bytes, &req)
if err != nil {
return nil, err
}
if len(rest) > 0 {
return nil, ParseError("trailing data in OCSP request")
}
if len(req.TBSRequest.RequestList) == 0 {
return nil, ParseError("OCSP request contains no request body")
}
innerRequest := req.TBSRequest.RequestList[0]
hashFunc := getHashAlgorithmFromOID(innerRequest.Cert.HashAlgorithm.Algorithm)
if hashFunc == crypto.Hash(0) {
return nil, ParseError("OCSP request uses unknown hash function")
}
return &Request{
HashAlgorithm: hashFunc,
IssuerNameHash: innerRequest.Cert.NameHash,
IssuerKeyHash: innerRequest.Cert.IssuerKeyHash,
SerialNumber: innerRequest.Cert.SerialNumber,
}, nil
}
// ParseResponse parses an OCSP response in DER form. It only supports
// responses for a single certificate. If the response contains a certificate
// then the signature over the response is checked. If issuer is not nil then
// it will be used to validate the signature or embedded certificate.
//
// Invalid signatures or parse failures will result in a ParseError. Error
// responses will result in a ResponseError.
func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) {
return ParseResponseForCert(bytes, nil, issuer)
}
// ParseResponseForCert parses an OCSP response in DER form and searches for a
// Response relating to cert. If such a Response is found and the OCSP response
// contains a certificate then the signature over the response is checked. If
// issuer is not nil then it will be used to validate the signature or embedded
// certificate.
//
// Invalid signatures or parse failures will result in a ParseError. Error
// responses will result in a ResponseError.
func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Response, error) {
var resp responseASN1
rest, err := asn1.Unmarshal(bytes, &resp)
if err != nil {
return nil, err
}
if len(rest) > 0 {
return nil, ParseError("trailing data in OCSP response")
}
if status := ResponseStatus(resp.Status); status != Success {
return nil, ResponseError{status}
}
if !resp.Response.ResponseType.Equal(idPKIXOCSPBasic) {
return nil, ParseError("bad OCSP response type")
}
var basicResp basicResponse
rest, err = asn1.Unmarshal(resp.Response.Response, &basicResp)
if err != nil {
return nil, err
}
if len(basicResp.Certificates) > 1 {
return nil, ParseError("OCSP response contains bad number of certificates")
}
if n := len(basicResp.TBSResponseData.Responses); n == 0 || cert == nil && n > 1 {
return nil, ParseError("OCSP response contains bad number of responses")
}
ret := &Response{
TBSResponseData: basicResp.TBSResponseData.Raw,
Signature: basicResp.Signature.RightAlign(),
SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm),
}
// Handle the ResponderID CHOICE tag. ResponderID can be flattened into
// TBSResponseData once https://go-review.googlesource.com/34503 has been
// released.
rawResponderID := basicResp.TBSResponseData.RawResponderID
switch rawResponderID.Tag {
case 1: // Name
var rdn pkix.RDNSequence
if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &rdn); err != nil || len(rest) != 0 {
return nil, ParseError("invalid responder name")
}
ret.RawResponderName = rawResponderID.Bytes
case 2: // KeyHash
if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &ret.ResponderKeyHash); err != nil || len(rest) != 0 {
return nil, ParseError("invalid responder key hash")
}
default:
return nil, ParseError("invalid responder id tag")
}
if len(basicResp.Certificates) > 0 {
ret.Certificate, err = x509.ParseCertificate(basicResp.Certificates[0].FullBytes)
if err != nil {
return nil, err
}
if err := ret.CheckSignatureFrom(ret.Certificate); err != nil {
return nil, ParseError("bad signature on embedded certificate: " + err.Error())
}
if issuer != nil {
if err := issuer.CheckSignature(ret.Certificate.SignatureAlgorithm, ret.Certificate.RawTBSCertificate, ret.Certificate.Signature); err != nil {
return nil, ParseError("bad OCSP signature: " + err.Error())
}
}
} else if issuer != nil {
if err := ret.CheckSignatureFrom(issuer); err != nil {
return nil, ParseError("bad OCSP signature: " + err.Error())
}
}
var r singleResponse
for _, resp := range basicResp.TBSResponseData.Responses {
if cert == nil || cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 {
r = resp
break
}
}
for _, ext := range r.SingleExtensions {
if ext.Critical {
return nil, ParseError("unsupported critical extension")
}
}
ret.Extensions = r.SingleExtensions
ret.SerialNumber = r.CertID.SerialNumber
for h, oid := range hashOIDs {
if r.CertID.HashAlgorithm.Algorithm.Equal(oid) {
ret.IssuerHash = h
break
}
}
if ret.IssuerHash == 0 {
return nil, ParseError("unsupported issuer hash algorithm")
}
switch {
case bool(r.Good):
ret.Status = Good
case bool(r.Unknown):
ret.Status = Unknown
default:
ret.Status = Revoked
ret.RevokedAt = r.Revoked.RevocationTime
ret.RevocationReason = int(r.Revoked.Reason)
}
ret.ProducedAt = basicResp.TBSResponseData.ProducedAt
ret.ThisUpdate = r.ThisUpdate
ret.NextUpdate = r.NextUpdate
return ret, nil
}
// RequestOptions contains options for constructing OCSP requests.
type RequestOptions struct {
// Hash contains the hash function that should be used when
// constructing the OCSP request. If zero, SHA-1 will be used.
Hash crypto.Hash
}
func (opts *RequestOptions) hash() crypto.Hash {
if opts == nil || opts.Hash == 0 {
// SHA-1 is nearly universally used in OCSP.
return crypto.SHA1
}
return opts.Hash
}
// CreateRequest returns a DER-encoded, OCSP request for the status of cert. If
// opts is nil then sensible defaults are used.
func CreateRequest(cert, issuer *x509.Certificate, opts *RequestOptions) ([]byte, error) {
hashFunc := opts.hash()
// OCSP seems to be the only place where these raw hash identifiers are
// used. I took the following from
// http://msdn.microsoft.com/en-us/library/ff635603.aspx
_, ok := hashOIDs[hashFunc]
if !ok {
return nil, x509.ErrUnsupportedAlgorithm
}
if !hashFunc.Available() {
return nil, x509.ErrUnsupportedAlgorithm
}
h := opts.hash().New()
var publicKeyInfo struct {
Algorithm pkix.AlgorithmIdentifier
PublicKey asn1.BitString
}
if _, err := asn1.Unmarshal(issuer.RawSubjectPublicKeyInfo, &publicKeyInfo); err != nil {
return nil, err
}
h.Write(publicKeyInfo.PublicKey.RightAlign())
issuerKeyHash := h.Sum(nil)
h.Reset()
h.Write(issuer.RawSubject)
issuerNameHash := h.Sum(nil)
req := &Request{
HashAlgorithm: hashFunc,
IssuerNameHash: issuerNameHash,
IssuerKeyHash: issuerKeyHash,
SerialNumber: cert.SerialNumber,
}
return req.Marshal()
}

View file

@ -0,0 +1,124 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package https
import (
"crypto/tls"
"fmt"
"github.com/ossrs/go-oryx-lib/https/letsencrypt"
"runtime"
"strconv"
"strings"
)
// Requires golang 1.6+, because there's bug in http.Server
// to set the GetCertificate of TLSConfig.
func checkRuntime() (err error) {
version := strings.Trim(runtime.Version(), "go")
if versions := strings.Split(version, "."); len(versions) < 1 {
return fmt.Errorf("invalid version=%v", version)
} else if major, err := strconv.Atoi(versions[0]); err != nil {
return fmt.Errorf("invalid version=%v, err=%v", version, err)
} else if minor, err := strconv.Atoi(versions[1]); err != nil {
return fmt.Errorf("invalid version=%v, err=%v", version, err)
} else if major == 1 && minor < 6 {
return fmt.Errorf("requires golang 1.6+, version=%v(%v.%v)", version, major, minor)
}
return
}
// The https manager which provides the certificate.
type Manager interface {
GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// The cert is sign by ourself.
type selfSignManager struct {
cert *tls.Certificate
certFile string
keyFile string
}
func NewSelfSignManager(certFile, keyFile string) (m Manager, err error) {
if err = checkRuntime(); err != nil {
return
}
return &selfSignManager{certFile: certFile, keyFile: keyFile}, nil
}
func (v *selfSignManager) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if v.cert != nil {
return v.cert, nil
}
cert, err := tls.LoadX509KeyPair(v.certFile, v.keyFile)
if err != nil {
return nil, err
}
// cache the cert.
v.cert = &cert
return &cert, err
}
// The cert is sign by letsencrypt
type letsencryptManager struct {
lets letsencrypt.Manager
}
// Register the email to letsencrypt, cache the certs in cacheFile, set allow hosts.
// @remark set hosts to empty string when allow all request hosts, but maybe attack.
// @remark set email to nil to not regiester, use empty email to request cert from letsencrypt.
// @remark set cacheFile to empty string to not cache the info and certs.
// @remark we only use tls validate, https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#tls-with-server-name-indication-tls-sni
// so the https port must be 443, we cannot serve at other ports.
func NewLetsencryptManager(email string, hosts []string, cacheFile string) (m Manager, err error) {
v := &letsencryptManager{}
if err = checkRuntime(); err != nil {
return
}
if cacheFile != "" {
if err = v.lets.CacheFile(cacheFile); err != nil {
return
}
}
if len(hosts) > 0 {
v.lets.SetHosts(hosts)
}
if email != "" {
if err = v.lets.Register(email, nil); err != nil {
return
}
}
return v, nil
}
func (v *letsencryptManager) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return v.lets.GetCertificate(clientHello)
}

View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -0,0 +1,521 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto"
"crypto/aes"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"errors"
"fmt"
"math/big"
"github.com/ossrs/go-oryx-lib/https/jose/cipher"
)
// A generic RSA-based encrypter/verifier
type rsaEncrypterVerifier struct {
publicKey *rsa.PublicKey
}
// A generic RSA-based decrypter/signer
type rsaDecrypterSigner struct {
privateKey *rsa.PrivateKey
}
// A generic EC-based encrypter/verifier
type ecEncrypterVerifier struct {
publicKey *ecdsa.PublicKey
}
// A key generator for ECDH-ES
type ecKeyGenerator struct {
size int
algID string
publicKey *ecdsa.PublicKey
}
// A generic EC-based decrypter/signer
type ecDecrypterSigner struct {
privateKey *ecdsa.PrivateKey
}
// newRSARecipient creates recipientKeyInfo based on the given key.
func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch keyAlg {
case RSA1_5, RSA_OAEP, RSA_OAEP_256:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
if publicKey == nil {
return recipientKeyInfo{}, errors.New("invalid public key")
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &rsaEncrypterVerifier{
publicKey: publicKey,
},
}, nil
}
// newRSASigner creates a recipientSigInfo based on the given key.
func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch sigAlg {
case RS256, RS384, RS512, PS256, PS384, PS512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
if privateKey == nil {
return recipientSigInfo{}, errors.New("invalid private key")
}
return recipientSigInfo{
sigAlg: sigAlg,
publicKey: &JsonWebKey{
Key: &privateKey.PublicKey,
},
signer: &rsaDecrypterSigner{
privateKey: privateKey,
},
}, nil
}
// newECDHRecipient creates recipientKeyInfo based on the given key.
func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch keyAlg {
case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return recipientKeyInfo{}, errors.New("invalid public key")
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &ecEncrypterVerifier{
publicKey: publicKey,
},
}, nil
}
// newECDSASigner creates a recipientSigInfo based on the given key.
func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch sigAlg {
case ES256, ES384, ES512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
if privateKey == nil {
return recipientSigInfo{}, errors.New("invalid private key")
}
return recipientSigInfo{
sigAlg: sigAlg,
publicKey: &JsonWebKey{
Key: &privateKey.PublicKey,
},
signer: &ecDecrypterSigner{
privateKey: privateKey,
},
}, nil
}
// Encrypt the given payload and update the object.
func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
encryptedKey, err := ctx.encrypt(cek, alg)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: encryptedKey,
header: &rawHeader{},
}, nil
}
// Encrypt the given payload. Based on the key encryption algorithm,
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) {
switch alg {
case RSA1_5:
return rsa.EncryptPKCS1v15(randReader, ctx.publicKey, cek)
case RSA_OAEP:
return rsa.EncryptOAEP(sha1.New(), randReader, ctx.publicKey, cek, []byte{})
case RSA_OAEP_256:
return rsa.EncryptOAEP(sha256.New(), randReader, ctx.publicKey, cek, []byte{})
}
return nil, ErrUnsupportedAlgorithm
}
// Decrypt the given payload and return the content encryption key.
func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
return ctx.decrypt(recipient.encryptedKey, KeyAlgorithm(headers.Alg), generator)
}
// Decrypt the given payload. Based on the key encryption algorithm,
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) {
// Note: The random reader on decrypt operations is only used for blinding,
// so stubbing is meanlingless (hence the direct use of rand.Reader).
switch alg {
case RSA1_5:
defer func() {
// DecryptPKCS1v15SessionKey sometimes panics on an invalid payload
// because of an index out of bounds error, which we want to ignore.
// This has been fixed in Go 1.3.1 (released 2014/08/13), the recover()
// only exists for preventing crashes with unpatched versions.
// See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k
// See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33
_ = recover()
}()
// Perform some input validation.
keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8
if keyBytes != len(jek) {
// Input size is incorrect, the encrypted payload should always match
// the size of the public modulus (e.g. using a 2048 bit key will
// produce 256 bytes of output). Reject this since it's invalid input.
return nil, ErrCryptoFailure
}
cek, _, err := generator.genKey()
if err != nil {
return nil, ErrCryptoFailure
}
// When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to
// prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing
// the Million Message Attack on Cryptographic Message Syntax". We are
// therefore deliberately ignoring errors here.
_ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek)
return cek, nil
case RSA_OAEP:
// Use rand.Reader for RSA blinding
return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{})
case RSA_OAEP_256:
// Use rand.Reader for RSA blinding
return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{})
}
return nil, ErrUnsupportedAlgorithm
}
// Sign the given payload
func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
var hash crypto.Hash
switch alg {
case RS256, PS256:
hash = crypto.SHA256
case RS384, PS384:
hash = crypto.SHA384
case RS512, PS512:
hash = crypto.SHA512
default:
return Signature{}, ErrUnsupportedAlgorithm
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
var out []byte
var err error
switch alg {
case RS256, RS384, RS512:
out, err = rsa.SignPKCS1v15(randReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512:
out, err = rsa.SignPSS(randReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
})
}
if err != nil {
return Signature{}, err
}
return Signature{
Signature: out,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
var hash crypto.Hash
switch alg {
case RS256, PS256:
hash = crypto.SHA256
case RS384, PS384:
hash = crypto.SHA384
case RS512, PS512:
hash = crypto.SHA512
default:
return ErrUnsupportedAlgorithm
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
switch alg {
case RS256, RS384, RS512:
return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature)
case PS256, PS384, PS512:
return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil)
}
return ErrUnsupportedAlgorithm
}
// Encrypt the given payload and update the object.
func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
switch alg {
case ECDH_ES:
// ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key.
return recipientInfo{
header: &rawHeader{},
}, nil
case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
default:
return recipientInfo{}, ErrUnsupportedAlgorithm
}
generator := ecKeyGenerator{
algID: string(alg),
publicKey: ctx.publicKey,
}
switch alg {
case ECDH_ES_A128KW:
generator.size = 16
case ECDH_ES_A192KW:
generator.size = 24
case ECDH_ES_A256KW:
generator.size = 32
}
kek, header, err := generator.genKey()
if err != nil {
return recipientInfo{}, err
}
block, err := aes.NewCipher(kek)
if err != nil {
return recipientInfo{}, err
}
jek, err := josecipher.KeyWrap(block, cek)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: jek,
header: &header,
}, nil
}
// Get key size for EC key generator
func (ctx ecKeyGenerator) keySize() int {
return ctx.size
}
// Get a content encryption key for ECDH-ES
func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) {
priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, randReader)
if err != nil {
return nil, rawHeader{}, err
}
out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size)
headers := rawHeader{
Epk: &JsonWebKey{
Key: &priv.PublicKey,
},
}
return out, headers, nil
}
// Decrypt the given payload and return the content encryption key.
func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
if headers.Epk == nil {
return nil, errors.New("square/go-jose: missing epk header")
}
publicKey, ok := headers.Epk.Key.(*ecdsa.PublicKey)
if publicKey == nil || !ok {
return nil, errors.New("square/go-jose: invalid epk header")
}
if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return nil, errors.New("square/go-jose: invalid public key in epk header")
}
apuData := headers.Apu.bytes()
apvData := headers.Apv.bytes()
deriveKey := func(algID string, size int) []byte {
return josecipher.DeriveECDHES(algID, apuData, apvData, ctx.privateKey, publicKey, size)
}
var keySize int
switch KeyAlgorithm(headers.Alg) {
case ECDH_ES:
// ECDH-ES uses direct key agreement, no key unwrapping necessary.
return deriveKey(string(headers.Enc), generator.keySize()), nil
case ECDH_ES_A128KW:
keySize = 16
case ECDH_ES_A192KW:
keySize = 24
case ECDH_ES_A256KW:
keySize = 32
default:
return nil, ErrUnsupportedAlgorithm
}
key := deriveKey(headers.Alg, keySize)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return josecipher.KeyUnwrap(block, recipient.encryptedKey)
}
// Sign the given payload
func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
var expectedBitSize int
var hash crypto.Hash
switch alg {
case ES256:
expectedBitSize = 256
hash = crypto.SHA256
case ES384:
expectedBitSize = 384
hash = crypto.SHA384
case ES512:
expectedBitSize = 521
hash = crypto.SHA512
}
curveBits := ctx.privateKey.Curve.Params().BitSize
if expectedBitSize != curveBits {
return Signature{}, fmt.Errorf("square/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits)
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
r, s, err := ecdsa.Sign(randReader, ctx.privateKey, hashed)
if err != nil {
return Signature{}, err
}
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes += 1
}
// We serialize the outpus (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return Signature{
Signature: out,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
var keySize int
var hash crypto.Hash
switch alg {
case ES256:
keySize = 32
hash = crypto.SHA256
case ES384:
keySize = 48
hash = crypto.SHA384
case ES512:
keySize = 66
hash = crypto.SHA512
default:
return ErrUnsupportedAlgorithm
}
if len(signature) != 2*keySize {
return fmt.Errorf("square/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize)
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
r := big.NewInt(0).SetBytes(signature[:keySize])
s := big.NewInt(0).SetBytes(signature[keySize:])
match := ecdsa.Verify(ctx.publicKey, hashed, r, s)
if !match {
return errors.New("square/go-jose: ecdsa signature failed to verify")
}
return nil
}

View file

@ -0,0 +1,197 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1/cipher
package josecipher
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
)
const (
nonceBytes = 16
)
// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC.
func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) {
keySize := len(key) / 2
integrityKey := key[:keySize]
encryptionKey := key[keySize:]
blockCipher, err := newBlockCipher(encryptionKey)
if err != nil {
return nil, err
}
var hash func() hash.Hash
switch keySize {
case 16:
hash = sha256.New
case 24:
hash = sha512.New384
case 32:
hash = sha512.New
}
return &cbcAEAD{
hash: hash,
blockCipher: blockCipher,
authtagBytes: keySize,
integrityKey: integrityKey,
}, nil
}
// An AEAD based on CBC+HMAC
type cbcAEAD struct {
hash func() hash.Hash
authtagBytes int
integrityKey []byte
blockCipher cipher.Block
}
func (ctx *cbcAEAD) NonceSize() int {
return nonceBytes
}
func (ctx *cbcAEAD) Overhead() int {
// Maximum overhead is block size (for padding) plus auth tag length, where
// the length of the auth tag is equivalent to the key size.
return ctx.blockCipher.BlockSize() + ctx.authtagBytes
}
// Seal encrypts and authenticates the plaintext.
func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte {
// Output buffer -- must take care not to mangle plaintext input.
ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)]
copy(ciphertext, plaintext)
ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize())
cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce)
cbc.CryptBlocks(ciphertext, ciphertext)
authtag := ctx.computeAuthTag(data, nonce, ciphertext)
ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag)))
copy(out, ciphertext)
copy(out[len(ciphertext):], authtag)
return ret
}
// Open decrypts and authenticates the ciphertext.
func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(ciphertext) < ctx.authtagBytes {
return nil, errors.New("square/go-jose: invalid ciphertext (too short)")
}
offset := len(ciphertext) - ctx.authtagBytes
expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset])
match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:])
if match != 1 {
return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)")
}
cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce)
// Make copy of ciphertext buffer, don't want to modify in place
buffer := append([]byte{}, []byte(ciphertext[:offset])...)
if len(buffer)%ctx.blockCipher.BlockSize() > 0 {
return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)")
}
cbc.CryptBlocks(buffer, buffer)
// Remove padding
plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize())
if err != nil {
return nil, err
}
ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext)))
copy(out, plaintext)
return ret, nil
}
// Compute an authentication tag
func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte {
buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8)
n := 0
n += copy(buffer, aad)
n += copy(buffer[n:], nonce)
n += copy(buffer[n:], ciphertext)
binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8)
// According to documentation, Write() on hash.Hash never fails.
hmac := hmac.New(ctx.hash, ctx.integrityKey)
_, _ = hmac.Write(buffer)
return hmac.Sum(nil)[:ctx.authtagBytes]
}
// resize ensures the the given slice has a capacity of at least n bytes.
// If the capacity of the slice is less than n, a new slice is allocated
// and the existing data will be copied.
func resize(in []byte, n uint64) (head, tail []byte) {
if uint64(cap(in)) >= n {
head = in[:n]
} else {
head = make([]byte, n)
copy(head, in)
}
tail = head[len(in):]
return
}
// Apply padding
func padBuffer(buffer []byte, blockSize int) []byte {
missing := blockSize - (len(buffer) % blockSize)
ret, out := resize(buffer, uint64(len(buffer))+uint64(missing))
padding := bytes.Repeat([]byte{byte(missing)}, missing)
copy(out, padding)
return ret
}
// Remove padding
func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) {
if len(buffer)%blockSize != 0 {
return nil, errors.New("square/go-jose: invalid padding")
}
last := buffer[len(buffer)-1]
count := int(last)
if count == 0 || count > blockSize || count > len(buffer) {
return nil, errors.New("square/go-jose: invalid padding")
}
padding := bytes.Repeat([]byte{last}, count)
if !bytes.HasSuffix(buffer, padding) {
return nil, errors.New("square/go-jose: invalid padding")
}
return buffer[:len(buffer)-count], nil
}

View file

@ -0,0 +1,76 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1/cipher
package josecipher
import (
"crypto"
"encoding/binary"
"hash"
"io"
)
type concatKDF struct {
z, info []byte
i uint32
cache []byte
hasher hash.Hash
}
// NewConcatKDF builds a KDF reader based on the given inputs.
func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader {
buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo)))
n := 0
n += copy(buffer, algID)
n += copy(buffer[n:], ptyUInfo)
n += copy(buffer[n:], ptyVInfo)
n += copy(buffer[n:], supPubInfo)
copy(buffer[n:], supPrivInfo)
hasher := hash.New()
return &concatKDF{
z: z,
info: buffer,
hasher: hasher,
cache: []byte{},
i: 1,
}
}
func (ctx *concatKDF) Read(out []byte) (int, error) {
copied := copy(out, ctx.cache)
ctx.cache = ctx.cache[copied:]
for copied < len(out) {
ctx.hasher.Reset()
// Write on a hash.Hash never fails
_ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i)
_, _ = ctx.hasher.Write(ctx.z)
_, _ = ctx.hasher.Write(ctx.info)
hash := ctx.hasher.Sum(nil)
chunkCopied := copy(out[copied:], hash)
copied += chunkCopied
ctx.cache = hash[chunkCopied:]
ctx.i++
}
return copied, nil
}

View file

@ -0,0 +1,63 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1/cipher
package josecipher
import (
"crypto"
"crypto/ecdsa"
"encoding/binary"
)
// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA.
// It is an error to call this function with a private/public key that are not on the same
// curve. Callers must ensure that the keys are valid before calling this function. Output
// size may be at most 1<<16 bytes (64 KiB).
func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte {
if size > 1<<16 {
panic("ECDH-ES output size too large, must be less than 1<<16")
}
// algId, partyUInfo, partyVInfo inputs must be prefixed with the length
algID := lengthPrefixed([]byte(alg))
ptyUInfo := lengthPrefixed(apuData)
ptyVInfo := lengthPrefixed(apvData)
// suppPubInfo is the encoded length of the output size in bits
supPubInfo := make([]byte, 4)
binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8)
if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) {
panic("public key not on same curve as private key")
}
z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
key := make([]byte, size)
// Read on the KDF will never fail
_, _ = reader.Read(key)
return key
}
func lengthPrefixed(data []byte) []byte {
out := make([]byte, len(data)+4)
binary.BigEndian.PutUint32(out, uint32(len(data)))
copy(out[4:], data)
return out
}

View file

@ -0,0 +1,110 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1/cipher
package josecipher
import (
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
"errors"
)
var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6}
// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher.
func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) {
if len(cek)%8 != 0 {
return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks")
}
n := len(cek) / 8
r := make([][]byte, n)
for i := range r {
r[i] = make([]byte, 8)
copy(r[i], cek[i*8:])
}
buffer := make([]byte, 16)
tBytes := make([]byte, 8)
copy(buffer, defaultIV)
for t := 0; t < 6*n; t++ {
copy(buffer[8:], r[t%n])
block.Encrypt(buffer, buffer)
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
for i := 0; i < 8; i++ {
buffer[i] = buffer[i] ^ tBytes[i]
}
copy(r[t%n], buffer[8:])
}
out := make([]byte, (n+1)*8)
copy(out, buffer[:8])
for i := range r {
copy(out[(i+1)*8:], r[i])
}
return out, nil
}
// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher.
func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) {
if len(ciphertext)%8 != 0 {
return nil, errors.New("square/go-jose: key wrap input must be 8 byte blocks")
}
n := (len(ciphertext) / 8) - 1
r := make([][]byte, n)
for i := range r {
r[i] = make([]byte, 8)
copy(r[i], ciphertext[(i+1)*8:])
}
buffer := make([]byte, 16)
tBytes := make([]byte, 8)
copy(buffer[:8], ciphertext[:8])
for t := 6*n - 1; t >= 0; t-- {
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
for i := 0; i < 8; i++ {
buffer[i] = buffer[i] ^ tBytes[i]
}
copy(buffer[8:], r[t%n])
block.Decrypt(buffer, buffer)
copy(r[t%n], buffer[8:])
}
if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 {
return nil, errors.New("square/go-jose: failed to unwrap key")
}
out := make([]byte, n*8)
for i := range r {
copy(out[i*8:], r[i])
}
return out, nil
}

View file

@ -0,0 +1,350 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"fmt"
"reflect"
)
// Encrypter represents an encrypter which produces an encrypted JWE object.
type Encrypter interface {
Encrypt(plaintext []byte) (*JsonWebEncryption, error)
EncryptWithAuthData(plaintext []byte, aad []byte) (*JsonWebEncryption, error)
SetCompression(alg CompressionAlgorithm)
}
// MultiEncrypter represents an encrypter which supports multiple recipients.
type MultiEncrypter interface {
Encrypt(plaintext []byte) (*JsonWebEncryption, error)
EncryptWithAuthData(plaintext []byte, aad []byte) (*JsonWebEncryption, error)
SetCompression(alg CompressionAlgorithm)
AddRecipient(alg KeyAlgorithm, encryptionKey interface{}) error
}
// A generic content cipher
type contentCipher interface {
keySize() int
encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}
// A key generator (for generating/getting a CEK)
type keyGenerator interface {
keySize() int
genKey() ([]byte, rawHeader, error)
}
// A generic key encrypter
type keyEncrypter interface {
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
}
// A generic key decrypter
type keyDecrypter interface {
decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
}
// A generic encrypter based on the given key encrypter and content cipher.
type genericEncrypter struct {
contentAlg ContentEncryption
compressionAlg CompressionAlgorithm
cipher contentCipher
recipients []recipientKeyInfo
keyGenerator keyGenerator
}
type recipientKeyInfo struct {
keyID string
keyAlg KeyAlgorithm
keyEncrypter keyEncrypter
}
// SetCompression sets a compression algorithm to be applied before encryption.
func (ctx *genericEncrypter) SetCompression(compressionAlg CompressionAlgorithm) {
ctx.compressionAlg = compressionAlg
}
// NewEncrypter creates an appropriate encrypter based on the key type
func NewEncrypter(alg KeyAlgorithm, enc ContentEncryption, encryptionKey interface{}) (Encrypter, error) {
encrypter := &genericEncrypter{
contentAlg: enc,
compressionAlg: NONE,
recipients: []recipientKeyInfo{},
cipher: getContentCipher(enc),
}
if encrypter.cipher == nil {
return nil, ErrUnsupportedAlgorithm
}
var keyID string
var rawKey interface{}
switch encryptionKey := encryptionKey.(type) {
case *JsonWebKey:
keyID = encryptionKey.KeyID
rawKey = encryptionKey.Key
default:
rawKey = encryptionKey
}
switch alg {
case DIRECT:
// Direct encryption mode must be treated differently
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
return nil, ErrUnsupportedKeyType
}
encrypter.keyGenerator = staticKeyGenerator{
key: rawKey.([]byte),
}
recipient, _ := newSymmetricRecipient(alg, rawKey.([]byte))
if keyID != "" {
recipient.keyID = keyID
}
encrypter.recipients = []recipientKeyInfo{recipient}
return encrypter, nil
case ECDH_ES:
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
typeOf := reflect.TypeOf(rawKey)
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
return nil, ErrUnsupportedKeyType
}
encrypter.keyGenerator = ecKeyGenerator{
size: encrypter.cipher.keySize(),
algID: string(enc),
publicKey: rawKey.(*ecdsa.PublicKey),
}
recipient, _ := newECDHRecipient(alg, rawKey.(*ecdsa.PublicKey))
if keyID != "" {
recipient.keyID = keyID
}
encrypter.recipients = []recipientKeyInfo{recipient}
return encrypter, nil
default:
// Can just add a standard recipient
encrypter.keyGenerator = randomKeyGenerator{
size: encrypter.cipher.keySize(),
}
err := encrypter.AddRecipient(alg, encryptionKey)
return encrypter, err
}
}
// NewMultiEncrypter creates a multi-encrypter based on the given parameters
func NewMultiEncrypter(enc ContentEncryption) (MultiEncrypter, error) {
cipher := getContentCipher(enc)
if cipher == nil {
return nil, ErrUnsupportedAlgorithm
}
encrypter := &genericEncrypter{
contentAlg: enc,
compressionAlg: NONE,
recipients: []recipientKeyInfo{},
cipher: cipher,
keyGenerator: randomKeyGenerator{
size: cipher.keySize(),
},
}
return encrypter, nil
}
func (ctx *genericEncrypter) AddRecipient(alg KeyAlgorithm, encryptionKey interface{}) (err error) {
var recipient recipientKeyInfo
switch alg {
case DIRECT, ECDH_ES:
return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", alg)
}
recipient, err = makeJWERecipient(alg, encryptionKey)
if err == nil {
ctx.recipients = append(ctx.recipients, recipient)
}
return err
}
func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
switch encryptionKey := encryptionKey.(type) {
case *rsa.PublicKey:
return newRSARecipient(alg, encryptionKey)
case *ecdsa.PublicKey:
return newECDHRecipient(alg, encryptionKey)
case []byte:
return newSymmetricRecipient(alg, encryptionKey)
case *JsonWebKey:
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
if err == nil && encryptionKey.KeyID != "" {
recipient.keyID = encryptionKey.KeyID
}
return recipient, err
default:
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
}
// newDecrypter creates an appropriate decrypter based on the key type
func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
switch decryptionKey := decryptionKey.(type) {
case *rsa.PrivateKey:
return &rsaDecrypterSigner{
privateKey: decryptionKey,
}, nil
case *ecdsa.PrivateKey:
return &ecDecrypterSigner{
privateKey: decryptionKey,
}, nil
case []byte:
return &symmetricKeyCipher{
key: decryptionKey,
}, nil
case *JsonWebKey:
return newDecrypter(decryptionKey.Key)
default:
return nil, ErrUnsupportedKeyType
}
}
// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JsonWebEncryption, error) {
return ctx.EncryptWithAuthData(plaintext, nil)
}
// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JsonWebEncryption, error) {
obj := &JsonWebEncryption{}
obj.aad = aad
obj.protected = &rawHeader{
Enc: ctx.contentAlg,
}
obj.recipients = make([]recipientInfo, len(ctx.recipients))
if len(ctx.recipients) == 0 {
return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
}
cek, headers, err := ctx.keyGenerator.genKey()
if err != nil {
return nil, err
}
obj.protected.merge(&headers)
for i, info := range ctx.recipients {
recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
if err != nil {
return nil, err
}
recipient.header.Alg = string(info.keyAlg)
if info.keyID != "" {
recipient.header.Kid = info.keyID
}
obj.recipients[i] = recipient
}
if len(ctx.recipients) == 1 {
// Move per-recipient headers into main protected header if there's
// only a single recipient.
obj.protected.merge(obj.recipients[0].header)
obj.recipients[0].header = nil
}
if ctx.compressionAlg != NONE {
plaintext, err = compress(ctx.compressionAlg, plaintext)
if err != nil {
return nil, err
}
obj.protected.Zip = ctx.compressionAlg
}
authData := obj.computeAuthData()
parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
if err != nil {
return nil, err
}
obj.iv = parts.iv
obj.ciphertext = parts.ciphertext
obj.tag = parts.tag
return obj, nil
}
// Decrypt and validate the object and return the plaintext.
func (obj JsonWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil)
if len(headers.Crit) > 0 {
return nil, fmt.Errorf("square/go-jose: unsupported crit header")
}
decrypter, err := newDecrypter(decryptionKey)
if err != nil {
return nil, err
}
cipher := getContentCipher(headers.Enc)
if cipher == nil {
return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.Enc))
}
generator := randomKeyGenerator{
size: cipher.keySize(),
}
parts := &aeadParts{
iv: obj.iv,
ciphertext: obj.ciphertext,
tag: obj.tag,
}
authData := obj.computeAuthData()
var plaintext []byte
for _, recipient := range obj.recipients {
recipientHeaders := obj.mergedHeaders(&recipient)
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
if err == nil {
break
}
}
}
if plaintext == nil {
return nil, ErrCryptoFailure
}
// The "zip" header parameter may only be present in the protected header.
if obj.protected.Zip != "" {
plaintext, err = decompress(obj.protected.Zip, plaintext)
}
return plaintext, err
}

View file

@ -0,0 +1,27 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. For the moment, it mainly focuses on
encryption and signing based on the JSON Web Encryption and JSON Web Signature
standards. The library supports both the compact and full serialization
formats, and has optional support for multiple recipients.
*/
// from gopkg.in/square/go-jose.v1
package jose // import "github.com/ossrs/go-oryx-lib/https/jose"

View file

@ -0,0 +1,194 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"bytes"
"compress/flate"
"encoding/base64"
"encoding/binary"
"io"
"math/big"
"regexp"
"strings"
"encoding/json"
)
var stripWhitespaceRegex = regexp.MustCompile("\\s")
// Url-safe base64 encode that strips padding
func base64URLEncode(data []byte) string {
var result = base64.URLEncoding.EncodeToString(data)
return strings.TrimRight(result, "=")
}
// Url-safe base64 decoder that adds padding
func base64URLDecode(data string) ([]byte, error) {
var missing = (4 - len(data)%4) % 4
data += strings.Repeat("=", missing)
return base64.URLEncoding.DecodeString(data)
}
// Helper function to serialize known-good objects.
// Precondition: value is not a nil pointer.
func mustSerializeJSON(value interface{}) []byte {
out, err := json.Marshal(value)
if err != nil {
panic(err)
}
// We never want to serialize the top-level value "null," since it's not a
// valid JOSE message. But if a caller passes in a nil pointer to this method,
// MarshalJSON will happily serialize it as the top-level value "null". If
// that value is then embedded in another operation, for instance by being
// base64-encoded and fed as input to a signing algorithm
// (https://github.com/square/go-jose/issues/22), the result will be
// incorrect. Because this method is intended for known-good objects, and a nil
// pointer is not a known-good object, we are free to panic in this case.
// Note: It's not possible to directly check whether the data pointed at by an
// interface is a nil pointer, so we do this hacky workaround.
// https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I
if string(out) == "null" {
panic("Tried to serialize a nil pointer.")
}
return out
}
// Strip all newlines and whitespace
func stripWhitespace(data string) string {
return stripWhitespaceRegex.ReplaceAllString(data, "")
}
// Perform compression based on algorithm
func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
switch algorithm {
case DEFLATE:
return deflate(input)
default:
return nil, ErrUnsupportedAlgorithm
}
}
// Perform decompression based on algorithm
func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
switch algorithm {
case DEFLATE:
return inflate(input)
default:
return nil, ErrUnsupportedAlgorithm
}
}
// Compress with DEFLATE
func deflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
// Writing to byte buffer, err is always nil
writer, _ := flate.NewWriter(output, 1)
_, _ = io.Copy(writer, bytes.NewBuffer(input))
err := writer.Close()
return output.Bytes(), err
}
// Decompress with DEFLATE
func inflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
reader := flate.NewReader(bytes.NewBuffer(input))
_, err := io.Copy(output, reader)
if err != nil {
return nil, err
}
err = reader.Close()
return output.Bytes(), err
}
// byteBuffer represents a slice of bytes that can be serialized to url-safe base64.
type byteBuffer struct {
data []byte
}
func newBuffer(data []byte) *byteBuffer {
if data == nil {
return nil
}
return &byteBuffer{
data: data,
}
}
func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
if len(data) > length {
panic("square/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)")
}
pad := make([]byte, length-len(data))
return newBuffer(append(pad, data...))
}
func newBufferFromInt(num uint64) *byteBuffer {
data := make([]byte, 8)
binary.BigEndian.PutUint64(data, num)
return newBuffer(bytes.TrimLeft(data, "\x00"))
}
func (b *byteBuffer) MarshalJSON() ([]byte, error) {
return json.Marshal(b.base64())
}
func (b *byteBuffer) UnmarshalJSON(data []byte) error {
var encoded string
err := json.Unmarshal(data, &encoded)
if err != nil {
return err
}
if encoded == "" {
return nil
}
decoded, err := base64URLDecode(encoded)
if err != nil {
return err
}
*b = *newBuffer(decoded)
return nil
}
func (b *byteBuffer) base64() string {
return base64URLEncode(b.data)
}
func (b *byteBuffer) bytes() []byte {
// Handling nil here allows us to transparently handle nil slices when serializing.
if b == nil {
return nil
}
return b.data
}
func (b byteBuffer) bigInt() *big.Int {
return new(big.Int).SetBytes(b.data)
}
func (b byteBuffer) toInt() int {
return int(b.bigInt().Int64())
}

View file

@ -0,0 +1,281 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"fmt"
"strings"
"encoding/json"
)
// rawJsonWebEncryption represents a raw JWE JSON object. Used for parsing/serializing.
type rawJsonWebEncryption struct {
Protected *byteBuffer `json:"protected,omitempty"`
Unprotected *rawHeader `json:"unprotected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Recipients []rawRecipientInfo `json:"recipients,omitempty"`
Aad *byteBuffer `json:"aad,omitempty"`
EncryptedKey *byteBuffer `json:"encrypted_key,omitempty"`
Iv *byteBuffer `json:"iv,omitempty"`
Ciphertext *byteBuffer `json:"ciphertext,omitempty"`
Tag *byteBuffer `json:"tag,omitempty"`
}
// rawRecipientInfo represents a raw JWE Per-Recipient header JSON object. Used for parsing/serializing.
type rawRecipientInfo struct {
Header *rawHeader `json:"header,omitempty"`
EncryptedKey string `json:"encrypted_key,omitempty"`
}
// JsonWebEncryption represents an encrypted JWE object after parsing.
type JsonWebEncryption struct {
Header JoseHeader
protected, unprotected *rawHeader
recipients []recipientInfo
aad, iv, ciphertext, tag []byte
original *rawJsonWebEncryption
}
// recipientInfo represents a raw JWE Per-Recipient header JSON object after parsing.
type recipientInfo struct {
header *rawHeader
encryptedKey []byte
}
// GetAuthData retrieves the (optional) authenticated data attached to the object.
func (obj JsonWebEncryption) GetAuthData() []byte {
if obj.aad != nil {
out := make([]byte, len(obj.aad))
copy(out, obj.aad)
return out
}
return nil
}
// Get the merged header values
func (obj JsonWebEncryption) mergedHeaders(recipient *recipientInfo) rawHeader {
out := rawHeader{}
out.merge(obj.protected)
out.merge(obj.unprotected)
if recipient != nil {
out.merge(recipient.header)
}
return out
}
// Get the additional authenticated data from a JWE object.
func (obj JsonWebEncryption) computeAuthData() []byte {
var protected string
if obj.original != nil {
protected = obj.original.Protected.base64()
} else {
protected = base64URLEncode(mustSerializeJSON((obj.protected)))
}
output := []byte(protected)
if obj.aad != nil {
output = append(output, '.')
output = append(output, []byte(base64URLEncode(obj.aad))...)
}
return output
}
// ParseEncrypted parses an encrypted message in compact or full serialization format.
func ParseEncrypted(input string) (*JsonWebEncryption, error) {
input = stripWhitespace(input)
if strings.HasPrefix(input, "{") {
return parseEncryptedFull(input)
}
return parseEncryptedCompact(input)
}
// parseEncryptedFull parses a message in compact format.
func parseEncryptedFull(input string) (*JsonWebEncryption, error) {
var parsed rawJsonWebEncryption
err := json.Unmarshal([]byte(input), &parsed)
if err != nil {
return nil, err
}
return parsed.sanitized()
}
// sanitized produces a cleaned-up JWE object from the raw JSON.
func (parsed *rawJsonWebEncryption) sanitized() (*JsonWebEncryption, error) {
obj := &JsonWebEncryption{
original: parsed,
unprotected: parsed.Unprotected,
}
// Check that there is not a nonce in the unprotected headers
if (parsed.Unprotected != nil && parsed.Unprotected.Nonce != "") ||
(parsed.Header != nil && parsed.Header.Nonce != "") {
return nil, ErrUnprotectedNonce
}
if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
err := json.Unmarshal(parsed.Protected.bytes(), &obj.protected)
if err != nil {
return nil, fmt.Errorf("square/go-jose: invalid protected header: %s, %s", err, parsed.Protected.base64())
}
}
// Note: this must be called _after_ we parse the protected header,
// otherwise fields from the protected header will not get picked up.
obj.Header = obj.mergedHeaders(nil).sanitized()
if len(parsed.Recipients) == 0 {
obj.recipients = []recipientInfo{
recipientInfo{
header: parsed.Header,
encryptedKey: parsed.EncryptedKey.bytes(),
},
}
} else {
obj.recipients = make([]recipientInfo, len(parsed.Recipients))
for r := range parsed.Recipients {
encryptedKey, err := base64URLDecode(parsed.Recipients[r].EncryptedKey)
if err != nil {
return nil, err
}
// Check that there is not a nonce in the unprotected header
if parsed.Recipients[r].Header != nil && parsed.Recipients[r].Header.Nonce != "" {
return nil, ErrUnprotectedNonce
}
obj.recipients[r].header = parsed.Recipients[r].Header
obj.recipients[r].encryptedKey = encryptedKey
}
}
for _, recipient := range obj.recipients {
headers := obj.mergedHeaders(&recipient)
if headers.Alg == "" || headers.Enc == "" {
return nil, fmt.Errorf("square/go-jose: message is missing alg/enc headers")
}
}
obj.iv = parsed.Iv.bytes()
obj.ciphertext = parsed.Ciphertext.bytes()
obj.tag = parsed.Tag.bytes()
obj.aad = parsed.Aad.bytes()
return obj, nil
}
// parseEncryptedCompact parses a message in compact format.
func parseEncryptedCompact(input string) (*JsonWebEncryption, error) {
parts := strings.Split(input, ".")
if len(parts) != 5 {
return nil, fmt.Errorf("square/go-jose: compact JWE format must have five parts")
}
rawProtected, err := base64URLDecode(parts[0])
if err != nil {
return nil, err
}
encryptedKey, err := base64URLDecode(parts[1])
if err != nil {
return nil, err
}
iv, err := base64URLDecode(parts[2])
if err != nil {
return nil, err
}
ciphertext, err := base64URLDecode(parts[3])
if err != nil {
return nil, err
}
tag, err := base64URLDecode(parts[4])
if err != nil {
return nil, err
}
raw := &rawJsonWebEncryption{
Protected: newBuffer(rawProtected),
EncryptedKey: newBuffer(encryptedKey),
Iv: newBuffer(iv),
Ciphertext: newBuffer(ciphertext),
Tag: newBuffer(tag),
}
return raw.sanitized()
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JsonWebEncryption) CompactSerialize() (string, error) {
if len(obj.recipients) != 1 || obj.unprotected != nil ||
obj.protected == nil || obj.recipients[0].header != nil {
return "", ErrNotSupported
}
serializedProtected := mustSerializeJSON(obj.protected)
return fmt.Sprintf(
"%s.%s.%s.%s.%s",
base64URLEncode(serializedProtected),
base64URLEncode(obj.recipients[0].encryptedKey),
base64URLEncode(obj.iv),
base64URLEncode(obj.ciphertext),
base64URLEncode(obj.tag)), nil
}
// FullSerialize serializes an object using the full JSON serialization format.
func (obj JsonWebEncryption) FullSerialize() string {
raw := rawJsonWebEncryption{
Unprotected: obj.unprotected,
Iv: newBuffer(obj.iv),
Ciphertext: newBuffer(obj.ciphertext),
EncryptedKey: newBuffer(obj.recipients[0].encryptedKey),
Tag: newBuffer(obj.tag),
Aad: newBuffer(obj.aad),
Recipients: []rawRecipientInfo{},
}
if len(obj.recipients) > 1 {
for _, recipient := range obj.recipients {
info := rawRecipientInfo{
Header: recipient.header,
EncryptedKey: base64URLEncode(recipient.encryptedKey),
}
raw.Recipients = append(raw.Recipients, info)
}
} else {
// Use flattened serialization
raw.Header = obj.recipients[0].header
raw.EncryptedKey = newBuffer(obj.recipients[0].encryptedKey)
}
if obj.protected != nil {
raw.Protected = newBuffer(mustSerializeJSON(obj.protected))
}
return string(mustSerializeJSON(raw))
}

View file

@ -0,0 +1,448 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"math/big"
"reflect"
"strings"
"encoding/json"
)
// rawJsonWebKey represents a public or private key in JWK format, used for parsing/serializing.
type rawJsonWebKey struct {
Use string `json:"use,omitempty"`
Kty string `json:"kty,omitempty"`
Kid string `json:"kid,omitempty"`
Crv string `json:"crv,omitempty"`
Alg string `json:"alg,omitempty"`
K *byteBuffer `json:"k,omitempty"`
X *byteBuffer `json:"x,omitempty"`
Y *byteBuffer `json:"y,omitempty"`
N *byteBuffer `json:"n,omitempty"`
E *byteBuffer `json:"e,omitempty"`
// -- Following fields are only used for private keys --
// RSA uses D, P and Q, while ECDSA uses only D. Fields Dp, Dq, and Qi are
// completely optional. Therefore for RSA/ECDSA, D != nil is a contract that
// we have a private key whereas D == nil means we have only a public key.
D *byteBuffer `json:"d,omitempty"`
P *byteBuffer `json:"p,omitempty"`
Q *byteBuffer `json:"q,omitempty"`
Dp *byteBuffer `json:"dp,omitempty"`
Dq *byteBuffer `json:"dq,omitempty"`
Qi *byteBuffer `json:"qi,omitempty"`
// Certificates
X5c []string `json:"x5c,omitempty"`
}
// JsonWebKey represents a public or private key in JWK format.
type JsonWebKey struct {
Key interface{}
Certificates []*x509.Certificate
KeyID string
Algorithm string
Use string
}
// MarshalJSON serializes the given key to its JSON representation.
func (k JsonWebKey) MarshalJSON() ([]byte, error) {
var raw *rawJsonWebKey
var err error
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
raw, err = fromEcPublicKey(key)
case *rsa.PublicKey:
raw = fromRsaPublicKey(key)
case *ecdsa.PrivateKey:
raw, err = fromEcPrivateKey(key)
case *rsa.PrivateKey:
raw, err = fromRsaPrivateKey(key)
case []byte:
raw, err = fromSymmetricKey(key)
default:
return nil, fmt.Errorf("square/go-jose: unknown key type '%s'", reflect.TypeOf(key))
}
if err != nil {
return nil, err
}
raw.Kid = k.KeyID
raw.Alg = k.Algorithm
raw.Use = k.Use
for _, cert := range k.Certificates {
raw.X5c = append(raw.X5c, base64.StdEncoding.EncodeToString(cert.Raw))
}
return json.Marshal(raw)
}
// UnmarshalJSON reads a key from its JSON representation.
func (k *JsonWebKey) UnmarshalJSON(data []byte) (err error) {
var raw rawJsonWebKey
err = json.Unmarshal(data, &raw)
if err != nil {
return err
}
var key interface{}
switch raw.Kty {
case "EC":
if raw.D != nil {
key, err = raw.ecPrivateKey()
} else {
key, err = raw.ecPublicKey()
}
case "RSA":
if raw.D != nil {
key, err = raw.rsaPrivateKey()
} else {
key, err = raw.rsaPublicKey()
}
case "oct":
key, err = raw.symmetricKey()
default:
err = fmt.Errorf("square/go-jose: unknown json web key type '%s'", raw.Kty)
}
if err == nil {
*k = JsonWebKey{Key: key, KeyID: raw.Kid, Algorithm: raw.Alg, Use: raw.Use}
}
k.Certificates = make([]*x509.Certificate, len(raw.X5c))
for i, cert := range raw.X5c {
raw, err := base64.StdEncoding.DecodeString(cert)
if err != nil {
return err
}
k.Certificates[i], err = x509.ParseCertificate(raw)
if err != nil {
return err
}
}
return
}
// JsonWebKeySet represents a JWK Set object.
type JsonWebKeySet struct {
Keys []JsonWebKey `json:"keys"`
}
// Key convenience method returns keys by key ID. Specification states
// that a JWK Set "SHOULD" use distinct key IDs, but allows for some
// cases where they are not distinct. Hence method returns a slice
// of JsonWebKeys.
func (s *JsonWebKeySet) Key(kid string) []JsonWebKey {
var keys []JsonWebKey
for _, key := range s.Keys {
if key.KeyID == kid {
keys = append(keys, key)
}
}
return keys
}
const rsaThumbprintTemplate = `{"e":"%s","kty":"RSA","n":"%s"}`
const ecThumbprintTemplate = `{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`
func ecThumbprintInput(curve elliptic.Curve, x, y *big.Int) (string, error) {
coordLength := curveSize(curve)
crv, err := curveName(curve)
if err != nil {
return "", err
}
return fmt.Sprintf(ecThumbprintTemplate, crv,
newFixedSizeBuffer(x.Bytes(), coordLength).base64(),
newFixedSizeBuffer(y.Bytes(), coordLength).base64()), nil
}
func rsaThumbprintInput(n *big.Int, e int) (string, error) {
return fmt.Sprintf(rsaThumbprintTemplate,
newBufferFromInt(uint64(e)).base64(),
newBuffer(n.Bytes()).base64()), nil
}
// Thumbprint computes the JWK Thumbprint of a key using the
// indicated hash algorithm.
func (k *JsonWebKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
var input string
var err error
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
input, err = ecThumbprintInput(key.Curve, key.X, key.Y)
case *ecdsa.PrivateKey:
input, err = ecThumbprintInput(key.Curve, key.X, key.Y)
case *rsa.PublicKey:
input, err = rsaThumbprintInput(key.N, key.E)
case *rsa.PrivateKey:
input, err = rsaThumbprintInput(key.N, key.E)
default:
return nil, fmt.Errorf("square/go-jose: unknown key type '%s'", reflect.TypeOf(key))
}
if err != nil {
return nil, err
}
h := hash.New()
h.Write([]byte(input))
return h.Sum(nil), nil
}
// Valid checks that the key contains the expected parameters
func (k *JsonWebKey) Valid() bool {
if k.Key == nil {
return false
}
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
if key.Curve == nil || key.X == nil || key.Y == nil {
return false
}
case *ecdsa.PrivateKey:
if key.Curve == nil || key.X == nil || key.Y == nil || key.D == nil {
return false
}
case *rsa.PublicKey:
if key.N == nil || key.E == 0 {
return false
}
case *rsa.PrivateKey:
if key.N == nil || key.E == 0 || key.D == nil || len(key.Primes) < 2 {
return false
}
default:
return false
}
return true
}
func (key rawJsonWebKey) rsaPublicKey() (*rsa.PublicKey, error) {
if key.N == nil || key.E == nil {
return nil, fmt.Errorf("square/go-jose: invalid RSA key, missing n/e values")
}
return &rsa.PublicKey{
N: key.N.bigInt(),
E: key.E.toInt(),
}, nil
}
func fromRsaPublicKey(pub *rsa.PublicKey) *rawJsonWebKey {
return &rawJsonWebKey{
Kty: "RSA",
N: newBuffer(pub.N.Bytes()),
E: newBufferFromInt(uint64(pub.E)),
}
}
func (key rawJsonWebKey) ecPublicKey() (*ecdsa.PublicKey, error) {
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("square/go-jose: unsupported elliptic curve '%s'", key.Crv)
}
if key.X == nil || key.Y == nil {
return nil, errors.New("square/go-jose: invalid EC key, missing x/y values")
}
x := key.X.bigInt()
y := key.Y.bigInt()
if !curve.IsOnCurve(x, y) {
return nil, errors.New("square/go-jose: invalid EC key, X/Y are not on declared curve")
}
return &ecdsa.PublicKey{
Curve: curve,
X: x,
Y: y,
}, nil
}
func fromEcPublicKey(pub *ecdsa.PublicKey) (*rawJsonWebKey, error) {
if pub == nil || pub.X == nil || pub.Y == nil {
return nil, fmt.Errorf("square/go-jose: invalid EC key (nil, or X/Y missing)")
}
name, err := curveName(pub.Curve)
if err != nil {
return nil, err
}
size := curveSize(pub.Curve)
xBytes := pub.X.Bytes()
yBytes := pub.Y.Bytes()
if len(xBytes) > size || len(yBytes) > size {
return nil, fmt.Errorf("square/go-jose: invalid EC key (X/Y too large)")
}
key := &rawJsonWebKey{
Kty: "EC",
Crv: name,
X: newFixedSizeBuffer(xBytes, size),
Y: newFixedSizeBuffer(yBytes, size),
}
return key, nil
}
func (key rawJsonWebKey) rsaPrivateKey() (*rsa.PrivateKey, error) {
var missing []string
switch {
case key.N == nil:
missing = append(missing, "N")
case key.E == nil:
missing = append(missing, "E")
case key.D == nil:
missing = append(missing, "D")
case key.P == nil:
missing = append(missing, "P")
case key.Q == nil:
missing = append(missing, "Q")
}
if len(missing) > 0 {
return nil, fmt.Errorf("square/go-jose: invalid RSA private key, missing %s value(s)", strings.Join(missing, ", "))
}
rv := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N.bigInt(),
E: key.E.toInt(),
},
D: key.D.bigInt(),
Primes: []*big.Int{
key.P.bigInt(),
key.Q.bigInt(),
},
}
if key.Dp != nil {
rv.Precomputed.Dp = key.Dp.bigInt()
}
if key.Dq != nil {
rv.Precomputed.Dq = key.Dq.bigInt()
}
if key.Qi != nil {
rv.Precomputed.Qinv = key.Qi.bigInt()
}
err := rv.Validate()
return rv, err
}
func fromRsaPrivateKey(rsa *rsa.PrivateKey) (*rawJsonWebKey, error) {
if len(rsa.Primes) != 2 {
return nil, ErrUnsupportedKeyType
}
raw := fromRsaPublicKey(&rsa.PublicKey)
raw.D = newBuffer(rsa.D.Bytes())
raw.P = newBuffer(rsa.Primes[0].Bytes())
raw.Q = newBuffer(rsa.Primes[1].Bytes())
return raw, nil
}
func (key rawJsonWebKey) ecPrivateKey() (*ecdsa.PrivateKey, error) {
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("square/go-jose: unsupported elliptic curve '%s'", key.Crv)
}
if key.X == nil || key.Y == nil || key.D == nil {
return nil, fmt.Errorf("square/go-jose: invalid EC private key, missing x/y/d values")
}
x := key.X.bigInt()
y := key.Y.bigInt()
if !curve.IsOnCurve(x, y) {
return nil, errors.New("square/go-jose: invalid EC key, X/Y are not on declared curve")
}
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: x,
Y: y,
},
D: key.D.bigInt(),
}, nil
}
func fromEcPrivateKey(ec *ecdsa.PrivateKey) (*rawJsonWebKey, error) {
raw, err := fromEcPublicKey(&ec.PublicKey)
if err != nil {
return nil, err
}
if ec.D == nil {
return nil, fmt.Errorf("square/go-jose: invalid EC private key")
}
raw.D = newBuffer(ec.D.Bytes())
return raw, nil
}
func fromSymmetricKey(key []byte) (*rawJsonWebKey, error) {
return &rawJsonWebKey{
Kty: "oct",
K: newBuffer(key),
}, nil
}
func (key rawJsonWebKey) symmetricKey() ([]byte, error) {
if key.K == nil {
return nil, fmt.Errorf("square/go-jose: invalid OCT (symmetric) key, missing k value")
}
return key.K.bytes(), nil
}

View file

@ -0,0 +1,255 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"fmt"
"strings"
"encoding/json"
)
// rawJsonWebSignature represents a raw JWS JSON object. Used for parsing/serializing.
type rawJsonWebSignature struct {
Payload *byteBuffer `json:"payload,omitempty"`
Signatures []rawSignatureInfo `json:"signatures,omitempty"`
Protected *byteBuffer `json:"protected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Signature *byteBuffer `json:"signature,omitempty"`
}
// rawSignatureInfo represents a single JWS signature over the JWS payload and protected header.
type rawSignatureInfo struct {
Protected *byteBuffer `json:"protected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Signature *byteBuffer `json:"signature,omitempty"`
}
// JsonWebSignature represents a signed JWS object after parsing.
type JsonWebSignature struct {
payload []byte
Signatures []Signature
}
// Signature represents a single signature over the JWS payload and protected header.
type Signature struct {
// Header fields, such as the signature algorithm
Header JoseHeader
// The actual signature value
Signature []byte
protected *rawHeader
header *rawHeader
original *rawSignatureInfo
}
// ParseSigned parses a signed message in compact or full serialization format.
func ParseSigned(input string) (*JsonWebSignature, error) {
input = stripWhitespace(input)
if strings.HasPrefix(input, "{") {
return parseSignedFull(input)
}
return parseSignedCompact(input)
}
// Get a header value
func (sig Signature) mergedHeaders() rawHeader {
out := rawHeader{}
out.merge(sig.protected)
out.merge(sig.header)
return out
}
// Compute data to be signed
func (obj JsonWebSignature) computeAuthData(signature *Signature) []byte {
var serializedProtected string
if signature.original != nil && signature.original.Protected != nil {
serializedProtected = signature.original.Protected.base64()
} else if signature.protected != nil {
serializedProtected = base64URLEncode(mustSerializeJSON(signature.protected))
} else {
serializedProtected = ""
}
return []byte(fmt.Sprintf("%s.%s",
serializedProtected,
base64URLEncode(obj.payload)))
}
// parseSignedFull parses a message in full format.
func parseSignedFull(input string) (*JsonWebSignature, error) {
var parsed rawJsonWebSignature
err := json.Unmarshal([]byte(input), &parsed)
if err != nil {
return nil, err
}
return parsed.sanitized()
}
// sanitized produces a cleaned-up JWS object from the raw JSON.
func (parsed *rawJsonWebSignature) sanitized() (*JsonWebSignature, error) {
if parsed.Payload == nil {
return nil, fmt.Errorf("square/go-jose: missing payload in JWS message")
}
obj := &JsonWebSignature{
payload: parsed.Payload.bytes(),
Signatures: make([]Signature, len(parsed.Signatures)),
}
if len(parsed.Signatures) == 0 {
// No signatures array, must be flattened serialization
signature := Signature{}
if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
signature.protected = &rawHeader{}
err := json.Unmarshal(parsed.Protected.bytes(), signature.protected)
if err != nil {
return nil, err
}
}
if parsed.Header != nil && parsed.Header.Nonce != "" {
return nil, ErrUnprotectedNonce
}
signature.header = parsed.Header
signature.Signature = parsed.Signature.bytes()
// Make a fake "original" rawSignatureInfo to store the unprocessed
// Protected header. This is necessary because the Protected header can
// contain arbitrary fields not registered as part of the spec. See
// https://tools.ietf.org/html/draft-ietf-jose-json-web-signature-41#section-4
// If we unmarshal Protected into a rawHeader with its explicit list of fields,
// we cannot marshal losslessly. So we have to keep around the original bytes.
// This is used in computeAuthData, which will first attempt to use
// the original bytes of a protected header, and fall back on marshaling the
// header struct only if those bytes are not available.
signature.original = &rawSignatureInfo{
Protected: parsed.Protected,
Header: parsed.Header,
Signature: parsed.Signature,
}
signature.Header = signature.mergedHeaders().sanitized()
obj.Signatures = append(obj.Signatures, signature)
}
for i, sig := range parsed.Signatures {
if sig.Protected != nil && len(sig.Protected.bytes()) > 0 {
obj.Signatures[i].protected = &rawHeader{}
err := json.Unmarshal(sig.Protected.bytes(), obj.Signatures[i].protected)
if err != nil {
return nil, err
}
}
// Check that there is not a nonce in the unprotected header
if sig.Header != nil && sig.Header.Nonce != "" {
return nil, ErrUnprotectedNonce
}
obj.Signatures[i].Signature = sig.Signature.bytes()
// Copy value of sig
original := sig
obj.Signatures[i].header = sig.Header
obj.Signatures[i].original = &original
obj.Signatures[i].Header = obj.Signatures[i].mergedHeaders().sanitized()
}
return obj, nil
}
// parseSignedCompact parses a message in compact format.
func parseSignedCompact(input string) (*JsonWebSignature, error) {
parts := strings.Split(input, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
}
rawProtected, err := base64URLDecode(parts[0])
if err != nil {
return nil, err
}
payload, err := base64URLDecode(parts[1])
if err != nil {
return nil, err
}
signature, err := base64URLDecode(parts[2])
if err != nil {
return nil, err
}
raw := &rawJsonWebSignature{
Payload: newBuffer(payload),
Protected: newBuffer(rawProtected),
Signature: newBuffer(signature),
}
return raw.sanitized()
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JsonWebSignature) CompactSerialize() (string, error) {
if len(obj.Signatures) != 1 || obj.Signatures[0].header != nil || obj.Signatures[0].protected == nil {
return "", ErrNotSupported
}
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
return fmt.Sprintf(
"%s.%s.%s",
base64URLEncode(serializedProtected),
base64URLEncode(obj.payload),
base64URLEncode(obj.Signatures[0].Signature)), nil
}
// FullSerialize serializes an object using the full JSON serialization format.
func (obj JsonWebSignature) FullSerialize() string {
raw := rawJsonWebSignature{
Payload: newBuffer(obj.payload),
}
if len(obj.Signatures) == 1 {
if obj.Signatures[0].protected != nil {
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
raw.Protected = newBuffer(serializedProtected)
}
raw.Header = obj.Signatures[0].header
raw.Signature = newBuffer(obj.Signatures[0].Signature)
} else {
raw.Signatures = make([]rawSignatureInfo, len(obj.Signatures))
for i, signature := range obj.Signatures {
raw.Signatures[i] = rawSignatureInfo{
Header: signature.header,
Signature: newBuffer(signature.Signature),
}
if signature.protected != nil {
raw.Signatures[i].Protected = newBuffer(mustSerializeJSON(signature.protected))
}
}
}
return string(mustSerializeJSON(raw))
}

View file

@ -0,0 +1,225 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto/elliptic"
"errors"
"fmt"
)
// KeyAlgorithm represents a key management algorithm.
type KeyAlgorithm string
// SignatureAlgorithm represents a signature (or MAC) algorithm.
type SignatureAlgorithm string
// ContentEncryption represents a content encryption algorithm.
type ContentEncryption string
// CompressionAlgorithm represents an algorithm used for plaintext compression.
type CompressionAlgorithm string
var (
// ErrCryptoFailure represents an error in cryptographic primitive. This
// occurs when, for example, a message had an invalid authentication tag or
// could not be decrypted.
ErrCryptoFailure = errors.New("square/go-jose: error in cryptographic primitive")
// ErrUnsupportedAlgorithm indicates that a selected algorithm is not
// supported. This occurs when trying to instantiate an encrypter for an
// algorithm that is not yet implemented.
ErrUnsupportedAlgorithm = errors.New("square/go-jose: unknown/unsupported algorithm")
// ErrUnsupportedKeyType indicates that the given key type/format is not
// supported. This occurs when trying to instantiate an encrypter and passing
// it a key of an unrecognized type or with unsupported parameters, such as
// an RSA private key with more than two primes.
ErrUnsupportedKeyType = errors.New("square/go-jose: unsupported key type/format")
// ErrNotSupported serialization of object is not supported. This occurs when
// trying to compact-serialize an object which can't be represented in
// compact form.
ErrNotSupported = errors.New("square/go-jose: compact serialization not supported for object")
// ErrUnprotectedNonce indicates that while parsing a JWS or JWE object, a
// nonce header parameter was included in an unprotected header object.
ErrUnprotectedNonce = errors.New("square/go-jose: Nonce parameter included in unprotected header")
)
// Key management algorithms
const (
RSA1_5 = KeyAlgorithm("RSA1_5") // RSA-PKCS1v1.5
RSA_OAEP = KeyAlgorithm("RSA-OAEP") // RSA-OAEP-SHA1
RSA_OAEP_256 = KeyAlgorithm("RSA-OAEP-256") // RSA-OAEP-SHA256
A128KW = KeyAlgorithm("A128KW") // AES key wrap (128)
A192KW = KeyAlgorithm("A192KW") // AES key wrap (192)
A256KW = KeyAlgorithm("A256KW") // AES key wrap (256)
DIRECT = KeyAlgorithm("dir") // Direct encryption
ECDH_ES = KeyAlgorithm("ECDH-ES") // ECDH-ES
ECDH_ES_A128KW = KeyAlgorithm("ECDH-ES+A128KW") // ECDH-ES + AES key wrap (128)
ECDH_ES_A192KW = KeyAlgorithm("ECDH-ES+A192KW") // ECDH-ES + AES key wrap (192)
ECDH_ES_A256KW = KeyAlgorithm("ECDH-ES+A256KW") // ECDH-ES + AES key wrap (256)
A128GCMKW = KeyAlgorithm("A128GCMKW") // AES-GCM key wrap (128)
A192GCMKW = KeyAlgorithm("A192GCMKW") // AES-GCM key wrap (192)
A256GCMKW = KeyAlgorithm("A256GCMKW") // AES-GCM key wrap (256)
PBES2_HS256_A128KW = KeyAlgorithm("PBES2-HS256+A128KW") // PBES2 + HMAC-SHA256 + AES key wrap (128)
PBES2_HS384_A192KW = KeyAlgorithm("PBES2-HS384+A192KW") // PBES2 + HMAC-SHA384 + AES key wrap (192)
PBES2_HS512_A256KW = KeyAlgorithm("PBES2-HS512+A256KW") // PBES2 + HMAC-SHA512 + AES key wrap (256)
)
// Signature algorithms
const (
HS256 = SignatureAlgorithm("HS256") // HMAC using SHA-256
HS384 = SignatureAlgorithm("HS384") // HMAC using SHA-384
HS512 = SignatureAlgorithm("HS512") // HMAC using SHA-512
RS256 = SignatureAlgorithm("RS256") // RSASSA-PKCS-v1.5 using SHA-256
RS384 = SignatureAlgorithm("RS384") // RSASSA-PKCS-v1.5 using SHA-384
RS512 = SignatureAlgorithm("RS512") // RSASSA-PKCS-v1.5 using SHA-512
ES256 = SignatureAlgorithm("ES256") // ECDSA using P-256 and SHA-256
ES384 = SignatureAlgorithm("ES384") // ECDSA using P-384 and SHA-384
ES512 = SignatureAlgorithm("ES512") // ECDSA using P-521 and SHA-512
PS256 = SignatureAlgorithm("PS256") // RSASSA-PSS using SHA256 and MGF1-SHA256
PS384 = SignatureAlgorithm("PS384") // RSASSA-PSS using SHA384 and MGF1-SHA384
PS512 = SignatureAlgorithm("PS512") // RSASSA-PSS using SHA512 and MGF1-SHA512
)
// Content encryption algorithms
const (
A128CBC_HS256 = ContentEncryption("A128CBC-HS256") // AES-CBC + HMAC-SHA256 (128)
A192CBC_HS384 = ContentEncryption("A192CBC-HS384") // AES-CBC + HMAC-SHA384 (192)
A256CBC_HS512 = ContentEncryption("A256CBC-HS512") // AES-CBC + HMAC-SHA512 (256)
A128GCM = ContentEncryption("A128GCM") // AES-GCM (128)
A192GCM = ContentEncryption("A192GCM") // AES-GCM (192)
A256GCM = ContentEncryption("A256GCM") // AES-GCM (256)
)
// Compression algorithms
const (
NONE = CompressionAlgorithm("") // No compression
DEFLATE = CompressionAlgorithm("DEF") // DEFLATE (RFC 1951)
)
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
type rawHeader struct {
Alg string `json:"alg,omitempty"`
Enc ContentEncryption `json:"enc,omitempty"`
Zip CompressionAlgorithm `json:"zip,omitempty"`
Crit []string `json:"crit,omitempty"`
Apu *byteBuffer `json:"apu,omitempty"`
Apv *byteBuffer `json:"apv,omitempty"`
Epk *JsonWebKey `json:"epk,omitempty"`
Iv *byteBuffer `json:"iv,omitempty"`
Tag *byteBuffer `json:"tag,omitempty"`
Jwk *JsonWebKey `json:"jwk,omitempty"`
Kid string `json:"kid,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
// JoseHeader represents the read-only JOSE header for JWE/JWS objects.
type JoseHeader struct {
KeyID string
JsonWebKey *JsonWebKey
Algorithm string
Nonce string
}
// sanitized produces a cleaned-up header object from the raw JSON.
func (parsed rawHeader) sanitized() JoseHeader {
return JoseHeader{
KeyID: parsed.Kid,
JsonWebKey: parsed.Jwk,
Algorithm: parsed.Alg,
Nonce: parsed.Nonce,
}
}
// Merge headers from src into dst, giving precedence to headers from l.
func (dst *rawHeader) merge(src *rawHeader) {
if src == nil {
return
}
if dst.Alg == "" {
dst.Alg = src.Alg
}
if dst.Enc == "" {
dst.Enc = src.Enc
}
if dst.Zip == "" {
dst.Zip = src.Zip
}
if dst.Crit == nil {
dst.Crit = src.Crit
}
if dst.Crit == nil {
dst.Crit = src.Crit
}
if dst.Apu == nil {
dst.Apu = src.Apu
}
if dst.Apv == nil {
dst.Apv = src.Apv
}
if dst.Epk == nil {
dst.Epk = src.Epk
}
if dst.Iv == nil {
dst.Iv = src.Iv
}
if dst.Tag == nil {
dst.Tag = src.Tag
}
if dst.Kid == "" {
dst.Kid = src.Kid
}
if dst.Jwk == nil {
dst.Jwk = src.Jwk
}
if dst.Nonce == "" {
dst.Nonce = src.Nonce
}
}
// Get JOSE name of curve
func curveName(crv elliptic.Curve) (string, error) {
switch crv {
case elliptic.P256():
return "P-256", nil
case elliptic.P384():
return "P-384", nil
case elliptic.P521():
return "P-521", nil
default:
return "", fmt.Errorf("square/go-jose: unsupported/unknown elliptic curve")
}
}
// Get size of curve in bytes
func curveSize(crv elliptic.Curve) int {
bits := crv.Params().BitSize
div := bits / 8
mod := bits % 8
if mod == 0 {
return div
}
return div + 1
}

View file

@ -0,0 +1,219 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"fmt"
)
// NonceSource represents a source of random nonces to go into JWS objects
type NonceSource interface {
Nonce() (string, error)
}
// Signer represents a signer which takes a payload and produces a signed JWS object.
type Signer interface {
Sign(payload []byte) (*JsonWebSignature, error)
SetNonceSource(source NonceSource)
SetEmbedJwk(embed bool)
}
// MultiSigner represents a signer which supports multiple recipients.
type MultiSigner interface {
Sign(payload []byte) (*JsonWebSignature, error)
SetNonceSource(source NonceSource)
SetEmbedJwk(embed bool)
AddRecipient(alg SignatureAlgorithm, signingKey interface{}) error
}
type payloadSigner interface {
signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error)
}
type payloadVerifier interface {
verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error
}
type genericSigner struct {
recipients []recipientSigInfo
nonceSource NonceSource
embedJwk bool
}
type recipientSigInfo struct {
sigAlg SignatureAlgorithm
keyID string
publicKey *JsonWebKey
signer payloadSigner
}
// NewSigner creates an appropriate signer based on the key type
func NewSigner(alg SignatureAlgorithm, signingKey interface{}) (Signer, error) {
// NewMultiSigner never fails (currently)
signer := NewMultiSigner()
err := signer.AddRecipient(alg, signingKey)
if err != nil {
return nil, err
}
return signer, nil
}
// NewMultiSigner creates a signer for multiple recipients
func NewMultiSigner() MultiSigner {
return &genericSigner{
recipients: []recipientSigInfo{},
embedJwk: true,
}
}
// newVerifier creates a verifier based on the key type
func newVerifier(verificationKey interface{}) (payloadVerifier, error) {
switch verificationKey := verificationKey.(type) {
case *rsa.PublicKey:
return &rsaEncrypterVerifier{
publicKey: verificationKey,
}, nil
case *ecdsa.PublicKey:
return &ecEncrypterVerifier{
publicKey: verificationKey,
}, nil
case []byte:
return &symmetricMac{
key: verificationKey,
}, nil
case *JsonWebKey:
return newVerifier(verificationKey.Key)
default:
return nil, ErrUnsupportedKeyType
}
}
func (ctx *genericSigner) AddRecipient(alg SignatureAlgorithm, signingKey interface{}) error {
recipient, err := makeJWSRecipient(alg, signingKey)
if err != nil {
return err
}
ctx.recipients = append(ctx.recipients, recipient)
return nil
}
func makeJWSRecipient(alg SignatureAlgorithm, signingKey interface{}) (recipientSigInfo, error) {
switch signingKey := signingKey.(type) {
case *rsa.PrivateKey:
return newRSASigner(alg, signingKey)
case *ecdsa.PrivateKey:
return newECDSASigner(alg, signingKey)
case []byte:
return newSymmetricSigner(alg, signingKey)
case *JsonWebKey:
recipient, err := makeJWSRecipient(alg, signingKey.Key)
if err != nil {
return recipientSigInfo{}, err
}
recipient.keyID = signingKey.KeyID
return recipient, nil
default:
return recipientSigInfo{}, ErrUnsupportedKeyType
}
}
func (ctx *genericSigner) Sign(payload []byte) (*JsonWebSignature, error) {
obj := &JsonWebSignature{}
obj.payload = payload
obj.Signatures = make([]Signature, len(ctx.recipients))
for i, recipient := range ctx.recipients {
protected := &rawHeader{
Alg: string(recipient.sigAlg),
}
if recipient.publicKey != nil && ctx.embedJwk {
protected.Jwk = recipient.publicKey
}
if recipient.keyID != "" {
protected.Kid = recipient.keyID
}
if ctx.nonceSource != nil {
nonce, err := ctx.nonceSource.Nonce()
if err != nil {
return nil, fmt.Errorf("square/go-jose: Error generating nonce: %v", err)
}
protected.Nonce = nonce
}
serializedProtected := mustSerializeJSON(protected)
input := []byte(fmt.Sprintf("%s.%s",
base64URLEncode(serializedProtected),
base64URLEncode(payload)))
signatureInfo, err := recipient.signer.signPayload(input, recipient.sigAlg)
if err != nil {
return nil, err
}
signatureInfo.protected = protected
obj.Signatures[i] = signatureInfo
}
return obj, nil
}
// SetNonceSource provides or updates a nonce pool to the first recipients.
// After this method is called, the signer will consume one nonce per
// signature, returning an error it is unable to get a nonce.
func (ctx *genericSigner) SetNonceSource(source NonceSource) {
ctx.nonceSource = source
}
// SetEmbedJwk specifies if the signing key should be embedded in the protected header,
// if any. It defaults to 'true'.
func (ctx *genericSigner) SetEmbedJwk(embed bool) {
ctx.embedJwk = embed
}
// Verify validates the signature on the object and returns the payload.
func (obj JsonWebSignature) Verify(verificationKey interface{}) ([]byte, error) {
verifier, err := newVerifier(verificationKey)
if err != nil {
return nil, err
}
for _, signature := range obj.Signatures {
headers := signature.mergedHeaders()
if len(headers.Crit) > 0 {
// Unsupported crit header
continue
}
input := obj.computeAuthData(&signature)
alg := SignatureAlgorithm(headers.Alg)
err := verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return obj.payload, nil
}
}
return nil, ErrCryptoFailure
}

View file

@ -0,0 +1,350 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"errors"
"hash"
"io"
"github.com/ossrs/go-oryx-lib/https/jose/cipher"
)
// Random reader (stubbed out in tests)
var randReader = rand.Reader
// Dummy key cipher for shared symmetric key mode
type symmetricKeyCipher struct {
key []byte // Pre-shared content-encryption key
}
// Signer/verifier for MAC modes
type symmetricMac struct {
key []byte
}
// Input/output from an AEAD operation
type aeadParts struct {
iv, ciphertext, tag []byte
}
// A content cipher based on an AEAD construction
type aeadContentCipher struct {
keyBytes int
authtagBytes int
getAead func(key []byte) (cipher.AEAD, error)
}
// Random key generator
type randomKeyGenerator struct {
size int
}
// Static key generator
type staticKeyGenerator struct {
key []byte
}
// Create a new content cipher based on AES-GCM
func newAESGCM(keySize int) contentCipher {
return &aeadContentCipher{
keyBytes: keySize,
authtagBytes: 16,
getAead: func(key []byte) (cipher.AEAD, error) {
aes, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(aes)
},
}
}
// Create a new content cipher based on AES-CBC+HMAC
func newAESCBC(keySize int) contentCipher {
return &aeadContentCipher{
keyBytes: keySize * 2,
authtagBytes: 16,
getAead: func(key []byte) (cipher.AEAD, error) {
return josecipher.NewCBCHMAC(key, aes.NewCipher)
},
}
}
// Get an AEAD cipher object for the given content encryption algorithm
func getContentCipher(alg ContentEncryption) contentCipher {
switch alg {
case A128GCM:
return newAESGCM(16)
case A192GCM:
return newAESGCM(24)
case A256GCM:
return newAESGCM(32)
case A128CBC_HS256:
return newAESCBC(16)
case A192CBC_HS384:
return newAESCBC(24)
case A256CBC_HS512:
return newAESCBC(32)
default:
return nil
}
}
// newSymmetricRecipient creates a JWE encrypter based on AES-GCM key wrap.
func newSymmetricRecipient(keyAlg KeyAlgorithm, key []byte) (recipientKeyInfo, error) {
switch keyAlg {
case DIRECT, A128GCMKW, A192GCMKW, A256GCMKW, A128KW, A192KW, A256KW:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &symmetricKeyCipher{
key: key,
},
}, nil
}
// newSymmetricSigner creates a recipientSigInfo based on the given key.
func newSymmetricSigner(sigAlg SignatureAlgorithm, key []byte) (recipientSigInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch sigAlg {
case HS256, HS384, HS512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
return recipientSigInfo{
sigAlg: sigAlg,
signer: &symmetricMac{
key: key,
},
}, nil
}
// Generate a random key for the given content cipher
func (ctx randomKeyGenerator) genKey() ([]byte, rawHeader, error) {
key := make([]byte, ctx.size)
_, err := io.ReadFull(randReader, key)
if err != nil {
return nil, rawHeader{}, err
}
return key, rawHeader{}, nil
}
// Key size for random generator
func (ctx randomKeyGenerator) keySize() int {
return ctx.size
}
// Generate a static key (for direct mode)
func (ctx staticKeyGenerator) genKey() ([]byte, rawHeader, error) {
cek := make([]byte, len(ctx.key))
copy(cek, ctx.key)
return cek, rawHeader{}, nil
}
// Key size for static generator
func (ctx staticKeyGenerator) keySize() int {
return len(ctx.key)
}
// Get key size for this cipher
func (ctx aeadContentCipher) keySize() int {
return ctx.keyBytes
}
// Encrypt some data
func (ctx aeadContentCipher) encrypt(key, aad, pt []byte) (*aeadParts, error) {
// Get a new AEAD instance
aead, err := ctx.getAead(key)
if err != nil {
return nil, err
}
// Initialize a new nonce
iv := make([]byte, aead.NonceSize())
_, err = io.ReadFull(randReader, iv)
if err != nil {
return nil, err
}
ciphertextAndTag := aead.Seal(nil, iv, pt, aad)
offset := len(ciphertextAndTag) - ctx.authtagBytes
return &aeadParts{
iv: iv,
ciphertext: ciphertextAndTag[:offset],
tag: ciphertextAndTag[offset:],
}, nil
}
// Decrypt some data
func (ctx aeadContentCipher) decrypt(key, aad []byte, parts *aeadParts) ([]byte, error) {
aead, err := ctx.getAead(key)
if err != nil {
return nil, err
}
return aead.Open(nil, parts.iv, append(parts.ciphertext, parts.tag...), aad)
}
// Encrypt the content encryption key.
func (ctx *symmetricKeyCipher) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
switch alg {
case DIRECT:
return recipientInfo{
header: &rawHeader{},
}, nil
case A128GCMKW, A192GCMKW, A256GCMKW:
aead := newAESGCM(len(ctx.key))
parts, err := aead.encrypt(ctx.key, []byte{}, cek)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
header: &rawHeader{
Iv: newBuffer(parts.iv),
Tag: newBuffer(parts.tag),
},
encryptedKey: parts.ciphertext,
}, nil
case A128KW, A192KW, A256KW:
block, err := aes.NewCipher(ctx.key)
if err != nil {
return recipientInfo{}, err
}
jek, err := josecipher.KeyWrap(block, cek)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: jek,
header: &rawHeader{},
}, nil
}
return recipientInfo{}, ErrUnsupportedAlgorithm
}
// Decrypt the content encryption key.
func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
switch KeyAlgorithm(headers.Alg) {
case DIRECT:
cek := make([]byte, len(ctx.key))
copy(cek, ctx.key)
return cek, nil
case A128GCMKW, A192GCMKW, A256GCMKW:
aead := newAESGCM(len(ctx.key))
parts := &aeadParts{
iv: headers.Iv.bytes(),
ciphertext: recipient.encryptedKey,
tag: headers.Tag.bytes(),
}
cek, err := aead.decrypt(ctx.key, []byte{}, parts)
if err != nil {
return nil, err
}
return cek, nil
case A128KW, A192KW, A256KW:
block, err := aes.NewCipher(ctx.key)
if err != nil {
return nil, err
}
cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
if err != nil {
return nil, err
}
return cek, nil
}
return nil, ErrUnsupportedAlgorithm
}
// Sign the given payload
func (ctx symmetricMac) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
mac, err := ctx.hmac(payload, alg)
if err != nil {
return Signature{}, errors.New("square/go-jose: failed to compute hmac")
}
return Signature{
Signature: mac,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx symmetricMac) verifyPayload(payload []byte, mac []byte, alg SignatureAlgorithm) error {
expected, err := ctx.hmac(payload, alg)
if err != nil {
return errors.New("square/go-jose: failed to compute hmac")
}
if len(mac) != len(expected) {
return errors.New("square/go-jose: invalid hmac")
}
match := subtle.ConstantTimeCompare(mac, expected)
if match != 1 {
return errors.New("square/go-jose: invalid hmac")
}
return nil
}
// Compute the HMAC based on the given alg value
func (ctx symmetricMac) hmac(payload []byte, alg SignatureAlgorithm) ([]byte, error) {
var hash func() hash.Hash
switch alg {
case HS256:
hash = sha256.New
case HS384:
hash = sha512.New384
case HS512:
hash = sha512.New
default:
return nil, ErrUnsupportedAlgorithm
}
hmac := hmac.New(hash, ctx.key)
// According to documentation, Write() on hash never fails
_, _ = hmac.Write(payload)
return hmac.Sum(nil), nil
}

View file

@ -0,0 +1,75 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// from gopkg.in/square/go-jose.v1
package jose
import (
"crypto/x509"
"encoding/pem"
"fmt"
)
// LoadPublicKey loads a public key from PEM/DER-encoded data.
func LoadPublicKey(data []byte) (interface{}, error) {
input := data
block, _ := pem.Decode(data)
if block != nil {
input = block.Bytes
}
// Try to load SubjectPublicKeyInfo
pub, err0 := x509.ParsePKIXPublicKey(input)
if err0 == nil {
return pub, nil
}
cert, err1 := x509.ParseCertificate(input)
if err1 == nil {
return cert.PublicKey, nil
}
return nil, fmt.Errorf("square/go-jose: parse error, got '%s' and '%s'", err0, err1)
}
// LoadPrivateKey loads a private key from PEM/DER-encoded data.
func LoadPrivateKey(data []byte) (interface{}, error) {
input := data
block, _ := pem.Decode(data)
if block != nil {
input = block.Bytes
}
var priv interface{}
priv, err0 := x509.ParsePKCS1PrivateKey(input)
if err0 == nil {
return priv, nil
}
priv, err1 := x509.ParsePKCS8PrivateKey(input)
if err1 == nil {
return priv, nil
}
priv, err2 := x509.ParseECPrivateKey(input)
if err2 == nil {
return priv, nil
}
return nil, fmt.Errorf("square/go-jose: parse error, got '%s', '%s' and '%s'", err0, err1, err2)
}

View file

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,757 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package letsencrypt obtains TLS certificates from LetsEncrypt.org.
//
// LetsEncrypt.org is a service that issues free SSL/TLS certificates to servers
// that can prove control over the given domain's DNS records or
// the servers pointed at by those records.
//
// Quick Start
//
// A complete HTTP/HTTPS web server using TLS certificates from LetsEncrypt.org,
// redirecting all HTTP access to HTTPS, and maintaining TLS certificates in a file
// letsencrypt.cache across server restarts.
//
// package main
//
// import (
// "fmt"
// "log"
// "net/http"
// "rsc.io/letsencrypt"
// )
//
// func main() {
// http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "Hello, TLS!\n")
// })
// var m letsencrypt.Manager
// if err := m.CacheFile("letsencrypt.cache"); err != nil {
// log.Fatal(err)
// }
// log.Fatal(m.Serve())
// }
//
// Overview
//
// The fundamental type in this package is the Manager, which
// manages obtaining and refreshing a collection of TLS certificates,
// typically for use by an HTTPS server.
// The example above shows the most basic use of a Manager.
// The use can be customized by calling additional methods of the Manager.
//
// Registration
//
// A Manager m registers anonymously with LetsEncrypt.org, including agreeing to
// the letsencrypt.org terms of service, the first time it needs to obtain a certificate.
// To register with a particular email address and with the option of a
// prompt for agreement with the terms of service, call m.Register.
//
// GetCertificate
//
// The Manager's GetCertificate method returns certificates
// from the Manager's cache, filling the cache by requesting certificates
// from LetsEncrypt.org. In this way, a server with a tls.Config.GetCertificate
// set to m.GetCertificate will demand load a certificate for any host name
// it serves. To force loading of certificates ahead of time, install m.GetCertificate
// as before but then call m.Cert for each host name.
//
// A Manager can only obtain a certificate for a given host name if it can prove
// control of that host name to LetsEncrypt.org. By default it proves control by
// answering an HTTPS-based challenge: when
// the LetsEncrypt.org servers connect to the named host on port 443 (HTTPS),
// the TLS SNI handshake must use m.GetCertificate to obtain a per-host certificate.
// The most common way to satisfy this requirement is for the host name to
// resolve to the IP address of a (single) computer running m.ServeHTTPS,
// or at least running a Go TLS server with tls.Config.GetCertificate set to m.GetCertificate.
// However, other configurations are possible. For example, a group of machines
// could use an implementation of tls.Config.GetCertificate that cached
// certificates but handled cache misses by making RPCs to a Manager m
// on an elected leader machine.
//
// In typical usage, then, the setting of tls.Config.GetCertificate to m.GetCertificate
// serves two purposes: it provides certificates to the TLS server for ordinary serving,
// and it also answers challenges to prove ownership of the domains in order to
// obtain those certificates.
//
// To force the loading of a certificate for a given host into the Manager's cache,
// use m.Cert.
//
// Persistent Storage
//
// If a server always starts with a zero Manager m, the server effectively fetches
// a new certificate for each of its host name from LetsEncrypt.org on each restart.
// This is unfortunate both because the server cannot start if LetsEncrypt.org is
// unavailable and because LetsEncrypt.org limits how often it will issue a certificate
// for a given host name (at time of writing, the limit is 5 per week for a given host name).
// To save server state proactively to a cache file and to reload the server state from
// that same file when creating a new manager, call m.CacheFile with the name of
// the file to use.
//
// For alternate storage uses, m.Marshal returns the current state of the Manager
// as an opaque string, m.Unmarshal sets the state of the Manager using a string
// previously returned by m.Marshal (usually a different m), and m.Watch returns
// a channel that receives notifications about state changes.
//
// Limits
//
// To avoid hitting basic rate limits on LetsEncrypt.org, a given Manager limits all its
// interactions to at most one request every minute, with an initial allowed burst of
// 20 requests.
//
// By default, if GetCertificate is asked for a certificate it does not have, it will in turn
// ask LetsEncrypt.org for that certificate. This opens a potential attack where attackers
// connect to a server by IP address and pretend to be asking for an incorrect host name.
// Then GetCertificate will attempt to obtain a certificate for that host, incorrectly,
// eventually hitting LetsEncrypt.org's rate limit for certificate requests and making it
// impossible to obtain actual certificates. Because servers hold certificates for months
// at a time, however, an attack would need to be sustained over a time period
// of at least a month in order to cause real problems.
//
// To mitigate this kind of attack, a given Manager limits
// itself to an average of one certificate request for a new host every three hours,
// with an initial allowed burst of up to 20 requests.
// Long-running servers will therefore stay
// within the LetsEncrypt.org limit of 300 failed requests per month.
// Certificate refreshes are not subject to this limit.
//
// To eliminate the attack entirely, call m.SetHosts to enumerate the exact set
// of hosts that are allowed in certificate requests.
//
// Web Servers
//
// The basic requirement for use of a Manager is that there be an HTTPS server
// running on port 443 and calling m.GetCertificate to obtain TLS certificates.
// Using standard primitives, the way to do this is:
//
// srv := &http.Server{
// Addr: ":https",
// TLSConfig: &tls.Config{
// GetCertificate: m.GetCertificate,
// },
// }
// srv.ListenAndServeTLS("", "")
//
// However, this pattern of serving HTTPS with demand-loaded TLS certificates
// comes up enough to wrap into a single method m.ServeHTTPS.
//
// Similarly, many HTTPS servers prefer to redirect HTTP clients to the HTTPS URLs.
// That functionality is provided by RedirectHTTP.
//
// The combination of serving HTTPS with demand-loaded TLS certificates and
// serving HTTPS redirects to HTTP clients is provided by m.Serve, as used in
// the original example above.
//
// fork from https://github.com/rsc/letsencrypt
package letsencrypt
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/ossrs/go-oryx-lib/https/net/context"
"github.com/ossrs/go-oryx-lib/https/time/rate"
"github.com/ossrs/go-oryx-lib/https/acme"
)
const letsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
const debug = false
// A Manager m takes care of obtaining and refreshing a collection of TLS certificates
// obtained by LetsEncrypt.org.
// The zero Manager is not yet registered with LetsEncrypt.org and has no TLS certificates
// but is nonetheless ready for use.
// See the package comment for an overview of how to use a Manager.
type Manager struct {
mu sync.Mutex
state state
rateLimit *rate.Limiter
newHostLimit *rate.Limiter
certCache map[string]*cacheEntry
certTokens map[string]*tls.Certificate
watchChan chan struct{}
}
// Serve runs an HTTP/HTTPS web server using TLS certificates obtained by the manager.
// The HTTP server redirects all requests to the HTTPS server.
// The HTTPS server obtains TLS certificates as needed and responds to requests
// by invoking http.DefaultServeMux.
//
// Serve does not return unitil the HTTPS server fails to start or else stops.
// Either way, Serve can only return a non-nil error, never nil.
func (m *Manager) Serve() error {
l, err := net.Listen("tcp", ":http")
if err != nil {
return err
}
defer l.Close()
go http.Serve(l, http.HandlerFunc(RedirectHTTP))
return m.ServeHTTPS()
}
// ServeHTTPS runs an HTTPS web server using TLS certificates obtained by the manager.
// The HTTPS server obtains TLS certificates as needed and responds to requests
// by invoking http.DefaultServeMux.
// ServeHTTPS does not return unitil the HTTPS server fails to start or else stops.
// Either way, ServeHTTPS can only return a non-nil error, never nil.
func (m *Manager) ServeHTTPS() error {
srv := &http.Server{
Addr: ":https",
TLSConfig: &tls.Config{
GetCertificate: m.GetCertificate,
},
}
return srv.ListenAndServeTLS("", "")
}
// RedirectHTTP is an HTTP handler (suitable for use with http.HandleFunc)
// that responds to all requests by redirecting to the same URL served over HTTPS.
// It should only be invoked for requests received over HTTP.
func RedirectHTTP(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil || r.Host == "" {
http.Error(w, "not found", 404)
}
u := r.URL
u.Host = r.Host
u.Scheme = "https"
http.Redirect(w, r, u.String(), 302)
}
// state is the serializable state for the Manager.
// It also implements acme.User.
type state struct {
Email string
Reg *acme.RegistrationResource
Key string
key *ecdsa.PrivateKey
Hosts []string
Certs map[string]stateCert
}
func (s *state) GetEmail() string { return s.Email }
func (s *state) GetRegistration() *acme.RegistrationResource { return s.Reg }
func (s *state) GetPrivateKey() crypto.PrivateKey { return s.key }
type stateCert struct {
Cert string
Key string
}
func (cert stateCert) toTLS() (*tls.Certificate, error) {
c, err := tls.X509KeyPair([]byte(cert.Cert), []byte(cert.Key))
if err != nil {
return nil, err
}
return &c, err
}
type cacheEntry struct {
host string
m *Manager
mu sync.Mutex
cert *tls.Certificate
timeout time.Time
refreshing bool
err error
}
func (m *Manager) init() {
m.mu.Lock()
if m.certCache == nil {
m.rateLimit = rate.NewLimiter(rate.Every(1*time.Minute), 20)
m.newHostLimit = rate.NewLimiter(rate.Every(3*time.Hour), 20)
m.certCache = map[string]*cacheEntry{}
m.certTokens = map[string]*tls.Certificate{}
m.watchChan = make(chan struct{}, 1)
m.watchChan <- struct{}{}
}
m.mu.Unlock()
}
// Watch returns the manager's watch channel,
// which delivers a notification after every time the
// manager's state (as exposed by Marshal and Unmarshal) changes.
// All calls to Watch return the same watch channel.
//
// The watch channel includes notifications about changes
// before the first call to Watch, so that in the pattern below,
// the range loop executes once immediately, saving
// the result of setup (along with any background updates that
// may have raced in quickly).
//
// m := new(letsencrypt.Manager)
// setup(m)
// go backgroundUpdates(m)
// for range m.Watch() {
// save(m.Marshal())
// }
//
func (m *Manager) Watch() <-chan struct{} {
m.init()
m.updated()
return m.watchChan
}
func (m *Manager) updated() {
select {
case m.watchChan <- struct{}{}:
default:
}
}
func (m *Manager) CacheFile(name string) error {
f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
return err
}
f.Close()
data, err := ioutil.ReadFile(name)
if err != nil {
return err
}
if len(data) > 0 {
if err := m.Unmarshal(string(data)); err != nil {
return err
}
}
go func() {
for range m.Watch() {
err := ioutil.WriteFile(name, []byte(m.Marshal()), 0600)
if err != nil {
log.Printf("writing letsencrypt cache: %v", err)
}
}
}()
return nil
}
// Registered reports whether the manager has registered with letsencrypt.org yet.
func (m *Manager) Registered() bool {
m.init()
m.mu.Lock()
defer m.mu.Unlock()
return m.registered()
}
func (m *Manager) registered() bool {
return m.state.Reg != nil && m.state.Reg.Body.Agreement != ""
}
// Register registers the manager with letsencrypt.org, using the given email address.
// Registration may require agreeing to the letsencrypt.org terms of service.
// If so, Register calls prompt(url) where url is the URL of the terms of service.
// Prompt should report whether the caller agrees to the terms.
// A nil prompt func is taken to mean that the user always agrees.
// The email address is sent to LetsEncrypt.org but otherwise unchecked;
// it can be omitted by passing the empty string.
//
// Calling Register is only required to make sure registration uses a
// particular email address or to insert an explicit prompt into the
// registration sequence. If the manager is not registered, it will
// automatically register with no email address and automatic
// agreement to the terms of service at the first call to Cert or GetCertificate.
func (m *Manager) Register(email string, prompt func(string) bool) error {
m.init()
m.mu.Lock()
defer m.mu.Unlock()
return m.register(email, prompt)
}
func (m *Manager) register(email string, prompt func(string) bool) error {
if m.registered() {
return fmt.Errorf("already registered")
}
m.state.Email = email
if m.state.key == nil {
key, err := newKey()
if err != nil {
return fmt.Errorf("generating key: %v", err)
}
Key, err := marshalKey(key)
if err != nil {
return fmt.Errorf("generating key: %v", err)
}
m.state.key = key
m.state.Key = string(Key)
}
c, err := acme.NewClient(letsEncryptURL, &m.state, acme.EC256)
if err != nil {
return fmt.Errorf("create client: %v", err)
}
reg, err := c.Register()
if err != nil {
return fmt.Errorf("register: %v", err)
}
m.state.Reg = reg
if reg.Body.Agreement == "" {
if prompt != nil && !prompt(reg.TosURL) {
return fmt.Errorf("did not agree to TOS")
}
if err := c.AgreeToTOS(); err != nil {
return fmt.Errorf("agreeing to TOS: %v", err)
}
}
m.updated()
return nil
}
// Marshal returns an encoding of the manager's state,
// suitable for writing to disk and reloading by calling Unmarshal.
// The state includes registration status, the configured host list
// from SetHosts, and all known certificates, including their private
// cryptographic keys.
// Consequently, the state should be kept private.
func (m *Manager) Marshal() string {
m.init()
js, err := json.MarshalIndent(&m.state, "", "\t")
if err != nil {
panic("unexpected json.Marshal failure")
}
return string(js)
}
// Unmarshal restores the state encoded by a previous call to Marshal
// (perhaps on a different Manager in a different program).
func (m *Manager) Unmarshal(enc string) error {
m.init()
var st state
if err := json.Unmarshal([]byte(enc), &st); err != nil {
return err
}
if st.Key != "" {
key, err := unmarshalKey(st.Key)
if err != nil {
return err
}
st.key = key
}
m.state = st
for host, cert := range m.state.Certs {
c, err := cert.toTLS()
if err != nil {
log.Printf("letsencrypt: ignoring entry for %s: %v", host, err)
continue
}
m.certCache[host] = &cacheEntry{host: host, m: m, cert: c}
}
m.updated()
return nil
}
// SetHosts sets the manager's list of known host names.
// If the list is non-nil, the manager will only ever attempt to acquire
// certificates for host names on the list.
// If the list is nil, the manager does not restrict the hosts it will
// ask for certificates for.
func (m *Manager) SetHosts(hosts []string) {
m.init()
m.mu.Lock()
m.state.Hosts = append(m.state.Hosts[:0], hosts...)
m.mu.Unlock()
m.updated()
}
// GetCertificate can be placed a tls.Config's GetCertificate field to make
// the TLS server use Let's Encrypt certificates.
// Each time a client connects to the TLS server expecting a new host name,
// the TLS server's call to GetCertificate will trigger an exchange with the
// Let's Encrypt servers to obtain that certificate, subject to the manager rate limits.
//
// As noted in the Manager's documentation comment,
// to obtain a certificate for a given host name, that name
// must resolve to a computer running a TLS server on port 443
// that obtains TLS SNI certificates by calling m.GetCertificate.
// In the standard usage, then, installing m.GetCertificate in the tls.Config
// both automatically provisions the TLS certificates needed for
// ordinary HTTPS service and answers the challenges from LetsEncrypt.org.
func (m *Manager) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
m.init()
host := clientHello.ServerName
if debug {
log.Printf("GetCertificate %s", host)
}
if strings.HasSuffix(host, ".acme.invalid") {
m.mu.Lock()
cert := m.certTokens[host]
m.mu.Unlock()
if cert == nil {
return nil, fmt.Errorf("unknown host")
}
return cert, nil
}
return m.Cert(host)
}
// Cert returns the certificate for the given host name, obtaining a new one if necessary.
//
// As noted in the documentation for Manager and for the GetCertificate method,
// obtaining a certificate requires that m.GetCertificate be associated with host.
// In most servers, simply starting a TLS server with a configuration referring
// to m.GetCertificate is sufficient, and Cert need not be called.
//
// The main use of Cert is to force the manager to obtain a certificate
// for a particular host name ahead of time.
func (m *Manager) Cert(host string) (*tls.Certificate, error) {
host = strings.ToLower(host)
if debug {
log.Printf("Cert %s", host)
}
m.init()
m.mu.Lock()
if !m.registered() {
m.register("", nil)
}
ok := false
if m.state.Hosts == nil {
ok = true
} else {
for _, h := range m.state.Hosts {
if host == h {
ok = true
break
}
}
}
if !ok {
m.mu.Unlock()
return nil, fmt.Errorf("unknown host")
}
// Otherwise look in our cert cache.
entry, ok := m.certCache[host]
if !ok {
r := m.rateLimit.Reserve()
ok := r.OK()
if ok {
ok = m.newHostLimit.Allow()
if !ok {
r.Cancel()
}
}
if !ok {
m.mu.Unlock()
return nil, fmt.Errorf("rate limited")
}
entry = &cacheEntry{host: host, m: m}
m.certCache[host] = entry
}
m.mu.Unlock()
entry.mu.Lock()
defer entry.mu.Unlock()
entry.init()
if entry.err != nil {
return nil, entry.err
}
return entry.cert, nil
}
func (e *cacheEntry) init() {
if e.err != nil && time.Now().Before(e.timeout) {
return
}
if e.cert != nil {
if e.timeout.IsZero() {
t, err := certRefreshTime(e.cert)
if err != nil {
e.err = err
e.timeout = time.Now().Add(1 * time.Minute)
e.cert = nil
return
}
e.timeout = t
}
if time.Now().After(e.timeout) && !e.refreshing {
e.refreshing = true
go e.refresh()
}
return
}
cert, refreshTime, err := e.m.verify(e.host)
e.m.mu.Lock()
e.m.certCache[e.host] = e
e.m.mu.Unlock()
e.install(cert, refreshTime, err)
}
func (e *cacheEntry) install(cert *tls.Certificate, refreshTime time.Time, err error) {
e.cert = nil
e.timeout = time.Time{}
e.err = nil
if err != nil {
e.err = err
e.timeout = time.Now().Add(1 * time.Minute)
return
}
e.cert = cert
e.timeout = refreshTime
}
func (e *cacheEntry) refresh() {
e.m.rateLimit.Wait(context.Background())
cert, refreshTime, err := e.m.verify(e.host)
e.mu.Lock()
defer e.mu.Unlock()
e.refreshing = false
if err == nil {
e.install(cert, refreshTime, nil)
}
}
func (m *Manager) verify(host string) (cert *tls.Certificate, refreshTime time.Time, err error) {
c, err := acme.NewClient(letsEncryptURL, &m.state, acme.EC256)
if err != nil {
return
}
if err = c.SetChallengeProvider(acme.TLSSNI01, tlsProvider{m}); err != nil {
return
}
c.SetChallengeProvider(acme.TLSSNI01, tlsProvider{m})
c.ExcludeChallenges([]acme.Challenge{acme.HTTP01})
acmeCert, errmap := c.ObtainCertificate([]string{host}, true, nil)
if len(errmap) > 0 {
if debug {
log.Printf("ObtainCertificate %v => %v", host, errmap)
}
err = fmt.Errorf("%v", errmap)
return
}
entryCert := stateCert{
Cert: string(acmeCert.Certificate),
Key: string(acmeCert.PrivateKey),
}
cert, err = entryCert.toTLS()
if err != nil {
if debug {
log.Printf("ObtainCertificate %v toTLS failure: %v", host, err)
}
err = err
return
}
if refreshTime, err = certRefreshTime(cert); err != nil {
return
}
m.mu.Lock()
if m.state.Certs == nil {
m.state.Certs = make(map[string]stateCert)
}
m.state.Certs[host] = entryCert
m.mu.Unlock()
m.updated()
return cert, refreshTime, nil
}
func certRefreshTime(cert *tls.Certificate) (time.Time, error) {
xc, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
if debug {
log.Printf("ObtainCertificate to X.509 failure: %v", err)
}
return time.Time{}, err
}
t := xc.NotBefore.Add(xc.NotAfter.Sub(xc.NotBefore) / 2)
monthEarly := xc.NotAfter.Add(-30 * 24 * time.Hour)
if t.Before(monthEarly) {
t = monthEarly
}
return t, nil
}
// tlsProvider implements acme.ChallengeProvider for TLS handshake challenges.
type tlsProvider struct {
m *Manager
}
func (p tlsProvider) Present(domain, token, keyAuth string) error {
cert, dom, err := acme.TLSSNI01ChallengeCertDomain(keyAuth)
if err != nil {
return err
}
p.m.mu.Lock()
p.m.certTokens[dom] = &cert
p.m.mu.Unlock()
return nil
}
func (p tlsProvider) CleanUp(domain, token, keyAuth string) error {
_, dom, err := acme.TLSSNI01ChallengeCertDomain(keyAuth)
if err != nil {
return err
}
p.m.mu.Lock()
delete(p.m.certTokens, dom)
p.m.mu.Unlock()
return nil
}
func marshalKey(key *ecdsa.PrivateKey) ([]byte, error) {
data, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: data}), nil
}
func unmarshalKey(text string) (*ecdsa.PrivateKey, error) {
b, _ := pem.Decode([]byte(text))
if b == nil {
return nil, fmt.Errorf("unmarshalKey: missing key")
}
if b.Type != "EC PRIVATE KEY" {
return nil, fmt.Errorf("unmarshalKey: found %q, not %q", b.Type, "EC PRIVATE KEY")
}
k, err := x509.ParseECPrivateKey(b.Bytes)
if err != nil {
return nil, fmt.Errorf("unmarshalKey: %v", err)
}
return k, nil
}
func newKey() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
}

View file

@ -0,0 +1,158 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancelation signals, and other request-scoped values across API boundaries
// and between processes.
//
// Incoming requests to a server should create a Context, and outgoing calls to
// servers should accept a Context. The chain of function calls between must
// propagate the Context, optionally replacing it with a modified copy created
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See http://blog.golang.org/context for example code for a server that uses
// Contexts.
// fork from golang.org/x/net/context
// from https://github.com/golang/net
package context // import "github.com/ossrs/go-oryx-lib/https/net/context"
import "time"
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out chan<- Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See http://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancelation.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stores using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "golang.org/x/net/context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key = 0
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key interface{}) interface{}
}
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return background
}
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter). TODO is recognized by static analysis tools that determine
// whether Contexts are propagated correctly in a program.
func TODO() Context {
return todo
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()

View file

@ -0,0 +1,74 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
// fork from golang.org/x/net/context
// from https://github.com/golang/net
package context
import (
"context" // standard library's context, as of Go 1.7
"time"
)
var (
todo = context.TODO()
background = context.Background()
)
// Canceled is the error returned by Context.Err when the context is canceled.
var Canceled = context.Canceled
// DeadlineExceeded is the error returned by Context.Err when the context's
// deadline passes.
var DeadlineExceeded = context.DeadlineExceeded
// WithCancel returns a copy of parent with a new Done channel. The returned
// context's Done channel is closed when the returned cancel function is called
// or when the parent context's Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
ctx, f := context.WithCancel(parent)
return ctx, CancelFunc(f)
}
// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
ctx, f := context.WithDeadline(parent, deadline)
return ctx, CancelFunc(f)
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithValue returns a copy of parent in which the value associated with key is
// val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
func WithValue(parent Context, key interface{}, val interface{}) Context {
return context.WithValue(parent, key, val)
}

View file

@ -0,0 +1,302 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.7
// fork from golang.org/x/net/context
// from https://github.com/golang/net
package context
import (
"errors"
"fmt"
"sync"
"time"
)
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case background:
return "context.Background"
case todo:
return "context.TODO"
}
return "unknown empty Context"
}
var (
background = new(emptyCtx)
todo = new(emptyCtx)
)
// Canceled is the error returned by Context.Err when the context is canceled.
var Canceled = errors.New("context canceled")
// DeadlineExceeded is the error returned by Context.Err when the context's
// deadline passes.
var DeadlineExceeded = errors.New("context deadline exceeded")
// WithCancel returns a copy of parent with a new Done channel. The returned
// context's Done channel is closed when the returned cancel function is called
// or when the parent context's Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := newCancelCtx(parent)
propagateCancel(parent, c)
return c, func() { c.cancel(true, Canceled) }
}
// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) *cancelCtx {
return &cancelCtx{
Context: parent,
done: make(chan struct{}),
}
}
// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
if parent.Done() == nil {
return // parent is never canceled
}
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
} else {
if p.children == nil {
p.children = make(map[canceler]bool)
}
p.children[child] = true
}
p.mu.Unlock()
} else {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()
}
}
// parentCancelCtx follows a chain of parent references until it finds a
// *cancelCtx. This function understands how each of the concrete types in this
// package represents its parent.
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
for {
switch c := parent.(type) {
case *cancelCtx:
return c, true
case *timerCtx:
return c.cancelCtx, true
case *valueCtx:
parent = c.Context
default:
return nil, false
}
}
}
// removeChild removes a context from its parent.
func removeChild(parent Context, child canceler) {
p, ok := parentCancelCtx(parent)
if !ok {
return
}
p.mu.Lock()
if p.children != nil {
delete(p.children, child)
}
p.mu.Unlock()
}
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err error)
Done() <-chan struct{}
}
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
done chan struct{} // closed by the first cancel call.
mu sync.Mutex
children map[canceler]bool // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Done() <-chan struct{} {
return c.done
}
func (c *cancelCtx) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}
func (c *cancelCtx) String() string {
return fmt.Sprintf("%v.WithCancel", c.Context)
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
close(c.done)
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
removeChild(c.Context, c)
}
}
// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: deadline,
}
propagateCancel(parent, c)
d := deadline.Sub(time.Now())
if d <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(true, Canceled) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(d, func() {
c.cancel(true, DeadlineExceeded)
})
}
return c, func() { c.cancel(true, Canceled) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
*cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
return c.deadline, true
}
func (c *timerCtx) String() string {
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now()))
}
func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(false, err)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
}
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithValue returns a copy of parent in which the value associated with key is
// val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
func WithValue(parent Context, key interface{}, val interface{}) Context {
return &valueCtx{parent, key, val}
}
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
Context
key, val interface{}
}
func (c *valueCtx) String() string {
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val)
}
func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
return c.Context.Value(key)
}

View file

@ -0,0 +1,371 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rate provides a rate limiter.
// fork from golang.org/x/time/rate
// from https://github.com/golang/time
package rate
import (
"fmt"
"math"
"sync"
"time"
"github.com/ossrs/go-oryx-lib/https/net/context"
)
// Limit defines the maximum frequency of some events.
// Limit is represented as number of events per second.
// A zero Limit allows no events.
type Limit float64
// Inf is the infinite rate limit; it allows all events (even if burst is zero).
const Inf = Limit(math.MaxFloat64)
// Every converts a minimum time interval between events to a Limit.
func Every(interval time.Duration) Limit {
if interval <= 0 {
return Inf
}
return 1 / Limit(interval.Seconds())
}
// A Limiter controls how frequently events are allowed to happen.
// It implements a "token bucket" of size b, initially full and refilled
// at rate r tokens per second.
// Informally, in any large enough time interval, the Limiter limits the
// rate to r tokens per second, with a maximum burst size of b events.
// As a special case, if r == Inf (the infinite rate), b is ignored.
// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets.
//
// The zero value is a valid Limiter, but it will reject all events.
// Use NewLimiter to create non-zero Limiters.
//
// Limiter has three main methods, Allow, Reserve, and Wait.
// Most callers should use Wait.
//
// Each of the three methods consumes a single token.
// They differ in their behavior when no token is available.
// If no token is available, Allow returns false.
// If no token is available, Reserve returns a reservation for a future token
// and the amount of time the caller must wait before using it.
// If no token is available, Wait blocks until one can be obtained
// or its associated context.Context is canceled.
//
// The methods AllowN, ReserveN, and WaitN consume n tokens.
type Limiter struct {
limit Limit
burst int
mu sync.Mutex
tokens float64
// last is the last time the limiter's tokens field was updated
last time.Time
// lastEvent is the latest time of a rate-limited event (past or future)
lastEvent time.Time
}
// Limit returns the maximum overall event rate.
func (lim *Limiter) Limit() Limit {
lim.mu.Lock()
defer lim.mu.Unlock()
return lim.limit
}
// Burst returns the maximum burst size. Burst is the maximum number of tokens
// that can be consumed in a single call to Allow, Reserve, or Wait, so higher
// Burst values allow more events to happen at once.
// A zero Burst allows no events, unless limit == Inf.
func (lim *Limiter) Burst() int {
return lim.burst
}
// NewLimiter returns a new Limiter that allows events up to rate r and permits
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
}
}
// Allow is shorthand for AllowN(time.Now(), 1).
func (lim *Limiter) Allow() bool {
return lim.AllowN(time.Now(), 1)
}
// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate limit.
// Otherwise use Reserve or Wait.
func (lim *Limiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(now, n, 0).ok
}
// A Reservation holds information about events that are permitted by a Limiter to happen after a delay.
// A Reservation may be canceled, which may enable the Limiter to permit additional events.
type Reservation struct {
ok bool
lim *Limiter
tokens int
timeToAct time.Time
// This is the Limit at reservation time, it can change later.
limit Limit
}
// OK returns whether the limiter can provide the requested number of tokens
// within the maximum wait time. If OK is false, Delay returns InfDuration, and
// Cancel does nothing.
func (r *Reservation) OK() bool {
return r.ok
}
// Delay is shorthand for DelayFrom(time.Now()).
func (r *Reservation) Delay() time.Duration {
return r.DelayFrom(time.Now())
}
// InfDuration is the duration returned by Delay when a Reservation is not OK.
const InfDuration = time.Duration(1<<63 - 1)
// DelayFrom returns the duration for which the reservation holder must wait
// before taking the reserved action. Zero duration means act immediately.
// InfDuration means the limiter cannot grant the tokens requested in this
// Reservation within the maximum wait time.
func (r *Reservation) DelayFrom(now time.Time) time.Duration {
if !r.ok {
return InfDuration
}
delay := r.timeToAct.Sub(now)
if delay < 0 {
return 0
}
return delay
}
// Cancel is shorthand for CancelAt(time.Now()).
func (r *Reservation) Cancel() {
r.CancelAt(time.Now())
return
}
// CancelAt indicates that the reservation holder will not perform the reserved action
// and reverses the effects of this Reservation on the rate limit as much as possible,
// considering that other reservations may have already been made.
func (r *Reservation) CancelAt(now time.Time) {
if !r.ok {
return
}
r.lim.mu.Lock()
defer r.lim.mu.Unlock()
if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) {
return
}
// calculate tokens to restore
// The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved
// after r was obtained. These tokens should not be restored.
restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct))
if restoreTokens <= 0 {
return
}
// advance time to now
now, _, tokens := r.lim.advance(now)
// calculate new number of tokens
tokens += restoreTokens
if burst := float64(r.lim.burst); tokens > burst {
tokens = burst
}
// update state
r.lim.last = now
r.lim.tokens = tokens
if r.timeToAct == r.lim.lastEvent {
prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
if !prevEvent.Before(now) {
r.lim.lastEvent = prevEvent
}
}
return
}
// Reserve is shorthand for ReserveN(time.Now(), 1).
func (lim *Limiter) Reserve() *Reservation {
return lim.ReserveN(time.Now(), 1)
}
// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen.
// The Limiter takes this Reservation into account when allowing future events.
// ReserveN returns false if n exceeds the Limiter's burst size.
// Usage example:
// r, ok := lim.ReserveN(time.Now(), 1)
// if !ok {
// // Not allowed to act! Did you remember to set lim.burst to be > 0 ?
// }
// time.Sleep(r.Delay())
// Act()
// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events.
// If you need to respect a deadline or cancel the delay, use Wait instead.
// To drop or skip events exceeding rate limit, use Allow instead.
func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation {
r := lim.reserveN(now, n, InfDuration)
return &r
}
// Wait is shorthand for WaitN(ctx, 1).
func (lim *Limiter) Wait(ctx context.Context) (err error) {
return lim.WaitN(ctx, 1)
}
// WaitN blocks until lim permits n events to happen.
// It returns an error if n exceeds the Limiter's burst size, the Context is
// canceled, or the expected wait time exceeds the Context's Deadline.
func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) {
if n > lim.burst {
return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.burst)
}
// Check if ctx is already cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Determine wait limit
now := time.Now()
waitLimit := InfDuration
if deadline, ok := ctx.Deadline(); ok {
waitLimit = deadline.Sub(now)
}
// Reserve
r := lim.reserveN(now, n, waitLimit)
if !r.ok {
return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n)
}
// Wait
t := time.NewTimer(r.DelayFrom(now))
defer t.Stop()
select {
case <-t.C:
// We can proceed.
return nil
case <-ctx.Done():
// Context was canceled before we could proceed. Cancel the
// reservation, which may permit other events to proceed sooner.
r.Cancel()
return ctx.Err()
}
}
// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit).
func (lim *Limiter) SetLimit(newLimit Limit) {
lim.SetLimitAt(time.Now(), newLimit)
}
// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated
// or underutilized by those which reserved (using Reserve or Wait) but did not yet act
// before SetLimitAt was called.
func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) {
lim.mu.Lock()
defer lim.mu.Unlock()
now, _, tokens := lim.advance(now)
lim.last = now
lim.tokens = tokens
lim.limit = newLimit
}
// reserveN is a helper method for AllowN, ReserveN, and WaitN.
// maxFutureReserve specifies the maximum reservation wait duration allowed.
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation {
lim.mu.Lock()
if lim.limit == Inf {
lim.mu.Unlock()
return Reservation{
ok: true,
lim: lim,
tokens: n,
timeToAct: now,
}
}
now, last, tokens := lim.advance(now)
// Calculate the remaining number of tokens resulting from the request.
tokens -= float64(n)
// Calculate the wait duration
var waitDuration time.Duration
if tokens < 0 {
waitDuration = lim.limit.durationFromTokens(-tokens)
}
// Decide result
ok := n <= lim.burst && waitDuration <= maxFutureReserve
// Prepare reservation
r := Reservation{
ok: ok,
lim: lim,
limit: lim.limit,
}
if ok {
r.tokens = n
r.timeToAct = now.Add(waitDuration)
}
// Update state
if ok {
lim.last = now
lim.tokens = tokens
lim.lastEvent = r.timeToAct
} else {
lim.last = last
}
lim.mu.Unlock()
return r
}
// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed.
func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {
last := lim.last
if now.Before(last) {
last = now
}
// Avoid making delta overflow below when last is very old.
maxElapsed := lim.limit.durationFromTokens(float64(lim.burst) - lim.tokens)
elapsed := now.Sub(last)
if elapsed > maxElapsed {
elapsed = maxElapsed
}
// Calculate the new number of tokens, due to time that passed.
delta := lim.limit.tokensFromDuration(elapsed)
tokens := lim.tokens + delta
if burst := float64(lim.burst); tokens > burst {
tokens = burst
}
return now, last, tokens
}
// durationFromTokens is a unit conversion function from the number of tokens to the duration
// of time it takes to accumulate them at a rate of limit tokens per second.
func (limit Limit) durationFromTokens(tokens float64) time.Duration {
seconds := tokens / float64(limit)
return time.Nanosecond * time.Duration(1e9*seconds)
}
// tokensFromDuration is a unit conversion function from a time duration to the number of tokens
// which could be accumulated during that duration at a rate of limit tokens per second.
func (limit Limit) tokensFromDuration(d time.Duration) float64 {
return d.Seconds() * float64(limit)
}

View file

@ -0,0 +1,86 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build go1.7
package logger
import (
"context"
"fmt"
"os"
)
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.contextFormat(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.contextFormatf(ctx, format, a...)
v.doPrintf(format, args...)
}
func (v *loggerPlus) contextFormat(ctx Context, a ...interface{}) []interface{} {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v]", os.Getpid(), cid)}, a...)
}
} else {
return v.format(ctx, a...)
}
return a
}
func (v *loggerPlus) contextFormatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), cid}, a...)
}
} else {
return v.formatf(ctx, format, a...)
}
return format, a
}
// User should use context with value to pass the cid.
type key string
var cidKey key = "cid.logger.ossrs.org"
var gCid int = 999
// Create context with value.
func WithContext(ctx context.Context) context.Context {
gCid += 1
return context.WithValue(ctx, cidKey, gCid)
}
// Create context with value from parent, copy the cid from source context.
// @remark Create new cid if source has no cid represent.
func AliasContext(parent context.Context, source context.Context) context.Context {
if source != nil {
if cid, ok := source.Value(cidKey).(int); ok {
return context.WithValue(parent, cidKey, cid)
}
}
return WithContext(parent)
}

View file

@ -0,0 +1,239 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx logger package provides connection-oriented log service.
// logger.I(ctx, ...)
// logger.T(ctx, ...)
// logger.W(ctx, ...)
// logger.E(ctx, ...)
// Or use format:
// logger.If(ctx, format, ...)
// logger.Tf(ctx, format, ...)
// logger.Wf(ctx, format, ...)
// logger.Ef(ctx, format, ...)
// @remark the Context is optional thus can be nil.
// @remark From 1.7+, the ctx could be context.Context, wrap by logger.WithContext,
// please read ExampleLogger_ContextGO17().
package logger
import (
"fmt"
"io"
"io/ioutil"
"log"
"os"
)
// default level for logger.
const (
logInfoLabel = "[info] "
logTraceLabel = "[trace] "
logWarnLabel = "[warn] "
logErrorLabel = "[error] "
)
// The context for current goroutine.
// It maybe a cidContext or context.Context from GO1.7.
// @remark Use logger.WithContext(ctx) to wrap the context.
type Context interface{}
// The context to get current coroutine cid.
type cidContext interface {
Cid() int
}
// the LOG+ which provides connection-based log.
type loggerPlus struct {
logger *log.Logger
}
func NewLoggerPlus(l *log.Logger) Logger {
return &loggerPlus{logger: l}
}
func (v *loggerPlus) format(ctx Context, a ...interface{}) []interface{} {
if ctx == nil {
return append([]interface{}{fmt.Sprintf("[%v] ", os.Getpid())}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v] ", os.Getpid(), ctx.Cid())}, a...)
}
return a
}
func (v *loggerPlus) formatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx == nil {
return "[%v] " + format, append([]interface{}{os.Getpid()}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), ctx.Cid()}, a...)
}
return format, a
}
var colorYellow = "\033[33m"
var colorRed = "\033[31m"
var colorBlack = "\033[0m"
func (v *loggerPlus) doPrintln(args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Println(args...)
}
} else {
v.logger.Println(args...)
}
}
func (v *loggerPlus) doPrintf(format string, args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Printf(format, args...)
}
} else {
v.logger.Printf(format, args...)
}
}
// Info, the verbose info level, very detail log, the lowest level, to discard.
var Info Logger
// Alias for Info level println.
func I(ctx Context, a ...interface{}) {
Info.Println(ctx, a...)
}
// Printf for Info level log.
func If(ctx Context, format string, a ...interface{}) {
Info.Printf(ctx, format, a...)
}
// Trace, the trace level, something important, the default log level, to stdout.
var Trace Logger
// Alias for Trace level println.
func T(ctx Context, a ...interface{}) {
Trace.Println(ctx, a...)
}
// Printf for Trace level log.
func Tf(ctx Context, format string, a ...interface{}) {
Trace.Printf(ctx, format, a...)
}
// Warn, the warning level, dangerous information, to Stdout.
var Warn Logger
// Alias for Warn level println.
func W(ctx Context, a ...interface{}) {
Warn.Println(ctx, a...)
}
// Printf for Warn level log.
func Wf(ctx Context, format string, a ...interface{}) {
Warn.Printf(ctx, format, a...)
}
// Error, the error level, fatal error things, ot Stdout.
var Error Logger
// Alias for Error level println.
func E(ctx Context, a ...interface{}) {
Error.Println(ctx, a...)
}
// Printf for Error level log.
func Ef(ctx Context, format string, a ...interface{}) {
Error.Printf(ctx, format, a...)
}
// The logger for oryx.
type Logger interface {
// Println for logger plus,
// @param ctx the connection-oriented context,
// or context.Context from GO1.7, or nil to ignore.
Println(ctx Context, a ...interface{})
Printf(ctx Context, format string, a ...interface{})
}
func init() {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(os.Stdout, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(os.Stderr, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(os.Stderr, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
// init writer and closer.
previousWriter = os.Stdout
previousCloser = nil
}
// Switch the underlayer io.
// @remark user must close previous io for logger never close it.
func Switch(w io.Writer) io.Writer {
// TODO: support level, default to trace here.
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(w, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(w, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(w, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
ow := previousWriter
previousWriter = w
if c, ok := w.(io.Closer); ok {
previousCloser = c
}
return ow
}
// The previous underlayer io for logger.
var previousCloser io.Closer
var previousWriter io.Writer
// The interface io.Closer
// Cleanup the logger, discard any log util switch to fresh writer.
func Close() (err error) {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(ioutil.Discard, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(ioutil.Discard, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(ioutil.Discard, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
if previousCloser != nil {
err = previousCloser.Close()
previousCloser = nil
}
return
}

View file

@ -0,0 +1,34 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build !go1.7
package logger
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.format(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.formatf(ctx, format, a...)
v.doPrintf(format, args...)
}

View file

@ -0,0 +1,13 @@
# github.com/ossrs/go-oryx-lib v0.0.8
## explicit
github.com/ossrs/go-oryx-lib/errors
github.com/ossrs/go-oryx-lib/http
github.com/ossrs/go-oryx-lib/https
github.com/ossrs/go-oryx-lib/https/acme
github.com/ossrs/go-oryx-lib/https/crypto/ocsp
github.com/ossrs/go-oryx-lib/https/jose
github.com/ossrs/go-oryx-lib/https/jose/cipher
github.com/ossrs/go-oryx-lib/https/letsencrypt
github.com/ossrs/go-oryx-lib/https/net/context
github.com/ossrs/go-oryx-lib/https/time/rate
github.com/ossrs/go-oryx-lib/logger

47
trunk/3rdparty/httpx-static/version.go vendored Normal file
View file

@ -0,0 +1,47 @@
/*
The MIT License (MIT)
Copyright (c) 2019 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
package main
import "fmt"
func VersionMajor() int {
return 1
}
func VersionMinor() int {
return 0
}
func VersionRevision() int {
return 5
}
func Version() string {
return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision())
}
func Signature() string {
return "GoOryx"
}

18
trunk/3rdparty/signaling/.gitignore vendored Normal file
View file

@ -0,0 +1,18 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
.format.txt
objs

21
trunk/3rdparty/signaling/LICENSE vendored Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 srs-org
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

19
trunk/3rdparty/signaling/Makefile vendored Normal file
View file

@ -0,0 +1,19 @@
.PHONY: help default clean signaling
default: signaling
clean:
rm -f ./objs/signaling
.format.txt: *.go
gofmt -w .
echo "done" > .format.txt
signaling: ./objs/signaling
./objs/signaling: .format.txt *.go Makefile
go build -mod=vendor -o objs/signaling .
help:
@echo "Usage: make [signaling]"
@echo " signaling Make the signaling to ./objs/signaling"

24
trunk/3rdparty/signaling/README.md vendored Normal file
View file

@ -0,0 +1,24 @@
# signaling
WebRTC signaling for https://github.com/ossrs/srs
## Usage
Build and [run SRS](https://github.com/ossrs/srs/tree/4.0release#usage):
```bash
git clone -b 4.0release https://gitee.com/ossrs/srs.git srs &&
cd srs/trunk && ./configure && make && ./objs/srs -c conf/rtc.conf
```
Build and run signaling:
```bash
cd srs/trunk/3rdparty/signaling && make && ./objs/signaling
```
Open the H5 demos:
* [WebRTC: One to One over SFU(SRS)](http://localhost:1989/demos/one2one.html?autostart=true)
Winlin 2021.05

8
trunk/3rdparty/signaling/go.mod vendored Normal file
View file

@ -0,0 +1,8 @@
module github.com/ossrs/signaling
go 1.16
require (
github.com/ossrs/go-oryx-lib v0.0.8
golang.org/x/net v0.0.0-20210502030024-e5908800b52b
)

9
trunk/3rdparty/signaling/go.sum vendored Normal file
View file

@ -0,0 +1,9 @@
github.com/ossrs/go-oryx-lib v0.0.8 h1:k8ml3ZLsjIMoQEdZdWuy8zkU0w/fbJSyHvT/s9NyeCc=
github.com/ossrs/go-oryx-lib v0.0.8/go.mod h1:i2tH4TZBzAw5h+HwGrNOKvP/nmZgSQz0OEnLLdzcT/8=
golang.org/x/net v0.0.0-20210502030024-e5908800b52b h1:jCRjgm6WJHzM8VQrm/es2wXYqqbq0NZ1yXFHHgzkiVQ=
golang.org/x/net v0.0.0-20210502030024-e5908800b52b/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

319
trunk/3rdparty/signaling/main.go vendored Normal file
View file

@ -0,0 +1,319 @@
// The MIT License (MIT)
//
// Copyright (c) 2021 Winlin
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"net/http"
"os"
"path"
"strings"
"sync"
"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"
"golang.org/x/net/websocket"
)
type Participant struct {
Room *Room `json:"-"`
Display string `json:"display"`
Publishing bool `json:"publishing"`
Out chan []byte `json:"-"`
}
func (v *Participant) String() string {
return fmt.Sprintf("display=%v, room=%v", v.Display, v.Room.Name)
}
type Room struct {
Name string `json:"room"`
Participants []*Participant `json:"participants"`
lock sync.RWMutex `json:"-"`
}
func (v *Room) String() string {
return fmt.Sprintf("room=%v, participants=%v", v.Name, len(v.Participants))
}
func (v *Room) Add(p *Participant) error {
v.lock.Lock()
defer v.lock.Unlock()
for _, r := range v.Participants {
if r.Display == p.Display {
return errors.Errorf("Participant %v exists in room %v", p.Display, v.Name)
}
}
v.Participants = append(v.Participants, p)
return nil
}
func (v *Room) Get(display string) *Participant {
v.lock.RLock()
defer v.lock.RUnlock()
for _, r := range v.Participants {
if r.Display == display {
return r
}
}
return nil
}
func (v *Room) Remove(p *Participant) {
v.lock.Lock()
defer v.lock.Unlock()
for i, r := range v.Participants {
if p == r {
v.Participants = append(v.Participants[:i], v.Participants[i+1:]...)
return
}
}
}
func (v *Room) Notify(ctx context.Context, peer *Participant, event string) {
var participants []*Participant
func() {
v.lock.RLock()
defer v.lock.RUnlock()
participants = append(participants, v.Participants...)
}()
for _, r := range participants {
if r == peer {
continue
}
res := struct {
Action string `json:"action"`
Event string `json:"event"`
Room string `json:"room"`
Self *Participant `json:"self"`
Peer *Participant `json:"peer"`
Participants []*Participant `json:"participants"`
}{
"notify", event, v.Name, r, peer, participants,
}
b, err := json.Marshal(struct {
Message interface{} `json:"msg"`
}{
res,
})
if err != nil {
return
}
select {
case <-ctx.Done():
return
case r.Out <- b:
}
logger.Tf(ctx, "Notify %v about %v %v", r, peer, event)
}
}
func main() {
var listen string
flag.StringVar(&listen, "listen", "1989", "The TCP listen port")
var html string
flag.StringVar(&html, "root", "./www", "The www web root")
flag.Usage = func() {
fmt.Println(fmt.Sprintf("Usage: %v [Options]", os.Args[0]))
fmt.Println(fmt.Sprintf("Options:"))
fmt.Println(fmt.Sprintf(" -listen The TCP listen port. Default: %v", listen))
fmt.Println(fmt.Sprintf(" -root The www web root. Default: %v", html))
fmt.Println(fmt.Sprintf("For example:"))
fmt.Println(fmt.Sprintf(" %v -listen %v -html %v", os.Args[0], listen, html))
}
flag.Parse()
if !strings.Contains(listen, ":") {
listen = ":" + listen
}
ctx := context.Background()
home := listen
if strings.HasPrefix(home, ":") {
home = "http://localhost" + listen
}
if !path.IsAbs(html) && path.IsAbs(os.Args[0]) {
html = path.Join(path.Dir(os.Args[0]), html)
}
logger.Tf(ctx, "Signaling ok, root=%v, home page is %v", html, home)
http.Handle("/", http.FileServer(http.Dir(html)))
// Key is name of room, value is Room
var rooms sync.Map
http.Handle("/sig/v1/rtc", websocket.Handler(func(c *websocket.Conn) {
ctx, cancel := context.WithCancel(logger.WithContext(ctx))
defer cancel()
r := c.Request()
logger.Tf(ctx, "Serve client %v at %v", r.RemoteAddr, r.RequestURI)
defer c.Close()
var self *Participant
go func() {
<-ctx.Done()
if self != nil {
self.Room.Remove(self)
logger.Tf(ctx, "Remove client %v", self)
}
}()
inMessages := make(chan []byte, 0)
go func() {
defer cancel()
buf := make([]byte, 16384)
for {
n, err := c.Read(buf)
if err != nil {
logger.Wf(ctx, "Ignore err %v for %v", err, r.RemoteAddr)
break
}
select {
case <-ctx.Done():
case inMessages <- buf[:n]:
}
}
}()
outMessages := make(chan []byte, 0)
go func() {
defer cancel()
handleMessage := func(m []byte) error {
action := struct {
TID string `json:"tid"`
Message struct {
Action string `json:"action"`
} `json:"msg"`
}{}
if err := json.Unmarshal(m, &action); err != nil {
return errors.Wrapf(err, "Unmarshal %s", m)
}
var res interface{}
var p *Participant
if action.Message.Action == "join" {
obj := struct {
Message struct {
Room string `json:"room"`
Display string `json:"display"`
} `json:"msg"`
}{}
if err := json.Unmarshal(m, &obj); err != nil {
return errors.Wrapf(err, "Unmarshal %s", m)
}
r, _ := rooms.LoadOrStore(obj.Message.Room, &Room{Name: obj.Message.Room})
p = &Participant{Room: r.(*Room), Display: obj.Message.Display, Out: outMessages}
if err := r.(*Room).Add(p); err != nil {
return errors.Wrapf(err, "join")
}
self = p
logger.Tf(ctx, "Join %v ok", self)
res = struct {
Action string `json:"action"`
Room string `json:"room"`
Self *Participant `json:"self"`
Participants []*Participant `json:"participants"`
}{
action.Message.Action, obj.Message.Room, p, r.(*Room).Participants,
}
go r.(*Room).Notify(ctx, p, action.Message.Action)
} else if action.Message.Action == "publish" {
obj := struct {
Message struct {
Room string `json:"room"`
Display string `json:"display"`
} `json:"msg"`
}{}
if err := json.Unmarshal(m, &obj); err != nil {
return errors.Wrapf(err, "Unmarshal %s", m)
}
r, _ := rooms.LoadOrStore(obj.Message.Room, &Room{Name: obj.Message.Room})
p := r.(*Room).Get(obj.Message.Display)
// Now, the peer is publishing.
p.Publishing = true
go r.(*Room).Notify(ctx, p, action.Message.Action)
} else {
return errors.Errorf("Invalid message %s", m)
}
if b, err := json.Marshal(struct {
TID string `json:"tid"`
Message interface{} `json:"msg"`
}{
action.TID, res,
}); err != nil {
return errors.Wrapf(err, "marshal")
} else {
select {
case <-ctx.Done():
return ctx.Err()
case outMessages <- b:
}
}
return nil
}
for m := range inMessages {
if err := handleMessage(m); err != nil {
logger.Wf(ctx, "Handle %s err %v", m, err)
break
}
}
}()
for m := range outMessages {
if _, err := c.Write(m); err != nil {
logger.Wf(ctx, "Ignore err %v for %v", err, r.RemoteAddr)
break
}
}
}))
http.ListenAndServe(listen, nil)
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-2017 winlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,23 @@
Copyright (c) 2015, Dave Cheney <dave@cheney.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,52 @@
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors)
Package errors provides simple error handling primitives.
`go get github.com/pkg/errors`
The traditional error handling idiom in Go is roughly akin to
```go
if err != nil {
return err
}
```
which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error.
## Adding context to an error
The errors.Wrap function returns a new error that adds context to the original error. For example
```go
_, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "read failed")
}
```
## Retrieving the cause of an error
Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`.
```go
type causer interface {
Cause() error
}
```
`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example:
```go
switch err := errors.Cause(err).(type) {
case *MyError:
// handle specifically
default:
// unknown error
}
```
[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors).
## Contributing
We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high.
Before proposing a change, please discuss your change by raising an issue.
## Licence
BSD-2-Clause

View file

@ -0,0 +1,270 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

View file

@ -0,0 +1,187 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

View file

@ -0,0 +1,86 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build go1.7
package logger
import (
"context"
"fmt"
"os"
)
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.contextFormat(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.contextFormatf(ctx, format, a...)
v.doPrintf(format, args...)
}
func (v *loggerPlus) contextFormat(ctx Context, a ...interface{}) []interface{} {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v]", os.Getpid(), cid)}, a...)
}
} else {
return v.format(ctx, a...)
}
return a
}
func (v *loggerPlus) contextFormatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx, ok := ctx.(context.Context); ok {
if cid, ok := ctx.Value(cidKey).(int); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), cid}, a...)
}
} else {
return v.formatf(ctx, format, a...)
}
return format, a
}
// User should use context with value to pass the cid.
type key string
var cidKey key = "cid.logger.ossrs.org"
var gCid int = 999
// Create context with value.
func WithContext(ctx context.Context) context.Context {
gCid += 1
return context.WithValue(ctx, cidKey, gCid)
}
// Create context with value from parent, copy the cid from source context.
// @remark Create new cid if source has no cid represent.
func AliasContext(parent context.Context, source context.Context) context.Context {
if source != nil {
if cid, ok := source.Value(cidKey).(int); ok {
return context.WithValue(parent, cidKey, cid)
}
}
return WithContext(parent)
}

View file

@ -0,0 +1,239 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// The oryx logger package provides connection-oriented log service.
// logger.I(ctx, ...)
// logger.T(ctx, ...)
// logger.W(ctx, ...)
// logger.E(ctx, ...)
// Or use format:
// logger.If(ctx, format, ...)
// logger.Tf(ctx, format, ...)
// logger.Wf(ctx, format, ...)
// logger.Ef(ctx, format, ...)
// @remark the Context is optional thus can be nil.
// @remark From 1.7+, the ctx could be context.Context, wrap by logger.WithContext,
// please read ExampleLogger_ContextGO17().
package logger
import (
"fmt"
"io"
"io/ioutil"
"log"
"os"
)
// default level for logger.
const (
logInfoLabel = "[info] "
logTraceLabel = "[trace] "
logWarnLabel = "[warn] "
logErrorLabel = "[error] "
)
// The context for current goroutine.
// It maybe a cidContext or context.Context from GO1.7.
// @remark Use logger.WithContext(ctx) to wrap the context.
type Context interface{}
// The context to get current coroutine cid.
type cidContext interface {
Cid() int
}
// the LOG+ which provides connection-based log.
type loggerPlus struct {
logger *log.Logger
}
func NewLoggerPlus(l *log.Logger) Logger {
return &loggerPlus{logger: l}
}
func (v *loggerPlus) format(ctx Context, a ...interface{}) []interface{} {
if ctx == nil {
return append([]interface{}{fmt.Sprintf("[%v] ", os.Getpid())}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return append([]interface{}{fmt.Sprintf("[%v][%v] ", os.Getpid(), ctx.Cid())}, a...)
}
return a
}
func (v *loggerPlus) formatf(ctx Context, format string, a ...interface{}) (string, []interface{}) {
if ctx == nil {
return "[%v] " + format, append([]interface{}{os.Getpid()}, a...)
} else if ctx, ok := ctx.(cidContext); ok {
return "[%v][%v] " + format, append([]interface{}{os.Getpid(), ctx.Cid()}, a...)
}
return format, a
}
var colorYellow = "\033[33m"
var colorRed = "\033[31m"
var colorBlack = "\033[0m"
func (v *loggerPlus) doPrintln(args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Println(args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Println(args...)
}
} else {
v.logger.Println(args...)
}
}
func (v *loggerPlus) doPrintf(format string, args ...interface{}) {
if previousCloser == nil {
if v == Error {
fmt.Fprintf(os.Stdout, colorRed)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else if v == Warn {
fmt.Fprintf(os.Stdout, colorYellow)
v.logger.Printf(format, args...)
fmt.Fprintf(os.Stdout, colorBlack)
} else {
v.logger.Printf(format, args...)
}
} else {
v.logger.Printf(format, args...)
}
}
// Info, the verbose info level, very detail log, the lowest level, to discard.
var Info Logger
// Alias for Info level println.
func I(ctx Context, a ...interface{}) {
Info.Println(ctx, a...)
}
// Printf for Info level log.
func If(ctx Context, format string, a ...interface{}) {
Info.Printf(ctx, format, a...)
}
// Trace, the trace level, something important, the default log level, to stdout.
var Trace Logger
// Alias for Trace level println.
func T(ctx Context, a ...interface{}) {
Trace.Println(ctx, a...)
}
// Printf for Trace level log.
func Tf(ctx Context, format string, a ...interface{}) {
Trace.Printf(ctx, format, a...)
}
// Warn, the warning level, dangerous information, to Stdout.
var Warn Logger
// Alias for Warn level println.
func W(ctx Context, a ...interface{}) {
Warn.Println(ctx, a...)
}
// Printf for Warn level log.
func Wf(ctx Context, format string, a ...interface{}) {
Warn.Printf(ctx, format, a...)
}
// Error, the error level, fatal error things, ot Stdout.
var Error Logger
// Alias for Error level println.
func E(ctx Context, a ...interface{}) {
Error.Println(ctx, a...)
}
// Printf for Error level log.
func Ef(ctx Context, format string, a ...interface{}) {
Error.Printf(ctx, format, a...)
}
// The logger for oryx.
type Logger interface {
// Println for logger plus,
// @param ctx the connection-oriented context,
// or context.Context from GO1.7, or nil to ignore.
Println(ctx Context, a ...interface{})
Printf(ctx Context, format string, a ...interface{})
}
func init() {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(os.Stdout, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(os.Stderr, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(os.Stderr, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
// init writer and closer.
previousWriter = os.Stdout
previousCloser = nil
}
// Switch the underlayer io.
// @remark user must close previous io for logger never close it.
func Switch(w io.Writer) io.Writer {
// TODO: support level, default to trace here.
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(w, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(w, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(w, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
ow := previousWriter
previousWriter = w
if c, ok := w.(io.Closer); ok {
previousCloser = c
}
return ow
}
// The previous underlayer io for logger.
var previousCloser io.Closer
var previousWriter io.Writer
// The interface io.Closer
// Cleanup the logger, discard any log util switch to fresh writer.
func Close() (err error) {
Info = NewLoggerPlus(log.New(ioutil.Discard, logInfoLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Trace = NewLoggerPlus(log.New(ioutil.Discard, logTraceLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Warn = NewLoggerPlus(log.New(ioutil.Discard, logWarnLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
Error = NewLoggerPlus(log.New(ioutil.Discard, logErrorLabel, log.Ldate|log.Ltime|log.Lmicroseconds))
if previousCloser != nil {
err = previousCloser.Close()
previousCloser = nil
}
return
}

View file

@ -0,0 +1,34 @@
// The MIT License (MIT)
//
// Copyright (c) 2013-2017 Oryx(ossrs)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// +build !go1.7
package logger
func (v *loggerPlus) Println(ctx Context, a ...interface{}) {
args := v.format(ctx, a...)
v.doPrintln(args...)
}
func (v *loggerPlus) Printf(ctx Context, format string, a ...interface{}) {
format, args := v.formatf(ctx, format, a...)
v.doPrintf(format, args...)
}

View file

@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.

View file

@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.

View file

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

View file

@ -0,0 +1,106 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
var portMap = map[string]string{
"ws": "80",
"wss": "443",
}
func parseAuthority(location *url.URL) string {
if _, ok := portMap[location.Scheme]; ok {
if _, _, err := net.SplitHostPort(location.Host); err != nil {
return net.JoinHostPort(location.Host, portMap[location.Scheme])
}
}
return location.Host
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
goto Error
}
return
Error:
return nil, &DialError{config, err}
}

View file

@ -0,0 +1,24 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
return
}

View file

@ -0,0 +1,583 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
// This file implements a protocol of hybi draft.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
const (
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
closeStatusNormal = 1000
closeStatusGoingAway = 1001
closeStatusProtocolError = 1002
closeStatusUnsupportedData = 1003
closeStatusFrameTooLarge = 1004
closeStatusNoStatusRcvd = 1005
closeStatusAbnormalClosure = 1006
closeStatusBadMessageData = 1007
closeStatusPolicyViolation = 1008
closeStatusTooBigData = 1009
closeStatusExtensionMismatch = 1010
maxControlFramePayloadLength = 125
)
var (
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
ErrBadPongMessage = &ProtocolError{"bad pong message"}
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
type hybiFrameHeader struct {
Fin bool
Rsv [3]bool
OpCode byte
Length int64
MaskingKey []byte
data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
type hybiFrameReader struct {
reader io.Reader
header hybiFrameHeader
pos int64
length int
}
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
n, err = frame.reader.Read(msg)
if frame.header.MaskingKey != nil {
for i := 0; i < n; i++ {
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
frame.pos++
}
}
return n, err
}
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
func (frame *hybiFrameReader) HeaderReader() io.Reader {
if frame.header.data == nil {
return nil
}
if frame.header.data.Len() == 0 {
return nil
}
return frame.header.data
}
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
type hybiFrameReaderFactory struct {
*bufio.Reader
}
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
// See Section 5.2 Base Framing protocol for detail.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
hybiFrame := new(hybiFrameReader)
frame = hybiFrame
var header []byte
var b byte
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
for i := 0; i < 3; i++ {
j := uint(6 - i)
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
}
hybiFrame.header.OpCode = header[0] & 0x0f
// Second byte. Mask/Payload len(7bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
mask := (b & 0x80) != 0
b &= 0x7f
lengthFields := 0
switch {
case b <= 125: // Payload length 7bits.
hybiFrame.header.Length = int64(b)
case b == 126: // Payload length 7+16bits
lengthFields = 2
case b == 127: // Payload length 7+64bits
lengthFields = 8
}
for i := 0; i < lengthFields; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
b &= 0x7f
}
header = append(header, b)
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
}
if mask {
// Masking key. 4 bytes.
for i := 0; i < 4; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
}
}
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
hybiFrame.header.data = bytes.NewBuffer(header)
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
return
}
// A HybiFrameWriter is a writer for hybi frame.
type hybiFrameWriter struct {
writer *bufio.Writer
header *hybiFrameHeader
}
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
var header []byte
var b byte
if frame.header.Fin {
b |= 0x80
}
for i := 0; i < 3; i++ {
if frame.header.Rsv[i] {
j := uint(6 - i)
b |= 1 << j
}
}
b |= frame.header.OpCode
header = append(header, b)
if frame.header.MaskingKey != nil {
b = 0x80
} else {
b = 0
}
lengthFields := 0
length := len(msg)
switch {
case length <= 125:
b |= byte(length)
case length < 65536:
b |= 126
lengthFields = 2
default:
b |= 127
lengthFields = 8
}
header = append(header, b)
for i := 0; i < lengthFields; i++ {
j := uint((lengthFields - i - 1) * 8)
b = byte((length >> j) & 0xff)
header = append(header, b)
}
if frame.header.MaskingKey != nil {
if len(frame.header.MaskingKey) != 4 {
return 0, ErrBadMaskingKey
}
header = append(header, frame.header.MaskingKey...)
frame.writer.Write(header)
data := make([]byte, length)
for i := range data {
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
}
frame.writer.Write(data)
err = frame.writer.Flush()
return length, err
}
frame.writer.Write(header)
frame.writer.Write(msg)
err = frame.writer.Flush()
return length, err
}
func (frame *hybiFrameWriter) Close() error { return nil }
type hybiFrameWriterFactory struct {
*bufio.Writer
needMaskingKey bool
}
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
if buf.needMaskingKey {
frameHeader.MaskingKey, err = generateMaskingKey()
if err != nil {
return nil, err
}
}
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
}
type hybiFrameHandler struct {
conn *Conn
payloadType byte
}
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
} else {
// The server MUST NOT mask all frames.
if frame.(*hybiFrameReader).header.MaskingKey != nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
case TextFrame, BinaryFrame:
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
case PingFrame, PongFrame:
b := make([]byte, maxControlFramePayloadLength)
n, err := io.ReadFull(frame, b)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
if frame.PayloadType() == PingFrame {
if _, err := handler.WritePong(b[:n]); err != nil {
return nil, err
}
}
return nil, nil
}
return frame, nil
}
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
if err != nil {
return err
}
msg := make([]byte, 2)
binary.BigEndian.PutUint16(msg, uint16(status))
_, err = w.Write(msg)
w.Close()
return err
}
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
if buf == nil {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
frameWriterFactory: hybiFrameWriterFactory{
buf.Writer, request == nil},
PayloadType: TextFrame,
defaultCloseStatus: closeStatusNormal}
ws.frameHandler = &hybiFrameHandler{conn: ws}
return ws
}
// generateMaskingKey generates a masking key for a frame.
func generateMaskingKey() (maskingKey []byte, err error) {
maskingKey = make([]byte, 4)
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
return
}
return
}
// generateNonce generates a nonce consisting of a randomly selected 16-byte
// value that has been base64-encoded.
func generateNonce() (nonce []byte) {
key := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
nonce = make([]byte, 24)
base64.StdEncoding.Encode(nonce, key)
return
}
// removeZone removes IPv6 zone identifer from host.
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
func removeZone(host string) string {
if !strings.HasPrefix(host, "[") {
return host
}
i := strings.LastIndex(host, "]")
if i < 0 {
return host
}
j := strings.LastIndex(host[:i], "%")
if j < 0 {
return host
}
return host[:j] + host[i:]
}
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func getNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(websocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
// According to RFC 6874, an HTTP client, proxy, or other
// intermediary must remove any IPv6 zone identifier attached
// to an outgoing URI.
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
bw.WriteString("Upgrade: websocket\r\n")
bw.WriteString("Connection: Upgrade\r\n")
nonce := generateNonce()
if config.handshakeData != nil {
nonce = []byte(config.handshakeData["key"])
}
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
if config.Version != ProtocolVersionHybi13 {
return ErrBadProtocolVersion
}
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
return err
}
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
if err != nil {
return err
}
if resp.StatusCode != 101 {
return ErrBadStatus
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
expectedAccept, err := getNonceAccept(nonce)
if err != nil {
return err
}
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
return ErrChallengeResponse
}
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
return ErrUnsupportedExtensions
}
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return ErrBadWebSocketProtocol
}
config.Protocol = []string{offeredProtocol}
}
return nil
}
// newHybiClientConn creates a client WebSocket connection after handshake.
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
return newHybiConn(config, buf, rwc, nil)
}
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
*Config
accept []byte
}
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
c.Version = ProtocolVersionHybi13
if req.Method != "GET" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}
// HTTP version can be safely ignored.
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
return http.StatusBadRequest, ErrNotWebSocket
}
key := req.Header.Get("Sec-Websocket-Key")
if key == "" {
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
switch version {
case "13":
c.Version = ProtocolVersionHybi13
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
var scheme string
if req.TLS != nil {
scheme = "wss"
} else {
scheme = "ws"
}
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
if err != nil {
return http.StatusBadRequest, err
}
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocol != "" {
protocols := strings.Split(protocol, ",")
for i := 0; i < len(protocols); i++ {
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
}
}
c.accept, err = getNonceAccept([]byte(key))
if err != nil {
return http.StatusInternalServerError, err
}
return http.StatusSwitchingProtocols, nil
}
// Origin parses the Origin header in req.
// If the Origin header is not set, it returns nil and nil.
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
if origin == "" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
// You need choose a Protocol in Handshake func in Server.
return ErrBadWebSocketProtocol
}
}
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiServerConn(c.Config, buf, rwc, request)
}
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiConn(config, buf, rwc, request)
}

View file

@ -0,0 +1,113 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}

View file

@ -0,0 +1,451 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements a client and server for the WebSocket protocol
// as specified in RFC 6455.
//
// This package currently lacks some features found in alternative
// and more actively maintained WebSocket packages:
//
// https://godoc.org/github.com/gorilla/websocket
// https://godoc.org/nhooyr.io/websocket
package websocket // import "golang.org/x/net/websocket"
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
ProtocolVersionHybi13 = 13
ProtocolVersionHybi = ProtocolVersionHybi13
SupportedProtocolVersion = "13"
ContinuationFrame = 0
TextFrame = 1
BinaryFrame = 2
CloseFrame = 8
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
DefaultMaxPayloadBytes = 32 << 20 // 32MB
)
// ProtocolError represents WebSocket protocol errors.
type ProtocolError struct {
ErrorString string
}
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
ErrBadScheme = &ProtocolError{"bad scheme"}
ErrBadStatus = &ProtocolError{"bad status"}
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
ErrBadFrame = &ProtocolError{"bad frame"}
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
ErrBadRequestMethod = &ProtocolError{"bad method"}
ErrNotSupported = &ProtocolError{"not supported"}
)
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
// exceeds limit set by Conn.MaxPayloadBytes
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string { return "websocket" }
// Config is a WebSocket configuration
type Config struct {
// A WebSocket server address.
Location *url.URL
// A Websocket client origin.
Origin *url.URL
// WebSocket subprotocols.
Protocol []string
// WebSocket protocol version.
Version int
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
// Dialer used when opening websocket connections.
Dialer *net.Dialer
handshakeData map[string]string
}
// serverHandshaker is an interface to handle WebSocket server side handshake.
type serverHandshaker interface {
// ReadHandshake reads handshake request message from client.
// Returns http response code and error if any.
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
// AcceptHandshake accepts the client handshake request and sends
// handshake response back to client.
AcceptHandshake(buf *bufio.Writer) (err error)
// NewServerConn creates a new WebSocket connection.
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
// frameReader is an interface to read a WebSocket frame.
type frameReader interface {
// Reader is to read payload of the frame.
io.Reader
// PayloadType returns payload type.
PayloadType() byte
// HeaderReader returns a reader to read header of the frame.
HeaderReader() io.Reader
// TrailerReader returns a reader to read trailer of the frame.
// If it returns nil, there is no trailer in the frame.
TrailerReader() io.Reader
// Len returns total length of the frame, including header and trailer.
Len() int
}
// frameReaderFactory is an interface to creates new frame reader.
type frameReaderFactory interface {
NewFrameReader() (r frameReader, err error)
}
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
// Writer is to write payload of the frame.
io.WriteCloser
}
// frameWriterFactory is an interface to create new frame writer.
type frameWriterFactory interface {
NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
type frameHandler interface {
HandleFrame(frame frameReader) (r frameReader, err error)
WriteClose(status int) (err error)
}
// Conn represents a WebSocket connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
config *Config
request *http.Request
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
rio sync.Mutex
frameReaderFactory
frameReader
wio sync.Mutex
frameWriterFactory
frameHandler
PayloadType byte
defaultCloseStatus int
// MaxPayloadBytes limits the size of frame payload received over Conn
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
MaxPayloadBytes int
}
// Read implements the io.Reader interface:
// it reads data of a frame from the WebSocket connection.
// if msg is not large enough for the frame data, it fills the msg and next Read
// will read the rest of the frame data.
// it reads Text frame or Binary frame.
func (ws *Conn) Read(msg []byte) (n int, err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
again:
if ws.frameReader == nil {
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return 0, err
}
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return 0, err
}
if ws.frameReader == nil {
goto again
}
}
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
}
ws.frameReader = nil
goto again
}
return n, err
}
// Write implements the io.Writer interface:
// it writes data as a frame to the WebSocket connection.
func (ws *Conn) Write(msg []byte) (n int, err error) {
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// Close implements the io.Closer interface.
func (ws *Conn) Close() error {
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
err1 := ws.rwc.Close()
if err != nil {
return err
}
return err1
}
// IsClientConn reports whether ws is a client-side connection.
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
// IsServerConn reports whether ws is a server-side connection.
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
// LocalAddr returns the WebSocket Origin for the connection for client, or
// the WebSocket location for server.
func (ws *Conn) LocalAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Origin}
}
return &Addr{ws.config.Location}
}
// RemoteAddr returns the WebSocket location for the connection for client, or
// the Websocket Origin for server.
func (ws *Conn) RemoteAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Location}
}
return &Addr{ws.config.Origin}
}
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
// SetDeadline sets the connection's network read & write deadlines.
func (ws *Conn) SetDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetDeadline(t)
}
return errSetDeadline
}
// SetReadDeadline sets the connection's network read deadline.
func (ws *Conn) SetReadDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetReadDeadline(t)
}
return errSetDeadline
}
// SetWriteDeadline sets the connection's network write deadline.
func (ws *Conn) SetWriteDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetWriteDeadline(t)
}
return errSetDeadline
}
// Config returns the WebSocket config.
func (ws *Conn) Config() *Config { return ws.config }
// Request returns the http request upgraded to the WebSocket.
// It is nil for client side.
func (ws *Conn) Request() *http.Request { return ws.request }
// Codec represents a symmetric pair of functions that implement a codec.
type Codec struct {
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
// Send sends v marshaled by cd.Marshal as single frame to ws.
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
data, payloadType, err := cd.Marshal(v)
if err != nil {
return err
}
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
if err != nil {
return err
}
_, err = w.Write(data)
w.Close()
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
// in v. The whole frame payload is read to an in-memory buffer; max size of
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
// completely. The next call to Receive would read and discard leftover data of
// previous oversized frame before processing next frame.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
if err != nil {
return err
}
ws.frameReader = nil
}
again:
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return err
}
frame, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return err
}
if frame == nil {
goto again
}
maxPayloadBytes := ws.MaxPayloadBytes
if maxPayloadBytes == 0 {
maxPayloadBytes = DefaultMaxPayloadBytes
}
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
// payload size exceeds limit, no need to call Unmarshal
//
// set frameReader to current oversized frame so that
// the next call to this function can drain leftover
// data before processing the next frame
ws.frameReader = frame
return ErrFrameTooLarge
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {
return err
}
return cd.Unmarshal(data, payloadType, v)
}
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
switch data := v.(type) {
case string:
return []byte(data), TextFrame, nil
case []byte:
return data, BinaryFrame, nil
}
return nil, UnknownFrame, ErrNotSupported
}
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
switch data := v.(type) {
case *string:
*data = string(msg)
return nil
case *[]byte:
*data = msg
return nil
}
return ErrNotSupported
}
/*
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
To send/receive text frame, use string type.
To send/receive binary frame, use []byte type.
Trivial usage:
import "websocket"
// receive text frame
var message string
websocket.Message.Receive(ws, &message)
// send text frame
message = "hello"
websocket.Message.Send(ws, message)
// receive binary frame
var data []byte
websocket.Message.Receive(ws, &data)
// send binary frame
data = []byte{0, 1, 2}
websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
msg, err = json.Marshal(v)
return msg, TextFrame, err
}
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
return json.Unmarshal(msg, v)
}
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Trivial usage:
import "websocket"
type T struct {
Msg string
Count int
}
// receive JSON type T
var data T
websocket.JSON.Receive(ws, &data)
// send JSON type T
websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}

View file

@ -0,0 +1,7 @@
# github.com/ossrs/go-oryx-lib v0.0.8
## explicit
github.com/ossrs/go-oryx-lib/errors
github.com/ossrs/go-oryx-lib/logger
# golang.org/x/net v0.0.0-20210502030024-e5908800b52b
## explicit
golang.org/x/net/websocket

View file

@ -0,0 +1,3 @@
<cross-domain-policy>
<allow-access-from domain="*"/>
</cross-domain-policy>

File diff suppressed because one or more lines are too long

Binary file not shown.

After

Width:  |  Height:  |  Size: 783 B

View file

@ -0,0 +1,13 @@
<html>
<head>
<title>SRS</title>
<meta charset="utf-8">
</head>
<body>
<h3><a href="https://github.com/ossrs/signaling">Signaling</a> works!</h3>
<p>
Run demo for <a href="one2one.html">WebRTC: One to One over SFU(SRS)</a><br/>
点击进入<a href="one2one.html">SRS一对一通话演示</a>
</p>
</body>

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,509 @@
/**
* The MIT License (MIT)
*
* Copyright (c) 2013-2021 Winlin
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
// Depends on adapter-7.4.0.min.js from https://github.com/webrtc/adapter
// Async-awat-prmise based SRS RTC Publisher.
function SrsRtcPublisherAsync() {
var self = {};
// @see https://github.com/rtcdn/rtcdn-draft
// @url The WebRTC url to play with, for example:
// webrtc://r.ossrs.net/live/livestream
// or specifies the API port:
// webrtc://r.ossrs.net:11985/live/livestream
// or autostart the publish:
// webrtc://r.ossrs.net/live/livestream?autostart=true
// or change the app from live to myapp:
// webrtc://r.ossrs.net:11985/myapp/livestream
// or change the stream from livestream to mystream:
// webrtc://r.ossrs.net:11985/live/mystream
// or set the api server to myapi.domain.com:
// webrtc://myapi.domain.com/live/livestream
// or set the candidate(ip) of answer:
// webrtc://r.ossrs.net/live/livestream?eip=39.107.238.185
// or force to access https API:
// webrtc://r.ossrs.net/live/livestream?schema=https
// or use plaintext, without SRTP:
// webrtc://r.ossrs.net/live/livestream?encrypt=false
// or any other information, will pass-by in the query:
// webrtc://r.ossrs.net/live/livestream?vhost=xxx
// webrtc://r.ossrs.net/live/livestream?token=xxx
self.publish = async function (url) {
var conf = self.__internal.prepareUrl(url);
self.pc.addTransceiver("audio", {direction: "sendonly"});
self.pc.addTransceiver("video", {direction: "sendonly"});
var stream = await navigator.mediaDevices.getUserMedia(
{audio: true, video: {height: {max: 320}}}
);
// @see https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/addStream#Migrating_to_addTrack
stream.getTracks().forEach(function (track) {
self.pc.addTrack(track);
});
var offer = await self.pc.createOffer();
await self.pc.setLocalDescription(offer);
var session = await new Promise(function (resolve, reject) {
// @see https://github.com/rtcdn/rtcdn-draft
var data = {
api: conf.apiUrl, tid: conf.tid, streamurl: conf.streamUrl,
clientip: null, sdp: offer.sdp
};
console.log("Generated offer: ", data);
$.ajax({
type: "POST", url: conf.apiUrl, data: JSON.stringify(data),
contentType: 'application/json', dataType: 'json'
}).done(function (data) {
console.log("Got answer: ", data);
if (data.code) {
reject(data);
return;
}
resolve(data);
}).fail(function (reason) {
reject(reason);
});
});
await self.pc.setRemoteDescription(
new RTCSessionDescription({type: 'answer', sdp: session.sdp})
);
session.simulator = conf.schema + '//' + conf.urlObject.server + ':' + conf.port + '/rtc/v1/nack/';
// Notify about local stream when success.
self.onaddstream && self.onaddstream({stream: stream});
return session;
};
// Close the publisher.
self.close = function () {
self.pc && self.pc.close();
self.pc = null;
};
// The callback when got local stream.
self.onaddstream = function (event) {
};
// Internal APIs.
self.__internal = {
defaultPath: '/rtc/v1/publish/',
prepareUrl: function (webrtcUrl) {
var urlObject = self.__internal.parse(webrtcUrl);
// If user specifies the schema, use it as API schema.
var schema = urlObject.user_query.schema;
schema = schema ? schema + ':' : window.location.protocol;
var port = urlObject.port || 1985;
if (schema === 'https:') {
port = urlObject.port || 443;
}
// @see https://github.com/rtcdn/rtcdn-draft
var api = urlObject.user_query.play || self.__internal.defaultPath;
if (api.lastIndexOf('/') !== api.length - 1) {
api += '/';
}
apiUrl = schema + '//' + urlObject.server + ':' + port + api;
for (var key in urlObject.user_query) {
if (key !== 'api' && key !== 'play') {
apiUrl += '&' + key + '=' + urlObject.user_query[key];
}
}
// Replace /rtc/v1/play/&k=v to /rtc/v1/play/?k=v
var apiUrl = apiUrl.replace(api + '&', api + '?');
var streamUrl = urlObject.url;
return {
apiUrl: apiUrl, streamUrl: streamUrl, schema: schema, urlObject: urlObject, port: port,
tid: new Date().getTime().toString(16)
};
},
parse: function (url) {
// @see: http://stackoverflow.com/questions/10469575/how-to-use-location-object-to-parse-url-without-redirecting-the-page-in-javascri
var a = document.createElement("a");
a.href = url.replace("rtmp://", "http://")
.replace("webrtc://", "http://")
.replace("rtc://", "http://");
var vhost = a.hostname;
var app = a.pathname.substr(1, a.pathname.lastIndexOf("/") - 1);
var stream = a.pathname.substr(a.pathname.lastIndexOf("/") + 1);
// parse the vhost in the params of app, that srs supports.
app = app.replace("...vhost...", "?vhost=");
if (app.indexOf("?") >= 0) {
var params = app.substr(app.indexOf("?"));
app = app.substr(0, app.indexOf("?"));
if (params.indexOf("vhost=") > 0) {
vhost = params.substr(params.indexOf("vhost=") + "vhost=".length);
if (vhost.indexOf("&") > 0) {
vhost = vhost.substr(0, vhost.indexOf("&"));
}
}
}
// when vhost equals to server, and server is ip,
// the vhost is __defaultVhost__
if (a.hostname === vhost) {
var re = /^(\d+)\.(\d+)\.(\d+)\.(\d+)$/;
if (re.test(a.hostname)) {
vhost = "__defaultVhost__";
}
}
// parse the schema
var schema = "rtmp";
if (url.indexOf("://") > 0) {
schema = url.substr(0, url.indexOf("://"));
}
var port = a.port;
if (!port) {
if (schema === 'http') {
port = 80;
} else if (schema === 'https') {
port = 443;
} else if (schema === 'rtmp') {
port = 1935;
}
}
var ret = {
url: url,
schema: schema,
server: a.hostname, port: port,
vhost: vhost, app: app, stream: stream
};
self.__internal.fill_query(a.search, ret);
// For webrtc API, we use 443 if page is https, or schema specified it.
if (!ret.port) {
if (schema === 'webrtc' || schema === 'rtc') {
if (ret.user_query.schema === 'https') {
ret.port = 443;
} else if (window.location.href.indexOf('https://') === 0) {
ret.port = 443;
} else {
// For WebRTC, SRS use 1985 as default API port.
ret.port = 1985;
}
}
}
return ret;
},
fill_query: function (query_string, obj) {
// pure user query object.
obj.user_query = {};
if (query_string.length === 0) {
return;
}
// split again for angularjs.
if (query_string.indexOf("?") >= 0) {
query_string = query_string.split("?")[1];
}
var queries = query_string.split("&");
for (var i = 0; i < queries.length; i++) {
var elem = queries[i];
var query = elem.split("=");
obj[query[0]] = query[1];
obj.user_query[query[0]] = query[1];
}
// alias domain for vhost.
if (obj.domain) {
obj.vhost = obj.domain;
}
}
};
self.pc = new RTCPeerConnection(null);
return self;
}
// Depends on adapter-7.4.0.min.js from https://github.com/webrtc/adapter
// Async-await-promise based SRS RTC Player.
function SrsRtcPlayerAsync() {
var self = {};
// @see https://github.com/rtcdn/rtcdn-draft
// @url The WebRTC url to play with, for example:
// webrtc://r.ossrs.net/live/livestream
// or specifies the API port:
// webrtc://r.ossrs.net:11985/live/livestream
// or autostart the play:
// webrtc://r.ossrs.net/live/livestream?autostart=true
// or change the app from live to myapp:
// webrtc://r.ossrs.net:11985/myapp/livestream
// or change the stream from livestream to mystream:
// webrtc://r.ossrs.net:11985/live/mystream
// or set the api server to myapi.domain.com:
// webrtc://myapi.domain.com/live/livestream
// or set the candidate(ip) of answer:
// webrtc://r.ossrs.net/live/livestream?eip=39.107.238.185
// or force to access https API:
// webrtc://r.ossrs.net/live/livestream?schema=https
// or use plaintext, without SRTP:
// webrtc://r.ossrs.net/live/livestream?encrypt=false
// or any other information, will pass-by in the query:
// webrtc://r.ossrs.net/live/livestream?vhost=xxx
// webrtc://r.ossrs.net/live/livestream?token=xxx
self.play = async function(url) {
var conf = self.__internal.prepareUrl(url);
self.pc.addTransceiver("audio", {direction: "recvonly"});
self.pc.addTransceiver("video", {direction: "recvonly"});
var offer = await self.pc.createOffer();
await self.pc.setLocalDescription(offer);
var session = await new Promise(function(resolve, reject) {
// @see https://github.com/rtcdn/rtcdn-draft
var data = {
api: conf.apiUrl, tid: conf.tid, streamurl: conf.streamUrl,
clientip: null, sdp: offer.sdp
};
console.log("Generated offer: ", data);
$.ajax({
type: "POST", url: conf.apiUrl, data: JSON.stringify(data),
contentType:'application/json', dataType: 'json'
}).done(function(data) {
console.log("Got answer: ", data);
if (data.code) {
reject(data); return;
}
resolve(data);
}).fail(function(reason){
reject(reason);
});
});
await self.pc.setRemoteDescription(
new RTCSessionDescription({type: 'answer', sdp: session.sdp})
);
return session;
};
// Close the player.
self.close = function() {
self.pc && self.pc.close();
self.pc = null;
};
// The callback when got remote stream.
self.onaddstream = function (event) {};
// Internal APIs.
self.__internal = {
defaultPath: '/rtc/v1/play/',
prepareUrl: function (webrtcUrl) {
var urlObject = self.__internal.parse(webrtcUrl);
// If user specifies the schema, use it as API schema.
var schema = urlObject.user_query.schema;
schema = schema ? schema + ':' : window.location.protocol;
var port = urlObject.port || 1985;
if (schema === 'https:') {
port = urlObject.port || 443;
}
// @see https://github.com/rtcdn/rtcdn-draft
var api = urlObject.user_query.play || self.__internal.defaultPath;
if (api.lastIndexOf('/') !== api.length - 1) {
api += '/';
}
apiUrl = schema + '//' + urlObject.server + ':' + port + api;
for (var key in urlObject.user_query) {
if (key !== 'api' && key !== 'play') {
apiUrl += '&' + key + '=' + urlObject.user_query[key];
}
}
// Replace /rtc/v1/play/&k=v to /rtc/v1/play/?k=v
var apiUrl = apiUrl.replace(api + '&', api + '?');
var streamUrl = urlObject.url;
return {
apiUrl: apiUrl, streamUrl: streamUrl, schema: schema, urlObject: urlObject, port: port,
tid: new Date().getTime().toString(16)
};
},
parse: function (url) {
// @see: http://stackoverflow.com/questions/10469575/how-to-use-location-object-to-parse-url-without-redirecting-the-page-in-javascri
var a = document.createElement("a");
a.href = url.replace("rtmp://", "http://")
.replace("webrtc://", "http://")
.replace("rtc://", "http://");
var vhost = a.hostname;
var app = a.pathname.substr(1, a.pathname.lastIndexOf("/") - 1);
var stream = a.pathname.substr(a.pathname.lastIndexOf("/") + 1);
// parse the vhost in the params of app, that srs supports.
app = app.replace("...vhost...", "?vhost=");
if (app.indexOf("?") >= 0) {
var params = app.substr(app.indexOf("?"));
app = app.substr(0, app.indexOf("?"));
if (params.indexOf("vhost=") > 0) {
vhost = params.substr(params.indexOf("vhost=") + "vhost=".length);
if (vhost.indexOf("&") > 0) {
vhost = vhost.substr(0, vhost.indexOf("&"));
}
}
}
// when vhost equals to server, and server is ip,
// the vhost is __defaultVhost__
if (a.hostname === vhost) {
var re = /^(\d+)\.(\d+)\.(\d+)\.(\d+)$/;
if (re.test(a.hostname)) {
vhost = "__defaultVhost__";
}
}
// parse the schema
var schema = "rtmp";
if (url.indexOf("://") > 0) {
schema = url.substr(0, url.indexOf("://"));
}
var port = a.port;
if (!port) {
if (schema === 'http') {
port = 80;
} else if (schema === 'https') {
port = 443;
} else if (schema === 'rtmp') {
port = 1935;
}
}
var ret = {
url: url,
schema: schema,
server: a.hostname, port: port,
vhost: vhost, app: app, stream: stream
};
self.__internal.fill_query(a.search, ret);
// For webrtc API, we use 443 if page is https, or schema specified it.
if (!ret.port) {
if (schema === 'webrtc' || schema === 'rtc') {
if (ret.user_query.schema === 'https') {
ret.port = 443;
} else if (window.location.href.indexOf('https://') === 0) {
ret.port = 443;
} else {
// For WebRTC, SRS use 1985 as default API port.
ret.port = 1985;
}
}
}
return ret;
},
fill_query: function (query_string, obj) {
// pure user query object.
obj.user_query = {};
if (query_string.length === 0) {
return;
}
// split again for angularjs.
if (query_string.indexOf("?") >= 0) {
query_string = query_string.split("?")[1];
}
var queries = query_string.split("&");
for (var i = 0; i < queries.length; i++) {
var elem = queries[i];
var query = elem.split("=");
obj[query[0]] = query[1];
obj.user_query[query[0]] = query[1];
}
// alias domain for vhost.
if (obj.domain) {
obj.vhost = obj.domain;
}
}
};
self.pc = new RTCPeerConnection(null);
self.pc.onaddstream = function (event) {
if (self.onaddstream) {
self.onaddstream(event);
}
};
return self;
}
// Format the codec of RTCRtpSender, kind(audio/video) is optional filter.
// https://developer.mozilla.org/en-US/docs/Web/Media/Formats/WebRTC_codecs#getting_the_supported_codecs
function SrsRtcFormatSenders(senders, kind) {
var codecs = [];
senders.forEach(function (sender) {
sender.getParameters().codecs.forEach(function(c) {
if (kind && sender.track.kind !== kind) {
return;
}
if (c.mimeType.indexOf('/red') > 0 || c.mimeType.indexOf('/rtx') > 0 || c.mimeType.indexOf('/fec') > 0) {
return;
}
var s = '';
s += c.mimeType.replace('audio/', '').replace('video/', '');
s += ', ' + c.clockRate + 'HZ';
if (sender.track.kind === "audio") {
s += ', channels: ' + c.channels;
}
s += ', pt: ' + c.payloadType;
codecs.push(s);
});
});
return codecs.join(", ");
}

View file

@ -0,0 +1,120 @@
/**
* The MIT License (MIT)
*
* Copyright (c) 2013-2021 Winlin
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
// Async-await-promise based SRS RTC Signaling.
function SrsRtcSignalingAsync() {
var self = {};
// The schema is ws or wss, host is ip or ip:port, display is nickname
// of user to join the room.
self.connect = async function (schema, host, room, display) {
var url = schema + '://' + host + '/sig/v1/rtc';
self.ws = new WebSocket(url + '?room=' + room + '&display=' + display);
self.ws.onmessage = function(event) {
var r = JSON.parse(event.data);
var promise = self._internals.msgs[r.tid];
if (promise) {
promise.resolve(r.msg);
delete self._internals.msgs[r.tid];
} else {
self.onmessage(r.msg);
}
};
return new Promise(function (resolve, reject) {
self.ws.onopen = function (event) {
resolve(event);
};
self.ws.onerror = function (event) {
reject(event);
};
});
};
// The message is a json object.
self.send = async function (message) {
return new Promise(function (resolve, reject) {
var r = {tid: new Date().getTime().toString(16), msg: message};
self._internals.msgs[r.tid] = {resolve: resolve, reject: reject};
self.ws.send(JSON.stringify(r));
});
};
self.close = function () {
self.ws && self.ws.close();
self.ws = null;
for (const tid in self._internals.msgs) {
var promise = self._internals.msgs[tid];
promise.reject('close');
}
};
// The callback when got messages from signaling server.
self.onmessage = function (msg) {
};
self._internals = {
// Key is tid, value is object {resolve, reject, response}.
msgs: {}
};
return self;
}
// Parse params in query string.
function SrsRtcSignalingParse(location) {
let query = location.href.split('?')[1];
query = query? '?' + query : null;
let wsSchema = location.href.split('wss=')[1];
wsSchema = wsSchema? wsSchema.split('&')[0] : (location.protocol === 'http:'? 'ws' : 'wss');
let wsHost = location.href.split('wsh=')[1];
wsHost = wsHost? wsHost.split('&')[0] : location.hostname;
let wsPort = location.href.split('wsp=')[1];
wsPort = wsPort? wsPort.split('&')[0] : location.host.split(':')[1];
wsHost = wsPort? wsHost.split(':')[0] + ':' + wsPort : wsHost;
let host = location.href.split('host=')[1];
host = host? host.split('&')[0] : location.hostname;
let room = location.href.split('room=')[1];
let display = location.href.split('display=')[1];
display = display? display.split('&')[0] : new Date().getTime().toString(16).substr(3);
let autostart = location.href.split('autostart=')[1];
autostart = autostart && autostart.split('&')[0] === 'true';
return {
query: query, wsSchema: wsSchema, wsHost: wsHost, host: host,
room: room, display: display, autostart: autostart,
};
}

View file

@ -0,0 +1,179 @@
<!DOCTYPE html>
<html>
<head>
<title>SRS</title>
<meta charset="utf-8">
<style>
body{
padding-top: 55px;
}
</style>
<link rel="stylesheet" type="text/css" href="css/bootstrap.min.css"/>
<script type="text/javascript" src="js/jquery-1.10.2.min.js"></script>
<script type="text/javascript" src="js/adapter-7.4.0.min.js"></script>
<script type="text/javascript" src="js/srs.sdk.js"></script>
<script type="text/javascript" src="js/srs.sig.js"></script>
</head>
<body>
<img src='https://ossrs.net/gif/v1/sls.gif?site=ossrs.net&path=/player/rtcpublisher'/>
<div class="navbar navbar-fixed-top">
<div class="navbar-inner">
<div class="container">
<a class="brand" href="https://github.com/ossrs/srs">SRS</a>
<div class="nav-collapse collapse">
<ul class="nav">
<li class="active"><a href="#">一对一通话</a></li>
<li>
<a href="https://github.com/ossrs/signaling">
<img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/ossrs/signaling?style=social">
</a>
</li>
</ul>
</div>
</div>
</div>
</div>
<div class="container">
<div class="form-inline">
SRS:
<input type="text" id="txt_host" class="input-medium" value="">
Room:
<input type="text" id="txt_room" class="input-small" value="live">
Display:
<input type="text" id="txt_display" class="input-small" value="">
<button class="btn btn-primary" id="btn_start">开始通话</button>
</div>
<div class="row">
<div class="span5">
<label></label>
<video id="rtc_media_publisher" width="320" autoplay muted controls></video>
<label></label>
<span id='self'></span>
</div>
<div class="span6">
<label></label>
<video id="rtc_media_player" width="320" autoplay muted controls></video>
<label></label>
<span id='peer'></span>
</div>
</div>
</div>
<script type="text/javascript">
var sig = null;
var publisher = null;
var player = null;
$(function(){
console.log('?wss=x to specify the websocket schema, ws or wss');
console.log('?wsh=x to specify the websocket server ip');
console.log('?wsp=x to specify the websocket server port');
console.log('?host=x to specify the SRS server');
console.log('?room=x to specify the room to join');
console.log('?display=x to specify your nick name');
var startDemo = async function () {
var host = $('#txt_host').val();
var room = $('#txt_room').val();
var display = $('#txt_display').val();
// Connect to signaling first.
if (sig) {
sig.close();
}
sig = new SrsRtcSignalingAsync();
sig.onmessage = function (msg) {
console.log('Notify: ', msg);
msg.participants.forEach(function (participant) {
if (participant.display === display || !participant.publishing) return;
startPlay(host, room, participant.display);
});
};
await sig.connect(conf.wsSchema, conf.wsHost, room, display);
let r0 = await sig.send({action:'join', room:room, display:display});
console.log('Signaling: join ok', r0);
// For one to one demo, alert and ignore when room is full.
if (r0.participants.length > 2) {
alert('Room is full, already ' + (r0.participants.length - 1) + ' participants');
sig.close();
return;
}
// Start publish media if signaling is ok.
await startPublish(host, room, display);
let r1 = await sig.send({action:'publish', room:room, display:display});
console.log('Signaling: publish ok', r1);
// Play the stream already in room.
r0.participants.forEach(function(participant) {
if (participant.display === display || !participant.publishing) return;
startPlay(host, room, participant.display);
});
};
var startPublish = function (host, room, display) {
var url = 'webrtc://' + host + '/' + room + '/' + display + conf.query;
$('#rtc_media_publisher').show();
if (publisher) {
publisher.close();
}
publisher = new SrsRtcPublisherAsync();
publisher.onaddstream = function (event) {
console.log('Start publish, event: ', event);
$('#rtc_media_publisher').prop('srcObject', event.stream);
};
return publisher.publish(url).then(function(session){
$('#self').text('Self: ' + display);
}).catch(function (reason) {
publisher.close();
$('#rtc_media_publisher').hide();
console.error(reason);
});
};
var startPlay = function (host, room, display) {
var url = 'webrtc://' + host + '/' + room + '/' + display + conf.query;
$('#rtc_media_player').show();
if (player) {
player.close();
}
player = new SrsRtcPlayerAsync();
player.onaddstream = function (event) {
console.log('Start play, event: ', event);
$('#rtc_media_player').prop('srcObject', event.stream);
};
player.play(url).then(function(session){
$('#peer').text('Peer: ' + display);
$('#rtc_media_player').prop('muted', false);
}).catch(function (reason) {
player.close();
$('#rtc_media_player').hide();
console.error(reason);
});
};
$('#rtc_media_publisher').hide();
$('#rtc_media_player').hide();
$("#btn_start").click(startDemo);
// Pass-by to SRS url.
let conf = SrsRtcSignalingParse(window.location);
$('#txt_host').val(conf.host);
conf.room && $('#txt_room').val(conf.room);
$('#txt_display').val(conf.display);
if (conf.autostart) {
startDemo();
}
});
</script>
</body>
</html>

BIN
trunk/3rdparty/signaling/www/favicon.ico vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 KiB

9
trunk/3rdparty/signaling/www/index.html vendored Executable file
View file

@ -0,0 +1,9 @@
<html>
<head>
<title>SRS</title>
<meta charset="utf-8">
<script type="text/javascript">
window.location.href = "demos/"
</script>
</head>

View file

@ -432,6 +432,10 @@ ln -sf `pwd`/research/api-server/static-dir/favicon.ico ${SRS_OBJS}/nginx/html/f
rm -rf ${SRS_OBJS}/nginx/html/console &&
ln -sf `pwd`/research/console ${SRS_OBJS}/nginx/html/console
# For SRS signaling.
rm -rf ${SRS_OBJS}/nginx/html/demos &&
ln -sf `pwd`/3rdparty/signaling/www/demos ${SRS_OBJS}/nginx/html/demos
# For home page index.html
rm -rf ${SRS_OBJS}/nginx/html/index.html &&
ln -sf `pwd`/research/api-server/static-dir/index.html ${SRS_OBJS}/nginx/html/index.html

View file

@ -119,6 +119,11 @@ srs_error_t SrsGoApiRtcPlay::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMe
api = prop->to_str();
}
string tid;
if ((prop = req->ensure_property_string("tid")) != NULL) {
tid = prop->to_str();
}
// TODO: FIXME: Parse vhost.
// Parse app and stream from streamurl.
string app;
@ -139,9 +144,9 @@ srs_error_t SrsGoApiRtcPlay::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMe
string srtp = r->query_get("encrypt");
string dtls = r->query_get("dtls");
srs_trace("RTC play %s, api=%s, clientip=%s, app=%s, stream=%s, offer=%dB, eip=%s, codec=%s, srtp=%s, dtls=%s",
streamurl.c_str(), api.c_str(), clientip.c_str(), app.c_str(), stream_name.c_str(), remote_sdp_str.length(), eip.c_str(),
codec.c_str(), srtp.c_str(), dtls.c_str()
srs_trace("RTC play %s, api=%s, tid=%s, clientip=%s, app=%s, stream=%s, offer=%dB, eip=%s, codec=%s, srtp=%s, dtls=%s",
streamurl.c_str(), api.c_str(), tid.c_str(), clientip.c_str(), app.c_str(), stream_name.c_str(), remote_sdp_str.length(),
eip.c_str(), codec.c_str(), srtp.c_str(), dtls.c_str()
);
// The RTC user config object.
@ -488,6 +493,11 @@ srs_error_t SrsGoApiRtcPublish::do_serve_http(ISrsHttpResponseWriter* w, ISrsHtt
api = prop->to_str();
}
string tid;
if ((prop = req->ensure_property_string("tid")) != NULL) {
tid = prop->to_str();
}
// Parse app and stream from streamurl.
string app;
string stream_name;
@ -504,9 +514,9 @@ srs_error_t SrsGoApiRtcPublish::do_serve_http(ISrsHttpResponseWriter* w, ISrsHtt
string eip = r->query_get("eip");
string codec = r->query_get("codec");
srs_trace("RTC publish %s, api=%s, clientip=%s, app=%s, stream=%s, offer=%dB, eip=%s, codec=%s",
streamurl.c_str(), api.c_str(), clientip.c_str(), app.c_str(), stream_name.c_str(), remote_sdp_str.length(), eip.c_str(),
codec.c_str()
srs_trace("RTC publish %s, api=%s, tid=%s, clientip=%s, app=%s, stream=%s, offer=%dB, eip=%s, codec=%s",
streamurl.c_str(), api.c_str(), tid.c_str(), clientip.c_str(), app.c_str(), stream_name.c_str(),
remote_sdp_str.length(), eip.c_str(), codec.c_str()
);
// The RTC user config object.

Some files were not shown because too many files have changed in this diff Show more