1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

For regression test, add srs-bench to 3rdparty

This commit is contained in:
winlin 2021-03-04 13:23:01 +08:00
parent de87dd427d
commit 876210f6c9
1158 changed files with 256967 additions and 3 deletions

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,110 @@
// Package deadline provides deadline timer used to implement
// net.Conn compatible connection
package deadline
import (
"context"
"sync"
"time"
)
// Deadline signals updatable deadline timer.
// Also, it implements context.Context.
type Deadline struct {
exceeded chan struct{}
stop chan struct{}
stopped chan bool
deadline time.Time
mu sync.RWMutex
}
// New creates new deadline timer.
func New() *Deadline {
d := &Deadline{
exceeded: make(chan struct{}),
stop: make(chan struct{}),
stopped: make(chan bool, 1),
}
d.stopped <- true
return d
}
// Set new deadline. Zero value means no deadline.
func (d *Deadline) Set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
d.deadline = t
close(d.stop)
select {
case <-d.exceeded:
d.exceeded = make(chan struct{})
default:
stopped := <-d.stopped
if !stopped {
d.exceeded = make(chan struct{})
}
}
d.stop = make(chan struct{})
d.stopped = make(chan bool, 1)
if t.IsZero() {
d.stopped <- true
return
}
if dur := time.Until(t); dur > 0 {
exceeded := d.exceeded
stopped := d.stopped
go func() {
select {
case <-time.After(dur):
close(exceeded)
stopped <- false
case <-d.stop:
stopped <- true
}
}()
return
}
close(d.exceeded)
d.stopped <- false
}
// Done receives deadline signal.
func (d *Deadline) Done() <-chan struct{} {
d.mu.RLock()
defer d.mu.RUnlock()
return d.exceeded
}
// Err returns context.DeadlineExceeded if the deadline is exceeded.
// Otherwise, it returns nil.
func (d *Deadline) Err() error {
d.mu.RLock()
defer d.mu.RUnlock()
select {
case <-d.exceeded:
return context.DeadlineExceeded
default:
return nil
}
}
// Deadline returns current deadline.
func (d *Deadline) Deadline() (time.Time, bool) {
d.mu.RLock()
defer d.mu.RUnlock()
if d.deadline.IsZero() {
return d.deadline, false
}
return d.deadline, true
}
// Value returns nil.
func (d *Deadline) Value(interface{}) interface{} {
return nil
}

View file

@ -0,0 +1,347 @@
// Package packetio provides packet buffer
package packetio
import (
"errors"
"io"
"sync"
"time"
"github.com/pion/transport/deadline"
)
var errPacketTooBig = errors.New("packet too big")
// BufferPacketType allow the Buffer to know which packet protocol is writing.
type BufferPacketType int
const (
// RTPBufferPacket indicates the Buffer that is handling RTP packets
RTPBufferPacket BufferPacketType = 1
// RTCPBufferPacket indicates the Buffer that is handling RTCP packets
RTCPBufferPacket BufferPacketType = 2
)
// Buffer allows writing packets to an intermediate buffer, which can then be read form.
// This is verify similar to bytes.Buffer but avoids combining multiple writes into a single read.
type Buffer struct {
mutex sync.Mutex
// this is a circular buffer. If head <= tail, then the useful
// data is in the interval [head, tail[. If tail < head, then
// the useful data is the union of [head, len[ and [0, tail[.
// In order to avoid ambiguity when head = tail, we always leave
// an unused byte in the buffer.
data []byte
head, tail int
notify chan struct{}
subs bool
closed bool
count int
limitCount, limitSize int
readDeadline *deadline.Deadline
}
const (
minSize = 2048
cutoffSize = 128 * 1024
maxSize = 4 * 1024 * 1024
)
// NewBuffer creates a new Buffer.
func NewBuffer() *Buffer {
return &Buffer{
notify: make(chan struct{}),
readDeadline: deadline.New(),
}
}
// available returns true if the buffer is large enough to fit a packet
// of the given size, taking overhead into account.
func (b *Buffer) available(size int) bool {
available := b.head - b.tail
if available <= 0 {
available += len(b.data)
}
// we interpret head=tail as empty, so always keep a byte free
if size+2+1 > available {
return false
}
return true
}
// grow increases the size of the buffer. If it returns nil, then the
// buffer has been grown. It returns ErrFull if hits a limit.
func (b *Buffer) grow() error {
var newsize int
if len(b.data) < cutoffSize {
newsize = 2 * len(b.data)
} else {
newsize = 5 * len(b.data) / 4
}
if newsize < minSize {
newsize = minSize
}
if (b.limitSize <= 0 || sizeHardlimit) && newsize > maxSize {
newsize = maxSize
}
// one byte slack
if b.limitSize > 0 && newsize > b.limitSize+1 {
newsize = b.limitSize + 1
}
if newsize <= len(b.data) {
return ErrFull
}
newdata := make([]byte, newsize)
var n int
if b.head <= b.tail {
// data was contiguous
n = copy(newdata, b.data[b.head:b.tail])
} else {
// data was discontiguous
n = copy(newdata, b.data[b.head:])
n += copy(newdata[n:], b.data[:b.tail])
}
b.head = 0
b.tail = n
b.data = newdata
return nil
}
// Write appends a copy of the packet data to the buffer.
// Returns ErrFull if the packet doesn't fit.
//
// Note that the packet size is limited to 65536 bytes since v0.11.0 due to the internal data structure.
func (b *Buffer) Write(packet []byte) (int, error) {
if len(packet) >= 0x10000 {
return 0, errPacketTooBig
}
b.mutex.Lock()
if b.closed {
b.mutex.Unlock()
return 0, io.ErrClosedPipe
}
if (b.limitCount > 0 && b.count >= b.limitCount) ||
(b.limitSize > 0 && b.size()+2+len(packet) > b.limitSize) {
b.mutex.Unlock()
return 0, ErrFull
}
// grow the buffer until the packet fits
for !b.available(len(packet)) {
err := b.grow()
if err != nil {
b.mutex.Unlock()
return 0, err
}
}
var notify chan struct{}
if b.subs {
// readers are waiting. Prepare to notify, but only
// actually do it after we release the lock.
notify = b.notify
b.notify = make(chan struct{})
b.subs = false
}
// store the length of the packet
b.data[b.tail] = uint8(len(packet) >> 8)
b.tail++
if b.tail >= len(b.data) {
b.tail = 0
}
b.data[b.tail] = uint8(len(packet))
b.tail++
if b.tail >= len(b.data) {
b.tail = 0
}
// store the packet
n := copy(b.data[b.tail:], packet)
b.tail += n
if b.tail >= len(b.data) {
// we reached the end, wrap around
m := copy(b.data, packet[n:])
b.tail = m
}
b.count++
b.mutex.Unlock()
if notify != nil {
close(notify)
}
return len(packet), nil
}
// Read populates the given byte slice, returning the number of bytes read.
// Blocks until data is available or the buffer is closed.
// Returns io.ErrShortBuffer is the packet is too small to copy the Write.
// Returns io.EOF if the buffer is closed.
func (b *Buffer) Read(packet []byte) (n int, err error) {
// Return immediately if the deadline is already exceeded.
select {
case <-b.readDeadline.Done():
return 0, &netError{ErrTimeout, true, true}
default:
}
for {
b.mutex.Lock()
if b.head != b.tail {
// decode the packet size
n1 := b.data[b.head]
b.head++
if b.head >= len(b.data) {
b.head = 0
}
n2 := b.data[b.head]
b.head++
if b.head >= len(b.data) {
b.head = 0
}
count := int((uint16(n1) << 8) | uint16(n2))
// determine the number of bytes we'll actually copy
copied := count
if copied > len(packet) {
copied = len(packet)
}
// copy the data
if b.head+copied < len(b.data) {
copy(packet, b.data[b.head:b.head+copied])
} else {
k := copy(packet, b.data[b.head:])
copy(packet[k:], b.data[:copied-k])
}
// advance head, discarding any data that wasn't copied
b.head += count
if b.head >= len(b.data) {
b.head -= len(b.data)
}
if b.head == b.tail {
// the buffer is empty, reset to beginning
// in order to improve cache locality.
b.head = 0
b.tail = 0
}
b.count--
b.mutex.Unlock()
if copied < count {
return copied, io.ErrShortBuffer
}
return copied, nil
}
if b.closed {
b.mutex.Unlock()
return 0, io.EOF
}
notify := b.notify
b.subs = true
b.mutex.Unlock()
select {
case <-b.readDeadline.Done():
return 0, &netError{ErrTimeout, true, true}
case <-notify:
}
}
}
// Close the buffer, unblocking any pending reads.
// Data in the buffer can still be read, Read will return io.EOF only when empty.
func (b *Buffer) Close() (err error) {
b.mutex.Lock()
if b.closed {
b.mutex.Unlock()
return nil
}
notify := b.notify
b.closed = true
b.mutex.Unlock()
close(notify)
return nil
}
// Count returns the number of packets in the buffer.
func (b *Buffer) Count() int {
b.mutex.Lock()
defer b.mutex.Unlock()
return b.count
}
// SetLimitCount controls the maximum number of packets that can be buffered.
// Causes Write to return ErrFull when this limit is reached.
// A zero value will disable this limit.
func (b *Buffer) SetLimitCount(limit int) {
b.mutex.Lock()
defer b.mutex.Unlock()
b.limitCount = limit
}
// Size returns the total byte size of packets in the buffer, including
// a small amount of administrative overhead.
func (b *Buffer) Size() int {
b.mutex.Lock()
defer b.mutex.Unlock()
return b.size()
}
func (b *Buffer) size() int {
size := b.tail - b.head
if size < 0 {
size += len(b.data)
}
return size
}
// SetLimitSize controls the maximum number of bytes that can be buffered.
// Causes Write to return ErrFull when this limit is reached.
// A zero value means 4MB since v0.11.0.
//
// User can set packetioSizeHardlimit build tag to enable 4MB hardlimit.
// When packetioSizeHardlimit build tag is set, SetLimitSize exceeding
// the hardlimit will be silently discarded.
func (b *Buffer) SetLimitSize(limit int) {
b.mutex.Lock()
defer b.mutex.Unlock()
b.limitSize = limit
}
// SetReadDeadline sets the deadline for the Read operation.
// Setting to zero means no deadline.
func (b *Buffer) SetReadDeadline(t time.Time) error {
b.readDeadline.Set(t)
return nil
}

