1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

GB28181: Support GB28181-2016 protocol. v5.0.74 (#3201)

01. Support GB config as StreamCaster.
02. Support disable GB by --gb28181=off.
03. Add utests for SIP examples.
04. Wireshark plugin to decode TCP/9000 as rtp.rfc4571
05. Support MPEGPS program stream codec.
06. Add utest for PS stream codec.
07. Decode MPEGPS packet stream.
08. Carry RTP and PS packet as helper in PS message.
09. Support recover from error mode.
10. Support process by a pack of PS/TS messages.
11. Add statistic for recovered and msgs dropped.
12. Recover from err position fastly.
13. Define state machine for GB session.
14. Bind context to GB session.
15. Re-invite when media disconnected.
16. Update GitHub actions with GB28181.
17. Support parse CANDIDATE by env or pip.
18. Support mux GB28181 to RTMP.
19. Support regression test by srs-bench.
This commit is contained in:
Winlin 2022-10-06 17:40:58 +08:00 committed by GitHub
parent 9c81a0e1bd
commit 5a420ece3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
298 changed files with 43343 additions and 763 deletions

View file

@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2017, The GoSIP authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,85 @@
package log
import (
"fmt"
"strings"
)
// Logger interface used as base logger throughout the library.
type Logger interface {
Print(args ...interface{})
Printf(format string, args ...interface{})
Trace(args ...interface{})
Tracef(format string, args ...interface{})
Debug(args ...interface{})
Debugf(format string, args ...interface{})
Info(args ...interface{})
Infof(format string, args ...interface{})
Warn(args ...interface{})
Warnf(format string, args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
Fatal(args ...interface{})
Fatalf(format string, args ...interface{})
Panic(args ...interface{})
Panicf(format string, args ...interface{})
WithPrefix(prefix string) Logger
Prefix() string
WithFields(fields Fields) Logger
Fields() Fields
SetLevel(level Level)
}
type Loggable interface {
Log() Logger
}
type Fields map[string]interface{}
func (fields Fields) String() string {
str := make([]string, 0)
for k, v := range fields {
str = append(str, fmt.Sprintf("%s=%+v", k, v))
}
return strings.Join(str, " ")
}
func (fields Fields) WithFields(newFields Fields) Fields {
allFields := make(Fields)
for k, v := range fields {
allFields[k] = v
}
for k, v := range newFields {
allFields[k] = v
}
return allFields
}
func AddFieldsFrom(logger Logger, values ...interface{}) Logger {
for _, value := range values {
switch v := value.(type) {
case Logger:
logger = logger.WithFields(v.Fields())
case Loggable:
logger = logger.WithFields(v.Log().Fields())
case interface{ Fields() Fields }:
logger = logger.WithFields(v.Fields())
}
}
return logger
}

View file

@ -0,0 +1,148 @@
package log
import (
"github.com/sirupsen/logrus"
prefixed "github.com/x-cray/logrus-prefixed-formatter"
)
type LogrusLogger struct {
log logrus.Ext1FieldLogger
prefix string
fields Fields
}
// Level type
type Level uint32
// These are the different logging levels. You can set the logging level to log
// on your instance of logger, obtained with `logrus.New()`.
const (
// PanicLevel level, highest level of severity. Logs and then calls panic with the
// message passed to Debug, Info, ...
PanicLevel Level = iota
// FatalLevel level. Logs and then calls `logger.Exit(1)`. It will exit even if the
// logging level is set to Panic.
FatalLevel
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
// Commonly used for hooks to send errors to an error tracking service.
ErrorLevel
// WarnLevel level. Non-critical entries that deserve eyes.
WarnLevel
// InfoLevel level. General operational entries about what's going on inside the
// application.
InfoLevel
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
DebugLevel
// TraceLevel level. Designates finer-grained informational events than the Debug.
TraceLevel
)
func NewLogrusLogger(logrus logrus.Ext1FieldLogger, prefix string, fields Fields) *LogrusLogger {
return &LogrusLogger{
log: logrus,
prefix: prefix,
fields: fields,
}
}
func NewDefaultLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.Formatter = &prefixed.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05.000",
}
return NewLogrusLogger(logger, "main", nil)
}
func (l *LogrusLogger) Print(args ...interface{}) {
l.prepareEntry().Print(args...)
}
func (l *LogrusLogger) Printf(format string, args ...interface{}) {
l.prepareEntry().Printf(format, args...)
}
func (l *LogrusLogger) Trace(args ...interface{}) {
l.prepareEntry().Trace(args...)
}
func (l *LogrusLogger) Tracef(format string, args ...interface{}) {
l.prepareEntry().Tracef(format, args...)
}
func (l *LogrusLogger) Debug(args ...interface{}) {
l.prepareEntry().Debug(args...)
}
func (l *LogrusLogger) Debugf(format string, args ...interface{}) {
l.prepareEntry().Debugf(format, args...)
}
func (l *LogrusLogger) Info(args ...interface{}) {
l.prepareEntry().Info(args...)
}
func (l *LogrusLogger) Infof(format string, args ...interface{}) {
l.prepareEntry().Infof(format, args...)
}
func (l *LogrusLogger) Warn(args ...interface{}) {
l.prepareEntry().Warn(args...)
}
func (l *LogrusLogger) Warnf(format string, args ...interface{}) {
l.prepareEntry().Warnf(format, args...)
}
func (l *LogrusLogger) Error(args ...interface{}) {
l.prepareEntry().Error(args...)
}
func (l *LogrusLogger) Errorf(format string, args ...interface{}) {
l.prepareEntry().Errorf(format, args...)
}
func (l *LogrusLogger) Fatal(args ...interface{}) {
l.prepareEntry().Fatal(args...)
}
func (l *LogrusLogger) Fatalf(format string, args ...interface{}) {
l.prepareEntry().Fatalf(format, args...)
}
func (l *LogrusLogger) Panic(args ...interface{}) {
l.prepareEntry().Panic(args...)
}
func (l *LogrusLogger) Panicf(format string, args ...interface{}) {
l.prepareEntry().Panicf(format, args...)
}
func (l *LogrusLogger) WithPrefix(prefix string) Logger {
return NewLogrusLogger(l.log, prefix, l.Fields())
}
func (l *LogrusLogger) Prefix() string {
return l.prefix
}
func (l *LogrusLogger) WithFields(fields Fields) Logger {
return NewLogrusLogger(l.log, l.Prefix(), l.Fields().WithFields(fields))
}
func (l *LogrusLogger) Fields() Fields {
return l.fields
}
func (l *LogrusLogger) prepareEntry() *logrus.Entry {
return l.log.
WithFields(logrus.Fields(l.Fields())).
WithField("prefix", l.Prefix())
}
func (l *LogrusLogger) SetLevel(level Level) {
if ll, ok := l.log.(*logrus.Logger); ok {
ll.SetLevel(logrus.Level(level))
}
}

View file

@ -0,0 +1,279 @@
package sip
import (
"crypto/md5"
"encoding/hex"
"fmt"
"regexp"
"strings"
)
// currently only Digest and MD5
type Authorization struct {
realm string
nonce string
algorithm string
username string
password string
uri string
response string
method string
qop string
nc string
cnonce string
other map[string]string
}
func AuthFromValue(value string) *Authorization {
auth := &Authorization{
algorithm: "MD5",
other: make(map[string]string),
}
re := regexp.MustCompile(`([\w]+)="([^"]+)"`)
matches := re.FindAllStringSubmatch(value, -1)
for _, match := range matches {
switch match[1] {
case "realm":
auth.realm = match[2]
case "algorithm":
auth.algorithm = match[2]
case "nonce":
auth.nonce = match[2]
case "username":
auth.username = match[2]
case "uri":
auth.uri = match[2]
case "response":
auth.response = match[2]
case "qop":
for _, v := range strings.Split(match[2], ",") {
v = strings.Trim(v, " ")
if v == "auth" || v == "auth-int" {
auth.qop = "auth"
break
}
}
case "nc":
auth.nc = match[2]
case "cnonce":
auth.cnonce = match[2]
default:
auth.other[match[1]] = match[2]
}
}
return auth
}
func (auth *Authorization) Realm() string {
return auth.realm
}
func (auth *Authorization) Nonce() string {
return auth.nonce
}
func (auth *Authorization) Algorithm() string {
return auth.algorithm
}
func (auth *Authorization) Username() string {
return auth.username
}
func (auth *Authorization) SetUsername(username string) *Authorization {
auth.username = username
return auth
}
func (auth *Authorization) SetPassword(password string) *Authorization {
auth.password = password
return auth
}
func (auth *Authorization) Uri() string {
return auth.uri
}
func (auth *Authorization) SetUri(uri string) *Authorization {
auth.uri = uri
return auth
}
func (auth *Authorization) SetMethod(method string) *Authorization {
auth.method = method
return auth
}
func (auth *Authorization) Response() string {
return auth.response
}
func (auth *Authorization) SetResponse(response string) {
auth.response = response
}
func (auth *Authorization) Qop() string {
return auth.qop
}
func (auth *Authorization) SetQop(qop string) {
auth.qop = qop
}
func (auth *Authorization) Nc() string {
return auth.nc
}
func (auth *Authorization) SetNc(nc string) {
auth.nc = nc
}
func (auth *Authorization) CNonce() string {
return auth.cnonce
}
func (auth *Authorization) SetCNonce(cnonce string) {
auth.cnonce = cnonce
}
func (auth *Authorization) CalcResponse() string {
return calcResponse(
auth.username,
auth.realm,
auth.password,
auth.method,
auth.uri,
auth.nonce,
auth.qop,
auth.cnonce,
auth.nc,
)
}
func (auth *Authorization) String() string {
if auth == nil {
return "<nil>"
}
str := fmt.Sprintf(
`Digest realm="%s",algorithm=%s,nonce="%s",username="%s",uri="%s",response="%s"`,
auth.realm,
auth.algorithm,
auth.nonce,
auth.username,
auth.uri,
auth.response,
)
if auth.qop == "auth" {
str += fmt.Sprintf(`,qop=%s,nc=%s,cnonce="%s"`, auth.qop, auth.nc, auth.cnonce)
}
return str
}
// calculates Authorization response https://www.ietf.org/rfc/rfc2617.txt
func calcResponse(username, realm, password, method, uri, nonce, qop, cnonce, nc string) string {
calcA1 := func() string {
encoder := md5.New()
encoder.Write([]byte(username + ":" + realm + ":" + password))
return hex.EncodeToString(encoder.Sum(nil))
}
calcA2 := func() string {
encoder := md5.New()
encoder.Write([]byte(method + ":" + uri))
return hex.EncodeToString(encoder.Sum(nil))
}
encoder := md5.New()
encoder.Write([]byte(calcA1() + ":" + nonce + ":"))
if qop != "" {
encoder.Write([]byte(nc + ":" + cnonce + ":" + qop + ":"))
}
encoder.Write([]byte(calcA2()))
return hex.EncodeToString(encoder.Sum(nil))
}
func AuthorizeRequest(request Request, response Response, user, password MaybeString) error {
if user == nil {
return fmt.Errorf("authorize request: user is nil")
}
var authenticateHeaderName, authorizeHeaderName string
if response.StatusCode() == 401 {
// on 401 Unauthorized increase request seq num, add Authorization header and send once again
authenticateHeaderName = "WWW-Authenticate"
authorizeHeaderName = "Authorization"
} else {
// 407 Proxy authentication
authenticateHeaderName = "Proxy-Authenticate"
authorizeHeaderName = "Proxy-Authorization"
}
if hdrs := response.GetHeaders(authenticateHeaderName); len(hdrs) > 0 {
authenticateHeader := hdrs[0].(*GenericHeader)
auth := AuthFromValue(authenticateHeader.Contents).
SetMethod(string(request.Method())).
SetUri(request.Recipient().String()).
SetUsername(user.String())
if password != nil {
auth.SetPassword(password.String())
}
if auth.Qop() == "auth" {
auth.SetNc("00000001")
encoder := md5.New()
encoder.Write([]byte(user.String() + request.Recipient().String()))
if password != nil {
encoder.Write([]byte(password.String()))
}
auth.SetCNonce(hex.EncodeToString(encoder.Sum(nil)))
}
auth.SetResponse(auth.CalcResponse())
if hdrs = request.GetHeaders(authorizeHeaderName); len(hdrs) > 0 {
authorizationHeader := hdrs[0].Clone().(*GenericHeader)
authorizationHeader.Contents = auth.String()
request.ReplaceHeaders(authorizationHeader.Name(), []Header{authorizationHeader})
} else {
request.AppendHeader(&GenericHeader{
HeaderName: authorizeHeaderName,
Contents: auth.String(),
})
}
} else {
return fmt.Errorf("authorize request: header '%s' not found in response", authenticateHeaderName)
}
if viaHop, ok := request.ViaHop(); ok {
viaHop.Params.Add("branch", String{Str: GenerateBranch()})
}
if cseq, ok := request.CSeq(); ok {
cseq := cseq.Clone().(*CSeq)
cseq.SeqNo++
request.ReplaceHeaders(cseq.Name(), []Header{cseq})
}
return nil
}
type Authorizer interface {
AuthorizeRequest(request Request, response Response) error
}
type DefaultAuthorizer struct {
User MaybeString
Password MaybeString
}
func (auth *DefaultAuthorizer) AuthorizeRequest(request Request, response Response) error {
return AuthorizeRequest(request, response, auth.User, auth.Password)
}

View file

