diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 9e39904f2..de8a4f6a4 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -179,6 +179,13 @@ func (v *RTMPMultipleError) Error() string { return b.String() } +func (v *RTMPMultipleError) Cause() error { + if len(v.errs) == 0 { + return nil + } + return v.errs[0] +} + type RTMPProxyError struct { // Whether error is caused by backend. isBackend bool @@ -190,6 +197,10 @@ func (v *RTMPProxyError) Error() string { return v.err.Error() } +func (v *RTMPProxyError) Cause() error { + return v.err +} + type RTMPConnection struct { // The random number generator. rd *rand.Rand @@ -207,13 +218,18 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) // Close the connection when ctx done. - connDoneCtx, connDoneCancel := context.WithCancel(ctx) - defer connDoneCancel() - go func() { - <-connDoneCtx.Done() - time.Sleep(10 * time.Millisecond) - conn.Close() - }() + var backend *RTMPClientToBackend + if true { + connDoneCtx, connDoneCancel := context.WithCancel(ctx) + defer connDoneCancel() + go func() { + <-connDoneCtx.Done() + conn.Close() + if backend != nil { + backend.Close() + } + }() + } // Simple handshake with client. hs := rtmp.NewHandshake(v.rd) @@ -358,7 +374,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { tcUrl, streamName, currentStreamID, clientType) // Find a backend SRS server to proxy the RTMP stream. - backend := NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { client.rd, client.typ = v.rd, clientType }) defer backend.Close() @@ -408,11 +424,16 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { var wg sync.WaitGroup defer wg.Wait() + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + // Proxy all message from backend to client. wg.Add(1) var r0 error go func() { defer wg.Done() + defer cancel() r0 = func() error { for { @@ -435,6 +456,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { var r1 error go func() { defer wg.Done() + defer cancel() r1 = func() error { for { @@ -452,8 +474,18 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { }() }() + // Wait until all goroutine quit. wg.Wait() - return NewRTMPMultipleError(r0, r1) + + // Reset the error if caused by another goroutine. + if errors.Cause(r0) == context.Canceled && parentCtx.Err() == nil { + r0 = nil + } + if errors.Cause(r1) == context.Canceled && parentCtx.Err() == nil { + r1 = nil + } + + return NewRTMPMultipleError(r0, r1, parentCtx.Err()) } type RTMPClientType string diff --git a/proxy/utils.go b/proxy/utils.go index 0644fabef..5f3e813fc 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -97,9 +97,12 @@ func buildStreamURL(r string) (string, error) { // isPeerClosedError indicates whether peer object closed the connection. func isPeerClosedError(err error) bool { causeErr := errors.Cause(err) - if stdErr.Is(causeErr, io.EOF) || - stdErr.Is(causeErr, net.ErrClosed) || - stdErr.Is(causeErr, syscall.EPIPE) { + + if stdErr.Is(causeErr, io.EOF) { + return true + } + + if stdErr.Is(causeErr, syscall.EPIPE) { return true }