View file

@ -0,0 +1,27 @@
package packetio
import (
"errors"
)
// netError implements net.Error
type netError struct {
error
timeout, temporary bool
}
func (e *netError) Timeout() bool {
return e.timeout
}
func (e *netError) Temporary() bool {
return e.temporary
}
var (
// ErrFull is returned when the buffer has hit the configured limits.
ErrFull = errors.New("packetio.Buffer is full, discarding write")
// ErrTimeout is returned when a deadline has expired
ErrTimeout = errors.New("i/o timeout")
)

View file

@ -0,0 +1,5 @@
// +build packetioSizeHardlimit
package packetio
const sizeHardlimit = true

View file

@ -0,0 +1,5 @@
// +build !packetioSizeHardlimit
package packetio
const sizeHardlimit = false

View file

@ -0,0 +1,78 @@
package replaydetector
import (
"fmt"
)
// fixedBigInt is the fix-sized multi-word integer.
type fixedBigInt struct {
bits []uint64
n uint
msbMask uint64
}
// newFixedBigInt creates a new fix-sized multi-word int.
func newFixedBigInt(n uint) *fixedBigInt {
chunkSize := (n + 63) / 64
if chunkSize == 0 {
chunkSize = 1
}
return &fixedBigInt{
bits: make([]uint64, chunkSize),
n: n,
msbMask: (1 << (64 - n%64)) - 1,
}
}
// Lsh is the left shift operation.
func (s *fixedBigInt) Lsh(n uint) {
if n == 0 {
return
}
nChunk := int(n / 64)
nN := n % 64
for i := len(s.bits) - 1; i >= 0; i-- {
var carry uint64
if i-nChunk >= 0 {
carry = s.bits[i-nChunk] << nN
if i-nChunk-1 >= 0 {
carry |= s.bits[i-nChunk-1] >> (64 - nN)
}
}
s.bits[i] = (s.bits[i] << n) | carry
}
s.bits[len(s.bits)-1] &= s.msbMask
}
// Bit returns i-th bit of the fixedBigInt.
func (s *fixedBigInt) Bit(i uint) uint {
if i >= s.n {
return 0
}
chunk := i / 64
pos := i % 64
if s.bits[chunk]&(1<<pos) != 0 {
return 1
}
return 0
}
// SetBit sets i-th bit to 1.
func (s *fixedBigInt) SetBit(i uint) {
if i >= s.n {
return
}
chunk := i / 64
pos := i % 64
s.bits[chunk] |= 1 << pos
}
// String returns string representation of fixedBigInt.
func (s *fixedBigInt) String() string {
var out string
for i := len(s.bits) - 1; i >= 0; i-- {
out += fmt.Sprintf("%016X", s.bits[i])
}
return out
}

View file

@ -0,0 +1,116 @@
// Package replaydetector provides packet replay detection algorithm.
package replaydetector
// ReplayDetector is the interface of sequence replay detector.
type ReplayDetector interface {
// Check returns true if given sequence number is not replayed.
// Call accept() to mark the packet is received properly.
Check(seq uint64) (accept func(), ok bool)
}
type slidingWindowDetector struct {
latestSeq uint64
maxSeq uint64
windowSize uint
mask *fixedBigInt
}
// New creates ReplayDetector.
// Created ReplayDetector doesn't allow wrapping.
// It can handle monotonically increasing sequence number up to
// full 64bit number. It is suitable for DTLS replay protection.
func New(windowSize uint, maxSeq uint64) ReplayDetector {
return &slidingWindowDetector{
maxSeq: maxSeq,
windowSize: windowSize,
mask: newFixedBigInt(windowSize),
}
}
func (d *slidingWindowDetector) Check(seq uint64) (accept func(), ok bool) {
if seq > d.maxSeq {
// Exceeded upper limit.
return func() {}, false
}
if seq <= d.latestSeq {
if d.latestSeq >= uint64(d.windowSize)+seq {
return func() {}, false
}
if d.mask.Bit(uint(d.latestSeq-seq)) != 0 {
// The sequence number is duplicated.
return func() {}, false
}
}
return func() {
if seq > d.latestSeq {
// Update the head of the window.
d.mask.Lsh(uint(seq - d.latestSeq))
d.latestSeq = seq
}
diff := (d.latestSeq - seq) % d.maxSeq
d.mask.SetBit(uint(diff))
}, true
}
// WithWrap creates ReplayDetector allowing sequence wrapping.
// This is suitable for short bitwidth counter like SRTP and SRTCP.
func WithWrap(windowSize uint, maxSeq uint64) ReplayDetector {
return &wrappedSlidingWindowDetector{
maxSeq: maxSeq,
windowSize: windowSize,
mask: newFixedBigInt(windowSize),
}
}
type wrappedSlidingWindowDetector struct {
latestSeq uint64
maxSeq uint64
windowSize uint
mask *fixedBigInt
init bool
}
func (d *wrappedSlidingWindowDetector) Check(seq uint64) (accept func(), ok bool) {
if seq > d.maxSeq {
// Exceeded upper limit.
return func() {}, false
}
if !d.init {
if seq != 0 {
d.latestSeq = seq - 1
} else {
d.latestSeq = d.maxSeq
}
d.init = true
}
diff := int64(d.latestSeq) - int64(seq)
// Wrap the number.
if diff > int64(d.maxSeq)/2 {
diff -= int64(d.maxSeq + 1)
} else if diff <= -int64(d.maxSeq)/2 {
diff += int64(d.maxSeq + 1)
}
if diff >= int64(d.windowSize) {
// Too old.
return func() {}, false
}
if diff >= 0 {
if d.mask.Bit(uint(diff)) != 0 {
// The sequence number is duplicated.
return func() {}, false
}
}
return func() {
if diff < 0 {
// Update the head of the window.
d.mask.Lsh(uint(-diff))
d.latestSeq = seq
}
d.mask.SetBit(uint(d.latestSeq - seq))
}, true
}

View file

@ -0,0 +1 @@
*.sw[poe]

View file