@ -0,0 +1,336 @@
package sip
import (
"fmt"
"github.com/ghettovoice/gosip/util"
)
type RequestBuilder struct {
protocol string
protocolVersion string
transport string
host string
method RequestMethod
cseq *CSeq
recipient Uri
body string
callID *CallID
via ViaHeader
from *FromHeader
to *ToHeader
contact *ContactHeader
expires *Expires
userAgent *UserAgentHeader
maxForwards *MaxForwards
supported *SupportedHeader
require *RequireHeader
allow AllowHeader
contentType *ContentType
accept *Accept
route *RouteHeader
generic map[string]Header
}
func NewRequestBuilder() *RequestBuilder {
callID := CallID(util.RandString(32))
maxForwards := MaxForwards(70)
userAgent := UserAgentHeader("GoSIP")
rb := &RequestBuilder{
protocol: "SIP",
protocolVersion: "2.0",
transport: "UDP",
host: "localhost",
cseq: &CSeq{SeqNo: 1},
body: "",
via: make(ViaHeader, 0),
callID: &callID,
userAgent: &userAgent,
maxForwards: &maxForwards,
generic: make(map[string]Header),
}
return rb
}
func (rb *RequestBuilder) SetTransport(transport string) *RequestBuilder {
if transport == "" {
rb.transport = "UDP"
} else {
rb.transport = transport
}
return rb
}
func (rb *RequestBuilder) SetHost(host string) *RequestBuilder {
if host == "" {
rb.host = "localhost"
} else {
rb.host = host
}
return rb
}
func (rb *RequestBuilder) SetMethod(method RequestMethod) *RequestBuilder {
rb.method = method
rb.cseq.MethodName = method
return rb
}
func (rb *RequestBuilder) SetSeqNo(seqNo uint) *RequestBuilder {
rb.cseq.SeqNo = uint32(seqNo)
return rb
}
func (rb *RequestBuilder) SetRecipient(uri Uri) *RequestBuilder {
rb.recipient = uri.Clone()
return rb
}
func (rb *RequestBuilder) SetBody(body string) *RequestBuilder {
rb.body = body
return rb
}
func (rb *RequestBuilder) SetCallID(callID *CallID) *RequestBuilder {
if callID != nil {
rb.callID = callID
}
return rb
}
func (rb *RequestBuilder) AddVia(via *ViaHop) *RequestBuilder {
if via.ProtocolName == "" {
via.ProtocolName = rb.protocol
}
if via.ProtocolVersion == "" {
via.ProtocolVersion = rb.protocolVersion
}
if via.Transport == "" {
via.Transport = rb.transport
}
if via.Host == "" {
via.Host = rb.host
}
if via.Params == nil {
via.Params = NewParams()
}
rb.via = append(rb.via, via)
return rb
}
func (rb *RequestBuilder) SetFrom(address *Address) *RequestBuilder {
if address == nil {
rb.from = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.from = &FromHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetTo(address *Address) *RequestBuilder {
if address == nil {
rb.to = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.to = &ToHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetContact(address *Address) *RequestBuilder {
if address == nil {
rb.contact = nil
} else {
address = address.Clone()
if address.Uri.Host() == "" {
address.Uri.SetHost(rb.host)
}
rb.contact = &ContactHeader{
DisplayName: address.DisplayName,
Address: address.Uri,
Params: address.Params,
}
}
return rb
}
func (rb *RequestBuilder) SetExpires(expires *Expires) *RequestBuilder {
rb.expires = expires
return rb
}
func (rb *RequestBuilder) SetUserAgent(userAgent *UserAgentHeader) *RequestBuilder {
rb.userAgent = userAgent
return rb
}
func (rb *RequestBuilder) SetMaxForwards(maxForwards *MaxForwards) *RequestBuilder {
rb.maxForwards = maxForwards
return rb
}
func (rb *RequestBuilder) SetAllow(methods []RequestMethod) *RequestBuilder {
rb.allow = methods
return rb
}
func (rb *RequestBuilder) SetSupported(options []string) *RequestBuilder {
if len(options) == 0 {
rb.supported = nil
} else {
rb.supported = &SupportedHeader{
Options: options,
}
}
return rb
}
func (rb *RequestBuilder) SetRequire(options []string) *RequestBuilder {
if len(options) == 0 {
rb.require = nil
} else {
rb.require = &RequireHeader{
Options: options,
}
}
return rb
}
func (rb *RequestBuilder) SetContentType(contentType *ContentType) *RequestBuilder {
rb.contentType = contentType
return rb
}
func (rb *RequestBuilder) SetAccept(accept *Accept) *RequestBuilder {
rb.accept = accept
return rb
}
func (rb *RequestBuilder) SetRoutes(routes []Uri) *RequestBuilder {
if len(routes) == 0 {
rb.route = nil
} else {
rb.route = &RouteHeader{
Addresses: routes,
}
}
return rb
}
func (rb *RequestBuilder) AddHeader(header Header) *RequestBuilder {
rb.generic[header.Name()] = header
return rb
}
func (rb *RequestBuilder) RemoveHeader(headerName string) *RequestBuilder {
if _, ok := rb.generic[headerName]; ok {
delete(rb.generic, headerName)
}
return rb
}
func (rb *RequestBuilder) Build() (Request, error) {
if rb.method == "" {
return nil, fmt.Errorf("undefined method name")
}
if rb.recipient == nil {
return nil, fmt.Errorf("empty recipient")
}
if rb.from == nil {
return nil, fmt.Errorf("empty 'From' header")
}
if rb.to == nil {
return nil, fmt.Errorf("empty 'From' header")
}
hdrs := make([]Header, 0)
if rb.route != nil {
hdrs = append(hdrs, rb.route)
}
if len(rb.via) != 0 {
via := make(ViaHeader, 0)
for _, viaHop := range rb.via {
via = append(via, viaHop)
}
hdrs = append(hdrs, via)
}
hdrs = append(hdrs, rb.cseq, rb.from, rb.to, rb.callID)
if rb.contact != nil {
hdrs = append(hdrs, rb.contact)
}
if rb.maxForwards != nil {
hdrs = append(hdrs, rb.maxForwards)
}
if rb.expires != nil {
hdrs = append(hdrs, rb.expires)
}
if rb.supported != nil {
hdrs = append(hdrs, rb.supported)
}
if rb.allow != nil {
hdrs = append(hdrs, rb.allow)
}
if rb.contentType != nil {
hdrs = append(hdrs, rb.contentType)
}
if rb.accept != nil {
hdrs = append(hdrs, rb.accept)
}
if rb.userAgent != nil {
hdrs = append(hdrs, rb.userAgent)
}
for _, header := range rb.generic {
hdrs = append(hdrs, header)
}
sipVersion := rb.protocol + "/" + rb.protocolVersion
// basic request
req := NewRequest("", rb.method, rb.recipient, sipVersion, hdrs, "", nil)
req.SetBody(rb.body, true)
return req, nil
}

View file

@ -0,0 +1,412 @@
package sip
import (
"bytes"
"fmt"
"strings"
"github.com/ghettovoice/gosip/util"
)
const (
MTU uint = 1500
DefaultHost = "127.0.0.1"
DefaultProtocol = "UDP"
DefaultUdpPort Port = 5060
DefaultTcpPort Port = 5060
DefaultTlsPort Port = 5061
DefaultWsPort Port = 80
DefaultWssPort Port = 443
)
// TODO should be refactored, currently here the pit
type Address struct {
DisplayName MaybeString
Uri Uri
Params Params
}
func NewAddressFromFromHeader(from *FromHeader) *Address {
addr := &Address{
DisplayName: from.DisplayName,
}
if from.Address != nil {
addr.Uri = from.Address.Clone()
}
if from.Params != nil {
addr.Params = from.Params.Clone()
}
return addr
}
func NewAddressFromToHeader(to *ToHeader) *Address {
addr := &Address{
DisplayName: to.DisplayName,
}
if to.Address != nil {
addr.Uri = to.Address.Clone()
}
if to.Params != nil {
addr.Params = to.Params.Clone()
}
return addr
}
func NewAddressFromContactHeader(cnt *ContactHeader) *Address {
addr := &Address{
DisplayName: cnt.DisplayName,
}
if cnt.Address != nil {
addr.Uri = cnt.Address.Clone()
}
if cnt.Params != nil {
addr.Params = cnt.Params.Clone()
}
return addr
}
func (addr *Address) String() string {
var buffer bytes.Buffer
if addr == nil {
return "<nil>"
}
if addr.DisplayName != nil {
if displayName, ok := addr.DisplayName.(String); ok && displayName.String() != "" {
buffer.WriteString(fmt.Sprintf("\"%s\" ", displayName))
}
}
buffer.WriteString(fmt.Sprintf("<%s>", addr.Uri))
if addr.Params != nil && addr.Params.Length() > 0 {
buffer.WriteString(";")
buffer.WriteString(addr.Params.ToString(';'))
}
return buffer.String()
}
func (addr *Address) Clone() *Address {
var name MaybeString
var uri Uri
var params Params
if addr.DisplayName != nil {
name = String{Str: addr.DisplayName.String()}
}
if addr.Uri != nil {
uri = addr.Uri.Clone()
}
if addr.Params != nil {
params = addr.Params.Clone()
}
return &Address{
DisplayName: name,
Uri: uri,
Params: params,
}
}
func (addr *Address) Equals(other interface{}) bool {
otherPtr, ok := other.(*Address)
if !ok {
return false
}
if addr == otherPtr {
return true
}
if addr == nil && otherPtr != nil || addr != nil && otherPtr == nil {
return false
}
res := true
if addr.DisplayName != otherPtr.DisplayName {
if addr.DisplayName == nil {
res = res && otherPtr.DisplayName == nil
} else {
res = res && addr.DisplayName.Equals(otherPtr.DisplayName)
}
}
if addr.Uri != otherPtr.Uri {
if addr.Uri == nil {
res = res && otherPtr.Uri == nil
} else {
res = res && addr.Uri.Equals(otherPtr.Uri)
}
}
if addr.Params != otherPtr.Params {
if addr.Params == nil {
res = res && otherPtr.Params == nil
} else {
res = res && addr.Params.Equals(otherPtr.Params)
}
}
return res
}
func (addr *Address) AsToHeader() *ToHeader {
to := &ToHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
to.Address = addr.Uri.Clone()
}
if addr.Params != nil {
to.Params = addr.Params.Clone()
}
return to
}
func (addr *Address) AsFromHeader() *FromHeader {
from := &FromHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
from.Address = addr.Uri.Clone()
}
if addr.Params != nil {
from.Params = addr.Params.Clone()
}
return from
}
func (addr *Address) AsContactHeader() *ContactHeader {
cnt := &ContactHeader{
DisplayName: addr.DisplayName,
}
if addr.Uri != nil {
cnt.Address = addr.Uri.Clone()
}
if addr.Params != nil {
cnt.Params = addr.Params.Clone()
}
return cnt
}
// Port number
type Port uint16
func (port *Port) Clone() *Port {
if port == nil {
return nil
}
newPort := *port
return &newPort
}
func (port *Port) String() string {
if port == nil {
return ""
}
return fmt.Sprintf("%d", *port)
}
func (port *Port) Equals(other interface{}) bool {
if p, ok := other.(*Port); ok {
return util.Uint16PtrEq((*uint16)(port), (*uint16)(p))
}
return false
}
// String wrapper
type MaybeString interface {
String() string
Equals(other interface{}) bool
}
type String struct {
Str string
}
func (str String) String() string {
return str.Str
}
func (str String) Equals(other interface{}) bool {
if v, ok := other.(string); ok {
return str.Str == v
}
if v, ok := other.(String); ok {
return str.Str == v.Str
}
return false
}
type CancelError interface {
Canceled() bool
}
type ExpireError interface {
Expired() bool
}
type MessageError interface {
error
// Malformed indicates that message is syntactically valid but has invalid headers, or
// without required headers.
Malformed() bool
// Broken or incomplete message, or not a SIP message
Broken() bool
}
// Broken or incomplete messages, or not a SIP message.
type BrokenMessageError struct {
Err error
Msg string
}
func (err *BrokenMessageError) Malformed() bool { return false }
func (err *BrokenMessageError) Broken() bool { return true }
func (err *BrokenMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "BrokenMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
// syntactically valid but logically invalid message
type MalformedMessageError struct {
Err error
Msg string
}
func (err *MalformedMessageError) Malformed() bool { return true }
func (err *MalformedMessageError) Broken() bool { return false }
func (err *MalformedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "MalformedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
type UnsupportedMessageError struct {
Err error
Msg string
}
func (err *UnsupportedMessageError) Malformed() bool { return true }
func (err *UnsupportedMessageError) Broken() bool { return false }
func (err *UnsupportedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "UnsupportedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
type UnexpectedMessageError struct {
Err error
Msg string
}
func (err *UnexpectedMessageError) Broken() bool { return false }
func (err *UnexpectedMessageError) Malformed() bool { return false }
func (err *UnexpectedMessageError) Error() string {
if err == nil {
return "<nil>"
}
s := "UnexpectedMessageError: " + err.Err.Error()
if err.Msg != "" {
s += fmt.Sprintf("\nMessage dump:\n%s", err.Msg)
}
return s
}
const RFC3261BranchMagicCookie = "z9hG4bK"
// GenerateBranch returns random unique branch ID.
func GenerateBranch() string {
return strings.Join([]string{
RFC3261BranchMagicCookie,
util.RandString(32),
}, ".")
}
// DefaultPort returns protocol default port by network.
func DefaultPort(protocol string) Port {
switch strings.ToLower(protocol) {
case "tls":
return DefaultTlsPort
case "tcp":
return DefaultTcpPort
case "udp":
return DefaultUdpPort
case "ws":
return DefaultWsPort
case "wss":
return DefaultWssPort
default:
return DefaultTcpPort
}
}
func MakeDialogIDFromMessage(msg Message) (string, error) {
callID, ok := msg.CallID()
if !ok {
return "", fmt.Errorf("missing Call-ID header")
}
to, ok := msg.To()
if !ok {
return "", fmt.Errorf("missing To header")
}
toTag, ok := to.Params.Get("tag")
if !ok {
return "", fmt.Errorf("missing tag param in To header")
}
from, ok := msg.From()
if !ok {
return "", fmt.Errorf("missing To header")
}
fromTag, ok := from.Params.Get("tag")
if !ok {
return "", fmt.Errorf("missing tag param in From header")
}
return MakeDialogID(string(*callID), toTag.String(), fromTag.String()), nil
}
func MakeDialogID(callID, innerID, externalID string) string {
return strings.Join([]string{callID, innerID, externalID}, "__")
}

View file

@ -0,0 +1,37 @@
package sip
import "fmt"
type RequestError struct {
Request Request
Response Response
Code uint
Reason string
}
func NewRequestError(code uint, reason string, request Request, response Response) *RequestError {
err := &RequestError{
Code: code,
Reason: reason,
}
if request != nil {
err.Request = CopyRequest(request)
}
if response != nil {
err.Response = CopyResponse(response)
}
return err
}
func (err *RequestError) Error() string {
if err == nil {
return "<nil>"
}
reason := err.Reason
if err.Code != 0 {
reason += fmt.Sprintf(" (Code %d)", err.Code)
}
return fmt.Sprintf("sip.RequestError: request failed with reason '%s'", reason)
}

View file

@ -0,0 +1,230 @@
package sip
import (
"strconv"
"strings"
)
// Copyright 2009 The Go Authors. All rights reserved.
// This is actually shorten copy of escape/unescape helpers of the net/url package.
type encoding int
const (
EncodeUserPassword encoding = 1 + iota
EncodeHost
EncodeZone
EncodeQueryComponent
)
type EscapeError string
func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
type InvalidHostError string
func (e InvalidHostError) Error() string {
return "invalid character " + strconv.Quote(string(e)) + " in host name"
}
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
func Unescape(s string, mode encoding) (string, error) {
// Count %, check that they're well-formed.
n := 0
hasPlus := false
for i := 0; i < len(s); {
switch s[i] {
case '%':
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", EscapeError(s)
}
// Per https://tools.ietf.org/html/rfc3986#page-21
// in the host component %-encoding can only be used
// for non-ASCII bytes.
// But https://tools.ietf.org/html/rfc6874#section-2
// introduces %25 being allowed to escape a percent sign
// in IPv6 scoped-address literals. Yay.
if mode == EncodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
return "", EscapeError(s[i : i+3])
}
if mode == EncodeZone {
// RFC 6874 says basically "anything goes" for zone identifiers
// and that even non-ASCII can be redundantly escaped,
// but it seems prudent to restrict %-escaped bytes here to those
// that are valid host name bytes in their unescaped form.
// That is, you can use escaping in the zone identifier but not
// to introduce bytes you couldn't just write directly.
// But Windows puts spaces here! Yay.
v := unhex(s[i+1])<<4 | unhex(s[i+2])
if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, EncodeHost) {
return "", EscapeError(s[i : i+3])
}
}
i += 3
case '+':
hasPlus = mode == EncodeQueryComponent
i++
default:
if (mode == EncodeHost || mode == EncodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
return "", InvalidHostError(s[i : i+1])
}
i++
}
}
if n == 0 && !hasPlus {
return s, nil
}
var t strings.Builder
t.Grow(len(s) - 2*n)
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
i += 2
case '+':
t.WriteByte('+')
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
//
// Please be informed that for now shouldEscape does not check all
// reserved characters correctly. See golang.org/issue/5684.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
if mode == EncodeHost || mode == EncodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
// We add [ ] because we include [ipv6]:port as part of host.
// We add < > because they're the only characters left that
// we could possibly allow, and Parse will reject them if we
// escape them (because hosts can't use %-encoding for
// ASCII bytes).
switch c {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
switch c {
case '-', '_', '.', '~': // §2.3 Unreserved characters (mark)
return false
case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case EncodeUserPassword: // §3.2.1
// The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
// userinfo, so we must escape only '@', '/', and '?'.
// The parsing of userinfo treats ':' as special so we must escape
// that too.
return c == '@' || c == '/' || c == '?' || c == ':'
case EncodeQueryComponent: // §3.4
// The RFC reserves (so we must escape) everything.
return true
}
}
// Everything else must be escaped.
return true
}
const upperhex = "0123456789ABCDEF"
func Escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == EncodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
var buf [64]byte
var t []byte
required := len(s) + 2*hexCount
if required <= len(buf) {
t = buf[:required]
} else {
t = make([]byte, required)
}
if hexCount == 0 {
copy(t, s)
return string(t)
}
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == EncodeQueryComponent:
t[j] = c
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = upperhex[c>>4]
t[j+2] = upperhex[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,542 @@
package sip
import (
"bytes"
"strings"
"sync"
uuid "github.com/satori/go.uuid"
"github.com/ghettovoice/gosip/log"
)
// A representation of a SIP method.
// This is syntactic sugar around the string type, so make sure to use
// the Equals method rather than built-in equality, or you'll fall foul of case differences.
// If you're defining your own Method, uppercase is preferred but not compulsory.
type RequestMethod string
// StatusCode - response status code: 1xx - 6xx
type StatusCode uint16
// Determine if the given method equals some other given method.
// This is syntactic sugar for case insensitive equality checking.
func (method *RequestMethod) Equals(other *RequestMethod) bool {
if method != nil && other != nil {
return strings.EqualFold(string(*method), string(*other))
} else {
return method == other
}
}
// It's nicer to avoid using raw strings to represent methods, so the following standard
// method names are defined here as constants for convenience.
const (
INVITE RequestMethod = "INVITE"
ACK RequestMethod = "ACK"
CANCEL RequestMethod = "CANCEL"
BYE RequestMethod = "BYE"
REGISTER RequestMethod = "REGISTER"
OPTIONS RequestMethod = "OPTIONS"
SUBSCRIBE RequestMethod = "SUBSCRIBE"
NOTIFY RequestMethod = "NOTIFY"
REFER RequestMethod = "REFER"
INFO RequestMethod = "INFO"
MESSAGE RequestMethod = "MESSAGE"
PRACK RequestMethod = "PRACK"
UPDATE RequestMethod = "UPDATE"
PUBLISH RequestMethod = "PUBLISH"
)
type MessageID string
func NextMessageID() MessageID {
return MessageID(uuid.Must(uuid.NewV4()).String())
}
// Message introduces common SIP message RFC 3261 - 7.
type Message interface {
MessageID() MessageID
Clone() Message
// Start line returns message start line.
StartLine() string
// String returns string representation of SIP message in RFC 3261 form.
String() string
// Short returns short string info about message.
Short() string
// SipVersion returns SIP protocol version.
SipVersion() string
// SetSipVersion sets SIP protocol version.
SetSipVersion(version string)
// Headers returns all message headers.
Headers() []Header
// GetHeaders returns slice of headers of the given type.
GetHeaders(name string) []Header
// AppendHeader appends header to message.
AppendHeader(header Header)
// PrependHeader prepends header to message.
PrependHeader(header Header)
PrependHeaderAfter(header Header, afterName string)
// RemoveHeader removes header from message.
RemoveHeader(name string)
ReplaceHeaders(name string, headers []Header)
// Body returns message body.
Body() string
// SetBody sets message body.
SetBody(body string, setContentLength bool)
/* Helper getters for common headers */
// CallID returns 'Call-ID' header.
CallID() (*CallID, bool)
// Via returns the top 'Via' header field.
Via() (ViaHeader, bool)
// ViaHop returns the first segment of the top 'Via' header.
ViaHop() (*ViaHop, bool)
// From returns 'From' header field.
From() (*FromHeader, bool)
// To returns 'To' header field.
To() (*ToHeader, bool)
// CSeq returns 'CSeq' header field.
CSeq() (*CSeq, bool)
ContentLength() (*ContentLength, bool)
ContentType() (*ContentType, bool)
Contact() (*ContactHeader, bool)
Transport() string
SetTransport(tp string)
Source() string
SetSource(src string)
Destination() string
SetDestination(dest string)
IsCancel() bool
IsAck() bool
Fields() log.Fields
WithFields(fields log.Fields) Message
}
// headers is a struct with methods to work with SIP headers.
type headers struct {
mu sync.RWMutex
// The logical SIP headers attached to this message.
headers map[string][]Header
// The order the headers should be displayed in.
headerOrder []string
}
func newHeaders(hdrs []Header) *headers {
hs := new(headers)
hs.headers = make(map[string][]Header)
hs.headerOrder = make([]string, 0)
for _, header := range hdrs {
hs.AppendHeader(header)
}
return hs
}
func (hs *headers) String() string {
buffer := bytes.Buffer{}
hs.mu.RLock()
// Construct each header in turn and add it to the message.
for typeIdx, name := range hs.headerOrder {
headers := hs.headers[name]
for idx, header := range headers {
buffer.WriteString(header.String())
if typeIdx < len(hs.headerOrder) || idx < len(headers) {
buffer.WriteString("\r\n")
}
}
}
hs.mu.RUnlock()
return buffer.String()
}
// Add the given header.
func (hs *headers) AppendHeader(header Header) {
name := strings.ToLower(header.Name())
hs.mu.Lock()
if _, ok := hs.headers[name]; ok {
hs.headers[name] = append(hs.headers[name], header)
} else {
hs.headers[name] = []Header{header}
hs.headerOrder = append(hs.headerOrder, name)
}
hs.mu.Unlock()
}
// AddFrontHeader adds header to the front of header list
// if there is no header has h's name, add h to the font of all headers
// if there are some headers have h's name, add h to front of the sublist
func (hs *headers) PrependHeader(header Header) {
name := strings.ToLower(header.Name())
hs.mu.Lock()
if hdrs, ok := hs.headers[name]; ok {
hs.headers[name] = append([]Header{header}, hdrs...)
} else {
hs.headers[name] = []Header{header}
newOrder := make([]string, 1, len(hs.headerOrder)+1)
newOrder[0] = name
hs.headerOrder = append(newOrder, hs.headerOrder...)
}
hs.mu.Unlock()
}
func (hs *headers) PrependHeaderAfter(header Header, afterName string) {
headerName := strings.ToLower(header.Name())
afterName = strings.ToLower(afterName)
hs.mu.Lock()
if _, ok := hs.headers[afterName]; ok {
afterIdx := -1
headerIdx := -1
for i, name := range hs.headerOrder {
if name == afterName {
afterIdx = i
}
if name == headerName {
headerIdx = i
}
}
if headerIdx == -1 {
hs.headers[headerName] = []Header{header}
newOrder := make([]string, 0)
newOrder = append(newOrder, hs.headerOrder[:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:]...)
hs.headerOrder = newOrder
} else {
hs.headers[headerName] = append([]Header{header}, hs.headers[headerName]...)
newOrder := make([]string, 0)
if afterIdx < headerIdx {
newOrder = append(newOrder, hs.headerOrder[:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:headerIdx]...)
newOrder = append(newOrder, hs.headerOrder[headerIdx+1:]...)
} else {
newOrder = append(newOrder, hs.headerOrder[:headerIdx]...)
newOrder = append(newOrder, hs.headerOrder[headerIdx+1:afterIdx+1]...)
newOrder = append(newOrder, headerName)
newOrder = append(newOrder, hs.headerOrder[afterIdx+1:]...)
}
hs.headerOrder = newOrder
}
hs.mu.Unlock()
} else {
hs.mu.Unlock()
hs.PrependHeader(header)
}
}
func (hs *headers) ReplaceHeaders(name string, headers []Header) {
name = strings.ToLower(name)
hs.mu.Lock()
if _, ok := hs.headers[name]; ok {
hs.headers[name] = headers
}
hs.mu.Unlock()
}
// Gets some headers.
func (hs *headers) Headers() []Header {
hdrs := make([]Header, 0)
hs.mu.RLock()
for _, key := range hs.headerOrder {
hdrs = append(hdrs, hs.headers[key]...)
}
hs.mu.RUnlock()
return hdrs
}
func (hs *headers) GetHeaders(name string) []Header {
name = strings.ToLower(name)
hs.mu.RLock()
defer hs.mu.RUnlock()
if hs.headers == nil {
hs.headers = map[string][]Header{}
hs.headerOrder = []string{}
}
if headers, ok := hs.headers[name]; ok {
return headers
}
return []Header{}
}
func (hs *headers) RemoveHeader(name string) {
name = strings.ToLower(name)
hs.mu.Lock()
delete(hs.headers, name)
// update order slice
for idx, entry := range hs.headerOrder {
if entry == name {
hs.headerOrder = append(hs.headerOrder[:idx], hs.headerOrder[idx+1:]...)
break
}
}
hs.mu.Unlock()
}
// CloneHeaders returns all cloned headers in slice.
func (hs *headers) CloneHeaders() []Header {
return cloneHeaders(hs)
}
func cloneHeaders(msg interface{ Headers() []Header }) []Header {
hdrs := make([]Header, 0)
for _, header := range msg.Headers() {
hdrs = append(hdrs, header.Clone())
}
return hdrs
}
func (hs *headers) CallID() (*CallID, bool) {
hdrs := hs.GetHeaders("Call-ID")
if len(hdrs) == 0 {
return nil, false
}
callId, ok := hdrs[0].(*CallID)
if !ok {
return nil, false
}
return callId, true
}
func (hs *headers) Via() (ViaHeader, bool) {
hdrs := hs.GetHeaders("Via")
if len(hdrs) == 0 {
return nil, false
}
via, ok := (hdrs[0]).(ViaHeader)
if !ok {
return nil, false
}
return via, true
}
func (hs *headers) ViaHop() (*ViaHop, bool) {
via, ok := hs.Via()
if !ok {
return nil, false
}
hops := []*ViaHop(via)
if len(hops) == 0 {
return nil, false
}
return hops[0], true
}
func (hs *headers) From() (*FromHeader, bool) {
hdrs := hs.GetHeaders("From")
if len(hdrs) == 0 {
return nil, false
}
from, ok := hdrs[0].(*FromHeader)
if !ok {
return nil, false
}
return from, true
}
func (hs *headers) To() (*ToHeader, bool) {
hdrs := hs.GetHeaders("To")
if len(hdrs) == 0 {
return nil, false
}
to, ok := hdrs[0].(*ToHeader)
if !ok {
return nil, false
}
return to, true
}
func (hs *headers) CSeq() (*CSeq, bool) {
hdrs := hs.GetHeaders("CSeq")
if len(hdrs) == 0 {
return nil, false
}
cseq, ok := hdrs[0].(*CSeq)
if !ok {
return nil, false
}
return cseq, true
}
func (hs *headers) ContentLength() (*ContentLength, bool) {
hdrs := hs.GetHeaders("Content-Length")
if len(hdrs) == 0 {
return nil, false
}
contentLength, ok := hdrs[0].(*ContentLength)
if !ok {
return nil, false
}
return contentLength, true
}
func (hs *headers) ContentType() (*ContentType, bool) {
hdrs := hs.GetHeaders("Content-Type")
if len(hdrs) == 0 {
return nil, false
}
contentType, ok := hdrs[0].(*ContentType)
if !ok {
return nil, false
}
return contentType, true
}
func (hs *headers) Contact() (*ContactHeader, bool) {
hdrs := hs.GetHeaders("Contact")
if len(hdrs) == 0 {
return nil, false
}
contactHeader, ok := hdrs[0].(*ContactHeader)
if !ok {
return nil, false
}
return contactHeader, true
}
// basic message implementation
type message struct {
// message headers
*headers
mu sync.RWMutex
messID MessageID
sipVersion string
body string
startLine func() string
tp string
src string
dest string
fields log.Fields
}
func (msg *message) MessageID() MessageID {
return msg.messID
}
func (msg *message) StartLine() string {
return msg.startLine()
}
func (msg *message) Fields() log.Fields {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.fields.WithFields(log.Fields{
"transport": msg.tp,
"source": msg.src,
"destination": msg.dest,
})
}
func (msg *message) String() string {
var buffer bytes.Buffer
// write message start line
buffer.WriteString(msg.StartLine() + "\r\n")
// Write the headers.
msg.mu.RLock()
buffer.WriteString(msg.headers.String())
msg.mu.RUnlock()
// message body
buffer.WriteString("\r\n" + msg.Body())
return buffer.String()
}
func (msg *message) SipVersion() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.sipVersion
}
func (msg *message) SetSipVersion(version string) {
msg.mu.Lock()
msg.sipVersion = version
msg.mu.Unlock()
}
func (msg *message) Body() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.body
}
// SetBody sets message body, calculates it length and add 'Content-Length' header.
func (msg *message) SetBody(body string, setContentLength bool) {
msg.mu.Lock()
msg.body = body
msg.mu.Unlock()
if setContentLength {
hdrs := msg.GetHeaders("Content-Length")
if len(hdrs) == 0 {
length := ContentLength(len(body))
msg.AppendHeader(&length)
} else {
length := ContentLength(len(body))
msg.ReplaceHeaders("Content-Length", []Header{&length})
}
}
}
func (msg *message) Transport() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.tp
}
func (msg *message) SetTransport(tp string) {
msg.mu.Lock()
msg.tp = strings.ToUpper(tp)
msg.mu.Unlock()
}
func (msg *message) Source() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.src
}
func (msg *message) SetSource(src string) {
msg.mu.Lock()
msg.src = src
msg.mu.Unlock()
}
func (msg *message) Destination() string {
msg.mu.RLock()
defer msg.mu.RUnlock()
return msg.dest
}
func (msg *message) SetDestination(dest string) {
msg.mu.Lock()
msg.dest = dest
msg.mu.Unlock()
}
// Copy all headers of one type from one message to another.
// Appending to any headers that were already there.
func CopyHeaders(name string, from, to Message) {
name = strings.ToLower(name)
for _, h := range from.GetHeaders(name) {
to.AppendHeader(h.Clone())
}
}
func PrependCopyHeaders(name string, from, to Message) {
name = strings.ToLower(name)
for _, h := range from.GetHeaders(name) {
to.PrependHeader(h.Clone())
}
}
type MessageMapper func(msg Message) Message

