1
0
Fork 0
mirror of https://github.com/ossrs/srs.git synced 2025-03-09 15:49:59 +00:00

TEST: Upgrade pion to v3.2.9. (#3567)

------

Co-authored-by: chundonglinlin <chundonglinlin@163.com>
This commit is contained in:
Winlin 2023-06-05 11:25:04 +08:00 committed by GitHub
parent 104cf14d68
commit df854339ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
1383 changed files with 118469 additions and 41421 deletions

View file

@ -2,6 +2,8 @@ package sctp
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
@ -12,13 +14,33 @@ import (
"github.com/pion/logging"
"github.com/pion/randutil"
"github.com/pkg/errors"
)
// Use global random generator to properly seed by crypto grade random.
var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals
// Association errors
var (
globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals
errChunk = errors.New("Abort chunk, with following errors")
ErrChunk = errors.New("abort chunk, with following errors")
ErrShutdownNonEstablished = errors.New("shutdown called in non-established state")
ErrAssociationClosedBeforeConn = errors.New("association closed before connecting")
ErrSilentlyDiscard = errors.New("silently discard")
ErrInitNotStoredToSend = errors.New("the init not stored to send")
ErrCookieEchoNotStoredToSend = errors.New("cookieEcho not stored to send")
ErrSCTPPacketSourcePortZero = errors.New("sctp packet must not have a source port of 0")
ErrSCTPPacketDestinationPortZero = errors.New("sctp packet must not have a destination port of 0")
ErrInitChunkBundled = errors.New("init chunk must not be bundled with any other chunk")
ErrInitChunkVerifyTagNotZero = errors.New("init chunk expects a verification tag of 0 on the packet when out-of-the-blue")
ErrHandleInitState = errors.New("todo: handle Init when in state")
ErrInitAckNoCookie = errors.New("no cookie in InitAck")
ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN")
ErrTSNRequestNotExist = errors.New("requested non-existent TSN")
ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state")
ErrParamterType = errors.New("unexpected parameter type")
ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state")
ErrChunkTypeUnhandled = errors.New("unhandled chunk type")
ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)")
ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)")
)
const (
@ -46,6 +68,7 @@ const (
const (
timerT1Init int = iota
timerT1Cookie
timerT2Shutdown
timerT3RTX
timerReconfig
)
@ -60,8 +83,8 @@ const (
// ack transmission state
const (
ackStateIdle int = iota // ack timer is off
ackStateImmediate // ack timer is on (ack is being delayed)
ackStateDelay // will send ack immediately
ackStateImmediate // will send ack immediately
ackStateDelay // ack timer is on (ack is being delayed)
)
// other constants
@ -94,21 +117,17 @@ func getAssociationStateString(a uint32) string {
// Association represents an SCTP association
// 13.2. Parameters Necessary per Association (i.e., the TCB)
// Peer : Tag value to be sent in every packet and is received
// Verification: in the INIT or INIT ACK chunk.
// Tag :
//
// My : Tag expected in every inbound packet and sent in the
// Verification: INIT or INIT ACK chunk.
// Peer : Tag value to be sent in every packet and is received
// Verification: in the INIT or INIT ACK chunk.
// Tag :
// State : A state variable indicating what state the association
// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED,
// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED,
// : SHUTDOWN-ACK-SENT.
//
// Tag :
// State : A state variable indicating what state the association
// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED,
// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED,
// : SHUTDOWN-ACK-SENT.
//
// Note: No "CLOSED" state is illustrated since if a
// association is "CLOSED" its TCB SHOULD be removed.
// Note: No "CLOSED" state is illustrated since if a
// association is "CLOSED" its TCB SHOULD be removed.
type Association struct {
bytesReceived uint64
bytesSent uint64
@ -127,6 +146,13 @@ type Association struct {
willRetransmitFast bool
willRetransmitReconfig bool
willSendShutdown bool
willSendShutdownAck bool
willSendShutdownComplete bool
willSendAbort bool
willSendAbortCause errorCause
// Reconfig
myNextRSN uint32
reconfigs map[uint32]*chunkReconfig
@ -143,7 +169,8 @@ type Association struct {
pendingQueue *pendingQueue
controlQueue *controlQueue
mtu uint32
maxPayloadSize uint32 // max DATA chunk payload size
maxPayloadSize uint32 // max DATA chunk payload size
srtt atomic.Value // type float64
cumulativeTSNAckPoint uint32
advancedPeerTSNAckPoint uint32
useForwardTSN bool
@ -159,12 +186,13 @@ type Association struct {
fastRecoverExitPoint uint32
// RTX & Ack timer
rtoMgr *rtoManager
t1Init *rtxTimer
t1Cookie *rtxTimer
t3RTX *rtxTimer
tReconfig *rtxTimer
ackTimer *ackTimer
rtoMgr *rtoManager
t1Init *rtxTimer
t1Cookie *rtxTimer
t2Shutdown *rtxTimer
t3RTX *rtxTimer
tReconfig *rtxTimer
ackTimer *ackTimer
// Chunks stored for retransmission
storedInit *chunkInit
@ -217,7 +245,7 @@ func Server(config Config) (*Association, error) {
}
return a, nil
case <-a.readLoopCloseCh:
return nil, errors.Errorf("association closed before connecting")
return nil, ErrAssociationClosedBeforeConn
}
}
@ -233,7 +261,7 @@ func Client(config Config) (*Association, error) {
}
return a, nil
case <-a.readLoopCloseCh:
return nil, errors.Errorf("association closed before connecting")
return nil, ErrAssociationClosedBeforeConn
}
}
@ -281,7 +309,7 @@ func createAssociation(config Config) *Association {
handshakeCompletedCh: make(chan error),
cumulativeTSNAckPoint: tsn - 1,
advancedPeerTSNAckPoint: tsn - 1,
silentError: errors.Errorf("silently discard"),
silentError: ErrSilentlyDiscard,
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
}
@ -296,10 +324,12 @@ func createAssociation(config Config) *Association {
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)",
a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
a.srtt.Store(float64(0))
a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans)
a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans)
a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans) // retransmit forever
a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans) // retransmit forever
a.t2Shutdown = newRTXTimer(timerT2Shutdown, a, noMaxRetrans) // retransmit forever
a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans) // retransmit forever
a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans) // retransmit forever
a.ackTimer = newAckTimer(a)
return a
@ -336,7 +366,7 @@ func (a *Association) init(isClient bool) {
func (a *Association) sendInit() error {
a.log.Debugf("[%s] sending INIT", a.name)
if a.storedInit == nil {
return errors.Errorf("the init not stored to send")
return ErrInitNotStoredToSend
}
outbound := &packet{}
@ -357,7 +387,7 @@ func (a *Association) sendInit() error {
// caller must hold a.lock
func (a *Association) sendCookieEcho() error {
if a.storedCookieEcho == nil {
return errors.Errorf("cookieEcho not stored to send")
return ErrCookieEchoNotStoredToSend
}
a.log.Debugf("[%s] sending COOKIE-ECHO", a.name)
@ -374,10 +404,62 @@ func (a *Association) sendCookieEcho() error {
return nil
}
// Shutdown initiates the shutdown sequence. The method blocks until the
// shutdown sequence is completed and the connection is closed, or until the
// passed context is done, in which case the context's error is returned.
func (a *Association) Shutdown(ctx context.Context) error {
a.log.Debugf("[%s] closing association..", a.name)
state := a.getState()
if state != established {
return fmt.Errorf("%w: shutdown %s", ErrShutdownNonEstablished, a.name)
}
// Attempt a graceful shutdown.
a.setState(shutdownPending)
a.lock.Lock()
if a.inflightQueue.size() == 0 {
// No more outstanding, send shutdown.
a.willSendShutdown = true
a.awakeWriteLoop()
a.setState(shutdownSent)
}
a.lock.Unlock()
select {
case <-a.closeWriteLoopCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Close ends the SCTP Association and cleans up any state
func (a *Association) Close() error {
a.log.Debugf("[%s] closing association..", a.name)
err := a.close()
// Wait for readLoop to end
<-a.readLoopCloseCh
a.log.Debugf("[%s] association closed", a.name)
a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs())
a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
return err
}
func (a *Association) close() error {
a.log.Debugf("[%s] closing association..", a.name)
a.setState(closed)
err := a.netConn.Close()
@ -387,22 +469,34 @@ func (a *Association) Close() error {
// awake writeLoop to exit
a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) })
return err
}
// Abort sends the abort packet with user initiated abort and immediately
// closes the connection.
func (a *Association) Abort(reason string) {
a.log.Debugf("[%s] aborting association: %s", a.name, reason)
a.lock.Lock()
a.willSendAbort = true
a.willSendAbortCause = &errorCauseUserInitiatedAbort{
upperLayerAbortReason: []byte(reason),
}
a.lock.Unlock()
a.awakeWriteLoop()
// Wait for readLoop to end
<-a.readLoopCloseCh
a.log.Debugf("[%s] association closed", a.name)
a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs())
a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
return err
}
func (a *Association) closeAllTimers() {
// Close all retransmission & ack timers
a.t1Init.close()
a.t1Cookie.close()
a.t2Shutdown.close()
a.t3RTX.close()
a.tReconfig.close()
a.ackTimer.close()
@ -422,6 +516,13 @@ func (a *Association) readLoop() {
a.lock.Unlock()
close(a.acceptCh)
close(a.readLoopCloseCh)
a.log.Debugf("[%s] association closed", a.name)
a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs())
a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
}()
a.log.Debugf("[%s] readLoop entered", a.name)
@ -451,15 +552,16 @@ func (a *Association) readLoop() {
func (a *Association) writeLoop() {
a.log.Debugf("[%s] writeLoop entered", a.name)
defer a.log.Debugf("[%s] writeLoop exited", a.name)
loop:
for {
rawPackets := a.gatherOutbound()
rawPackets, ok := a.gatherOutbound()
for _, raw := range rawPackets {
_, err := a.netConn.Write(raw)
if err != nil {
if err != io.EOF {
if !errors.Is(err, io.EOF) {
a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err)
}
a.log.Debugf("[%s] writeLoop ended", a.name)
@ -468,6 +570,14 @@ loop:
atomic.AddUint64(&a.bytesSent, uint64(len(raw)))
}
if !ok {
if err := a.close(); err != nil {
a.log.Warnf("[%s] failed to close association: %v", a.name, err)
}
return
}
select {
case <-a.awakeWriteLoopCh:
case <-a.closeWriteLoopCh:
@ -477,8 +587,6 @@ loop:
a.setState(closed)
a.closeAllTimers()
a.log.Debugf("[%s] writeLoop exited", a.name)
}
func (a *Association) awakeWriteLoop() {
@ -526,7 +634,7 @@ func (a *Association) handleInbound(raw []byte) error {
}
// The caller should hold the lock
func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) [][]byte {
func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte {
for _, p := range a.getDataPacketsToRetransmit() {
raw, err := p.marshal()
if err != nil {
@ -536,6 +644,11 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
rawPackets = append(rawPackets, raw)
}
return rawPackets
}
// The caller should hold the lock
func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) [][]byte {
// Pop unsent data chunks from the pending queue to send as much as
// cwnd and rwnd allow.
chunks, sisToReset := a.popPendingDataChunksToSend()
@ -599,7 +712,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
}
// The caller should hold the lock
func (a *Association) gatherOutboundFrastRetransmissionPackets(rawPackets [][]byte) [][]byte {
func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byte) [][]byte {
if a.willRetransmitFast {
a.willRetransmitFast = false
@ -662,7 +775,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
if a.ackState == ackStateImmediate {
a.ackState = ackStateIdle
sack := a.createSelectiveAckChunk()
a.log.Debugf("[%s] sending SACK: %s", a.name, sack.String())
a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
raw, err := a.createPacket([]chunk{sack}).marshal()
if err != nil {
a.log.Warnf("[%s] failed to serialize a SACK packet", a.name)
@ -692,11 +805,86 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b
return rawPackets
}
// gatherOutbound gathers outgoing packets
func (a *Association) gatherOutbound() [][]byte {
func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]byte, bool) {
ok := true
switch {
case a.willSendShutdown:
a.willSendShutdown = false
shutdown := &chunkShutdown{
cumulativeTSNAck: a.cumulativeTSNAckPoint,
}
raw, err := a.createPacket([]chunk{shutdown}).marshal()
if err != nil {
a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name)
} else {
a.t2Shutdown.start(a.rtoMgr.getRTO())
rawPackets = append(rawPackets, raw)
}
case a.willSendShutdownAck:
a.willSendShutdownAck = false
shutdownAck := &chunkShutdownAck{}
raw, err := a.createPacket([]chunk{shutdownAck}).marshal()
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name)
} else {
a.t2Shutdown.start(a.rtoMgr.getRTO())
rawPackets = append(rawPackets, raw)
}
case a.willSendShutdownComplete:
a.willSendShutdownComplete = false
shutdownComplete := &chunkShutdownComplete{}
raw, err := a.createPacket([]chunk{shutdownComplete}).marshal()
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name)
} else {
rawPackets = append(rawPackets, raw)
ok = false
}
}
return rawPackets, ok
}
func (a *Association) gatherAbortPacket() ([]byte, error) {
cause := a.willSendAbortCause
a.willSendAbort = false
a.willSendAbortCause = nil
abort := &chunkAbort{}
if cause != nil {
abort.errorCauses = []errorCause{cause}
}
raw, err := a.createPacket([]chunk{abort}).marshal()
return raw, err
}
// gatherOutbound gathers outgoing packets. The returned bool value set to
// false means the association should be closed down after the final send.
func (a *Association) gatherOutbound() ([][]byte, bool) {
a.lock.Lock()
defer a.lock.Unlock()
if a.willSendAbort {
pkt, err := a.gatherAbortPacket()
if err != nil {
a.log.Warnf("[%s] failed to serialize an abort packet", a.name)
return nil, false
}
return [][]byte{pkt}, false
}
rawPackets := [][]byte{}
if a.controlQueue.size() > 0 {
@ -712,14 +900,25 @@ func (a *Association) gatherOutbound() [][]byte {
state := a.getState()
if state == established {
ok := true
switch state {
case established:
rawPackets = a.gatherDataPacketsToRetransmit(rawPackets)
rawPackets = a.gatherOutboundDataAndReconfigPackets(rawPackets)
rawPackets = a.gatherOutboundFrastRetransmissionPackets(rawPackets)
rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets)
rawPackets = a.gatherOutboundSackPackets(rawPackets)
rawPackets = a.gatherOutboundForwardTSNPackets(rawPackets)
case shutdownPending, shutdownSent, shutdownReceived:
rawPackets = a.gatherDataPacketsToRetransmit(rawPackets)
rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets)
rawPackets = a.gatherOutboundSackPackets(rawPackets)
rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets)
case shutdownAckSent:
rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets)
}
return rawPackets
return rawPackets, ok
}
func checkPacket(p *packet) error {
@ -731,7 +930,7 @@ func checkPacket(p *packet) error {
// identify the association to which this packet belongs. The port
// number 0 MUST NOT be used.
if p.sourcePort == 0 {
return errors.Errorf("sctp packet must not have a source port of 0")
return ErrSCTPPacketSourcePortZero
}
// This is the SCTP port number to which this packet is destined.
@ -739,7 +938,7 @@ func checkPacket(p *packet) error {
// SCTP packet to the correct receiving endpoint/application. The
// port number 0 MUST NOT be used.
if p.destinationPort == 0 {
return errors.Errorf("sctp packet must not have a destination port of 0")
return ErrSCTPPacketDestinationPortZero
}
// Check values on the packet that are specific to a particular chunk type
@ -750,13 +949,13 @@ func checkPacket(p *packet) error {
// They MUST be the only chunks present in the SCTP packets that carry
// them.
if len(p.chunks) != 1 {
return errors.Errorf("init chunk must not be bundled with any other chunk")
return ErrInitChunkBundled
}
// A packet containing an INIT chunk MUST have a zero Verification
// Tag.
if p.verificationTag != 0 {
return errors.Errorf("init chunk expects a verification tag of 0 on the packet when out-of-the-blue")
return ErrInitChunkVerifyTagNotZero
}
}
}
@ -812,6 +1011,26 @@ func (a *Association) BytesReceived() uint64 {
return atomic.LoadUint64(&a.bytesReceived)
}
// MTU returns the association's current MTU
func (a *Association) MTU() uint32 {
return atomic.LoadUint32(&a.mtu)
}
// CWND returns the association's current congestion window (cwnd)
func (a *Association) CWND() uint32 {
return atomic.LoadUint32(&a.cwnd)
}
// RWND returns the association's current receiver window (rwnd)
func (a *Association) RWND() uint32 {
return atomic.LoadUint32(&a.rwnd)
}
// SRTT returns the latest smoothed round-trip time (srrt)
func (a *Association) SRTT() float64 {
return a.srtt.Load().(float64) //nolint:forcetypeassert
}
func setSupportedExtensions(init *chunkInitCommon) {
// nolint:godox
// TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2
@ -838,7 +1057,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
if state != closed && state != cookieWait && state != cookieEchoed {
// 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED,
// COOKIE-WAIT, and SHUTDOWN-ACK-SENT
return nil, errors.Errorf("todo: handle Init when in state %s", getAssociationStateString(state))
return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state))
}
// Should we be setting any of these permanently until we've ACKed further?
@ -859,14 +1078,14 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
case *paramSupportedExtensions:
for _, t := range v.ChunkTypes {
if t == ctForwardTSN {
a.log.Debugf("[%s] use ForwardTSN (on init)\n", a.name)
a.log.Debugf("[%s] use ForwardTSN (on init)", a.name)
a.useForwardTSN = true
}
}
}
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on init)\n", a.name)
a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name)
}
outbound := &packet{}
@ -944,17 +1163,17 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
case *paramSupportedExtensions:
for _, t := range v.ChunkTypes {
if t == ctForwardTSN {
a.log.Debugf("[%s] use ForwardTSN (on initAck)\n", a.name)
a.log.Debugf("[%s] use ForwardTSN (on initAck)", a.name)
a.useForwardTSN = true
}
}
}
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)\n", a.name)
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
}
if cookieParam == nil {
return errors.Errorf("no cookie in InitAck")
return ErrInitAckNoCookie
}
a.storedCookieEcho = &chunkCookieEcho{}
@ -996,6 +1215,11 @@ func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet {
func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet {
state := a.getState()
a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state))
if a.myCookie == nil {
a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name)
return nil
}
switch state {
default:
return nil
@ -1054,7 +1278,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
canPush := a.payloadQueue.canPush(d, a.peerLastTSN)
if canPush {
s := a.getOrCreateStream(d.streamIdentifier)
s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
if s == nil {
// silentely discard the data. (sender will retry on T3-rtx timeout)
// see pion/sctp#30
@ -1145,14 +1369,7 @@ func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType Pay
a.lock.Lock()
defer a.lock.Unlock()
if _, ok := a.streams[streamIdentifier]; ok {
return nil, errors.Errorf("there already exists a stream with identifier %d", streamIdentifier)
}
s := a.createStream(streamIdentifier, false)
s.setDefaultPayloadType(defaultPayloadType)
return s, nil
return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil
}
// AcceptStream accepts a stream
@ -1195,12 +1412,17 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream
}
// getOrCreateStream gets or creates a stream. The caller should hold the lock.
func (a *Association) getOrCreateStream(streamIdentifier uint16) *Stream {
func (a *Association) getOrCreateStream(streamIdentifier uint16, accept bool, defaultPayloadType PayloadProtocolIdentifier) *Stream {
if s, ok := a.streams[streamIdentifier]; ok {
s.SetDefaultPayloadType(defaultPayloadType)
return s
}
return a.createStream(streamIdentifier, true)
s := a.createStream(streamIdentifier, accept)
if s != nil {
s.SetDefaultPayloadType(defaultPayloadType)
}
return s
}
// The caller should hold the lock.
@ -1213,7 +1435,7 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int,
for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, d.cumulativeTSNAck); i++ {
c, ok := a.inflightQueue.pop(i)
if !ok {
return nil, 0, errors.Errorf("tsn %v unable to be popped from inflight queue", i)
return nil, 0, fmt.Errorf("%w: %v", ErrInflightQueueTSNPop, i)
}
if !c.acked {
@ -1249,6 +1471,7 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int,
a.minTSN2MeasureRTT = a.myNextTSN
rtt := time.Since(c.since).Seconds() * 1000.0
srtt := a.rtoMgr.setNewRTT(rtt)
a.srtt.Store(srtt)
a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f",
a.name, rtt, srtt, a.rtoMgr.getRTO())
}
@ -1268,7 +1491,7 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int,
tsn := d.cumulativeTSNAck + uint32(i)
c, ok := a.inflightQueue.get(tsn)
if !ok {
return nil, 0, errors.Errorf("requested non-existent TSN %v", tsn)
return nil, 0, fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn)
}
if !c.acked {
@ -1287,6 +1510,7 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int,
a.minTSN2MeasureRTT = a.myNextTSN
rtt := time.Since(c.since).Seconds() * 1000.0
srtt := a.rtoMgr.setNewRTT(rtt)
a.srtt.Store(srtt)
a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f",
a.name, rtt, srtt, a.rtoMgr.getRTO())
}
@ -1384,7 +1608,7 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cum
for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ {
c, ok := a.inflightQueue.get(tsn)
if !ok {
return errors.Errorf("requested non-existent TSN %v", tsn)
return fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn)
}
if !c.acked && !c.abandoned() && c.missIndicator < 3 {
c.missIndicator++
@ -1419,7 +1643,7 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cum
func (a *Association) handleSack(d *chunkSelectiveAck) error {
a.log.Tracef("[%s] SACK: cumTSN=%d a_rwnd=%d", a.name, d.cumulativeTSNAck, d.advertisedReceiverWindowCredit)
state := a.getState()
if state != established {
if state != established && state != shutdownPending && state != shutdownReceived {
return nil
}
@ -1518,19 +1742,94 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error {
a.awakeWriteLoop()
}
if a.inflightQueue.size() > 0 {
a.postprocessSack(state, cumTSNAckPointAdvanced)
return nil
}
// The caller must hold the lock. This method was only added because the
// linter was complaining about the "cognitive complexity" of handleSack.
func (a *Association) postprocessSack(state uint32, shouldAwakeWriteLoop bool) {
switch {
case a.inflightQueue.size() > 0:
// Start timer. (noop if already started)
a.log.Tracef("[%s] T3-rtx timer start (pt3)", a.name)
a.t3RTX.start(a.rtoMgr.getRTO())
case state == shutdownPending:
// No more outstanding, send shutdown.
shouldAwakeWriteLoop = true
a.willSendShutdown = true
a.setState(shutdownSent)
case state == shutdownReceived:
// No more outstanding, send shutdown ack.
shouldAwakeWriteLoop = true
a.willSendShutdownAck = true
a.setState(shutdownAckSent)
}
if cumTSNAckPointAdvanced {
if shouldAwakeWriteLoop {
a.awakeWriteLoop()
}
}
// The caller should hold the lock.
func (a *Association) handleShutdown(_ *chunkShutdown) {
state := a.getState()
switch state {
case established:
if a.inflightQueue.size() > 0 {
a.setState(shutdownReceived)
} else {
// No more outstanding, send shutdown ack.
a.willSendShutdownAck = true
a.setState(shutdownAckSent)
a.awakeWriteLoop()
}
// a.cumulativeTSNAckPoint = c.cumulativeTSNAck
case shutdownSent:
a.willSendShutdownAck = true
a.setState(shutdownAckSent)
a.awakeWriteLoop()
}
}
// The caller should hold the lock.
func (a *Association) handleShutdownAck(_ *chunkShutdownAck) {
state := a.getState()
if state == shutdownSent || state == shutdownAckSent {
a.t2Shutdown.stop()
a.willSendShutdownComplete = true
a.awakeWriteLoop()
}
}
func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error {
state := a.getState()
if state == shutdownAckSent {
a.t2Shutdown.stop()
return a.close()
}
return nil
}
func (a *Association) handleAbort(c *chunkAbort) error {
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}
_ = a.close()
return fmt.Errorf("[%s] %w: %s", a.name, ErrChunk, errStr)
}
// createForwardTSN generates ForwardTSN chunk.
// This method will be be called if useForwardTSN is set to false.
// The caller should hold the lock.
@ -1633,7 +1932,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet {
// send a SACK to its peer (the sender of the FORWARD TSN) since such a
// duplicate may indicate the previous SACK was lost in the network.
a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d\n",
a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d",
a.name, c.newCumulativeTSN, a.peerLastTSN)
if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) {
a.log.Tracef("[%s] sending ack on Forward TSN", a.name)
@ -1686,7 +1985,7 @@ func (a *Association) sendResetRequest(streamIdentifier uint16) error {
state := a.getState()
if state != established {
return errors.Errorf("sending reset packet in non-established state: state=%s",
return fmt.Errorf("%w: state=%s", ErrResetPacketInStateNotExist,
getAssociationStateString(state))
}
@ -1708,21 +2007,23 @@ func (a *Association) sendResetRequest(streamIdentifier uint16) error {
func (a *Association) handleReconfigParam(raw param) (*packet, error) {
switch p := raw.(type) {
case *paramOutgoingResetRequest:
a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name)
a.reconfigRequests[p.reconfigRequestSequenceNumber] = p
resp := a.resetStreamsIfAny(p)
if resp != nil {
return resp, nil
}
return nil, nil
return nil, nil //nolint:nilnil
case *paramReconfigResponse:
a.log.Tracef("[%s] handleReconfigParam (ReconfigResponse)", a.name)
delete(a.reconfigs, p.reconfigResponseSequenceNumber)
if len(a.reconfigs) == 0 {
a.tReconfig.stop()
}
return nil, nil
return nil, nil //nolint:nilnil
default:
return nil, errors.Errorf("unexpected parameter type %T", p)
return nil, fmt.Errorf("%w: %t", ErrParamterType, p)
}
}
@ -1737,7 +2038,11 @@ func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet {
if !ok {
continue
}
a.unregisterStream(s, io.EOF)
a.lock.Unlock()
s.onInboundStreamReset()
a.lock.Lock()
a.log.Debugf("[%s] deleting stream %d", a.name, id)
delete(a.streams, s.streamIdentifier)
}
delete(a.reconfigRequests, p.reconfigRequestSequenceNumber)
} else {
@ -1777,7 +2082,6 @@ func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) {
a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)",
a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment)
// Push it into the inflightQueue
a.inflightQueue.pushNoCheck(c)
}
@ -1880,7 +2184,7 @@ func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error {
state := a.getState()
if state != established {
return errors.Errorf("sending payload data in non-established state: state=%s",
return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist,
getAssociationStateString(state))
}
@ -2014,7 +2318,6 @@ func (a *Association) handleChunkEnd() {
defer a.lock.Unlock()
if a.immediateAckTriggered {
// Send SACK now!
a.ackState = ackStateImmediate
a.ackTimer.stop()
a.awakeWriteLoop()
@ -2037,6 +2340,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
return nil
}
isAbort := false
switch c := c.(type) {
case *chunkInit:
packets, err = a.handleInit(p, c)
@ -2045,11 +2350,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
err = a.handleInitAck(p, c)
case *chunkAbort:
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}
return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr)
isAbort = true
err = a.handleAbort(c)
case *chunkError:
var errStr string
@ -2079,12 +2381,23 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
case *chunkForwardTSN:
packets = a.handleForwardTSN(c)
case *chunkShutdown:
a.handleShutdown(c)
case *chunkShutdownAck:
a.handleShutdownAck(c)
case *chunkShutdownComplete:
err = a.handleShutdownComplete(c)
default:
err = errors.Errorf("unhandled chunk type")
err = ErrChunkTypeUnhandled
}
// Log and return, the only condition that is fatal is a ABORT chunk
if err != nil {
if isAbort {
return err
}
a.log.Errorf("Failed to handle chunk: %v", err)
return nil
}
@ -2117,6 +2430,20 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) {
return
}
if id == timerT2Shutdown {
a.log.Debugf("[%s] retransmission of shutdown timeout (nRtos=%d): %v", a.name, nRtos)
state := a.getState()
switch state {
case shutdownSent:
a.willSendShutdown = true
a.awakeWriteLoop()
case shutdownAckSent:
a.willSendShutdownAck = true
a.awakeWriteLoop()
}
}
if id == timerT3RTX {
a.stats.incT3Timeouts()
@ -2190,13 +2517,18 @@ func (a *Association) onRetransmissionFailure(id int) {
if id == timerT1Init {
a.log.Errorf("[%s] retransmission failure: T1-init", a.name)
a.handshakeCompletedCh <- errors.Errorf("handshake failed (INIT ACK)")
a.handshakeCompletedCh <- ErrHandshakeInitAck
return
}
if id == timerT1Cookie {
a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name)
a.handshakeCompletedCh <- errors.Errorf("handshake failed (COOKIE ECHO)")
a.handshakeCompletedCh <- ErrHandshakeCookieEcho
return
}
if id == timerT2Shutdown {
a.log.Errorf("[%s] retransmission failure: T2-shutdown", a.name)
return
}