@ -0,0 +1,239 @@
# vnet
A virtual network layer for pion.
## Overview
### Goals
* To make NAT traversal tests easy.
* To emulate packet impairment at application level for testing.
* To monitor packets at specified arbitrary interfaces.
### Features
* Configurable virtual LAN and WAN
* Virtually hosted ICE servers
### Virtual network components
#### Top View
```
......................................
: Virtual Network (vnet) :
: :
+-------+ * 1 +----+ +--------+ :
| :App |------------>|:Net|--o<-----|:Router | :
+-------+ +----+ | | :
+-----------+ * 1 +----+ | | :
|:STUNServer|-------->|:Net|--o<-----| | :
+-----------+ +----+ | | :
+-----------+ * 1 +----+ | | :
|:TURNServer|-------->|:Net|--o<-----| | :
+-----------+ +----+ [1] | | :
: 1 | | 1 <<has>> :
: +---<>| |<>----+ [2] :
: | +--------+ | :
To form | *| v 0..1 :
a subnet tree | o [3] +-----+ :
: | ^ |:NAT | :
: | | +-----+ :
: +-------+ :
......................................
Note:
o: NIC (Netork Interface Controller)
[1]: Net implments NIC interface.
[2]: Root router has no NAT. All child routers have a NAT always.
[3]: Router implements NIC interface for accesses from the
parent router.
```
#### Net
Net provides 3 interfaces:
* Configuration API (direct)
* Network API via Net (equivalent to net.Xxx())
* Router access via NIC interface
```
(Pion module/app, ICE servers, etc.)
+-----------+
| :App |
+-----------+
* |
| <<uses>>
1 v
+---------+ 1 * +-----------+ 1 * +-----------+ 1 * +------+
..| :Router |----+------>o--| :Net |<>------|:Interface |<>------|:Addr |
+---------+ | NIC +-----------+ +-----------+ +------+
| <<interface>> (vnet.Interface) (net.Addr)
|
| * +-----------+ 1 * +-----------+ 1 * +------+
+------>o--| :Router |<>------|:Interface |<>------|:Addr |
NIC +-----------+ +-----------+ +------+
<<interface>> (vnet.Interface) (net.Addr)
```
> The instance of `Net` will be the one passed around the project.
> Net class has public methods for configuration and for application use.
## Implementation
### Design Policy
* Each pion package should have config object which has `Net` (of type vnet.Net) property. (just like how
we distribute `LoggerFactory` throughout the pion project.
* DNS => a simple dictionary (global)?
* Each Net has routing capability (a goroutine)
* Use interface provided net package as much as possible
* Routers are connected in a tree structure (no loop is allowed)
- To simplify routing
- Easy to control / monitor (stats, etc)
* Root router has no NAT (== Internet / WAN)
* Non-root router has a NAT always
* When a Net is instantiated, it will automatically add `lo0` and `eth0` interface, and `lo0` will
have one IP address, 127.0.0.1. (this is not used in pion/ice, however)
* When a Net is added to a router, the router automatically assign an IP address for `eth0`
interface.
- For simplicity
* User data won't fragment, but optionally drop chunk larger than MTU
* IPv6 is not supported
### Basic steps for setting up virtual network
1. Create a root router (WAN)
1. Create child routers and add to its parent (forms a tree, don't create a loop!)
1. Add instances of Net to each routers
1. Call Stop(), or Stop(), on the top router, which propages all other routers
#### Example: WAN with one endpoint (vnet)
```go
import (
"net"
"github.com/pion/transport/vnet"
"github.com/pion/logging"
)
// Create WAN (a root router).
wan, err := vnet.NewRouter(&RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
// Create a network.
// You can specify a static IP for the instance of Net to use. If not specified,
// router will assign an IP address that is contained in the router's CIDR.
nw := vnet.NewNet(&vnet.NetConfig{
StaticIP: "27.1.2.3",
})
// Add the network to the router.
// The router will assign an IP address to `nw`.
if err = wan.AddNet(nw); err != nil {
// handle error
}
// Start router.
// This will start internal goroutine to route packets.
// If you set child routers (using AddRouter), the call on the root router
// will start the rest of routers for you.
if err = wan.Start(); err != nil {
// handle error
}
//
// Your application runs here using `nw`.
//
// Stop the router.
// This will stop all internal goroutines in the router tree.
// (No need to call Stop() on child routers)
if err = wan.Stop(); err != nil {
// handle error
}
```
#### Example of how to pass around the instance of vnet.Net
The instance of vnet.Net wraps a subset of net package to enable operations
on the virtual network. Your project must be able to pass the instance to
all your routines that do network operation with net package. A typical way
is to use a config param to create your instances with the virtual network
instance (`nw` in the above example) like this:
```go
type AgentConfig struct {
:
Net: *vnet.Net,
}
type Agent struct {
:
net: *vnet.Net,
}
func NetAgent(config *AgentConfig) *Agent {
if config.Net == nil {
config.Net = vnet.NewNet(nil) // defaults to native operation
}
return &Agent {
:
net: config.Net,
}
}
```
```go
// a.net is the instance of vnet.Net class
func (a *Agent) listenUDP(...) error {
conn, err := a.net.ListenPacket(udpString, ...)
if err != nil {
return nil, err
}
:
}
```
### Compatibility and Support Status
|`net`<br>(built-in)|`vnet`|Note|
|---|---|---|
|net.Interfaces()|a.net.Interfaces()||
|net.InterfaceByName()|a.net.InterfaceByName()||
|net.ResolveUDPAddr()|a.net.ResolveUDPAddr()||
|net.ListenPacket()|a.net.ListenPacket()||
|net.ListenUDP()|a.net.ListenUDP()|(ListenPacket() is recommended)|
|net.Listen()|a.net.Listen()|(TODO)|
|net.ListenTCP()|(not supported)|(Listen() would be recommended)|
|net.Dial()|a.net.Dial()||
|net.DialUDP()|a.net.DialUDP()||
|net.DialTCP()|(not supported)||
|net.Interface|vnet.Interface||
|net.PacketConn|(use it as-is)||
|net.UDPConn|vnet.UDPConn|Use vnet.UDPPacketConn in your code|
|net.TCPConn|vnet.TCPConn|(TODO)|Use net.Conn in your code|
|net.Dialer|vnet.Dialer|Use a.net.CreateDialer() to create it.<br>The use of vnet.Dialer is currently experimental.|
> `a.net` is an instance of Net class, and types are defined under the package name `vnet`
> Most of other `interface` types in net package can be used as is.
> Please post a github issue when other types/methods need to be added to vnet/vnet.Net.
## TODO / Next Step
* Implement TCP (TCPConn, Listen)
* Support of IPv6
* Write a bunch of examples for building virtual networks.
* Add network impairment features (on Router)
- Introduce lantecy / jitter
- Packet filtering handler (allow selectively drop packets, etc.)
* Add statistics data retrieval
- Total number of packets forward by each router
- Total number of packet loss
- Total number of connection failure (TCP)
## References
* [Comparing Simulated Packet Loss and RealWorld Network Congestion](https://www.riverbed.com/document/fpo/WhitePaper-Riverbed-SimulatedPacketLoss.pdf)
### Code experiments
* [CIDR and IPMask](https://play.golang.org/p/B7OBhkZqjmj)
* [Test with net.IP](https://play.golang.org/p/AgXd23wKY4W)
* [ListenPacket](https://play.golang.org/p/d4vasbnRimQ)
* [isDottedIP()](https://play.golang.org/p/t4aZ47TgJfO)
* [SplitHostPort](https://play.golang.org/p/JtvurlcMbhn)

View file

@ -0,0 +1,283 @@
package vnet
import (
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
)
type tcpFlag uint8
const (
tcpFIN tcpFlag = 0x01
tcpSYN tcpFlag = 0x02
tcpRST tcpFlag = 0x04
tcpPSH tcpFlag = 0x08
tcpACK tcpFlag = 0x10
)
func (f tcpFlag) String() string {
var sa []string
if f&tcpFIN != 0 {
sa = append(sa, "FIN")
}
if f&tcpSYN != 0 {
sa = append(sa, "SYN")
}
if f&tcpRST != 0 {
sa = append(sa, "RST")
}
if f&tcpPSH != 0 {
sa = append(sa, "PSH")
}
if f&tcpACK != 0 {
sa = append(sa, "ACK")
}
return strings.Join(sa, "-")
}
// Generate a base36-encoded unique tag
// See: https://play.golang.org/p/0ZaAID1q-HN
var assignChunkTag = func() func() string { //nolint:gochecknoglobals
var tagCtr uint64
return func() string {
n := atomic.AddUint64(&tagCtr, 1)
return strconv.FormatUint(n, 36)
}
}()
// Chunk represents a packet passed around in the vnet
type Chunk interface {
setTimestamp() time.Time // used by router
getTimestamp() time.Time // used by router
getSourceIP() net.IP // used by router
getDestinationIP() net.IP // used by router
setSourceAddr(address string) error // used by nat
setDestinationAddr(address string) error // used by nat
SourceAddr() net.Addr
DestinationAddr() net.Addr
UserData() []byte
Tag() string
Clone() Chunk
Network() string // returns "udp" or "tcp"
String() string
}
type chunkIP struct {
timestamp time.Time
sourceIP net.IP
destinationIP net.IP
tag string
}
func (c *chunkIP) setTimestamp() time.Time {
c.timestamp = time.Now()
return c.timestamp
}
func (c *chunkIP) getTimestamp() time.Time {
return c.timestamp
}
func (c *chunkIP) getDestinationIP() net.IP {
return c.destinationIP
}
func (c *chunkIP) getSourceIP() net.IP {
return c.sourceIP
}
func (c *chunkIP) Tag() string {
return c.tag
}
type chunkUDP struct {
chunkIP
sourcePort int
destinationPort int
userData []byte
}
func newChunkUDP(srcAddr, dstAddr *net.UDPAddr) *chunkUDP {
return &chunkUDP{
chunkIP: chunkIP{
sourceIP: srcAddr.IP,
destinationIP: dstAddr.IP,
tag: assignChunkTag(),
},
sourcePort: srcAddr.Port,
destinationPort: dstAddr.Port,
}
}
func (c *chunkUDP) SourceAddr() net.Addr {
return &net.UDPAddr{
IP: c.sourceIP,
Port: c.sourcePort,
}
}
func (c *chunkUDP) DestinationAddr() net.Addr {
return &net.UDPAddr{
IP: c.destinationIP,
Port: c.destinationPort,
}
}
func (c *chunkUDP) UserData() []byte {
return c.userData
}
func (c *chunkUDP) Clone() Chunk {
var userData []byte
if c.userData != nil {
userData = make([]byte, len(c.userData))
copy(userData, c.userData)
}
return &chunkUDP{
chunkIP: chunkIP{
timestamp: c.timestamp,
sourceIP: c.sourceIP,
destinationIP: c.destinationIP,
tag: c.tag,
},
sourcePort: c.sourcePort,
destinationPort: c.destinationPort,
userData: userData,
}
}
func (c *chunkUDP) Network() string {
return udpString
}
func (c *chunkUDP) String() string {
src := c.SourceAddr()
dst := c.DestinationAddr()
return fmt.Sprintf("%s chunk %s %s => %s",
src.Network(),
c.tag,
src.String(),
dst.String(),
)
}
func (c *chunkUDP) setSourceAddr(address string) error {
addr, err := net.ResolveUDPAddr(udpString, address)
if err != nil {
return err
}
c.sourceIP = addr.IP
c.sourcePort = addr.Port
return nil
}
func (c *chunkUDP) setDestinationAddr(address string) error {
addr, err := net.ResolveUDPAddr(udpString, address)
if err != nil {
return err
}
c.destinationIP = addr.IP
c.destinationPort = addr.Port
return nil
}
type chunkTCP struct {
chunkIP
sourcePort int
destinationPort int
flags tcpFlag // control bits
userData []byte // only with PSH flag
// seq uint32 // always starts with 0
// ack uint32 // always starts with 0
}
func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP {
return &chunkTCP{
chunkIP: chunkIP{
sourceIP: srcAddr.IP,
destinationIP: dstAddr.IP,
tag: assignChunkTag(),
},
sourcePort: srcAddr.Port,
destinationPort: dstAddr.Port,
flags: flags,
}
}
func (c *chunkTCP) SourceAddr() net.Addr {
return &net.TCPAddr{
IP: c.sourceIP,
Port: c.sourcePort,
}
}
func (c *chunkTCP) DestinationAddr() net.Addr {
return &net.TCPAddr{
IP: c.destinationIP,
Port: c.destinationPort,
}
}
func (c *chunkTCP) UserData() []byte {
return c.userData
}
func (c *chunkTCP) Clone() Chunk {
userData := make([]byte, len(c.userData))
copy(userData, c.userData)
return &chunkTCP{
chunkIP: chunkIP{
timestamp: c.timestamp,
sourceIP: c.sourceIP,
destinationIP: c.destinationIP,
},
sourcePort: c.sourcePort,
destinationPort: c.destinationPort,
userData: userData,
}
}
func (c *chunkTCP) Network() string {
return "tcp"
}
func (c *chunkTCP) String() string {
src := c.SourceAddr()
dst := c.DestinationAddr()
return fmt.Sprintf("%s %s chunk %s %s => %s",
src.Network(),
c.flags.String(),
c.tag,
src.String(),
dst.String(),
)
}
func (c *chunkTCP) setSourceAddr(address string) error {
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return err
}
c.sourceIP = addr.IP
c.sourcePort = addr.Port
return nil
}
func (c *chunkTCP) setDestinationAddr(address string) error {
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return err
}
c.destinationIP = addr.IP
c.destinationPort = addr.Port
return nil
}

View file

@ -0,0 +1,52 @@
package vnet
import (
"sync"
)
type chunkQueue struct {
chunks []Chunk
maxSize int // 0 or negative value: unlimited
mutex sync.RWMutex
}
func newChunkQueue(maxSize int) *chunkQueue {
return &chunkQueue{maxSize: maxSize}
}
func (q *chunkQueue) push(c Chunk) bool {
q.mutex.Lock()
defer q.mutex.Unlock()
if q.maxSize > 0 && len(q.chunks) >= q.maxSize {
return false // dropped
}
q.chunks = append(q.chunks, c)
return true
}
func (q *chunkQueue) pop() (Chunk, bool) {
q.mutex.Lock()
defer q.mutex.Unlock()
if len(q.chunks) == 0 {
return nil, false
}
c := q.chunks[0]
q.chunks = q.chunks[1:]
return c, true
}
func (q *chunkQueue) peek() Chunk {
q.mutex.RLock()
defer q.mutex.RUnlock()
if len(q.chunks) == 0 {
return nil
}
return q.chunks[0]
}

View file

@ -0,0 +1,246 @@
package vnet
import (
"errors"
"io"
"math"
"net"
"sync"
"time"
)
const (
maxReadQueueSize = 1024
)
var (
errObsCannotBeNil = errors.New("obs cannot be nil")
errUseClosedNetworkConn = errors.New("use of closed network connection")
errAddrNotUDPAddr = errors.New("addr is not a net.UDPAddr")
errLocAddr = errors.New("something went wrong with locAddr")
errAlreadyClosed = errors.New("already closed")
errNoRemAddr = errors.New("no remAddr defined")
)
// UDPPacketConn is packet-oriented connection for UDP.
type UDPPacketConn interface {
net.PacketConn
Read(b []byte) (int, error)
RemoteAddr() net.Addr
Write(b []byte) (int, error)
}
// vNet implements this
type connObserver interface {
write(c Chunk) error
onClosed(addr net.Addr)
determineSourceIP(locIP, dstIP net.IP) net.IP
}
// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
// comatible with net.PacketConn and net.Conn
type UDPConn struct {
locAddr *net.UDPAddr // read-only
remAddr *net.UDPAddr // read-only
obs connObserver // read-only
readCh chan Chunk // thread-safe
closed bool // requires mutex
mu sync.Mutex // to mutex closed flag
readTimer *time.Timer // thread-safe
}
func newUDPConn(locAddr, remAddr *net.UDPAddr, obs connObserver) (*UDPConn, error) {
if obs == nil {
return nil, errObsCannotBeNil
}
return &UDPConn{
locAddr: locAddr,
remAddr: remAddr,
obs: obs,
readCh: make(chan Chunk, maxReadQueueSize),
readTimer: time.NewTimer(time.Duration(math.MaxInt64)),
}, nil
}
// ReadFrom reads a packet from the connection,
// copying the payload into p. It returns the number of
// bytes copied into p and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(p))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// ReadFrom can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetReadDeadline.
func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
loop:
for {
select {
case chunk, ok := <-c.readCh:
if !ok {
break loop
}
var err error
n := copy(p, chunk.UserData())
addr := chunk.SourceAddr()
if n < len(chunk.UserData()) {
err = io.ErrShortBuffer
}
if c.remAddr != nil {
if addr.String() != c.remAddr.String() {
break // discard (shouldn't happen)
}
}
return n, addr, err
case <-c.readTimer.C:
return 0, nil, &net.OpError{
Op: "read",
Net: c.locAddr.Network(),
Addr: c.locAddr,
Err: newTimeoutError("i/o timeout"),
}
}
}
return 0, nil, &net.OpError{
Op: "read",
Net: c.locAddr.Network(),
Addr: c.locAddr,
Err: errUseClosedNetworkConn,
}
}
// WriteTo writes a packet with payload p to addr.
// WriteTo can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
dstAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, errAddrNotUDPAddr
}
srcIP := c.obs.determineSourceIP(c.locAddr.IP, dstAddr.IP)
if srcIP == nil {
return 0, errLocAddr
}
srcAddr := &net.UDPAddr{
IP: srcIP,
Port: c.locAddr.Port,
}
chunk := newChunkUDP(srcAddr, dstAddr)
chunk.userData = make([]byte, len(p))
copy(chunk.userData, p)
if err := c.obs.write(chunk); err != nil {
return 0, err
}
return len(p), nil
}
// Close closes the connection.
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
func (c *UDPConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return errAlreadyClosed
}
c.closed = true
close(c.readCh)
c.obs.onClosed(c.locAddr)
return nil
}
// LocalAddr returns the local network address.
func (c *UDPConn) LocalAddr() net.Addr {
return c.locAddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail with a timeout (see type Error) instead of
// blocking. The deadline applies to all future and pending
// I/O, not just the immediately following call to ReadFrom or
// WriteTo. After a deadline has been exceeded, the connection
// can be refreshed by setting a deadline in the future.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful ReadFrom or WriteTo calls.
//
// A zero value for t means I/O operations will not time out.
func (c *UDPConn) SetDeadline(t time.Time) error {
return c.SetReadDeadline(t)
}
// SetReadDeadline sets the deadline for future ReadFrom calls
// and any currently-blocked ReadFrom call.
// A zero value for t means ReadFrom will not time out.
func (c *UDPConn) SetReadDeadline(t time.Time) error {
var d time.Duration
var noDeadline time.Time
if t == noDeadline {
d = time.Duration(math.MaxInt64)
} else {
d = time.Until(t)
}
c.readTimer.Reset(d)
return nil
}
// SetWriteDeadline sets the deadline for future WriteTo calls
// and any currently-blocked WriteTo call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means WriteTo will not time out.
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
// Write never blocks.
return nil
}
// Read reads data from the connection.
// Read can be made to time out and return an Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *UDPConn) Read(b []byte) (int, error) {
n, _, err := c.ReadFrom(b)
return n, err
}
// RemoteAddr returns the remote network address.
func (c *UDPConn) RemoteAddr() net.Addr {
return c.remAddr
}
// Write writes data to the connection.
// Write can be made to time out and return an Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
func (c *UDPConn) Write(b []byte) (int, error) {
if c.remAddr == nil {
return 0, errNoRemAddr
}
return c.WriteTo(b, c.remAddr)
}
func (c *UDPConn) onInboundChunk(chunk Chunk) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
select {
case c.readCh <- chunk:
default:
}
}