View file

@ -0,0 +1,5 @@
# SIP parser
> Package implements SIP protocol parser compatible with [RFC 3261](https://tools.ietf.org/html/rfc3261)
Originally forked from [gossip](https://github.com/StefanKopieczek/gossip) library by @StefanKopieczek.

View file

@ -0,0 +1,26 @@
package parser
type Error interface {
error
// Syntax indicates that this is syntax error
Syntax() bool
}
type InvalidStartLineError string
func (err InvalidStartLineError) Syntax() bool { return true }
func (err InvalidStartLineError) Malformed() bool { return false }
func (err InvalidStartLineError) Broken() bool { return true }
func (err InvalidStartLineError) Error() string { return "parser.InvalidStartLineError: " + string(err) }
type InvalidMessageFormat string
func (err InvalidMessageFormat) Syntax() bool { return true }
func (err InvalidMessageFormat) Malformed() bool { return true }
func (err InvalidMessageFormat) Broken() bool { return true }
func (err InvalidMessageFormat) Error() string { return "parser.InvalidMessageFormat: " + string(err) }
type WriteError string
func (err WriteError) Syntax() bool { return false }
func (err WriteError) Error() string { return "parser.WriteError: " + string(err) }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,131 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package parser
import (
"bufio"
"bytes"
"fmt"
"io"
"sync"
"github.com/ghettovoice/gosip/log"
)
// parserBuffer is a specialized buffer for use in the parser.
// It is written to via the non-blocking Write.
// It exposes various blocking read methods, which wait until the requested
// data is available, and then return it.
type parserBuffer struct {
mu sync.RWMutex
writer io.Writer
buffer bytes.Buffer
// Wraps parserBuffer.pipeReader
reader *bufio.Reader
// Don't access this directly except when closing.
pipeReader *io.PipeReader
log log.Logger
}
// Create a new parserBuffer object (see struct comment for object details).
// Note that resources owned by the parserBuffer may not be able to be GCed
// until the Dispose() method is called.
func newParserBuffer(logger log.Logger) *parserBuffer {
var pb parserBuffer
pb.pipeReader, pb.writer = io.Pipe()
pb.reader = bufio.NewReader(pb.pipeReader)
pb.log = logger.
WithPrefix("parser.parserBuffer").
WithFields(log.Fields{
"parser_buffer_ptr": fmt.Sprintf("%p", &pb),
})
return &pb
}
func (pb *parserBuffer) Log() log.Logger {
return pb.log
}
func (pb *parserBuffer) Write(p []byte) (n int, err error) {
pb.mu.RLock()
defer pb.mu.RUnlock()
return pb.writer.Write(p)
}
// Block until the buffer contains at least one CRLF-terminated line.
// Return the line, excluding the terminal CRLF, and delete it from the buffer.
// Returns an error if the parserbuffer has been stopped.
func (pb *parserBuffer) NextLine() (response string, err error) {
var buffer bytes.Buffer
var data string
var b byte
// There has to be a better way!
for {
data, err = pb.reader.ReadString('\r')
if err != nil {
return
}
buffer.WriteString(data)
b, err = pb.reader.ReadByte()
if err != nil {
return
}
buffer.WriteByte(b)
if b == '\n' {
response = buffer.String()
response = response[:len(response)-2]
pb.Log().Tracef("return line '%s'", response)
return
}
}
}
// Block until the buffer contains at least n characters.
// Return precisely those n characters, then delete them from the buffer.
func (pb *parserBuffer) NextChunk(n int) (response string, err error) {
var data = make([]byte, n)
var read int
for total := 0; total < n; {
read, err = pb.reader.Read(data[total:])
total += read
if err != nil {
return
}
}
response = string(data)
pb.Log().Tracef("return chunk:\n%s", response)
return
}
// Stop the parser buffer.
func (pb *parserBuffer) Stop() {
pb.mu.RLock()
if err := pb.pipeReader.Close(); err != nil {
pb.Log().Errorf("parser pipe reader close failed: %s", err)
}
pb.mu.RUnlock()
pb.Log().Trace("parser buffer stopped")
}
func (pb *parserBuffer) Reset() {
pb.mu.Lock()
pb.pipeReader, pb.writer = io.Pipe()
pb.reader.Reset(pb.pipeReader)
pb.mu.Unlock()
}

View file

@ -0,0 +1,382 @@
package sip
import (
"bytes"
"fmt"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
)
// Request RFC 3261 - 7.1.
type Request interface {
Message
Method() RequestMethod
SetMethod(method RequestMethod)
Recipient() Uri
SetRecipient(recipient Uri)
/* Common Helpers */
IsInvite() bool
}
type request struct {
message
method RequestMethod
recipient Uri
}
func NewRequest(
messID MessageID,
method RequestMethod,
recipient Uri,
sipVersion string,
hdrs []Header,
body string,
fields log.Fields,
) Request {
req := new(request)
if messID == "" {
req.messID = NextMessageID()
} else {
req.messID = messID
}
req.startLine = req.StartLine
req.sipVersion = sipVersion
req.headers = newHeaders(hdrs)
req.method = method
req.recipient = recipient
req.body = body
req.fields = fields.WithFields(log.Fields{
"request_id": req.messID,
})
return req
}
func (req *request) Short() string {
if req == nil {
return "<nil>"
}
fields := log.Fields{
"method": req.Method(),
"recipient": req.Recipient(),
"transport": req.Transport(),
"source": req.Source(),
"destination": req.Destination(),
}
if cseq, ok := req.CSeq(); ok {
fields["sequence"] = cseq.SeqNo
}
fields = req.Fields().WithFields(fields)
return fmt.Sprintf("sip.Request<%s>", fields)
}
func (req *request) Method() RequestMethod {
req.mu.RLock()
defer req.mu.RUnlock()
return req.method
}
func (req *request) SetMethod(method RequestMethod) {
req.mu.Lock()
req.method = method
req.mu.Unlock()
}
func (req *request) Recipient() Uri {
req.mu.RLock()
defer req.mu.RUnlock()
return req.recipient
}
func (req *request) SetRecipient(recipient Uri) {
req.mu.Lock()
req.recipient = recipient
req.mu.Unlock()
}
// StartLine returns Request Line - RFC 2361 7.1.
func (req *request) StartLine() string {
var buffer bytes.Buffer
// Every SIP request starts with a Request Line - RFC 2361 7.1.
buffer.WriteString(
fmt.Sprintf(
"%s %s %s",
string(req.Method()),
req.Recipient(),
req.SipVersion(),
),
)
return buffer.String()
}
func (req *request) Clone() Message {
return cloneRequest(req, "", nil)
}
func (req *request) Fields() log.Fields {
return req.fields.WithFields(log.Fields{
"transport": req.Transport(),
"source": req.Source(),
"destination": req.Destination(),
})
}
func (req *request) WithFields(fields log.Fields) Message {
req.mu.Lock()
req.fields = req.fields.WithFields(fields)
req.mu.Unlock()
return req
}
func (req *request) IsInvite() bool {
return req.Method() == INVITE
}
func (req *request) IsAck() bool {
return req.Method() == ACK
}
func (req *request) IsCancel() bool {
return req.Method() == CANCEL
}
func (req *request) Transport() string {
if tp := req.message.Transport(); tp != "" {
return strings.ToUpper(tp)
}
var tp string
if viaHop, ok := req.ViaHop(); ok && viaHop.Transport != "" {
tp = viaHop.Transport
} else {
tp = DefaultProtocol
}
uri := req.Recipient()
if hdrs := req.GetHeaders("Route"); len(hdrs) > 0 {
routeHeader, ok := hdrs[0].(*RouteHeader)
if ok && len(routeHeader.Addresses) > 0 {
uri = routeHeader.Addresses[0]
}
}
if uri != nil {
if uri.UriParams() != nil {
if val, ok := uri.UriParams().Get("transport"); ok && !val.Equals("") {
tp = strings.ToUpper(val.String())
}
}
if uri.IsEncrypted() {
if tp == "TCP" {
tp = "TLS"
} else if tp == "WS" {
tp = "WSS"
}
}
}
if tp == "UDP" && len(req.String()) > int(MTU)-200 {
tp = "TCP"
}
return tp
}
func (req *request) Source() string {
if src := req.message.Source(); src != "" {
return src
}
viaHop, ok := req.ViaHop()
if !ok {
return ""
}
var (
host string
port Port
)
host = viaHop.Host
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = DefaultPort(req.Transport())
}
if viaHop.Params != nil {
if received, ok := viaHop.Params.Get("received"); ok && received.String() != "" {
host = received.String()
}
if rport, ok := viaHop.Params.Get("rport"); ok && rport != nil && rport.String() != "" {
if p, err := strconv.Atoi(rport.String()); err == nil {
port = Port(uint16(p))
}
}
}
return fmt.Sprintf("%v:%v", host, port)
}
func (req *request) Destination() string {
if dest := req.message.Destination(); dest != "" {
return dest
}
var uri *SipUri
if hdrs := req.GetHeaders("Route"); len(hdrs) > 0 {
routeHeader, ok := hdrs[0].(*RouteHeader)
if ok && len(routeHeader.Addresses) > 0 {
uri = routeHeader.Addresses[0].(*SipUri)
}
}
if uri == nil {
if u, ok := req.Recipient().(*SipUri); ok {
uri = u
} else {
return ""
}
}
host := uri.FHost
var port Port
if uri.FPort != nil {
port = *uri.FPort
} else {
port = DefaultPort(req.Transport())
}
return fmt.Sprintf("%v:%v", host, port)
}
// NewAckRequest creates ACK request for 2xx INVITE
// https://tools.ietf.org/html/rfc3261#section-13.2.2.4
func NewAckRequest(ackID MessageID, inviteRequest Request, inviteResponse Response, body string, fields log.Fields) Request {
recipient := inviteRequest.Recipient()
if contact, ok := inviteResponse.Contact(); ok {
// For ws and wss (like clients in browser), don't use Contact
if strings.Index(strings.ToLower(recipient.String()), "transport=ws") == -1 {
recipient = contact.Address
}
}
ackRequest := NewRequest(
ackID,
ACK,
recipient,
inviteRequest.SipVersion(),
[]Header{},
body,
inviteRequest.Fields().
WithFields(fields).
WithFields(log.Fields{
"invite_request_id": inviteRequest.MessageID(),
"invite_response_id": inviteResponse.MessageID(),
}),
)
CopyHeaders("Via", inviteRequest, ackRequest)
if inviteResponse.IsSuccess() {
// update branch, 2xx ACK is separate Tx
viaHop, _ := ackRequest.ViaHop()
viaHop.Params.Add("branch", String{Str: GenerateBranch()})
}
if len(inviteRequest.GetHeaders("Route")) > 0 {
CopyHeaders("Route", inviteRequest, ackRequest)
} else {
hdrs := inviteResponse.GetHeaders("Record-Route")
for i := len(hdrs) - 1; i >= 0; i-- {
h := hdrs[i]
uris := make([]Uri, 0)
for j := len(h.(*RecordRouteHeader).Addresses) - 1; j >= 0; j-- {
uris = append(uris, h.(*RecordRouteHeader).Addresses[j].Clone())
}
ackRequest.AppendHeader(&RouteHeader{
Addresses: uris,
})
}
}
maxForwardsHeader := MaxForwards(70)
ackRequest.AppendHeader(&maxForwardsHeader)
CopyHeaders("From", inviteRequest, ackRequest)
CopyHeaders("To", inviteResponse, ackRequest)
CopyHeaders("Call-ID", inviteRequest, ackRequest)
CopyHeaders("CSeq", inviteRequest, ackRequest)
cseq, _ := ackRequest.CSeq()
cseq.MethodName = ACK
ackRequest.SetBody("", true)
ackRequest.SetTransport(inviteRequest.Transport())
ackRequest.SetSource(inviteRequest.Source())
ackRequest.SetDestination(inviteRequest.Destination())
return ackRequest
}
func NewCancelRequest(cancelID MessageID, requestForCancel Request, fields log.Fields) Request {
cancelReq := NewRequest(
cancelID,
CANCEL,
requestForCancel.Recipient(),
requestForCancel.SipVersion(),
[]Header{},
"",
requestForCancel.Fields().
WithFields(fields).
WithFields(log.Fields{
"cancelling_request_id": requestForCancel.MessageID(),
}),
)
viaHop, _ := requestForCancel.ViaHop()
cancelReq.AppendHeader(ViaHeader{viaHop.Clone()})
CopyHeaders("Route", requestForCancel, cancelReq)
maxForwardsHeader := MaxForwards(70)
cancelReq.AppendHeader(&maxForwardsHeader)
CopyHeaders("From", requestForCancel, cancelReq)
CopyHeaders("To", requestForCancel, cancelReq)
CopyHeaders("Call-ID", requestForCancel, cancelReq)
CopyHeaders("CSeq", requestForCancel, cancelReq)
cseq, _ := cancelReq.CSeq()
cseq.MethodName = CANCEL
cancelReq.SetBody("", true)
cancelReq.SetTransport(requestForCancel.Transport())
cancelReq.SetSource(requestForCancel.Source())
cancelReq.SetDestination(requestForCancel.Destination())
return cancelReq
}
func cloneRequest(req Request, id MessageID, fields log.Fields) Request {
newFields := req.Fields()
if fields != nil {
newFields = newFields.WithFields(fields)
}
newReq := NewRequest(
id,
req.Method(),
req.Recipient().Clone(),
req.SipVersion(),
cloneHeaders(req),
req.Body(),
newFields,
)
newReq.SetTransport(req.Transport())
newReq.SetSource(req.Source())
newReq.SetDestination(req.Destination())
return newReq
}
func CopyRequest(req Request) Request {
return cloneRequest(req, req.MessageID(), nil)
}

View file

@ -0,0 +1,310 @@
package sip
import (
"bytes"
"fmt"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
)
// Response RFC 3261 - 7.2.
type Response interface {
Message
StatusCode() StatusCode
SetStatusCode(code StatusCode)
Reason() string
SetReason(reason string)
// Previous returns previous provisional responses
Previous() []Response
SetPrevious(responses []Response)
/* Common helpers */
IsProvisional() bool
IsSuccess() bool
IsRedirection() bool
IsClientError() bool
IsServerError() bool
IsGlobalError() bool
}
type response struct {
message
status StatusCode
reason string
previous []Response
}
func NewResponse(
messID MessageID,
sipVersion string,
statusCode StatusCode,
reason string,
hdrs []Header,
body string,
fields log.Fields,
) Response {
res := new(response)
if messID == "" {
res.messID = NextMessageID()
} else {
res.messID = messID
}
res.startLine = res.StartLine
res.sipVersion = sipVersion
res.headers = newHeaders(hdrs)
res.status = statusCode
res.reason = reason
res.body = body
res.fields = fields.WithFields(log.Fields{
"response_id": res.messID,
})
res.previous = make([]Response, 0)
return res
}
func (res *response) Short() string {
if res == nil {
return "<nil>"
}
fields := log.Fields{
"status": res.StatusCode(),
"reason": res.Reason(),
"transport": res.Transport(),
"source": res.Source(),
"destination": res.Destination(),
}
if cseq, ok := res.CSeq(); ok {
fields["method"] = cseq.MethodName
fields["sequence"] = cseq.SeqNo
}
fields = res.Fields().WithFields(fields)
return fmt.Sprintf("sip.Response<%s>", fields)
}
func (res *response) StatusCode() StatusCode {
res.mu.RLock()
defer res.mu.RUnlock()
return res.status
}
func (res *response) SetStatusCode(code StatusCode) {
res.mu.Lock()
res.status = code
res.mu.Unlock()
}
func (res *response) Reason() string {
res.mu.RLock()
defer res.mu.RUnlock()
return res.reason
}
func (res *response) SetReason(reason string) {
res.mu.Lock()
res.reason = reason
res.mu.Unlock()
}
func (res *response) Previous() []Response {
res.mu.RLock()
defer res.mu.RUnlock()
return res.previous
}
func (res *response) SetPrevious(responses []Response) {
res.mu.Lock()
res.previous = responses
res.mu.Unlock()
}
// StartLine returns Response Status Line - RFC 2361 7.2.
func (res *response) StartLine() string {
var buffer bytes.Buffer
// Every SIP response starts with a Status Line - RFC 2361 7.2.
buffer.WriteString(
fmt.Sprintf(
"%s %d %s",
res.SipVersion(),
res.StatusCode(),
res.Reason(),
),
)
return buffer.String()
}
func (res *response) Clone() Message {
return cloneResponse(res, "", nil)
}
func (res *response) Fields() log.Fields {
return res.fields.WithFields(log.Fields{
"transport": res.Transport(),
"source": res.Source(),
"destination": res.Destination(),
})
}
func (res *response) WithFields(fields log.Fields) Message {
res.mu.Lock()
res.fields = res.fields.WithFields(fields)
res.mu.Unlock()
return res
}
func (res *response) IsProvisional() bool {
return res.StatusCode() < 200
}
func (res *response) IsSuccess() bool {
return res.StatusCode() >= 200 && res.StatusCode() < 300
}
func (res *response) IsRedirection() bool {
return res.StatusCode() >= 300 && res.StatusCode() < 400
}
func (res *response) IsClientError() bool {
return res.StatusCode() >= 400 && res.StatusCode() < 500
}
func (res *response) IsServerError() bool {
return res.StatusCode() >= 500 && res.StatusCode() < 600
}
func (res *response) IsGlobalError() bool {
return res.StatusCode() >= 600
}
func (res *response) IsAck() bool {
if cseq, ok := res.CSeq(); ok {
return cseq.MethodName == ACK
}
return false
}
func (res *response) IsCancel() bool {
if cseq, ok := res.CSeq(); ok {
return cseq.MethodName == CANCEL
}
return false
}
func (res *response) Transport() string {
if tp := res.message.Transport(); tp != "" {
return strings.ToUpper(tp)
}
var tp string
if viaHop, ok := res.ViaHop(); ok && viaHop.Transport != "" {
tp = viaHop.Transport
} else {
tp = DefaultProtocol
}
return tp
}
func (res *response) Destination() string {
if dest := res.message.Destination(); dest != "" {
return dest
}
viaHop, ok := res.ViaHop()
if !ok {
return ""
}
var (
host string
port Port
)
host = viaHop.Host
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = DefaultPort(res.Transport())
}
if viaHop.Params != nil {
if received, ok := viaHop.Params.Get("received"); ok && received.String() != "" {
host = received.String()
}
if rport, ok := viaHop.Params.Get("rport"); ok && rport != nil && rport.String() != "" {
if p, err := strconv.Atoi(rport.String()); err == nil {
port = Port(uint16(p))
}
}
}
return fmt.Sprintf("%v:%v", host, port)
}
// RFC 3261 - 8.2.6
func NewResponseFromRequest(
resID MessageID,
req Request,
statusCode StatusCode,
reason string,
body string,
) Response {
res := NewResponse(
resID,
req.SipVersion(),
statusCode,
reason,
[]Header{},
"",
req.Fields(),
)
CopyHeaders("Record-Route", req, res)
CopyHeaders("Via", req, res)
CopyHeaders("From", req, res)
CopyHeaders("To", req, res)
CopyHeaders("Call-ID", req, res)
CopyHeaders("CSeq", req, res)
if statusCode == 100 {
CopyHeaders("Timestamp", req, res)
}
res.SetBody(body, true)
res.SetTransport(req.Transport())
res.SetSource(req.Destination())
res.SetDestination(req.Source())
return res
}
func cloneResponse(res Response, id MessageID, fields log.Fields) Response {
newFields := res.Fields()
if fields != nil {
newFields = newFields.WithFields(fields)
}
newRes := NewResponse(
id,
res.SipVersion(),
res.StatusCode(),
res.Reason(),
cloneHeaders(res),
res.Body(),
newFields,
)
newRes.SetPrevious(res.Previous())
newRes.SetTransport(res.Transport())
newRes.SetSource(res.Source())
newRes.SetDestination(res.Destination())
return newRes
}
func CopyResponse(res Response) Response {
return cloneResponse(res, res.MessageID(), nil)
}