View file

@ -0,0 +1,136 @@
package vnet
import (
"errors"
"net"
"sync"
)
var (
errAddressAlreadyInUse = errors.New("address already in use")
errNoSuchUDPConn = errors.New("no such UDPConn")
errCannotRemoveUnspecifiedIP = errors.New("cannot remove unspecified IP by the specified IP")
)
type udpConnMap struct {
portMap map[int][]*UDPConn
mutex sync.RWMutex
}
func newUDPConnMap() *udpConnMap {
return &udpConnMap{
portMap: map[int][]*UDPConn{},
}
}
func (m *udpConnMap) insert(conn *UDPConn) error {
m.mutex.Lock()
defer m.mutex.Unlock()
udpAddr := conn.LocalAddr().(*net.UDPAddr)
// check if the port has a listener
conns, ok := m.portMap[udpAddr.Port]
if ok {
if udpAddr.IP.IsUnspecified() {
return errAddressAlreadyInUse
}
for _, conn := range conns {
laddr := conn.LocalAddr().(*net.UDPAddr)
if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) {
return errAddressAlreadyInUse
}
}
conns = append(conns, conn)
} else {
conns = []*UDPConn{conn}
}
m.portMap[udpAddr.Port] = conns
return nil
}
func (m *udpConnMap) find(addr net.Addr) (*UDPConn, bool) {
m.mutex.Lock() // could be RLock, but we have delete() op
defer m.mutex.Unlock()
udpAddr := addr.(*net.UDPAddr)
if conns, ok := m.portMap[udpAddr.Port]; ok {
if udpAddr.IP.IsUnspecified() {
// pick the first one appears in the iteration
if len(conns) == 0 {
// This can't happen!
delete(m.portMap, udpAddr.Port)
return nil, false
}
return conns[0], true
}
for _, conn := range conns {
laddr := conn.LocalAddr().(*net.UDPAddr)
if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) {
return conn, ok
}
}
}
return nil, false
}
func (m *udpConnMap) delete(addr net.Addr) error {
m.mutex.Lock()
defer m.mutex.Unlock()
udpAddr := addr.(*net.UDPAddr)
conns, ok := m.portMap[udpAddr.Port]
if !ok {
return errNoSuchUDPConn
}
if udpAddr.IP.IsUnspecified() {
// remove all from this port
delete(m.portMap, udpAddr.Port)
return nil
}
newConns := []*UDPConn{}
for _, conn := range conns {
laddr := conn.LocalAddr().(*net.UDPAddr)
if laddr.IP.IsUnspecified() {
// This can't happen!
return errCannotRemoveUnspecifiedIP
}
if laddr.IP.Equal(udpAddr.IP) {
continue
}
newConns = append(newConns, conn)
}
if len(newConns) == 0 {
delete(m.portMap, udpAddr.Port)
} else {
m.portMap[udpAddr.Port] = newConns
}
return nil
}
// size returns the number of UDPConns (UDP listeners)
func (m *udpConnMap) size() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
n := 0
for _, conns := range m.portMap {
n += len(conns)
}
return n
}

View file

@ -0,0 +1,19 @@
package vnet
type timeoutError struct {
msg string
}
func newTimeoutError(msg string) error {
return &timeoutError{
msg: msg,
}
}
func (e *timeoutError) Error() string {
return e.msg
}
func (e *timeoutError) Timeout() bool {
return true
}

View file

@ -0,0 +1,40 @@
package vnet
import (
"errors"
"net"
)
var errNoAddressAssigned = errors.New("no address assigned")
// See: https://play.golang.org/p/nBO9KGYEziv
// InterfaceBase ...
type InterfaceBase net.Interface
// Interface ...
type Interface struct {
InterfaceBase
addrs []net.Addr
}
// NewInterface ...
func NewInterface(ifc net.Interface) *Interface {
return &Interface{
InterfaceBase: InterfaceBase(ifc),
addrs: nil,
}
}
// AddAddr ...
func (ifc *Interface) AddAddr(addr net.Addr) {
ifc.addrs = append(ifc.addrs, addr)
}
// Addrs ...
func (ifc *Interface) Addrs() ([]net.Addr, error) {
if len(ifc.addrs) == 0 {
return nil, errNoAddressAssigned
}
return ifc.addrs, nil
}

View file

@ -0,0 +1,338 @@
package vnet
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/pion/logging"
)
var (
errNATRequriesMapping = errors.New("1:1 NAT requires more than one mapping")
errMismatchLengthIP = errors.New("length mismtach between mappedIPs and localIPs")
errNonUDPTranslationNotSupported = errors.New("non-udp translation is not supported yet")
errNoAssociatedLocalAddress = errors.New("no associated local address")
errNoNATBindingFound = errors.New("no NAT binding found")
errHasNoPermission = errors.New("has no permission")
)
// EndpointDependencyType defines a type of behavioral dependendency on the
// remote endpoint's IP address or port number. This is used for the two
// kinds of behaviors:
// - Port mapping behavior
// - Filtering behavior
// See: https://tools.ietf.org/html/rfc4787
type EndpointDependencyType uint8
const (
// EndpointIndependent means the behavior is independent of the endpoint's address or port
EndpointIndependent EndpointDependencyType = iota
// EndpointAddrDependent means the behavior is dependent on the endpoint's address
EndpointAddrDependent
// EndpointAddrPortDependent means the behavior is dependent on the endpoint's address and port
EndpointAddrPortDependent
)
// NATMode defines basic behavior of the NAT
type NATMode uint8
const (
// NATModeNormal means the NAT behaves as a standard NAPT (RFC 2663).
NATModeNormal NATMode = iota
// NATModeNAT1To1 exhibits 1:1 DNAT where the external IP address is statically mapped to
// a specific local IP address with port number is preserved always between them.
// When this mode is selected, MappingBehavior, FilteringBehavior, PortPreservation and
// MappingLifeTime of NATType are ignored.
NATModeNAT1To1
)
const (
defaultNATMappingLifeTime = 30 * time.Second
)
// NATType has a set of parameters that define the behavior of NAT.
type NATType struct {
Mode NATMode
MappingBehavior EndpointDependencyType
FilteringBehavior EndpointDependencyType
Hairpining bool // Not implemented yet
PortPreservation bool // Not implemented yet
MappingLifeTime time.Duration
}
type natConfig struct {
name string
natType NATType
mappedIPs []net.IP // mapped IPv4
localIPs []net.IP // local IPv4, required only when the mode is NATModeNAT1To1
loggerFactory logging.LoggerFactory
}
type mapping struct {
proto string // "udp" or "tcp"
local string // "<local-ip>:<local-port>"
mapped string // "<mapped-ip>:<mapped-port>"
bound string // key: "[<remote-ip>[:<remote-port>]]"
filters map[string]struct{} // key: "[<remote-ip>[:<remote-port>]]"
expires time.Time // time to expire
}
type networkAddressTranslator struct {
name string
natType NATType
mappedIPs []net.IP // mapped IPv4
localIPs []net.IP // local IPv4, required only when the mode is NATModeNAT1To1
outboundMap map[string]*mapping // key: "<proto>:<local-ip>:<local-port>[:remote-ip[:remote-port]]
inboundMap map[string]*mapping // key: "<proto>:<mapped-ip>:<mapped-port>"
udpPortCounter int
mutex sync.RWMutex
log logging.LeveledLogger
}
func newNAT(config *natConfig) (*networkAddressTranslator, error) {
natType := config.natType
if natType.Mode == NATModeNAT1To1 {
// 1:1 NAT behavior
natType.MappingBehavior = EndpointIndependent
natType.FilteringBehavior = EndpointIndependent
natType.PortPreservation = true
natType.MappingLifeTime = 0
if len(config.mappedIPs) == 0 {
return nil, errNATRequriesMapping
}
if len(config.mappedIPs) != len(config.localIPs) {
return nil, errMismatchLengthIP
}
} else {
// Normal (NAPT) behavior
natType.Mode = NATModeNormal
if natType.MappingLifeTime == 0 {
natType.MappingLifeTime = defaultNATMappingLifeTime
}
}
return &networkAddressTranslator{
name: config.name,
natType: natType,
mappedIPs: config.mappedIPs,
localIPs: config.localIPs,
outboundMap: map[string]*mapping{},
inboundMap: map[string]*mapping{},
log: config.loggerFactory.NewLogger("vnet"),
}, nil
}
func (n *networkAddressTranslator) getPairedMappedIP(locIP net.IP) net.IP {
for i, ip := range n.localIPs {
if ip.Equal(locIP) {
return n.mappedIPs[i]
}
}
return nil
}
func (n *networkAddressTranslator) getPairedLocalIP(mappedIP net.IP) net.IP {
for i, ip := range n.mappedIPs {
if ip.Equal(mappedIP) {
return n.localIPs[i]
}
}
return nil
}
func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) {
n.mutex.Lock()
defer n.mutex.Unlock()
to := from.Clone()
if from.Network() == udpString {
if n.natType.Mode == NATModeNAT1To1 {
// 1:1 NAT behavior
srcAddr := from.SourceAddr().(*net.UDPAddr)
srcIP := n.getPairedMappedIP(srcAddr.IP)
if srcIP == nil {
n.log.Debugf("[%s] drop outbound chunk %s with not route", n.name, from.String())
return nil, nil // silently discard
}
srcPort := srcAddr.Port
if err := to.setSourceAddr(fmt.Sprintf("%s:%d", srcIP.String(), srcPort)); err != nil {
return nil, err
}
} else {
// Normal (NAPT) behavior
var bound, filterKey string
switch n.natType.MappingBehavior {
case EndpointIndependent:
bound = ""
case EndpointAddrDependent:
bound = from.getDestinationIP().String()
case EndpointAddrPortDependent:
bound = from.DestinationAddr().String()
}
switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getDestinationIP().String()
case EndpointAddrPortDependent:
filterKey = from.DestinationAddr().String()
}
oKey := fmt.Sprintf("udp:%s:%s", from.SourceAddr().String(), bound)
m := n.findOutboundMapping(oKey)
if m == nil {
// Create a new mapping
mappedPort := 0xC000 + n.udpPortCounter
n.udpPortCounter++
m = &mapping{
proto: from.SourceAddr().Network(),
local: from.SourceAddr().String(),
bound: bound,
mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort),
filters: map[string]struct{}{},
expires: time.Now().Add(n.natType.MappingLifeTime),
}
n.outboundMap[oKey] = m
iKey := fmt.Sprintf("udp:%s", m.mapped)
n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s\n",
n.name,
oKey,
iKey)
m.filters[filterKey] = struct{}{}
n.log.Debugf("[%s] permit access from %s to %s\n", n.name, filterKey, m.mapped)
n.inboundMap[iKey] = m
} else if _, ok := m.filters[filterKey]; !ok {
n.log.Debugf("[%s] permit access from %s to %s\n", n.name, filterKey, m.mapped)
m.filters[filterKey] = struct{}{}
}
if err := to.setSourceAddr(m.mapped); err != nil {
return nil, err
}
}
n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String())
return to, nil
}
return nil, errNonUDPTranslationNotSupported
}
func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) {
n.mutex.Lock()
defer n.mutex.Unlock()
to := from.Clone()
if from.Network() == udpString {
if n.natType.Mode == NATModeNAT1To1 {
// 1:1 NAT behavior
dstAddr := from.DestinationAddr().(*net.UDPAddr)
dstIP := n.getPairedLocalIP(dstAddr.IP)
if dstIP == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress)
}
dstPort := from.DestinationAddr().(*net.UDPAddr).Port
if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil {
return nil, err
}
} else {
// Normal (NAPT) behavior
iKey := fmt.Sprintf("udp:%s", from.DestinationAddr().String())
m := n.findInboundMapping(iKey)
if m == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound)
}
var filterKey string
switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getSourceIP().String()
case EndpointAddrPortDependent:
filterKey = from.SourceAddr().String()
}
if _, ok := m.filters[filterKey]; !ok {
return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission)
}
// See RFC 4847 Section 4.3. Mapping Refresh
// a) Inbound refresh may be useful for applications with no outgoing
// UDP traffic. However, allowing inbound refresh may allow an
// external attacker or misbehaving application to keep a mapping
// alive indefinitely. This may be a security risk. Also, if the
// process is repeated with different ports, over time, it could
// use up all the ports on the NAT.
if err := to.setDestinationAddr(m.local); err != nil {
return nil, err
}
}
n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String())
return to, nil
}
return nil, errNonUDPTranslationNotSupported
}
// caller must hold the mutex
func (n *networkAddressTranslator) findOutboundMapping(oKey string) *mapping {
now := time.Now()
m, ok := n.outboundMap[oKey]
if ok {
// check if this mapping is expired
if now.After(m.expires) {
n.removeMapping(m)
m = nil // expired
} else {
m.expires = time.Now().Add(n.natType.MappingLifeTime)
}
}
return m
}
// caller must hold the mutex
func (n *networkAddressTranslator) findInboundMapping(iKey string) *mapping {
now := time.Now()
m, ok := n.inboundMap[iKey]
if !ok {
return nil
}
// check if this mapping is expired
if now.After(m.expires) {
n.removeMapping(m)
return nil
}
return m
}
// caller must hold the mutex
func (n *networkAddressTranslator) removeMapping(m *mapping) {
oKey := fmt.Sprintf("%s:%s:%s", m.proto, m.local, m.bound)
iKey := fmt.Sprintf("%s:%s", m.proto, m.mapped)
delete(n.outboundMap, oKey)
delete(n.inboundMap, iKey)
}

View file