View file

@ -0,0 +1,28 @@
package sip
type TransactionKey string
func (key TransactionKey) String() string {
return string(key)
}
type Transaction interface {
Origin() Request
Key() TransactionKey
String() string
Errors() <-chan error
Done() <-chan bool
}
type ServerTransaction interface {
Transaction
Respond(res Response) error
Acks() <-chan Request
Cancels() <-chan Request
}
type ClientTransaction interface {
Transaction
Responses() <-chan Response
Cancel() error
}

View file

@ -0,0 +1,8 @@
package sip
type Transport interface {
Messages() <-chan Message
Send(msg Message) error
IsReliable(network string) bool
IsStreamed(network string) bool
}

View file

@ -0,0 +1,234 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package timing
import (
"sync"
"time"
)
// Controls whether library calls should be mocked, or whether we should use the standard Go time library.
// If we're in Mock Mode, then time does not pass as normal, but only progresses when Elapse is called.
// False by default, indicating that we just call through to standard Go functions.
var MockMode = false
var currentTimeMock = time.Unix(0, 0)
var mockTimers = make([]*mockTimer, 0)
var mockTimerMu = new(sync.Mutex)
// Interface over Golang's built-in Timers, allowing them to be swapped out for mocked timers.
type Timer interface {
// Returns a channel which sends the current time immediately when the timer expires.
// Equivalent to time.Timer.C; however, we have to use a method here instead of a member since this is an interface.
C() <-chan time.Time
// Resets the timer such that it will expire in duration 'd' after the current time.
// Returns true if the timer had been active, and false if it had expired or been stopped.
Reset(d time.Duration) bool
// Stops the timer, preventing it from firing.
// Returns true if the timer had been active, and false if it had expired or been stopped.
Stop() bool
}
// Implementation of Timer that just wraps time.Timer.
type realTimer struct {
*time.Timer
}
func (t *realTimer) C() <-chan time.Time {
return t.Timer.C
}
func (t *realTimer) Reset(d time.Duration) bool {
t.Stop()
return t.Timer.Reset(d)
}
func (t *realTimer) Stop() bool {
// return t.Timer.Stop()
if !t.Timer.Stop() {
select {
case <-t.Timer.C:
return true
default:
return false
}
}
return true
}
// Implementation of Timer that mocks time.Timer, firing when the total elapsed time (as controlled by Elapse)
// exceeds the duration specified when the timer was constructed.
type mockTimer struct {
EndTime time.Time
Chan chan time.Time
fired bool
toRun func()
}
func (t *mockTimer) C() <-chan time.Time {
return t.Chan
}
func (t *mockTimer) Reset(d time.Duration) bool {
wasActive := removeMockTimer(t)
t.EndTime = currentTimeMock.Add(d)
if d > 0 {
mockTimerMu.Lock()
mockTimers = append(mockTimers, t)
mockTimerMu.Unlock()
} else {
// The new timer has an expiry time of 0.
// Fire it right away, and don't bother tracking it.
t.Chan <- currentTimeMock
}
return wasActive
}
func (t *mockTimer) Stop() bool {
if !removeMockTimer(t) {
select {
case <-t.Chan:
return true
default:
return false
}
}
return true
}
// Creates a new Timer; either a wrapper around a standard Go time.Timer, or a mocked-out Timer,
// depending on whether MockMode is set.
func NewTimer(d time.Duration) Timer {
if MockMode {
t := mockTimer{currentTimeMock.Add(d), make(chan time.Time, 1), false, nil}
if d == 0 {
t.Chan <- currentTimeMock
} else {
mockTimerMu.Lock()
mockTimers = append(mockTimers, &t)
mockTimerMu.Unlock()
}
return &t
} else {
return &realTimer{time.NewTimer(d)}
}
}
// See built-in time.After() function.
func After(d time.Duration) <-chan time.Time {
return NewTimer(d).C()
}
// See built-in time.AfterFunc() function.
func AfterFunc(d time.Duration, f func()) Timer {
if MockMode {
mockTimerMu.Lock()
t := mockTimer{currentTimeMock.Add(d), make(chan time.Time, 1), false, f}
mockTimerMu.Unlock()
if d == 0 {
go f()
t.Chan <- currentTimeMock
} else {
mockTimerMu.Lock()
mockTimers = append(mockTimers, &t)
mockTimerMu.Unlock()
}
return &t
} else {
return &realTimer{time.AfterFunc(d, f)}
}
}
// See built-in time.Sleep() function.
func Sleep(d time.Duration) {
<-After(d)
}
// Increment the current time by the given Duration.
// This function can only be called in Mock Mode, otherwise we will panic.
func Elapse(d time.Duration) {
requireMockMode()
mockTimerMu.Lock()
currentTimeMock = currentTimeMock.Add(d)
mockTimerMu.Unlock()
// Fire any timers whose time has come up.
mockTimerMu.Lock()
for _, t := range mockTimers {
t.fired = false
if !t.EndTime.After(currentTimeMock) {
if t.toRun != nil {
go t.toRun()
}
// Clear the channel if something is already in it.
select {
case <-t.Chan:
default:
}
t.Chan <- currentTimeMock
t.fired = true
}
}
mockTimerMu.Unlock()
// Stop tracking any fired timers.
remainingTimers := make([]*mockTimer, 0)
mockTimerMu.Lock()
for _, t := range mockTimers {
if !t.fired {
remainingTimers = append(remainingTimers, t)
}
}
mockTimers = remainingTimers
mockTimerMu.Unlock()
}
// Returns the current time.
// If Mock Mode is set, this will be the sum of all Durations passed into Elapse calls;
// otherwise it will be the true system time.
func Now() time.Time {
if MockMode {
return currentTimeMock
} else {
return time.Now()
}
}
// Shortcut method to enforce that Mock Mode is enabled.
func requireMockMode() {
if !MockMode {
panic("This method requires MockMode to be enabled")
}
}
// Utility method to remove a mockTimer from the list of outstanding timers.
func removeMockTimer(t *mockTimer) bool {
// First, find the index of the timer in our list.
found := false
var idx int
var elt *mockTimer
mockTimerMu.Lock()
for idx, elt = range mockTimers {
if elt == t {
found = true
break
}
}
mockTimerMu.Unlock()
if found {
mockTimerMu.Lock()
// We found the given timer. Remove it.
mockTimers = append(mockTimers[:idx], mockTimers[idx+1:]...)
mockTimerMu.Unlock()
return true
} else {
// The timer was not present, indicating that it was already expired.
return false
}
}

View file