@ -0,0 +1,677 @@
package vnet
import (
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
)
const (
lo0String = "lo0String"
udpString = "udp"
)
var (
macAddrCounter uint64 = 0xBEEFED910200 //nolint:gochecknoglobals
errNoInterface = errors.New("no interface is available")
errNotFound = errors.New("not found")
errUnexpectedNetwork = errors.New("unexpected network")
errCantAssignRequestedAddr = errors.New("can't assign requested address")
errUnknownNetwork = errors.New("unknown network")
errNoRouterLinked = errors.New("no router linked")
errInvalidPortNumber = errors.New("invalid port number")
errUnexpectedTypeSwitchFailure = errors.New("unexpected type-switch failure")
errBindFailerFor = errors.New("bind failed for")
errEndPortLessThanStart = errors.New("end port is less than the start")
errPortSpaceExhausted = errors.New("port space exhausted")
errVNetDisabled = errors.New("vnet is not enabled")
)
func newMACAddress() net.HardwareAddr {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, macAddrCounter)
macAddrCounter++
return b[2:]
}
type vNet struct {
interfaces []*Interface // read-only
staticIPs []net.IP // read-only
router *Router // read-only
udpConns *udpConnMap // read-only
mutex sync.RWMutex
}
func (v *vNet) _getInterfaces() ([]*Interface, error) {
if len(v.interfaces) == 0 {
return nil, errNoInterface
}
return v.interfaces, nil
}
func (v *vNet) getInterfaces() ([]*Interface, error) {
v.mutex.RLock()
defer v.mutex.RUnlock()
return v._getInterfaces()
}
// caller must hold the mutex (read)
func (v *vNet) _getInterface(ifName string) (*Interface, error) {
ifs, err := v._getInterfaces()
if err != nil {
return nil, err
}
for _, ifc := range ifs {
if ifc.Name == ifName {
return ifc, nil
}
}
return nil, fmt.Errorf("interface %s %w", ifName, errNotFound)
}
func (v *vNet) getInterface(ifName string) (*Interface, error) {
v.mutex.RLock()
defer v.mutex.RUnlock()
return v._getInterface(ifName)
}
// caller must hold the mutex
func (v *vNet) getAllIPAddrs(ipv6 bool) []net.IP {
ips := []net.IP{}
for _, ifc := range v.interfaces {
addrs, err := ifc.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
if ipNet, ok := addr.(*net.IPNet); ok {
ip = ipNet.IP
} else if ipAddr, ok := addr.(*net.IPAddr); ok {
ip = ipAddr.IP
} else {
continue
}
if !ipv6 {
if ip.To4() != nil {
ips = append(ips, ip)
}
}
}
}
return ips
}
func (v *vNet) setRouter(r *Router) error {
v.mutex.Lock()
defer v.mutex.Unlock()
v.router = r
return nil
}
func (v *vNet) onInboundChunk(c Chunk) {
v.mutex.Lock()
defer v.mutex.Unlock()
if c.Network() == udpString {
if conn, ok := v.udpConns.find(c.DestinationAddr()); ok {
conn.onInboundChunk(c)
}
}
}
// caller must hold the mutex
func (v *vNet) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) {
// validate network
if network != udpString && network != "udp4" {
return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network)
}
if locAddr == nil {
locAddr = &net.UDPAddr{
IP: net.IPv4zero,
}
} else if locAddr.IP == nil {
locAddr.IP = net.IPv4zero
}
// validate address. do we have that address?
if !v.hasIPAddr(locAddr.IP) {
return nil, &net.OpError{
Op: "listen",
Net: network,
Addr: locAddr,
Err: fmt.Errorf("bind: %w", errCantAssignRequestedAddr),
}
}
if locAddr.Port == 0 {
// choose randomly from the range between 5000 and 5999
port, err := v.assignPort(locAddr.IP, 5000, 5999)
if err != nil {
return nil, &net.OpError{
Op: "listen",
Net: network,
Addr: locAddr,
Err: err,
}
}
locAddr.Port = port
} else if _, ok := v.udpConns.find(locAddr); ok {
return nil, &net.OpError{
Op: "listen",
Net: network,
Addr: locAddr,
Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse),
}
}
conn, err := newUDPConn(locAddr, remAddr, v)
if err != nil {
return nil, err
}
err = v.udpConns.insert(conn)
if err != nil {
return nil, err
}
return conn, nil
}
func (v *vNet) listenPacket(network string, address string) (UDPPacketConn, error) {
v.mutex.Lock()
defer v.mutex.Unlock()
locAddr, err := v.resolveUDPAddr(network, address)
if err != nil {
return nil, err
}
return v._dialUDP(network, locAddr, nil)
}
func (v *vNet) listenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
v.mutex.Lock()
defer v.mutex.Unlock()
return v._dialUDP(network, locAddr, nil)
}
func (v *vNet) dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) {
v.mutex.Lock()
defer v.mutex.Unlock()
return v._dialUDP(network, locAddr, remAddr)
}
func (v *vNet) dial(network string, address string) (UDPPacketConn, error) {
v.mutex.Lock()
defer v.mutex.Unlock()
remAddr, err := v.resolveUDPAddr(network, address)
if err != nil {
return nil, err
}
// Determine source address
srcIP := v.determineSourceIP(nil, remAddr.IP)
locAddr := &net.UDPAddr{IP: srcIP, Port: 0}
return v._dialUDP(network, locAddr, remAddr)
}
func (v *vNet) resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
if network != udpString && network != "udp4" {
return nil, fmt.Errorf("%w %s", errUnknownNetwork, network)
}
host, sPort, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
// Check if host is a domain name
ip := net.ParseIP(host)
if ip == nil {
host = strings.ToLower(host)
if host == "localhost" {
ip = net.IPv4(127, 0, 0, 1)
} else {
// host is a domain name. resolve IP address by the name
if v.router == nil {
return nil, errNoRouterLinked
}
ip, err = v.router.resolver.lookUp(host)
if err != nil {
return nil, err
}
}
}
port, err := strconv.Atoi(sPort)
if err != nil {
return nil, errInvalidPortNumber
}
udpAddr := &net.UDPAddr{
IP: ip,
Port: port,
}
return udpAddr, nil
}
func (v *vNet) write(c Chunk) error {
if c.Network() == udpString {
if udp, ok := c.(*chunkUDP); ok {
if c.getDestinationIP().IsLoopback() {
if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok {
conn.onInboundChunk(udp)
}
return nil
}
} else {
return errUnexpectedTypeSwitchFailure
}
}
if v.router == nil {
return errNoRouterLinked
}
v.router.push(c)
return nil
}
func (v *vNet) onClosed(addr net.Addr) {
if addr.Network() == udpString {
//nolint:errcheck
v.udpConns.delete(addr) // #nosec
}
}
// This method determines the srcIP based on the dstIP when locIP
// is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr,
// this method simply returns locIP.
// caller must hold the mutex
func (v *vNet) determineSourceIP(locIP, dstIP net.IP) net.IP {
if locIP != nil && !locIP.IsUnspecified() {
return locIP
}
var srcIP net.IP
if dstIP.IsLoopback() {
srcIP = net.ParseIP("127.0.0.1")
} else {
ifc, err2 := v._getInterface("eth0")
if err2 != nil {
return nil
}
addrs, err2 := ifc.Addrs()
if err2 != nil {
return nil
}
if len(addrs) == 0 {
return nil
}
var findIPv4 bool
if locIP != nil {
findIPv4 = (locIP.To4() != nil)
} else {
findIPv4 = (dstIP.To4() != nil)
}
for _, addr := range addrs {
ip := addr.(*net.IPNet).IP
if findIPv4 {
if ip.To4() != nil {
srcIP = ip
break
}
} else {
if ip.To4() == nil {
srcIP = ip
break
}
}
}
}
return srcIP
}
// caller must hold the mutex
func (v *vNet) hasIPAddr(ip net.IP) bool { //nolint:gocognit
for _, ifc := range v.interfaces {
if addrs, err := ifc.Addrs(); err == nil {
for _, addr := range addrs {
var locIP net.IP
if ipNet, ok := addr.(*net.IPNet); ok {
locIP = ipNet.IP
} else if ipAddr, ok := addr.(*net.IPAddr); ok {
locIP = ipAddr.IP
} else {
continue
}
switch ip.String() {
case "0.0.0.0":
if locIP.To4() != nil {
return true
}
case "::":
if locIP.To4() == nil {
return true
}
default:
if locIP.Equal(ip) {
return true
}
}
}
}
}
return false
}
// caller must hold the mutex
func (v *vNet) allocateLocalAddr(ip net.IP, port int) error {
// gather local IP addresses to bind
var ips []net.IP
if ip.IsUnspecified() {
ips = v.getAllIPAddrs(ip.To4() == nil)
} else if v.hasIPAddr(ip) {
ips = []net.IP{ip}
}
if len(ips) == 0 {
return fmt.Errorf("%w %s", errBindFailerFor, ip.String())
}
// check if all these transport addresses are not in use
for _, ip2 := range ips {
addr := &net.UDPAddr{
IP: ip2,
Port: port,
}
if _, ok := v.udpConns.find(addr); ok {
return &net.OpError{
Op: "bind",
Net: udpString,
Addr: addr,
Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse),
}
}
}
return nil
}
// caller must hold the mutex
func (v *vNet) assignPort(ip net.IP, start, end int) (int, error) {
// choose randomly from the range between start and end (inclusive)
if end < start {
return -1, errEndPortLessThanStart
}
space := end + 1 - start
offset := rand.Intn(space) //nolint:gosec
for i := 0; i < space; i++ {
port := ((offset + i) % space) + start
err := v.allocateLocalAddr(ip, port)
if err == nil {
return port, nil
}
}
return -1, errPortSpaceExhausted
}
// NetConfig is a bag of configuration parameters passed to NewNet().
type NetConfig struct {
// StaticIPs is an array of static IP addresses to be assigned for this Net.
// If no static IP address is given, the router will automatically assign
// an IP address.
StaticIPs []string
// StaticIP is deprecated. Use StaticIPs.
StaticIP string
}
// Net represents a local network stack euivalent to a set of layers from NIC
// up to the transport (UDP / TCP) layer.
type Net struct {
v *vNet
ifs []*Interface
}
// NewNet creates an instance of Net.
// If config is nil, the virtual network is disabled. (uses corresponding
// net.Xxxx() operations.
// By design, it always have lo0 and eth0 interfaces.
// The lo0 has the address 127.0.0.1 assigned by default.
// IP address for eth0 will be assigned when this Net is added to a router.
func NewNet(config *NetConfig) *Net {
if config == nil {
ifs := []*Interface{}
if orgIfs, err := net.Interfaces(); err == nil {
for _, orgIfc := range orgIfs {
ifc := NewInterface(orgIfc)
if addrs, err := orgIfc.Addrs(); err == nil {
for _, addr := range addrs {
ifc.AddAddr(addr)
}
}
ifs = append(ifs, ifc)
}
}
return &Net{ifs: ifs}
}
lo0 := NewInterface(net.Interface{
Index: 1,
MTU: 16384,
Name: lo0String,
HardwareAddr: nil,
Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast,
})
lo0.AddAddr(&net.IPNet{
IP: net.ParseIP("127.0.0.1"),
Mask: net.CIDRMask(8, 32),
})
eth0 := NewInterface(net.Interface{
Index: 2,
MTU: 1500,
Name: "eth0",
HardwareAddr: newMACAddress(),
Flags: net.FlagUp | net.FlagMulticast,
})
var staticIPs []net.IP
for _, ipStr := range config.StaticIPs {
if ip := net.ParseIP(ipStr); ip != nil {
staticIPs = append(staticIPs, ip)
}
}
if len(config.StaticIP) > 0 {
if ip := net.ParseIP(config.StaticIP); ip != nil {
staticIPs = append(staticIPs, ip)
}
}
v := &vNet{
interfaces: []*Interface{lo0, eth0},
staticIPs: staticIPs,
udpConns: newUDPConnMap(),
}
return &Net{
v: v,
}
}
// Interfaces returns a list of the system's network interfaces.
func (n *Net) Interfaces() ([]*Interface, error) {
if n.v == nil {
return n.ifs, nil
}
return n.v.getInterfaces()
}
// InterfaceByName returns the interface specified by name.
func (n *Net) InterfaceByName(name string) (*Interface, error) {
if n.v == nil {
for _, ifc := range n.ifs {
if ifc.Name == name {
return ifc, nil
}
}
return nil, fmt.Errorf("interface %s %w", name, errNotFound)
}
return n.v.getInterface(name)
}
// ListenPacket announces on the local network address.
func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) {
if n.v == nil {
return net.ListenPacket(network, address)
}
return n.v.listenPacket(network, address)
}
// ListenUDP acts like ListenPacket for UDP networks.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
if n.v == nil {
return net.ListenUDP(network, locAddr)
}
return n.v.listenUDP(network, locAddr)
}
// Dial connects to the address on the named network.
func (n *Net) Dial(network, address string) (net.Conn, error) {
if n.v == nil {
return net.Dial(network, address)
}
return n.v.dial(network, address)
}
// CreateDialer creates an instance of vnet.Dialer
func (n *Net) CreateDialer(dialer *net.Dialer) Dialer {
if n.v == nil {
return &vDialer{
dialer: dialer,
}
}
return &vDialer{
dialer: dialer,
v: n.v,
}
}
// DialUDP acts like Dial for UDP networks.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPPacketConn, error) {
if n.v == nil {
return net.DialUDP(network, laddr, raddr)
}
return n.v.dialUDP(network, laddr, raddr)
}
// ResolveUDPAddr returns an address of UDP end point.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
if n.v == nil {
return net.ResolveUDPAddr(network, address)
}
return n.v.resolveUDPAddr(network, address)
}
func (n *Net) getInterface(ifName string) (*Interface, error) {
if n.v == nil {
return nil, errVNetDisabled
}
return n.v.getInterface(ifName)
}
func (n *Net) setRouter(r *Router) error {
if n.v == nil {
return errVNetDisabled
}
return n.v.setRouter(r)
}
func (n *Net) onInboundChunk(c Chunk) {
if n.v == nil {
return
}
n.v.onInboundChunk(c)
}
func (n *Net) getStaticIPs() []net.IP {
if n.v == nil {
return nil
}
return n.v.staticIPs
}
// IsVirtual tests if the virtual network is enabled.
func (n *Net) IsVirtual() bool {
return n.v != nil
}
// Dialer is identical to net.Dialer excepts that its methods
// (Dial, DialContext) are overridden to use virtual network.
// Use vnet.CreateDialer() to create an instance of this Dialer.
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type vDialer struct {
dialer *net.Dialer
v *vNet
}
func (d *vDialer) Dial(network, address string) (net.Conn, error) {
if d.v == nil {
return d.dialer.Dial(network, address)
}
return d.v.dial(network, address)
}