@ -0,0 +1,219 @@
package transport
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
)
var (
bufferSize uint16 = 65535 - 20 - 8 // IPv4 max size - IPv4 Header size - UDP Header size
)
// Wrapper around net.Conn.
type Connection interface {
net.Conn
Key() ConnectionKey
Network() string
Streamed() bool
String() string
ReadFrom(buf []byte) (num int, raddr net.Addr, err error)
WriteTo(buf []byte, raddr net.Addr) (num int, err error)
}
// Connection implementation.
type connection struct {
baseConn net.Conn
key ConnectionKey
network string
laddr net.Addr
raddr net.Addr
streamed bool
mu sync.RWMutex
log log.Logger
}
func NewConnection(baseConn net.Conn, key ConnectionKey, network string, logger log.Logger) Connection {
var stream bool
switch baseConn.(type) {
case net.PacketConn:
stream = false
default:
stream = true
}
conn := &connection{
baseConn: baseConn,
key: key,
network: network,
laddr: baseConn.LocalAddr(),
raddr: baseConn.RemoteAddr(),
streamed: stream,
}
conn.log = logger.
WithPrefix("transport.Connection").
WithFields(log.Fields{
"connection_ptr": fmt.Sprintf("%p", conn),
"connection_key": conn.Key(),
})
return conn
}
func (conn *connection) String() string {
if conn == nil {
return "<nil>"
}
fields := conn.Log().Fields().WithFields(log.Fields{
"key": conn.Key(),
"network": conn.Network(),
"local_addr": conn.LocalAddr(),
"remote_addr": conn.RemoteAddr(),
})
return fmt.Sprintf("transport.Connection<%s>", fields)
}
func (conn *connection) Log() log.Logger {
return conn.log
}
func (conn *connection) Key() ConnectionKey {
return conn.key
}
func (conn *connection) Streamed() bool {
return conn.streamed
}
func (conn *connection) Network() string {
return strings.ToUpper(conn.network)
}
func (conn *connection) Read(buf []byte) (int, error) {
var (
num int
err error
)
num, err = conn.baseConn.Read(buf)
if err != nil {
return num, &ConnectionError{
err,
"read",
conn.Network(),
fmt.Sprintf("%v", conn.RemoteAddr()),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("read %d bytes %s <- %s:\n%s", num, conn.LocalAddr(), conn.RemoteAddr(), buf[:num])
return num, err
}
func (conn *connection) ReadFrom(buf []byte) (num int, raddr net.Addr, err error) {
num, raddr, err = conn.baseConn.(net.PacketConn).ReadFrom(buf)
if err != nil {
return num, raddr, &ConnectionError{
err,
"read",
conn.Network(),
fmt.Sprintf("%v", raddr),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("read %d bytes %s <- %s:\n%s", num, conn.LocalAddr(), raddr, buf[:num])
return num, raddr, err
}
func (conn *connection) Write(buf []byte) (int, error) {
var (
num int
err error
)
num, err = conn.baseConn.Write(buf)
if err != nil {
return num, &ConnectionError{
err,
"write",
conn.Network(),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%v", conn.RemoteAddr()),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("write %d bytes %s -> %s:\n%s", num, conn.LocalAddr(), conn.RemoteAddr(), buf[:num])
return num, err
}
func (conn *connection) WriteTo(buf []byte, raddr net.Addr) (num int, err error) {
num, err = conn.baseConn.(net.PacketConn).WriteTo(buf, raddr)
if err != nil {
return num, &ConnectionError{
err,
"write",
conn.Network(),
fmt.Sprintf("%v", conn.LocalAddr()),
fmt.Sprintf("%v", raddr),
fmt.Sprintf("%p", conn),
}
}
conn.Log().Tracef("write %d bytes %s -> %s:\n%s", num, conn.LocalAddr(), raddr, buf[:num])
return num, err
}
func (conn *connection) LocalAddr() net.Addr {
return conn.baseConn.LocalAddr()
}
func (conn *connection) RemoteAddr() net.Addr {
return conn.baseConn.RemoteAddr()
}
func (conn *connection) Close() error {
err := conn.baseConn.Close()
if err != nil {
return &ConnectionError{
err,
"close",
conn.Network(),
"",
"",
fmt.Sprintf("%p", conn),
}
}
conn.Log().Trace("connection closed")
return nil
}
func (conn *connection) SetDeadline(t time.Time) error {
return conn.baseConn.SetDeadline(t)
}
func (conn *connection) SetReadDeadline(t time.Time) error {
return conn.baseConn.SetReadDeadline(t)
}
func (conn *connection) SetWriteDeadline(t time.Time) error {
return conn.baseConn.SetWriteDeadline(t)
}

View file

@ -0,0 +1,802 @@
package transport
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
"github.com/ghettovoice/gosip/sip/parser"
"github.com/ghettovoice/gosip/timing"
"github.com/ghettovoice/gosip/util"
)
type ConnectionKey string
func (key ConnectionKey) String() string {
return string(key)
}
// ConnectionPool used for active connection management.
type ConnectionPool interface {
Done() <-chan struct{}
String() string
Put(connection Connection, ttl time.Duration) error
Get(key ConnectionKey) (Connection, error)
All() []Connection
Drop(key ConnectionKey) error
DropAll() error
Length() int
}
// ConnectionHandler serves associated connection, i.e. parses
// incoming data, manages expiry time & etc.
type ConnectionHandler interface {
Cancel()
Done() <-chan struct{}
String() string
Key() ConnectionKey
Connection() Connection
// Expiry returns connection expiry time.
Expiry() time.Time
Expired() bool
// Update updates connection expiry time.
// TODO put later to allow runtime update
// Update(conn Connection, ttl time.Duration)
// Manage runs connection serving.
Serve()
}
type connectionPool struct {
store map[ConnectionKey]ConnectionHandler
msgMapper sip.MessageMapper
output chan<- sip.Message
errs chan<- error
cancel <-chan struct{}
done chan struct{}
hmess chan sip.Message
herrs chan error
hwg sync.WaitGroup
mu sync.RWMutex
log log.Logger
}
func NewConnectionPool(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) ConnectionPool {
pool := &connectionPool{
store: make(map[ConnectionKey]ConnectionHandler),
msgMapper: msgMapper,
output: output,
errs: errs,
cancel: cancel,
done: make(chan struct{}),
hmess: make(chan sip.Message),
herrs: make(chan error),
}
pool.log = logger.
WithPrefix("transport.ConnectionPool").
WithFields(log.Fields{
"connection_pool_ptr": fmt.Sprintf("%p", pool),
})
go func() {
<-pool.cancel
pool.dispose()
}()
go pool.serveHandlers()
return pool
}
func (pool *connectionPool) String() string {
if pool == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ConnectionPool<%s>", pool.Log().Fields())
}
func (pool *connectionPool) Log() log.Logger {
return pool.log
}
func (pool *connectionPool) Done() <-chan struct{} {
return pool.done
}
// Put adds new connection to pool or updates TTL of existing connection
// TTL - 0 - unlimited; 1 - ... - time to live in pool
func (pool *connectionPool) Put(connection Connection, ttl time.Duration) error {
select {
case <-pool.cancel:
return &PoolError{
fmt.Errorf("connection pool closed"),
"get connection",
pool.String(),
}
default:
}
key := connection.Key()
if key == "" {
return &PoolError{
fmt.Errorf("empty connection key"),
"put connection",
pool.String(),
}
}
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.put(key, connection, ttl)
}
func (pool *connectionPool) Get(key ConnectionKey) (Connection, error) {
pool.mu.RLock()
defer pool.mu.RUnlock()
return pool.getConnection(key)
}
func (pool *connectionPool) Drop(key ConnectionKey) error {
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.drop(key)
}
func (pool *connectionPool) DropAll() error {
pool.mu.Lock()
for key := range pool.store {
if err := pool.drop(key); err != nil {
pool.Log().Errorf("drop connection %s failed: %s", key, err)
}
}
pool.mu.Unlock()
return nil
}
func (pool *connectionPool) All() []Connection {
pool.mu.RLock()
conns := make([]Connection, 0)
for _, handler := range pool.store {
conns = append(conns, handler.Connection())
}
pool.mu.RUnlock()
return conns
}
func (pool *connectionPool) Length() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.store)
}
func (pool *connectionPool) dispose() {
// clean pool
pool.DropAll()
pool.hwg.Wait()
// stop serveHandlers goroutine
close(pool.hmess)
close(pool.herrs)
close(pool.done)
}
func (pool *connectionPool) serveHandlers() {
pool.Log().Debug("begin serve connection handlers")
defer pool.Log().Debug("stop serve connection handlers")
for {
logger := pool.Log()
select {
case msg, ok := <-pool.hmess:
// cancel signal, serveStore exists
if !ok {
return
}
if msg == nil {
continue
}
logger = logger.WithFields(msg.Fields())
logger.Trace("passing up SIP message")
select {
case <-pool.cancel:
return
case pool.output <- msg:
logger.Trace("SIP message passed up")
continue
}
case err, ok := <-pool.herrs:
// cancel signal, serveStore exists
if !ok {
return
}
if err == nil {
continue
}
// on ConnectionHandleError we should drop handler in some cases
// all other possible errors ignored because in pool.herrs should be only ConnectionHandlerErrors
// so ConnectionPool passes up only Network (when connection falls) and MalformedMessage errors
var herr *ConnectionHandlerError
if !errors.As(err, &herr) {
// all other possible errors
logger.Tracef("ignore non connection error: %s", err)
continue
}
pool.mu.RLock()
handler, gerr := pool.get(herr.Key)
pool.mu.RUnlock()
if gerr != nil {
// ignore, handler already dropped out
logger.Tracef("ignore error from already dropped out connection %s: %s", herr.Key, gerr)
continue
}
logger = logger.WithFields(log.Fields{
"connection_handler": handler.String(),
})
if herr.Expired() {
// handler expired, drop it from pool and continue without emitting error
if handler.Expired() {
// connection expired
logger.Debug("connection expired, drop it and go further")
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// Due to a race condition, the socket has been updated since this expiry happened.
// Ignore the expiry since we already have a new socket for this address.
logger.Trace("ignore spurious connection expiry")
}
continue
} else if herr.EOF() {
select {
case <-pool.cancel:
return
default:
}
// remote endpoint closed
logger.Debugf("connection EOF: %s; drop it and go further", herr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
var connErr *ConnectionError
if errors.As(herr.Err, &connErr) {
pool.errs <- herr.Err
}
continue
} else if herr.Network() {
// connection broken or closed
logger.Debugf("connection network error: %s; drop it and pass the error up", herr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// syntax errors, malformed message errors and other
logger.Tracef("connection error: %s; pass the error up", herr)
}
// send initial error
select {
case <-pool.cancel:
return
case pool.errs <- herr.Err:
logger.Trace("error passed up")
continue
}
}
}
}
func (pool *connectionPool) put(key ConnectionKey, conn Connection, ttl time.Duration) error {
if _, err := pool.get(key); err == nil {
return &PoolError{
fmt.Errorf("key %s already exists in the pool", key),
"put connection",
pool.String(),
}
}
// wrap to handler
handler := NewConnectionHandler(
conn,
ttl,
pool.hmess,
pool.herrs,
pool.msgMapper,
pool.Log(),
)
logger := log.AddFieldsFrom(pool.Log(), handler)
logger.Tracef("put connection to the pool with TTL = %s", ttl)
pool.store[handler.Key()] = handler
// start serving
pool.hwg.Add(1)
go handler.Serve()
go func() {
<-handler.Done()
pool.hwg.Done()
}()
return nil
}
func (pool *connectionPool) drop(key ConnectionKey) error {
// check existence in pool
handler, err := pool.get(key)
if err != nil {
return err
}
handler.Cancel()
logger := log.AddFieldsFrom(pool.Log(), handler)
logger.Trace("drop connection from the pool")
// modify store
delete(pool.store, key)
return nil
}
func (pool *connectionPool) get(key ConnectionKey) (ConnectionHandler, error) {
if handler, ok := pool.store[key]; ok {
return handler, nil
}
return nil, &PoolError{
fmt.Errorf("connection %s not found in the pool", key),
"get connection",
pool.String(),
}
}
func (pool *connectionPool) getConnection(key ConnectionKey) (Connection, error) {
var conn Connection
handler, err := pool.get(key)
if err == nil {
conn = handler.Connection()
}
return conn, err
}
// connectionHandler actually serves associated connection
type connectionHandler struct {
connection Connection
msgMapper sip.MessageMapper
timer timing.Timer
ttl time.Duration
expiry time.Time
output chan<- sip.Message
errs chan<- error
cancelOnce sync.Once
canceled chan struct{}
done chan struct{}
addrs util.ElasticChan
log log.Logger
}
func NewConnectionHandler(
conn Connection,
ttl time.Duration,
output chan<- sip.Message,
errs chan<- error,
msgMapper sip.MessageMapper,
logger log.Logger,
) ConnectionHandler {
handler := &connectionHandler{
connection: conn,
msgMapper: msgMapper,
output: output,
errs: errs,
canceled: make(chan struct{}),
done: make(chan struct{}),
ttl: ttl,
}
handler.log = logger.
WithPrefix("transport.ConnectionHandler").
WithFields(log.Fields{
"connection_handler_ptr": fmt.Sprintf("%p", handler),
"connection_ptr": fmt.Sprintf("%p", conn),
"connection_key": conn.Key(),
"connection_network": conn.Network(),
})
// handler.Update(ttl)
if ttl > 0 {
handler.expiry = time.Now().Add(ttl)
handler.timer = timing.NewTimer(ttl)
} else {
handler.expiry = time.Time{}
handler.timer = timing.NewTimer(0)
if !handler.timer.Stop() {
<-handler.timer.C()
}
}
if handler.msgMapper == nil {
handler.msgMapper = func(msg sip.Message) sip.Message {
return msg
}
}
return handler
}
func (handler *connectionHandler) String() string {
if handler == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ConnectionHandler<%s>", handler.Log().Fields())
}
func (handler *connectionHandler) Log() log.Logger {
return handler.log
}
func (handler *connectionHandler) Key() ConnectionKey {
return handler.connection.Key()
}
func (handler *connectionHandler) Connection() Connection {
return handler.connection
}
func (handler *connectionHandler) Expiry() time.Time {
return handler.expiry
}
func (handler *connectionHandler) Expired() bool {
return !handler.Expiry().IsZero() && handler.Expiry().Before(time.Now())
}
// resets the timeout timer.
// func (handler *connectionHandler) Update(ttl time.Duration) {
// if ttl > 0 {
// expiryTime := timing.Now().Put(ttl)
// handler.Log().Debugf("set %s expiry time to %s", handler, expiryTime)
// handler.expiry = expiryTime
//
// if handler.timer == nil {
// handler.timer = timing.NewTimer(ttl)
// } else {
// handler.timer.Reset(ttl)
// }
// } else {
// handler.Log().Debugf("set %s unlimited expiry time")
// handler.expiry = time.Time{}
//
// if handler.timer == nil {
// handler.timer = timing.NewTimer(0)
// }
// handler.timer.Stop()
// }
// }
// Serve is connection serving loop.
// Waits for the connection to expire, and notifies the pool when it does.
func (handler *connectionHandler) Serve() {
defer close(handler.done)
handler.Log().Debug("begin serve connection")
defer handler.Log().Debug("stop serve connection")
// start connection serving goroutines
msgs, errs := handler.readConnection()
handler.pipeOutputs(msgs, errs)
}
func (handler *connectionHandler) readConnection() (<-chan sip.Message, <-chan error) {
msgs := make(chan sip.Message)
errs := make(chan error)
streamed := handler.Connection().Streamed()
var (
pktPrs *parser.PacketParser
strPrs parser.Parser
)
if streamed {
strPrs = parser.NewParser(msgs, errs, streamed, handler.Log())
} else {
pktPrs = parser.NewPacketParser(handler.Log())
}
var raddr net.Addr
if streamed {
raddr = handler.Connection().RemoteAddr()
} else {
handler.addrs.Init()
handler.addrs.SetLog(handler.Log())
handler.addrs.Run()
}
go func() {
defer func() {
handler.Connection().Close()
if streamed {
strPrs.Stop()
} else {
pktPrs.Stop()
}
if !streamed {
handler.addrs.Stop()
}
close(msgs)
close(errs)
}()
handler.Log().Debug("begin read connection")
defer handler.Log().Debug("stop read connection")
buf := make([]byte, bufferSize)
var (
num int
err error
)
for {
// wait for data
if streamed {
num, err = handler.Connection().Read(buf)
} else {
num, raddr, err = handler.Connection().ReadFrom(buf)
}
if err != nil {
//// if we get timeout error just go further and try read on the next iteration
//var netErr net.Error
//if errors.As(err, &netErr) {
// if netErr.Timeout() || netErr.Temporary() {
// handler.Log().Tracef(
// "connection read failed due to timeout or temporary unavailable reason: %s, sleep by %s",
// err,
// netErrRetryTime,
// )
//
// time.Sleep(netErrRetryTime)
//
// continue
// }
//}
// broken or closed connection
// so send error and exit
handler.handleError(err, fmt.Sprintf("%v", raddr))
return
}
data := buf[:num]
// skip empty udp packets
if len(bytes.Trim(data, "\x00")) == 0 {
handler.Log().Tracef("skip empty data: %#v", data)
continue
}
// parse received data
if streamed {
if _, err := strPrs.Write(data); err != nil {
handler.handleError(err, fmt.Sprintf("%v", raddr))
}
} else {
if msg, err := pktPrs.ParseMessage(data); err == nil {
handler.handleMessage(msg, fmt.Sprintf("%v", raddr))
} else {
handler.handleError(err, fmt.Sprintf("%v", raddr))
}
}
}
}()
return msgs, errs
}
func (handler *connectionHandler) pipeOutputs(msgs <-chan sip.Message, errs <-chan error) {
streamed := handler.Connection().Streamed()
handler.Log().Debug("begin pipe outputs")
defer handler.Log().Debug("stop pipe outputs")
for {
select {
case <-handler.timer.C():
var raddr string
if streamed {
raddr = fmt.Sprintf("%v", handler.Connection().RemoteAddr())
}
if handler.Expiry().IsZero() {
// handler expiryTime is zero only when TTL = 0 (unlimited handler)
// so we must not get here with zero expiryTime
handler.Log().Panic("fires expiry timer with ZERO expiryTime")
}
// pass up to the pool
handler.handleError(ExpireError("connection expired"), raddr)
case msg, ok := <-msgs:
if !ok {
return
}
handler.handleMessage(msg, handler.getRemoteAddr())
case err, ok := <-errs:
if !ok {
return
}
handler.handleError(err, handler.getRemoteAddr())
}
}
}
func (handler *connectionHandler) getRemoteAddr() string {
if handler.Connection().Streamed() {
return fmt.Sprintf("%v", handler.Connection().RemoteAddr())
} else {
// use non-blocking read because remote address already should be here
// or error occurred in read connection goroutine
select {
case v := <-handler.addrs.Out:
return v.(string)
default:
return "<nil>"
}
}
}
func (handler *connectionHandler) handleMessage(msg sip.Message, raddr string) {
msg.SetDestination(handler.Connection().LocalAddr().String())
rhost, rport, _ := net.SplitHostPort(raddr)
switch msg := msg.(type) {
case sip.Request:
// RFC 3261 - 18.2.1
viaHop, ok := msg.ViaHop()
if !ok {
handler.Log().Warn("ignore message without 'Via' header")
return
}
if rhost != "" && rhost != viaHop.Host {
viaHop.Params.Add("received", sip.String{Str: rhost})
}
// rfc3581
if viaHop.Params.Has("rport") {
viaHop.Params.Add("rport", sip.String{Str: rport})
}
if !handler.Connection().Streamed() {
if !viaHop.Params.Has("rport") {
var port sip.Port
if viaHop.Port != nil {
port = *viaHop.Port
} else {
port = sip.DefaultPort(handler.Connection().Network())
}
raddr = fmt.Sprintf("%s:%d", rhost, port)
}
}
msg.SetTransport(handler.connection.Network())
msg.SetSource(raddr)
case sip.Response:
// Set Remote Address as response source
msg.SetTransport(handler.connection.Network())
msg.SetSource(raddr)
}
msg = handler.msgMapper(msg.WithFields(log.Fields{
"connection_key": handler.Connection().Key(),
"received_at": time.Now(),
}))
// pass up
handler.output <- msg
if !handler.Expiry().IsZero() {
handler.expiry = time.Now().Add(handler.ttl)
handler.timer.Reset(handler.ttl)
}
}
func (handler *connectionHandler) handleError(err error, raddr string) {
if isSyntaxError(err) {
handler.Log().Tracef("ignore error: %s", err)
return
}
err = &ConnectionHandlerError{
err,
handler.Key(),
fmt.Sprintf("%p", handler),
handler.Connection().Network(),
fmt.Sprintf("%v", handler.Connection().LocalAddr()),
raddr,
}
select {
case <-handler.canceled:
case handler.errs <- err:
}
}
func isSyntaxError(err error) bool {
var perr parser.Error
if errors.As(err, &perr) && perr.Syntax() {
return true
}
var merr sip.MessageError
if errors.As(err, &merr) && merr.Broken() {
return true
}
return false
}
// Cancel simply calls runtime provided cancel function.
func (handler *connectionHandler) Cancel() {
handler.cancelOnce.Do(func() {
close(handler.canceled)
handler.Connection().Close()
handler.Log().Debug("connection handler canceled")
})
}
func (handler *connectionHandler) Done() <-chan struct{} {
return handler.done
}

View file

@ -0,0 +1,477 @@
package transport
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strings"
"sync"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// Layer is responsible for the actual transmission of messages - RFC 3261 - 18.
type Layer interface {
Cancel()
Done() <-chan struct{}
Messages() <-chan sip.Message
Errors() <-chan error
// Listen starts listening on `addr` for each registered protocol.
Listen(network string, addr string, options ...ListenOption) error
// Send sends message on suitable protocol.
Send(msg sip.Message) error
String() string
IsReliable(network string) bool
IsStreamed(network string) bool
}
var protocolFactory ProtocolFactory = func(
network string,
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) (Protocol, error) {
switch strings.ToLower(network) {
case "udp":
return NewUdpProtocol(output, errs, cancel, msgMapper, logger), nil
case "tcp":
return NewTcpProtocol(output, errs, cancel, msgMapper, logger), nil
case "tls":
return NewTlsProtocol(output, errs, cancel, msgMapper, logger), nil
case "ws":
return NewWsProtocol(output, errs, cancel, msgMapper, logger), nil
case "wss":
return NewWssProtocol(output, errs, cancel, msgMapper, logger), nil
default:
return nil, UnsupportedProtocolError(fmt.Sprintf("protocol %s is not supported", network))
}
}
// SetProtocolFactory replaces default protocol factory
func SetProtocolFactory(factory ProtocolFactory) {
protocolFactory = factory
}
// GetProtocolFactory returns default protocol factory
func GetProtocolFactory() ProtocolFactory {
return protocolFactory
}
// TransportLayer implementation.
type layer struct {
protocols *protocolStore
listenPorts map[string][]sip.Port
ip net.IP
dnsResolver *net.Resolver
msgMapper sip.MessageMapper
msgs chan sip.Message
errs chan error
pmsgs chan sip.Message
perrs chan error
canceled chan struct{}
done chan struct{}
wg sync.WaitGroup
cancelOnce sync.Once
log log.Logger
}
// NewLayer creates transport layer.
// - ip - host IP
// - dnsAddr - DNS server address, default is 127.0.0.1:53
func NewLayer(
ip net.IP,
dnsResolver *net.Resolver,
msgMapper sip.MessageMapper,
logger log.Logger,
) Layer {
tpl := &layer{
protocols: newProtocolStore(),
listenPorts: make(map[string][]sip.Port),
ip: ip,
dnsResolver: dnsResolver,
msgMapper: msgMapper,
msgs: make(chan sip.Message),
errs: make(chan error),
pmsgs: make(chan sip.Message),
perrs: make(chan error),
canceled: make(chan struct{}),
done: make(chan struct{}),
}
tpl.log = logger.
WithPrefix("transport.Layer").
WithFields(map[string]interface{}{
"transport_layer_ptr": fmt.Sprintf("%p", tpl),
})
go tpl.serveProtocols()
return tpl
}
func (tpl *layer) String() string {
if tpl == nil {
return "<nil>"
}
return fmt.Sprintf("transport.Layer<%s>", tpl.Log().Fields())
}
func (tpl *layer) Log() log.Logger {
return tpl.log
}
func (tpl *layer) Cancel() {
select {
case <-tpl.canceled:
return
default:
}
tpl.cancelOnce.Do(func() {
close(tpl.canceled)
tpl.Log().Debug("transport layer canceled")
})
}
func (tpl *layer) Done() <-chan struct{} {
return tpl.done
}
func (tpl *layer) Messages() <-chan sip.Message {
return tpl.msgs
}
func (tpl *layer) Errors() <-chan error {
return tpl.errs
}
func (tpl *layer) IsReliable(network string) bool {
if protocol, ok := tpl.protocols.get(protocolKey(network)); ok && protocol.Reliable() {
return true
}
return false
}
func (tpl *layer) IsStreamed(network string) bool {
if protocol, ok := tpl.protocols.get(protocolKey(network)); ok && protocol.Streamed() {
return true
}
return false
}
func (tpl *layer) Listen(network string, addr string, options ...ListenOption) error {
select {
case <-tpl.canceled:
return fmt.Errorf("transport layer is canceled")
default:
}
protocol, err := tpl.getProtocol(network)
if err != nil {
return err
}
target, err := NewTargetFromAddr(addr)
if err != nil {
return err
}
target = FillTargetHostAndPort(protocol.Network(), target)
err = protocol.Listen(target, options...)
if err == nil {
if _, ok := tpl.listenPorts[protocol.Network()]; !ok {
if tpl.listenPorts[protocol.Network()] == nil {
tpl.listenPorts[protocol.Network()] = make([]sip.Port, 0)
}
tpl.listenPorts[protocol.Network()] = append(tpl.listenPorts[protocol.Network()], *target.Port)
}
}
return err
}
func (tpl *layer) Send(msg sip.Message) error {
select {
case <-tpl.canceled:
return fmt.Errorf("transport layer is canceled")
default:
}
viaHop, ok := msg.ViaHop()
if !ok {
return &sip.MalformedMessageError{
Err: fmt.Errorf("missing required 'Via' header"),
Msg: msg.String(),
}
}
switch msg := msg.(type) {
// RFC 3261 - 18.1.1.
case sip.Request:
network := msg.Transport()
// rewrite sent-by transport
viaHop.Transport = strings.ToUpper(network)
viaHop.Host = tpl.ip.String()
protocol, err := tpl.getProtocol(network)
if err != nil {
return err
}
// rewrite sent-by port
if viaHop.Port == nil {
if ports, ok := tpl.listenPorts[network]; ok {
port := ports[rand.Intn(len(ports))]
viaHop.Port = &port
} else {
defPort := sip.DefaultPort(network)
viaHop.Port = &defPort
}
}
target, err := NewTargetFromAddr(msg.Destination())
if err != nil {
return fmt.Errorf("build address target for %s: %w", msg.Destination(), err)
}
// dns srv lookup
if net.ParseIP(target.Host) == nil {
ctx := context.Background()
proto := strings.ToLower(network)
if _, addrs, err := tpl.dnsResolver.LookupSRV(ctx, "sip", proto, target.Host); err == nil && len(addrs) > 0 {
addr := addrs[0]
addrStr := fmt.Sprintf("%s:%d", addr.Target[:len(addr.Target)-1], addr.Port)
switch network {
case "UDP":
if addr, err := net.ResolveUDPAddr("udp", addrStr); err == nil {
port := sip.Port(addr.Port)
if addr.IP.To4() == nil {
target.Host = fmt.Sprintf("[%v]", addr.IP.String())
} else {
target.Host = addr.IP.String()
}
target.Port = &port
}
case "TLS":
fallthrough
case "WS":
fallthrough
case "WSS":
fallthrough
case "TCP":
if addr, err := net.ResolveTCPAddr("tcp", addrStr); err == nil {
port := sip.Port(addr.Port)
if addr.IP.To4() == nil {
target.Host = fmt.Sprintf("[%v]", addr.IP.String())
} else {
target.Host = addr.IP.String()
}
target.Port = &port
}
}
}
}
logger := log.AddFieldsFrom(tpl.Log(), protocol, msg)
logger.Debugf("sending SIP request:\n%s", msg)
if err = protocol.Send(target, msg); err != nil {
return fmt.Errorf("send SIP message through %s protocol to %s: %w", protocol.Network(), target.Addr(), err)
}
return nil
// RFC 3261 - 18.2.2.
case sip.Response:
// resolve protocol from Via
protocol, err := tpl.getProtocol(msg.Transport())
if err != nil {
return err
}
target, err := NewTargetFromAddr(msg.Destination())
if err != nil {
return fmt.Errorf("build address target for %s: %w", msg.Destination(), err)
}
logger := log.AddFieldsFrom(tpl.Log(), protocol, msg)
logger.Debugf("sending SIP response:\n%s", msg)
if err = protocol.Send(target, msg); err != nil {
return fmt.Errorf("send SIP message through %s protocol to %s: %w", protocol.Network(), target.Addr(), err)
}
return nil
default:
return &sip.UnsupportedMessageError{
Err: fmt.Errorf("unsupported message %s", msg.Short()),
Msg: msg.String(),
}
}
}
func (tpl *layer) getProtocol(network string) (Protocol, error) {
network = strings.ToLower(network)
return tpl.protocols.getOrPutNew(protocolKey(network), func() (Protocol, error) {
return protocolFactory(
network,
tpl.pmsgs,
tpl.perrs,
tpl.canceled,
tpl.msgMapper,
tpl.Log(),
)
})
}
func (tpl *layer) serveProtocols() {
defer func() {
tpl.dispose()
close(tpl.done)
}()
tpl.Log().Debug("begin serve protocols")
defer tpl.Log().Debug("stop serve protocols")
for {
select {
case <-tpl.canceled:
return
case msg := <-tpl.pmsgs:
tpl.handleMessage(msg)
case err := <-tpl.perrs:
tpl.handlerError(err)
}
}
}
func (tpl *layer) dispose() {
tpl.Log().Debug("disposing...")
// wait for protocols
for _, protocol := range tpl.protocols.all() {
tpl.protocols.drop(protocolKey(protocol.Network()))
<-protocol.Done()
}
tpl.listenPorts = make(map[string][]sip.Port)
close(tpl.pmsgs)
close(tpl.perrs)
close(tpl.msgs)
close(tpl.errs)
}
// handles incoming message from protocol
// should be called inside goroutine for non-blocking forwarding
func (tpl *layer) handleMessage(msg sip.Message) {
logger := tpl.Log().WithFields(msg.Fields())
logger.Debugf("received SIP message:\n%s", msg)
logger.Trace("passing up SIP message...")
// pass up message
select {
case <-tpl.canceled:
case tpl.msgs <- msg:
logger.Trace("SIP message passed up")
}
}
func (tpl *layer) handlerError(err error) {
// TODO: implement re-connection strategy for listeners
var terr Error
if errors.As(err, &terr) {
// currently log
tpl.Log().Warnf("SIP transport error: %s", err)
}
logger := tpl.Log().WithFields(log.Fields{
"sip_error": err.Error(),
})
logger.Trace("passing up error...")
select {
case <-tpl.canceled:
case tpl.errs <- err:
logger.Trace("error passed up")
}
}
type protocolKey string
// Thread-safe protocols pool.
type protocolStore struct {
protocols map[protocolKey]Protocol
mu sync.RWMutex
}
func newProtocolStore() *protocolStore {
return &protocolStore{
protocols: make(map[protocolKey]Protocol),
}
}
func (store *protocolStore) put(key protocolKey, protocol Protocol) {
store.mu.Lock()
store.protocols[key] = protocol
store.mu.Unlock()
}
func (store *protocolStore) get(key protocolKey) (Protocol, bool) {
store.mu.RLock()
defer store.mu.RUnlock()
protocol, ok := store.protocols[key]
return protocol, ok
}
func (store *protocolStore) getOrPutNew(key protocolKey, factory func() (Protocol, error)) (Protocol, error) {
store.mu.Lock()
defer store.mu.Unlock()
protocol, ok := store.protocols[key]
if ok {
return protocol, nil
}
var err error
protocol, err = factory()
if err != nil {
return nil, err
}
store.protocols[key] = protocol
return protocol, nil
}
func (store *protocolStore) drop(key protocolKey) bool {
if _, ok := store.get(key); !ok {
return false
}
store.mu.Lock()
defer store.mu.Unlock()
delete(store.protocols, key)
return true
}
func (store *protocolStore) all() []Protocol {
all := make([]Protocol, 0)
store.mu.RLock()
defer store.mu.RUnlock()
for _, protocol := range store.protocols {
all = append(all, protocol)
}
return all
}

View file

@ -0,0 +1,498 @@
package transport
import (
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"github.com/ghettovoice/gosip/log"
)
type ListenerKey string
func (key ListenerKey) String() string {
return string(key)
}
type ListenerPool interface {
log.Loggable
Done() <-chan struct{}
String() string
Put(key ListenerKey, listener net.Listener) error
Get(key ListenerKey) (net.Listener, error)
All() []net.Listener
Drop(key ListenerKey) error
DropAll() error
Length() int
}
type ListenerHandler interface {
log.Loggable
Cancel()
Done() <-chan struct{}
String() string
Key() ListenerKey
Listener() net.Listener
Serve()
// TODO implement later, runtime replace of the net.Listener in handler
// Update(ls net.Listener)
}
type listenerPool struct {
hwg sync.WaitGroup
mu sync.RWMutex
store map[ListenerKey]ListenerHandler
output chan<- Connection
errs chan<- error
cancel <-chan struct{}
done chan struct{}
hconns chan Connection
herrs chan error
log log.Logger
}
func NewListenerPool(
output chan<- Connection,
errs chan<- error,
cancel <-chan struct{},
logger log.Logger,
) ListenerPool {
pool := &listenerPool{
store: make(map[ListenerKey]ListenerHandler),
output: output,
errs: errs,
cancel: cancel,
done: make(chan struct{}),
hconns: make(chan Connection),
herrs: make(chan error),
}
pool.log = logger.
WithPrefix("transport.ListenerPool").
WithFields(log.Fields{
"listener_pool_ptr": fmt.Sprintf("%p", pool),
})
go func() {
<-pool.cancel
pool.dispose()
}()
go pool.serveHandlers()
return pool
}
func (pool *listenerPool) String() string {
if pool == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ListenerPool<%s>", pool.Log().Fields())
}
func (pool *listenerPool) Log() log.Logger {
return pool.log
}
// Done returns channel that resolves when pool gracefully completes it work.
func (pool *listenerPool) Done() <-chan struct{} {
return pool.done
}
func (pool *listenerPool) Put(key ListenerKey, listener net.Listener) error {
select {
case <-pool.cancel:
return &PoolError{
fmt.Errorf("listener pool closed"),
"put listener",
pool.String(),
}
default:
}
if key == "" {
return &PoolError{
fmt.Errorf("empty listener key"),
"put listener",
pool.String(),
}
}
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.put(key, listener)
}
func (pool *listenerPool) Get(key ListenerKey) (net.Listener, error) {
pool.mu.RLock()
defer pool.mu.RUnlock()
return pool.getListener(key)
}
func (pool *listenerPool) Drop(key ListenerKey) error {
pool.mu.Lock()
defer pool.mu.Unlock()
return pool.drop(key)
}
func (pool *listenerPool) DropAll() error {
pool.mu.Lock()
for key := range pool.store {
if err := pool.drop(key); err != nil {
pool.Log().Errorf("drop listener %s failed: %s", key, err)
}
}
pool.mu.Unlock()
return nil
}
func (pool *listenerPool) All() []net.Listener {
pool.mu.RLock()
listns := make([]net.Listener, 0)
for _, handler := range pool.store {
listns = append(listns, handler.Listener())
}
pool.mu.RUnlock()
return listns
}
func (pool *listenerPool) Length() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
return len(pool.store)
}
func (pool *listenerPool) dispose() {
// clean pool
pool.DropAll()
pool.hwg.Wait()
// stop serveHandlers goroutine
close(pool.hconns)
close(pool.herrs)
close(pool.done)
}
func (pool *listenerPool) serveHandlers() {
pool.Log().Debug("start serve listener handlers")
defer pool.Log().Debug("stop serve listener handlers")
for {
logger := pool.Log()
select {
case conn, ok := <-pool.hconns:
if !ok {
return
}
if conn == nil {
continue
}
logger = log.AddFieldsFrom(logger, conn)
logger.Trace("passing up connection")
select {
case <-pool.cancel:
return
case pool.output <- conn:
logger.Trace("connection passed up")
}
case err, ok := <-pool.herrs:
if !ok {
return
}
if err == nil {
continue
}
var lerr *ListenerHandlerError
if errors.As(err, &lerr) {
pool.mu.RLock()
handler, gerr := pool.get(lerr.Key)
pool.mu.RUnlock()
if gerr == nil {
logger = logger.WithFields(handler.Log().Fields())
if lerr.Network() {
// listener broken or closed, should be dropped
logger.Debugf("listener network error: %s; drop it and go further", lerr)
if err := pool.Drop(handler.Key()); err != nil {
logger.Error(err)
}
} else {
// other
logger.Tracef("listener error: %s; pass the error up", lerr)
}
} else {
// ignore, handler already dropped out
logger.Tracef("ignore error from already dropped out listener %s: %s", lerr.Key, lerr)
continue
}
} else {
// all other possible errors
logger.Tracef("ignore non listener error: %s", err)
continue
}
select {
case <-pool.cancel:
return
case pool.errs <- err:
logger.Trace("error passed up")
}
}
}
}
func (pool *listenerPool) put(key ListenerKey, listener net.Listener) error {
if _, err := pool.get(key); err == nil {
return &PoolError{
fmt.Errorf("key %s already exists in the pool", key),
"put listener",
pool.String(),
}
}
// wrap to handler
handler := NewListenerHandler(key, listener, pool.hconns, pool.herrs, pool.Log())
pool.Log().WithFields(handler.Log().Fields()).Trace("put listener to the pool")
// update store
pool.store[handler.Key()] = handler
// start serving
pool.hwg.Add(1)
go handler.Serve()
go func() {
<-handler.Done()
pool.hwg.Done()
}()
return nil
}
func (pool *listenerPool) drop(key ListenerKey) error {
// check existence in pool
handler, err := pool.get(key)
if err != nil {
return err
}
handler.Cancel()
pool.Log().WithFields(handler.Log().Fields()).Trace("drop listener from the pool")
// modify store
delete(pool.store, key)
return nil
}
func (pool *listenerPool) get(key ListenerKey) (ListenerHandler, error) {
if handler, ok := pool.store[key]; ok {
return handler, nil
}
return nil, &PoolError{
fmt.Errorf("listenr %s not found in the pool", key),
"get listener",
pool.String(),
}
}
func (pool *listenerPool) getListener(key ListenerKey) (net.Listener, error) {
if handler, err := pool.get(key); err == nil {
return handler.Listener(), nil
} else {
return nil, err
}
}
type listenerHandler struct {
key ListenerKey
listener net.Listener
output chan<- Connection
errs chan<- error
cancelOnce sync.Once
canceled chan struct{}
done chan struct{}
log log.Logger
}
func NewListenerHandler(
key ListenerKey,
listener net.Listener,
output chan<- Connection,
errs chan<- error,
logger log.Logger,
) ListenerHandler {
handler := &listenerHandler{
key: key,
listener: listener,
output: output,
errs: errs,
canceled: make(chan struct{}),
done: make(chan struct{}),
}
handler.log = logger.
WithPrefix("transport.ListenerHandler").
WithFields(log.Fields{
"listener_handler_ptr": fmt.Sprintf("%p", handler),
"listener_ptr": fmt.Sprintf("%p", listener),
"listener_key": key,
})
return handler
}
func (handler *listenerHandler) String() string {
if handler == nil {
return "<nil>"
}
return fmt.Sprintf("transport.ListenerHandler<%s>", handler.Log().Fields())
}
func (handler *listenerHandler) Log() log.Logger {
return handler.log
}
func (handler *listenerHandler) Key() ListenerKey {
return handler.key
}
func (handler *listenerHandler) Listener() net.Listener {
return handler.listener
}
func (handler *listenerHandler) Serve() {
defer close(handler.done)
handler.Log().Debug("begin serve listener")
defer handler.Log().Debugf("stop serve listener")
wg := &sync.WaitGroup{}
wg.Add(1)
go handler.acceptConnections(wg)
wg.Wait()
}
func (handler *listenerHandler) acceptConnections(wg *sync.WaitGroup) {
defer func() {
handler.Listener().Close()
wg.Done()
}()
handler.Log().Debug("begin accept connections")
defer handler.Log().Debug("stop accept connections")
for {
// wait for the new connection
baseConn, err := handler.Listener().Accept()
if err != nil {
//// if we get timeout error just go further and try accept on the next iteration
//var netErr net.Error
//if errors.As(err, &netErr) {
// if netErr.Timeout() || netErr.Temporary() {
// handler.Log().Warnf("listener timeout or temporary unavailable, sleep by %s", netErrRetryTime)
//
// time.Sleep(netErrRetryTime)
//
// continue
// }
//}
// broken or closed listener
// pass up error and exit
err = &ListenerHandlerError{
err,
handler.Key(),
fmt.Sprintf("%p", handler),
listenerNetwork(handler.Listener()),
handler.Listener().Addr().String(),
}
select {
case <-handler.canceled:
case handler.errs <- err:
}
return
}
var network string
switch bc := baseConn.(type) {
case *tls.Conn:
network = "tls"
case *wsConn:
if _, ok := bc.Conn.(*tls.Conn); ok {
network = "wss"
} else {
network = "ws"
}
default:
network = strings.ToLower(baseConn.RemoteAddr().Network())
}
key := ConnectionKey(network + ":" + baseConn.RemoteAddr().String())
handler.output <- NewConnection(baseConn, key, network, handler.Log())
}
}
// Cancel stops serving.
// blocked until Serve completes
func (handler *listenerHandler) Cancel() {
handler.cancelOnce.Do(func() {
close(handler.canceled)
handler.Listener().Close()
handler.Log().Debug("listener handler canceled")
})
}
// Done returns channel that resolves when handler gracefully completes it work.
func (handler *listenerHandler) Done() <-chan struct{} {
return handler.done
}
func listenerNetwork(ls net.Listener) string {
if val, ok := ls.(interface{ Network() string }); ok {
return val.Network()
}
switch ls.(type) {
case *net.TCPListener:
return "tcp"
case *net.UnixListener:
return "unix"
default:
return ""
}
}

View file

@ -0,0 +1,90 @@
package transport
import (
"net"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
// TODO migrate other factories to functional arguments
type Options struct {
MessageMapper sip.MessageMapper
Logger log.Logger
}
type LayerOption interface {
ApplyLayer(opts *LayerOptions)
}
type LayerOptions struct {
Options
DNSResolver *net.Resolver
}
type ProtocolOption interface {
ApplyProtocol(opts *ProtocolOptions)
}
type ProtocolOptions struct {
Options
}
func WithMessageMapper(mapper sip.MessageMapper) interface {
LayerOption
ProtocolOption
} {
return withMessageMapper{mapper}
}
type withMessageMapper struct {
mapper sip.MessageMapper
}
func (o withMessageMapper) ApplyLayer(opts *LayerOptions) {
opts.MessageMapper = o.mapper
}
func (o withMessageMapper) ApplyProtocol(opts *ProtocolOptions) {
opts.MessageMapper = o.mapper
}
func WithLogger(logger log.Logger) interface {
LayerOption
ProtocolOption
} {
return withLogger{logger}
}
type withLogger struct {
logger log.Logger
}
func (o withLogger) ApplyLayer(opts *LayerOptions) {
opts.Logger = o.logger
}
func (o withLogger) ApplyProtocol(opts *ProtocolOptions) {
opts.Logger = o.logger
}
func WithDNSResolver(resolver *net.Resolver) LayerOption {
return withDnsResolver{resolver}
}
type withDnsResolver struct {
resolver *net.Resolver
}
func (o withDnsResolver) ApplyLayer(opts *LayerOptions) {
opts.DNSResolver = o.resolver
}
// Listen method options
type ListenOption interface {
ApplyListen(opts *ListenOptions)
}
type ListenOptions struct {
TLSConfig TLSConfig
}

View file

@ -0,0 +1,71 @@
package transport
import (
"fmt"
"strings"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
const (
//netErrRetryTime = 5 * time.Second
sockTTL = time.Hour
)
// Protocol implements network specific features.
type Protocol interface {
Done() <-chan struct{}
Network() string
Reliable() bool
Streamed() bool
Listen(target *Target, options ...ListenOption) error
Send(target *Target, msg sip.Message) error
String() string
}
type ProtocolFactory func(
network string,
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) (Protocol, error)
type protocol struct {
network string
reliable bool
streamed bool
log log.Logger
}
func (pr *protocol) Log() log.Logger {
return pr.log
}
func (pr *protocol) String() string {
if pr == nil {
return "<nil>"
}
fields := pr.Log().Fields().WithFields(log.Fields{
"network": pr.network,
})
return fmt.Sprintf("transport.Protocol<%s>", fields)
}
func (pr *protocol) Network() string {
return strings.ToUpper(pr.network)
}
func (pr *protocol) Reliable() bool {
return pr.reliable
}
func (pr *protocol) Streamed() bool {
return pr.streamed
}

View file

@ -0,0 +1,209 @@
package transport
import (
"fmt"
"net"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type tcpListener struct {
net.Listener
network string
}
func (l *tcpListener) Network() string {
return strings.ToUpper(l.network)
}
// TCP protocol implementation
type tcpProtocol struct {
protocol
listeners ListenerPool
connections ConnectionPool
conns chan Connection
listen func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error)
dial func(addr *net.TCPAddr) (net.Conn, error)
resolveAddr func(addr string) (*net.TCPAddr, error)
}
func NewTcpProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(tcpProtocol)
p.network = "tcp"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
// TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = p.defaultListen
p.dial = p.defaultDial
p.resolveAddr = p.defaultResolveAddr
// pipe listener and connection pools
go p.pipePools()
return p
}
func (p *tcpProtocol) defaultListen(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
return net.ListenTCP(p.network, addr)
}
func (p *tcpProtocol) defaultDial(addr *net.TCPAddr) (net.Conn, error) {
return net.DialTCP(p.network, nil, addr)
}
func (p *tcpProtocol) defaultResolveAddr(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr(p.network, addr)
}
func (p *tcpProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
// piping new connections to connection pool for serving
func (p *tcpProtocol) pipePools() {
defer close(p.conns)
p.Log().Debug("start pipe pools")
defer p.Log().Debug("stop pipe pools")
for {
select {
case <-p.listeners.Done():
return
case conn := <-p.conns:
logger := log.AddFieldsFrom(p.Log(), conn)
if err := p.connections.Put(conn, sockTTL); err != nil {
// TODO should it be passed up to UA?
logger.Errorf("put %s connection to the pool failed: %s", conn.Key(), err)
conn.Close()
continue
}
}
}
}
func (p *tcpProtocol) Listen(target *Target, options ...ListenOption) error {
target = FillTargetHostAndPort(p.Network(), target)
laddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
listener, err := p.listen(laddr, options...)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), target.Addr())
// index listeners by local address
// should live infinitely
key := ListenerKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, target.Port))
err = p.listeners.Put(key, &tcpListener{
Listener: listener,
network: p.network,
})
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s listener to the pool", key),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err // should be nil here
}
func (p *tcpProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
// validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// resolve remote address
raddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// find or create connection
conn, err := p.getOrCreateConnection(raddr)
if err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("get or create %s connection", p.Network()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
// send message
_, err = conn.Write([]byte(msg.String()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err
}
func (p *tcpProtocol) getOrCreateConnection(raddr *net.TCPAddr) (Connection, error) {
key := ConnectionKey(p.network + ":" + raddr.String())
conn, err := p.connections.Get(key)
if err != nil {
p.Log().Debugf("connection for remote address %s %s not found, create a new one", p.Network(), raddr)
tcpConn, err := p.dial(raddr)
if err != nil {
return nil, fmt.Errorf("dial to %s %s: %w", p.Network(), raddr, err)
}
conn = NewConnection(tcpConn, key, p.network, p.Log())
if err := p.connections.Put(conn, sockTTL); err != nil {
return conn, fmt.Errorf("put %s connection to the pool: %w", conn.Key(), err)
}
}
return conn, nil
}

View file

@ -0,0 +1,67 @@
package transport
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type tlsProtocol struct {
tcpProtocol
}
func NewTlsProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(tlsProtocol)
p.network = "tls"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
if len(options) == 0 {
return net.ListenTCP("tcp", addr)
}
optsHash := ListenOptions{}
for _, opt := range options {
opt.ApplyListen(&optsHash)
}
cert, err := tls.LoadX509KeyPair(optsHash.TLSConfig.Cert, optsHash.TLSConfig.Key)
if err != nil {
return nil, fmt.Errorf("load TLS certficate %s: %w", optsHash.TLSConfig.Cert, err)
}
return tls.Listen("tcp", addr.String(), &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
p.dial = func(addr *net.TCPAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return nil
},
})
}
p.resolveAddr = func(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", addr)
}
//pipe listener and connection pools
go p.pipePools()
return p
}

View file

@ -0,0 +1,357 @@
// transport package implements SIP transport layer.
package transport
import (
"errors"
"fmt"
"io"
"net"
"regexp"
"strconv"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
const (
MTU = sip.MTU
DefaultHost = sip.DefaultHost
DefaultProtocol = sip.DefaultProtocol
DefaultUdpPort = sip.DefaultUdpPort
DefaultTcpPort = sip.DefaultTcpPort
DefaultTlsPort = sip.DefaultTlsPort
DefaultWsPort = sip.DefaultWsPort
DefaultWssPort = sip.DefaultWssPort
)
// Target endpoint
type Target struct {
Host string
Port *sip.Port
}
func (trg *Target) Addr() string {
var (
host string
port sip.Port
)
if strings.TrimSpace(trg.Host) != "" {
host = trg.Host
} else {
host = DefaultHost
}
if trg.Port != nil {
port = *trg.Port
}
return fmt.Sprintf("%v:%v", host, port)
}
func (trg *Target) String() string {
if trg == nil {
return "<nil>"
}
fields := log.Fields{
"target_addr": trg.Addr(),
}
return fmt.Sprintf("transport.Target<%s>", fields)
}
func NewTarget(host string, port int) *Target {
cport := sip.Port(port)
return &Target{Host: host, Port: &cport}
}
func NewTargetFromAddr(addr string) (*Target, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
iport, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
return NewTarget(host, iport), nil
}
// Fills endpoint target with default values.
func FillTargetHostAndPort(network string, target *Target) *Target {
if strings.TrimSpace(target.Host) == "" {
target.Host = DefaultHost
}
if target.Port == nil {
p := sip.DefaultPort(network)
target.Port = &p
}
return target
}
// Transport error
type Error interface {
net.Error
// Network indicates network level errors
Network() bool
}
func isNetwork(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return true
} else {
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe)
}
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func isTemporary(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Temporary()
}
return false
}
func isCanceled(err error) bool {
var cancelErr sip.CancelError
if errors.As(err, &cancelErr) {
return cancelErr.Canceled()
}
return false
}
func isExpired(err error) bool {
var expiryErr sip.ExpireError
if errors.As(err, &expiryErr) {
return expiryErr.Expired()
}
return false
}
// Connection level error.
type ConnectionError struct {
Err error
Op string
Net string
Source string
Dest string
ConnPtr string
}
func (err *ConnectionError) Unwrap() error { return err.Err }
func (err *ConnectionError) Network() bool { return isNetwork(err.Err) }
func (err *ConnectionError) Timeout() bool { return isTimeout(err.Err) }
func (err *ConnectionError) Temporary() bool { return isTemporary(err.Err) }
func (err *ConnectionError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"network": "???",
"connection_ptr": "???",
"source": "???",
"destination": "???",
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.ConnPtr != "" {
fields["connection_ptr"] = err.ConnPtr
}
if err.Source != "" {
fields["source"] = err.Source
}
if err.Dest != "" {
fields["destination"] = err.Dest
}
return fmt.Sprintf("transport.ConnectionError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type ExpireError string
func (err ExpireError) Network() bool { return false }
func (err ExpireError) Timeout() bool { return true }
func (err ExpireError) Temporary() bool { return false }
func (err ExpireError) Canceled() bool { return false }
func (err ExpireError) Expired() bool { return true }
func (err ExpireError) Error() string { return "transport.ExpireError: " + string(err) }
// Net Protocol level error
type ProtocolError struct {
Err error
Op string
ProtoPtr string
}
func (err *ProtocolError) Unwrap() error { return err.Err }
func (err *ProtocolError) Network() bool { return isNetwork(err.Err) }
func (err *ProtocolError) Timeout() bool { return isTimeout(err.Err) }
func (err *ProtocolError) Temporary() bool { return isTemporary(err.Err) }
func (err *ProtocolError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"protocol_ptr": "???",
}
if err.ProtoPtr != "" {
fields["protocol_ptr"] = err.ProtoPtr
}
return fmt.Sprintf("transport.ProtocolError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type ConnectionHandlerError struct {
Err error
Key ConnectionKey
HandlerPtr string
Net string
LAddr string
RAddr string
}
func (err *ConnectionHandlerError) Unwrap() error { return err.Err }
func (err *ConnectionHandlerError) Network() bool { return isNetwork(err.Err) }
func (err *ConnectionHandlerError) Timeout() bool { return isTimeout(err.Err) }
func (err *ConnectionHandlerError) Temporary() bool { return isTemporary(err.Err) }
func (err *ConnectionHandlerError) Canceled() bool { return isCanceled(err.Err) }
func (err *ConnectionHandlerError) Expired() bool { return isExpired(err.Err) }
func (err *ConnectionHandlerError) EOF() bool {
if err.Err == io.EOF {
return true
}
ok, _ := regexp.MatchString("(?i)eof", err.Err.Error())
return ok
}
func (err *ConnectionHandlerError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"handler_ptr": "???",
"network": "???",
"local_addr": "???",
"remote_addr": "???",
}
if err.HandlerPtr != "" {
fields["handler_ptr"] = err.HandlerPtr
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.LAddr != "" {
fields["local_addr"] = err.LAddr
}
if err.RAddr != "" {
fields["remote_addr"] = err.RAddr
}
return fmt.Sprintf("transport.ConnectionHandlerError<%s>: %s", fields, err.Err)
}
type ListenerHandlerError struct {
Err error
Key ListenerKey
HandlerPtr string
Net string
Addr string
}
func (err *ListenerHandlerError) Unwrap() error { return err.Err }
func (err *ListenerHandlerError) Network() bool { return isNetwork(err.Err) }
func (err *ListenerHandlerError) Timeout() bool { return isTimeout(err.Err) }
func (err *ListenerHandlerError) Temporary() bool { return isTemporary(err.Err) }
func (err *ListenerHandlerError) Canceled() bool { return isCanceled(err.Err) }
func (err *ListenerHandlerError) Expired() bool { return isExpired(err.Err) }
func (err *ListenerHandlerError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"handler_ptr": "???",
"network": "???",
"local_addr": "???",
"remote_addr": "???",
}
if err.HandlerPtr != "" {
fields["handler_ptr"] = err.HandlerPtr
}
if err.Net != "" {
fields["network"] = err.Net
}
if err.Addr != "" {
fields["local_addr"] = err.Addr
}
return fmt.Sprintf("transport.ListenerHandlerError<%s>: %s", fields, err.Err)
}
type PoolError struct {
Err error
Op string
Pool string
}
func (err *PoolError) Unwrap() error { return err.Err }
func (err *PoolError) Network() bool { return isNetwork(err.Err) }
func (err *PoolError) Timeout() bool { return isTimeout(err.Err) }
func (err *PoolError) Temporary() bool { return isTemporary(err.Err) }
func (err *PoolError) Error() string {
if err == nil {
return "<nil>"
}
fields := log.Fields{
"pool": "???",
}
if err.Pool != "" {
fields["pool"] = err.Pool
}
return fmt.Sprintf("transport.PoolError<%s> %s failed: %s", fields, err.Op, err.Err)
}
type UnsupportedProtocolError string
func (err UnsupportedProtocolError) Network() bool { return false }
func (err UnsupportedProtocolError) Timeout() bool { return false }
func (err UnsupportedProtocolError) Temporary() bool { return false }
func (err UnsupportedProtocolError) Error() string {
return "transport.UnsupportedProtocolError: " + string(err)
}
//TLSConfig for TLS and WSS only
type TLSConfig struct {
Domain string
Cert string
Key string
Pass string
}
func (c TLSConfig) ApplyListen(opts *ListenOptions) {
opts.TLSConfig.Domain = c.Domain
opts.TLSConfig.Cert = c.Cert
opts.TLSConfig.Key = c.Key
opts.TLSConfig.Pass = c.Pass
}

View file

@ -0,0 +1,138 @@
package transport
import (
"fmt"
"net"
"strings"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
// UDP protocol implementation
type udpProtocol struct {
protocol
connections ConnectionPool
}
func NewUdpProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(udpProtocol)
p.network = "udp"
p.reliable = false
p.streamed = false
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
// TODO: add separate errs chan to listen errors from pool for reconnection?
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
return p
}
func (p *udpProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
func (p *udpProtocol) Listen(target *Target, options ...ListenOption) error {
// fill empty target props with default values
target = FillTargetHostAndPort(p.Network(), target)
// resolve local UDP endpoint
laddr, err := net.ResolveUDPAddr(p.network, target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// create UDP connection
udpConn, err := net.ListenUDP(p.network, laddr)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), laddr),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), laddr)
// register new connection
// index by local address, TTL=0 - unlimited expiry time
key := ConnectionKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, laddr.Port))
conn := NewConnection(udpConn, key, p.network, p.Log())
err = p.connections.Put(conn, 0)
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s connection to the pool", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err // should be nil here
}
func (p *udpProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
// validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
// resolve remote address
raddr, err := net.ResolveUDPAddr(p.network, target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
_, port, err := net.SplitHostPort(msg.Source())
if err != nil {
return &ProtocolError{
Err: err,
Op: "resolve source port",
ProtoPtr: fmt.Sprintf("%p", p),
}
}
for _, conn := range p.connections.All() {
parts := strings.Split(string(conn.Key()), ":")
if parts[2] == port {
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
if _, err = conn.WriteTo([]byte(msg.String()), raddr); err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return nil
}
}
return &ProtocolError{
fmt.Errorf("connection on port %s not found", port),
"search connection",
fmt.Sprintf("%p", p),
}
}

View file

@ -0,0 +1,301 @@
package transport
import (
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
var (
wsSubProtocol = "sip"
)
type wsConn struct {
net.Conn
client bool
}
func (wc *wsConn) Read(b []byte) (n int, err error) {
var msg []byte
var op ws.OpCode
if wc.client {
msg, op, err = wsutil.ReadServerData(wc.Conn)
} else {
msg, op, err = wsutil.ReadClientData(wc.Conn)
}
if err != nil {
// handle error
var wsErr wsutil.ClosedError
if errors.As(err, &wsErr) {
return n, io.EOF
}
return n, err
}
if op == ws.OpClose {
return n, io.EOF
}
copy(b, msg)
return len(msg), err
}
func (wc *wsConn) Write(b []byte) (n int, err error) {
if wc.client {
err = wsutil.WriteClientMessage(wc.Conn, ws.OpText, b)
} else {
err = wsutil.WriteServerMessage(wc.Conn, ws.OpText, b)
}
if err != nil {
// handle error
var wsErr wsutil.ClosedError
if errors.As(err, &wsErr) {
return n, io.EOF
}
return n, err
}
return len(b), nil
}
type wsListener struct {
net.Listener
network string
u ws.Upgrader
log log.Logger
}
func NewWsListener(listener net.Listener, network string, log log.Logger) *wsListener {
l := &wsListener{
Listener: listener,
network: network,
log: log,
}
l.u.Protocol = func(val []byte) bool {
return string(val) == wsSubProtocol
}
return l
}
func (l *wsListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, fmt.Errorf("accept new connection: %w", err)
}
if _, err = l.u.Upgrade(conn); err == nil {
conn = &wsConn{
Conn: conn,
client: false,
}
} else {
l.log.Warnf("fallback to simple TCP connection due to WS upgrade error: %s", err)
err = nil
}
return conn, err
}
func (l *wsListener) Network() string {
return strings.ToUpper(l.network)
}
type wsProtocol struct {
protocol
listeners ListenerPool
connections ConnectionPool
conns chan Connection
listen func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error)
resolveAddr func(addr string) (*net.TCPAddr, error)
dialer ws.Dialer
}
func NewWsProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(wsProtocol)
p.network = "ws"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = p.defaultListen
p.resolveAddr = p.defaultResolveAddr
p.dialer.Protocols = []string{wsSubProtocol}
p.dialer.Timeout = time.Minute
//pipe listener and connection pools
go p.pipePools()
return p
}
func (p *wsProtocol) defaultListen(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
return net.ListenTCP("tcp", addr)
}
func (p *wsProtocol) defaultResolveAddr(addr string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", addr)
}
func (p *wsProtocol) Done() <-chan struct{} {
return p.connections.Done()
}
//piping new connections to connection pool for serving
func (p *wsProtocol) pipePools() {
defer close(p.conns)
p.Log().Debug("start pipe pools")
defer p.Log().Debug("stop pipe pools")
for {
select {
case <-p.listeners.Done():
return
case conn := <-p.conns:
logger := log.AddFieldsFrom(p.Log(), conn)
if err := p.connections.Put(conn, sockTTL); err != nil {
// TODO should it be passed up to UA?
logger.Errorf("put %s connection to the pool failed: %s", conn.Key(), err)
conn.Close()
continue
}
}
}
}
func (p *wsProtocol) Listen(target *Target, options ...ListenOption) error {
target = FillTargetHostAndPort(p.Network(), target)
laddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
listener, err := p.listen(laddr, options...)
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("listen on %s %s address", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
p.Log().Debugf("begin listening on %s %s", p.Network(), target.Addr())
//index listeners by local address
// should live infinitely
key := ListenerKey(fmt.Sprintf("%s:0.0.0.0:%d", p.network, target.Port))
err = p.listeners.Put(key, NewWsListener(listener, p.network, p.Log()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("put %s listener to the pool", key),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err //should be nil here
}
func (p *wsProtocol) Send(target *Target, msg sip.Message) error {
target = FillTargetHostAndPort(p.Network(), target)
//validate remote address
if target.Host == "" {
return &ProtocolError{
fmt.Errorf("empty remote target host"),
fmt.Sprintf("send SIP message to %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
//resolve remote address
raddr, err := p.resolveAddr(target.Addr())
if err != nil {
return &ProtocolError{
err,
fmt.Sprintf("resolve target address %s %s", p.Network(), target.Addr()),
fmt.Sprintf("%p", p),
}
}
//find or create connection
conn, err := p.getOrCreateConnection(raddr)
if err != nil {
return &ProtocolError{
Err: err,
Op: fmt.Sprintf("get or create %s connection", p.Network()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
logger := log.AddFieldsFrom(p.Log(), conn, msg)
logger.Tracef("writing SIP message to %s %s", p.Network(), raddr)
//send message
_, err = conn.Write([]byte(msg.String()))
if err != nil {
err = &ProtocolError{
Err: err,
Op: fmt.Sprintf("write SIP message to the %s connection", conn.Key()),
ProtoPtr: fmt.Sprintf("%p", p),
}
}
return err
}
func (p *wsProtocol) getOrCreateConnection(raddr *net.TCPAddr) (Connection, error) {
key := ConnectionKey(p.network + ":" + raddr.String())
conn, err := p.connections.Get(key)
if err != nil {
p.Log().Debugf("connection for address %s %s not found; create a new one", p.Network(), raddr)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
url := fmt.Sprintf("%s://%s", p.network, raddr)
baseConn, _, _, err := p.dialer.Dial(ctx, url)
if err == nil {
baseConn = &wsConn{
Conn: baseConn,
client: true,
}
} else {
if baseConn == nil {
return nil, fmt.Errorf("dial to %s %s: %w", p.Network(), raddr, err)
}
p.Log().Warnf("fallback to TCP connection due to WS upgrade error: %s", err)
}
conn = NewConnection(baseConn, key, p.network, p.Log())
if err := p.connections.Put(conn, sockTTL); err != nil {
return conn, fmt.Errorf("put %s connection to the pool: %w", conn.Key(), err)
}
}
return conn, nil
}

View file

@ -0,0 +1,66 @@
package transport
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/ghettovoice/gosip/log"
"github.com/ghettovoice/gosip/sip"
)
type wssProtocol struct {
wsProtocol
}
func NewWssProtocol(
output chan<- sip.Message,
errs chan<- error,
cancel <-chan struct{},
msgMapper sip.MessageMapper,
logger log.Logger,
) Protocol {
p := new(wssProtocol)
p.network = "wss"
p.reliable = true
p.streamed = true
p.conns = make(chan Connection)
p.log = logger.
WithPrefix("transport.Protocol").
WithFields(log.Fields{
"protocol_ptr": fmt.Sprintf("%p", p),
})
//TODO: add separate errs chan to listen errors from pool for reconnection?
p.listeners = NewListenerPool(p.conns, errs, cancel, p.Log())
p.connections = NewConnectionPool(output, errs, cancel, msgMapper, p.Log())
p.listen = func(addr *net.TCPAddr, options ...ListenOption) (net.Listener, error) {
if len(options) == 0 {
return net.ListenTCP("tcp", addr)
}
optsHash := ListenOptions{}
for _, opt := range options {
opt.ApplyListen(&optsHash)
}
cert, err := tls.LoadX509KeyPair(optsHash.TLSConfig.Cert, optsHash.TLSConfig.Key)
if err != nil {
return nil, fmt.Errorf("load TLS certficate %s: %w", optsHash.TLSConfig.Cert, err)
}
return tls.Listen("tcp", addr.String(), &tls.Config{
Certificates: []tls.Certificate{cert},
})
}
p.resolveAddr = p.defaultResolveAddr
p.dialer.Protocols = []string{wsSubProtocol}
p.dialer.Timeout = time.Minute
p.dialer.TLSConfig = &tls.Config{
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return nil
},
}
//pipe listener and connection pools
go p.pipePools()
return p
}

View file

@ -0,0 +1,147 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import (
"fmt"
"sync"
"github.com/ghettovoice/gosip/log"
)
// The buffer size of the primitive input and output chans.
const c_ELASTIC_CHANSIZE = 3
// A dynamic channel that does not block on send, but has an unlimited buffer capacity.
// ElasticChan uses a dynamic slice to buffer signals received on the input channel until
// the output channel is ready to process them.
type ElasticChan struct {
In chan interface{}
Out chan interface{}
buffer []interface{}
stopped bool
done chan struct{}
log log.Logger
logMu sync.RWMutex
}
// Initialise the Elastic channel, and start the management goroutine.
func (c *ElasticChan) Init() {
c.In = make(chan interface{}, c_ELASTIC_CHANSIZE)
c.Out = make(chan interface{}, c_ELASTIC_CHANSIZE)
c.buffer = make([]interface{}, 0)
c.done = make(chan struct{})
}
func (c *ElasticChan) Run() {
go c.manage()
}
func (c *ElasticChan) Stop() {
select {
case <-c.done:
return
default:
}
logger := c.Log()
if logger != nil {
logger.Trace("stopping elastic chan...")
}
close(c.In)
<-c.done
if logger != nil {
logger.Trace("elastic chan stopped")
}
}
func (c *ElasticChan) Log() log.Logger {
c.logMu.RLock()
defer c.logMu.RUnlock()
return c.log
}
func (c *ElasticChan) SetLog(logger log.Logger) {
c.logMu.Lock()
c.log = logger.
WithPrefix("util.ElasticChan").
WithFields(log.Fields{
"elastic_chan_ptr": fmt.Sprintf("%p", c),
})
c.logMu.Unlock()
}
// Poll for input from one end of the channel and add it to the buffer.
// Also poll sending buffered signals out over the output chan.
// TODO: add cancel chan
func (c *ElasticChan) manage() {
defer close(c.done)
loop:
for {
logger := c.Log()
if len(c.buffer) > 0 {
// The buffer has something in it, so try to send as well as
// receive.
// (Receive first in order to minimize blocked Send() calls).
select {
case in, ok := <-c.In:
if !ok {
if logger != nil {
logger.Trace("elastic chan will dispose")
}
break loop
}
c.Log().Tracef("ElasticChan %p gets '%v'", c, in)
c.buffer = append(c.buffer, in)
case c.Out <- c.buffer[0]:
c.Log().Tracef("ElasticChan %p sends '%v'", c, c.buffer[0])
c.buffer = c.buffer[1:]
}
} else {
// The buffer is empty, so there's nothing to send.
// Just wait to receive.
in, ok := <-c.In
if !ok {
if logger != nil {
logger.Trace("elastic chan will dispose")
}
break loop
}
c.Log().Tracef("ElasticChan %p gets '%v'", c, in)
c.buffer = append(c.buffer, in)
}
}
c.dispose()
}
func (c *ElasticChan) dispose() {
logger := c.Log()
if logger != nil {
logger.Trace("elastic chan disposing...")
}
for len(c.buffer) > 0 {
select {
case c.Out <- c.buffer[0]:
c.buffer = c.buffer[1:]
default:
}
}
if logger != nil {
logger.Trace("elastic chan disposed")
}
}

View file

@ -0,0 +1,104 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import (
"errors"
"net"
"sync"
)
// Check two string pointers for equality as follows:
// - If neither pointer is nil, check equality of the underlying strings.
// - If either pointer is nil, return true if and only if they both are.
func StrPtrEq(a *string, b *string) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
// Check two uint16 pointers for equality as follows:
// - If neither pointer is nil, check equality of the underlying uint16s.
// - If either pointer is nil, return true if and only if they both are.
func Uint16PtrEq(a *uint16, b *uint16) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
func Coalesce(arg1 interface{}, arg2 interface{}, args ...interface{}) interface{} {
all := append([]interface{}{arg1, arg2}, args...)
for _, arg := range all {
if arg != nil {
return arg
}
}
return nil
}
func Noop() {}
func MergeErrs(chs ...<-chan error) <-chan error {
wg := new(sync.WaitGroup)
out := make(chan error)
pipe := func(ch <-chan error) {
defer wg.Done()
for err := range ch {
out <- err
}
}
wg.Add(len(chs))
for _, ch := range chs {
go pipe(ch)
}
go func() {
wg.Wait()
close(out)
}()
return out
}
func ResolveSelfIP() (net.IP, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip, nil
}
}
return nil, errors.New("server not connected to any network")
}

View file

@ -0,0 +1,38 @@
package util
import (
"math/rand"
"time"
)
const (
letterBytes = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// https://github.com/kpbird/golang_random_string
func RandString(n int) string {
output := make([]byte, n)
// We will take n bytes, one byte for each character of output.
randomness := make([]byte, n)
// read all random
_, err := rand.Read(randomness)
if err != nil {
panic(err)
}
l := len(letterBytes)
// fill output
for pos := range output {
// get random item
random := randomness[pos]
// random % 64
randomPos := random % uint8(l)
// put into output
output[pos] = letterBytes[randomPos]
}
return string(output)
}

View file

@ -0,0 +1,75 @@
// Forked from github.com/StefanKopieczek/gossip by @StefanKopieczek
package util
import "sync"
// Simple semaphore implementation.
// Any number of calls to Acquire() can be made; these will not block.
// If the semaphore has been acquired more times than it has been released, it is called 'blocked'.
// Otherwise, it is called 'free'.
type Semaphore interface {
// Take a semaphore lock.
Acquire()
// Release an acquired semaphore lock.
// This should only be called when the semaphore is blocked, otherwise behaviour is undefined
Release()
// Block execution until the semaphore is free.
Wait()
// Clean up the semaphore object.
Dispose()
}
func NewSemaphore() Semaphore {
sem := new(semaphore)
sem.cond = sync.NewCond(&sync.Mutex{})
go func(s *semaphore) {
select {
case <-s.stop:
return
case <-s.acquired:
s.locks += 1
case <-s.released:
s.locks -= 1
if s.locks == 0 {
s.cond.Broadcast()
}
}
}(sem)
return sem
}
// Concrete implementation of Semaphore.
type semaphore struct {
held bool
locks int
acquired chan bool
released chan bool
stop chan bool
cond *sync.Cond
}
// Implements Semaphore.Acquire()
func (sem *semaphore) Acquire() {
sem.acquired <- true
}
// Implements Semaphore.Release()
func (sem *semaphore) Release() {
sem.released <- true
}
// Implements Semaphore.Wait()
func (sem *semaphore) Wait() {
sem.cond.L.Lock()
for sem.locks != 0 {
sem.cond.Wait()
}
}
// Implements Semaphore.Dispose()
func (sem *semaphore) Dispose() {
sem.stop <- true
}