View file

@ -0,0 +1,89 @@
package vnet
import (
"errors"
"fmt"
"net"
"sync"
"github.com/pion/logging"
)
var (
errHostnameEmpty = errors.New("host name must not be empty")
errFailedtoParseIPAddr = errors.New("failed to parse IP address")
)
type resolverConfig struct {
LoggerFactory logging.LoggerFactory
}
type resolver struct {
parent *resolver // read-only
hosts map[string]net.IP // requires mutex
mutex sync.RWMutex // thread-safe
log logging.LeveledLogger // read-only
}
func newResolver(config *resolverConfig) *resolver {
r := &resolver{
hosts: map[string]net.IP{},
log: config.LoggerFactory.NewLogger("vnet"),
}
if err := r.addHost("localhost", "127.0.0.1"); err != nil {
r.log.Warn("failed to add localhost to resolver")
}
return r
}
func (r *resolver) setParent(parent *resolver) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.parent = parent
}
func (r *resolver) addHost(name string, ipAddr string) error {
r.mutex.Lock()
defer r.mutex.Unlock()
if len(name) == 0 {
return errHostnameEmpty
}
ip := net.ParseIP(ipAddr)
if ip == nil {
return fmt.Errorf("%w \"%s\"", errFailedtoParseIPAddr, ipAddr)
}
r.hosts[name] = ip
return nil
}
func (r *resolver) lookUp(hostName string) (net.IP, error) {
ip := func() net.IP {
r.mutex.RLock()
defer r.mutex.RUnlock()
if ip2, ok := r.hosts[hostName]; ok {
return ip2
}
return nil
}()
if ip != nil {
return ip, nil
}
// mutex must be unlocked before calling into parent resolver
if r.parent != nil {
return r.parent.lookUp(hostName)
}
return nil, &net.DNSError{
Err: "host not found",
Name: hostName,
Server: "vnet resolver",
IsTimeout: false,
IsTemporary: false,
}
}

View file

@ -0,0 +1,605 @@
package vnet
import (
"errors"
"fmt"
"math/rand"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
)
const (
defaultRouterQueueSize = 0 // unlimited
)
var (
errInvalidLocalIPinStaticIPs = errors.New("invalid local IP in StaticIPs")
errLocalIPBeyondStaticIPsSubset = errors.New("mapped in StaticIPs is beyond subnet")
errLocalIPNoStaticsIPsAssociated = errors.New("all StaticIPs must have associated local IPs")
errRouterAlreadyStarted = errors.New("router already started")
errRouterAlreadyStopped = errors.New("router already stopped")
errStaticIPisBeyondSubnet = errors.New("static IP is beyond subnet")
errAddressSpaceExhausted = errors.New("address space exhausted")
errNoIPAddrEth0 = errors.New("no IP address is assigned for eth0")
)
// Generate a unique router name
var assignRouterName = func() func() string { //nolint:gochecknoglobals
var routerIDCtr uint64
return func() string {
n := atomic.AddUint64(&routerIDCtr, 1)
return fmt.Sprintf("router%d", n)
}
}()
// RouterConfig ...
type RouterConfig struct {
// Name of router. If not specified, a unique name will be assigned.
Name string
// CIDR notation, like "192.0.2.0/24"
CIDR string
// StaticIPs is an array of static IP addresses to be assigned for this router.
// If no static IP address is given, the router will automatically assign
// an IP address.
// This will be ignored if this router is the root.
StaticIPs []string
// StaticIP is deprecated. Use StaticIPs.
StaticIP string
// Internal queue size
QueueSize int
// Effective only when this router has a parent router
NATType *NATType
// Minimum Delay
MinDelay time.Duration
// Max Jitter
MaxJitter time.Duration
// Logger factory
LoggerFactory logging.LoggerFactory
}
// NIC is a nework inerface controller that interfaces Router
type NIC interface {
getInterface(ifName string) (*Interface, error)
onInboundChunk(c Chunk)
getStaticIPs() []net.IP
setRouter(r *Router) error
}
// ChunkFilter is a handler users can add to filter chunks.
// If the filter returns false, the packet will be dropped.
type ChunkFilter func(c Chunk) bool
// Router ...
type Router struct {
name string // read-only
interfaces []*Interface // read-only
ipv4Net *net.IPNet // read-only
staticIPs []net.IP // read-only
staticLocalIPs map[string]net.IP // read-only,
lastID byte // requires mutex [x], used to assign the last digit of IPv4 address
queue *chunkQueue // read-only
parent *Router // read-only
children []*Router // read-only
natType *NATType // read-only
nat *networkAddressTranslator // read-only
nics map[string]NIC // read-only
stopFunc func() // requires mutex [x]
resolver *resolver // read-only
chunkFilters []ChunkFilter // requires mutex [x]
minDelay time.Duration // requires mutex [x]
maxJitter time.Duration // requires mutex [x]
mutex sync.RWMutex // thread-safe
pushCh chan struct{} // writer requires mutex
loggerFactory logging.LoggerFactory // read-only
log logging.LeveledLogger // read-only
}
// NewRouter ...
func NewRouter(config *RouterConfig) (*Router, error) {
loggerFactory := config.LoggerFactory
log := loggerFactory.NewLogger("vnet")
_, ipv4Net, err := net.ParseCIDR(config.CIDR)
if err != nil {
return nil, err
}
queueSize := defaultRouterQueueSize
if config.QueueSize > 0 {
queueSize = config.QueueSize
}
// set up network interface, lo0
lo0 := NewInterface(net.Interface{
Index: 1,
MTU: 16384,
Name: lo0String,
HardwareAddr: nil,
Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast,
})
lo0.AddAddr(&net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""})
// set up network interface, eth0
eth0 := NewInterface(net.Interface{
Index: 2,
MTU: 1500,
Name: "eth0",
HardwareAddr: newMACAddress(),
Flags: net.FlagUp | net.FlagMulticast,
})
// local host name resolver
resolver := newResolver(&resolverConfig{
LoggerFactory: config.LoggerFactory,
})
name := config.Name
if len(name) == 0 {
name = assignRouterName()
}
var staticIPs []net.IP
staticLocalIPs := map[string]net.IP{}
for _, ipStr := range config.StaticIPs {
ipPair := strings.Split(ipStr, "/")
if ip := net.ParseIP(ipPair[0]); ip != nil {
if len(ipPair) > 1 {
locIP := net.ParseIP(ipPair[1])
if locIP == nil {
return nil, errInvalidLocalIPinStaticIPs
}
if !ipv4Net.Contains(locIP) {
return nil, fmt.Errorf("local IP %s %w", locIP.String(), errLocalIPBeyondStaticIPsSubset)
}
staticLocalIPs[ip.String()] = locIP
}
staticIPs = append(staticIPs, ip)
}
}
if len(config.StaticIP) > 0 {
log.Warn("StaticIP is deprecated. Use StaticIPs instead")
if ip := net.ParseIP(config.StaticIP); ip != nil {
staticIPs = append(staticIPs, ip)
}
}
if nStaticLocal := len(staticLocalIPs); nStaticLocal > 0 {
if nStaticLocal != len(staticIPs) {
return nil, errLocalIPNoStaticsIPsAssociated
}
}
return &Router{
name: name,
interfaces: []*Interface{lo0, eth0},
ipv4Net: ipv4Net,
staticIPs: staticIPs,
staticLocalIPs: staticLocalIPs,
queue: newChunkQueue(queueSize),
natType: config.NATType,
nics: map[string]NIC{},
resolver: resolver,
minDelay: config.MinDelay,
maxJitter: config.MaxJitter,
pushCh: make(chan struct{}, 1),
loggerFactory: loggerFactory,
log: log,
}, nil
}
// caller must hold the mutex
func (r *Router) getInterfaces() ([]*Interface, error) {
if len(r.interfaces) == 0 {
return nil, fmt.Errorf("%w is available", errNoInterface)
}
return r.interfaces, nil
}
func (r *Router) getInterface(ifName string) (*Interface, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
ifs, err := r.getInterfaces()
if err != nil {
return nil, err
}
for _, ifc := range ifs {
if ifc.Name == ifName {
return ifc, nil
}
}
return nil, fmt.Errorf("interface %s %w", ifName, errNotFound)
}
// Start ...
func (r *Router) Start() error {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.stopFunc != nil {
return errRouterAlreadyStarted
}
cancelCh := make(chan struct{})
go func() {
loop:
for {
d, err := r.processChunks()
if err != nil {
r.log.Errorf("[%s] %s", r.name, err.Error())
break
}
if d <= 0 {
select {
case <-r.pushCh:
case <-cancelCh:
break loop
}
} else {
t := time.NewTimer(d)
select {
case <-t.C:
case <-cancelCh:
break loop
}
}
}
}()
r.stopFunc = func() {
close(cancelCh)
}
for _, child := range r.children {
if err := child.Start(); err != nil {
return err
}
}
return nil
}
// Stop ...
func (r *Router) Stop() error {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.stopFunc == nil {
return errRouterAlreadyStopped
}
for _, router := range r.children {
r.mutex.Unlock()
err := router.Stop()
r.mutex.Lock()
if err != nil {
return err
}
}
r.stopFunc()
r.stopFunc = nil
return nil
}
// caller must hold the mutex
func (r *Router) addNIC(nic NIC) error {
ifc, err := nic.getInterface("eth0")
if err != nil {
return err
}
var ips []net.IP
if ips = nic.getStaticIPs(); len(ips) == 0 {
// assign an IP address
ip, err2 := r.assignIPAddress()
if err2 != nil {
return err2
}
ips = append(ips, ip)
}
for _, ip := range ips {
if !r.ipv4Net.Contains(ip) {
return fmt.Errorf("%w: %s", errStaticIPisBeyondSubnet, r.ipv4Net.String())
}
ifc.AddAddr(&net.IPNet{
IP: ip,
Mask: r.ipv4Net.Mask,
})
r.nics[ip.String()] = nic
}
if err = nic.setRouter(r); err != nil {
return err
}
return nil
}
// AddRouter adds a chile Router.
func (r *Router) AddRouter(router *Router) error {
r.mutex.Lock()
defer r.mutex.Unlock()
// Router is a NIC. Add it as a NIC so that packets are routed to this child
// router.
err := r.addNIC(router)
if err != nil {
return err
}
if err = router.setRouter(r); err != nil {
return err
}
r.children = append(r.children, router)
return nil
}
// AddNet ...
func (r *Router) AddNet(nic NIC) error {
r.mutex.Lock()
defer r.mutex.Unlock()
return r.addNIC(nic)
}
// AddHost adds a mapping of hostname and an IP address to the local resolver.
func (r *Router) AddHost(hostName string, ipAddr string) error {
return r.resolver.addHost(hostName, ipAddr)
}
// AddChunkFilter adds a filter for chunks traversing this router.
// You may add more than one filter. The filters are called in the order of this method call.
// If a chunk is dropped by a filter, subsequent filter will not receive the chunk.
func (r *Router) AddChunkFilter(filter ChunkFilter) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.chunkFilters = append(r.chunkFilters, filter)
}
// caller should hold the mutex
func (r *Router) assignIPAddress() (net.IP, error) {
// See: https://stackoverflow.com/questions/14915188/ip-address-ending-with-zero
if r.lastID == 0xfe {
return nil, errAddressSpaceExhausted
}
ip := make(net.IP, 4)
copy(ip, r.ipv4Net.IP[:3])
r.lastID++
ip[3] = r.lastID
return ip, nil
}
func (r *Router) push(c Chunk) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.log.Debugf("[%s] route %s", r.name, c.String())
if r.stopFunc != nil {
c.setTimestamp()
if r.queue.push(c) {
select {
case r.pushCh <- struct{}{}:
default:
}
} else {
r.log.Warnf("[%s] queue was full. dropped a chunk", r.name)
}
}
}
func (r *Router) processChunks() (time.Duration, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
// Introduce jitter by delaying the processing of chunks.
if r.maxJitter > 0 {
jitter := time.Duration(rand.Int63n(int64(r.maxJitter))) //nolint:gosec
time.Sleep(jitter)
}
// cutOff
// v min delay
// |<--->|
// +------------:--
// |OOOOOOXXXXX : --> time
// +------------:--
// |<--->| now
// due
enteredAt := time.Now()
cutOff := enteredAt.Add(-r.minDelay)
var d time.Duration // the next sleep duration
for {
d = 0
c := r.queue.peek()
if c == nil {
break // no more chunk in the queue
}
// check timestamp to find if the chunk is due
if c.getTimestamp().After(cutOff) {
// There is one or more chunk in the queue but none of them are due.
// Calculate the next sleep duration here.
nextExpire := c.getTimestamp().Add(r.minDelay)
d = nextExpire.Sub(enteredAt)
break
}
var ok bool
if c, ok = r.queue.pop(); !ok {
break // no more chunk in the queue
}
blocked := false
for i := 0; i < len(r.chunkFilters); i++ {
filter := r.chunkFilters[i]
if !filter(c) {
blocked = true
break
}
}
if blocked {
continue // discard
}
dstIP := c.getDestinationIP()
// check if the desination is in our subnet
if r.ipv4Net.Contains(dstIP) {
// search for the destination NIC
var nic NIC
if nic, ok = r.nics[dstIP.String()]; !ok {
// NIC not found. drop it.
r.log.Debugf("[%s] %s unreachable", r.name, c.String())
continue
}
// found the NIC, forward the chunk to the NIC.
// call to NIC must unlock mutex
r.mutex.Unlock()
nic.onInboundChunk(c)
r.mutex.Lock()
continue
}
// the destination is outside of this subnet
// is this WAN?
if r.parent == nil {
// this WAN. No route for this chunk
r.log.Debugf("[%s] no route found for %s", r.name, c.String())
continue
}
// Pass it to the parent via NAT
toParent, err := r.nat.translateOutbound(c)
if err != nil {
return 0, err
}
if toParent == nil {
continue
}
//nolint:godox
/* FIXME: this implementation would introduce a duplicate packet!
if r.nat.natType.Hairpining {
hairpinned, err := r.nat.translateInbound(toParent)
if err != nil {
r.log.Warnf("[%s] %s", r.name, err.Error())
} else {
go func() {
r.push(hairpinned)
}()
}
}
*/
// call to parent router mutex unlock mutex
r.mutex.Unlock()
r.parent.push(toParent)
r.mutex.Lock()
}
return d, nil
}
// caller must hold the mutex
func (r *Router) setRouter(parent *Router) error {
r.parent = parent
r.resolver.setParent(parent.resolver)
// when this method is called, one or more IP address has already been assigned by
// the parent router.
ifc, err := r.getInterface("eth0")
if err != nil {
return err
}
if len(ifc.addrs) == 0 {
return errNoIPAddrEth0
}
mappedIPs := []net.IP{}
localIPs := []net.IP{}
for _, ifcAddr := range ifc.addrs {
var ip net.IP
switch addr := ifcAddr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr: // Do we really need this case?
ip = addr.IP
default:
}
if ip == nil {
continue
}
mappedIPs = append(mappedIPs, ip)
if locIP := r.staticLocalIPs[ip.String()]; locIP != nil {
localIPs = append(localIPs, locIP)
}
}
// Set up NAT here
if r.natType == nil {
r.natType = &NATType{
MappingBehavior: EndpointIndependent,
FilteringBehavior: EndpointAddrPortDependent,
Hairpining: false,
PortPreservation: false,
MappingLifeTime: 30 * time.Second,
}
}
r.nat, err = newNAT(&natConfig{
name: r.name,
natType: *r.natType,
mappedIPs: mappedIPs,
localIPs: localIPs,
loggerFactory: r.loggerFactory,
})
if err != nil {
return err
}
return nil
}
func (r *Router) onInboundChunk(c Chunk) {
fromParent, err := r.nat.translateInbound(c)
if err != nil {
r.log.Warnf("[%s] %s", r.name, err.Error())
return
}
r.push(fromParent)
}
func (r *Router) getStaticIPs() []net.IP {
return r.staticIPs
}

View file

@ -0,0 +1,2 @@
// Package vnet provides a virtual network layer for pion
package vnet