From 05c3a422a5bd163a5b61b231cdbe0b99e49bf2c0 Mon Sep 17 00:00:00 2001 From: Winlin Date: Sat, 31 Aug 2024 23:15:51 +0800 Subject: [PATCH 01/12] HTTP-FLV: Notify connection to expire when unpublishing. v6.0.152 v7.0.11 (#4164) When stopping the stream, it will wait for the HTTP Streaming to exit. If the HTTP Streaming goroutine hangs, it will not exit automatically. ```cpp void SrsHttpStreamServer::http_unmount(SrsRequest* r) { SrsUniquePtr stream(entry->stream); if (stream->entry) stream->entry->enabled = false; srs_usleep(...); // Wait for about 120s. mux.unhandle(entry->mount, stream.get()); // Free stream. } srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { err = do_serve_http(w, r); // If stuck in here for 120s+ alive_viewers_--; // Crash at here, because stream has been deleted. ``` We should notify http stream connection to interrupt(expire): ```cpp void SrsHttpStreamServer::http_unmount(SrsRequest* r) { SrsUniquePtr stream(entry->stream); if (stream->entry) stream->entry->enabled = false; stream->expire(); // Notify http stream to interrupt. ``` Note that we should notify all viewers pulling stream from this http stream. Note that we have tried to fix this issue, but only try to wait for all viewers to quit, without interrupting the viewers, see https://github.com/ossrs/srs/pull/4144 --------- Co-authored-by: Jacob Su --- trunk/doc/CHANGELOG.md | 2 ++ trunk/src/app/srs_app_http_api.cpp | 2 -- trunk/src/app/srs_app_http_stream.cpp | 30 ++++++++++++++++++++++----- trunk/src/app/srs_app_http_stream.hpp | 9 ++++++-- trunk/src/core/srs_core_version6.hpp | 2 +- trunk/src/core/srs_core_version7.hpp | 2 +- 6 files changed, 36 insertions(+), 11 deletions(-) diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2ca58e903..b6f396c5c 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v7.0.11 (#4164) * v7.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. v7.0.10 (#4157) * v7.0, 2024-08-24, Merge [#4156](https://github.com/ossrs/srs/pull/4156): Build: Fix srs_mp4_parser compiling error. v7.0.9 (#4156) * v7.0, 2024-08-22, Merge [#4154](https://github.com/ossrs/srs/pull/4154): ASAN: Disable memory leak detection by default. v7.0.8 (#4154) @@ -22,6 +23,7 @@ The changelog for SRS. ## SRS 6.0 Changelog +* v6.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v6.0.152 (#4164) * v6.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. v6.0.151 (#4157) * v6.0, 2024-08-24, Merge [#4156](https://github.com/ossrs/srs/pull/4156): Build: Fix srs_mp4_parser compiling error. v6.0.150 (#4156) * v6.0, 2024-08-21, Merge [#4150](https://github.com/ossrs/srs/pull/4150): API: Support new HTTP API for VALGRIND. 
v6.0.149 (#4150) diff --git a/trunk/src/app/srs_app_http_api.cpp b/trunk/src/app/srs_app_http_api.cpp index a083a218a..3eb342c30 100644 --- a/trunk/src/app/srs_app_http_api.cpp +++ b/trunk/src/app/srs_app_http_api.cpp @@ -1206,8 +1206,6 @@ SrsGoApiSignal::~SrsGoApiSignal() srs_error_t SrsGoApiSignal::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { - srs_error_t err = srs_success; - std::string signal = r->query_get("signo"); srs_trace("query signo=%s", signal.c_str()); diff --git a/trunk/src/app/srs_app_http_stream.cpp b/trunk/src/app/srs_app_http_stream.cpp index b6f7fa47b..0a4a8e5b8 100755 --- a/trunk/src/app/srs_app_http_stream.cpp +++ b/trunk/src/app/srs_app_http_stream.cpp @@ -583,13 +583,15 @@ SrsLiveStream::SrsLiveStream(SrsRequest* r, SrsBufferCache* c) cache = c; req = r->copy()->as_http(); security_ = new SrsSecurity(); - alive_viewers_ = 0; } SrsLiveStream::~SrsLiveStream() { srs_freep(req); srs_freep(security_); + + // The live stream should never be destroyed when it's serving any viewers. + srs_assert(viewers_.empty()); } srs_error_t SrsLiveStream::update_auth(SrsRequest* r) @@ -634,10 +636,18 @@ srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage return srs_error_wrap(err, "http hook"); } - alive_viewers_++; + // Add the viewer to the viewers list. + viewers_.push_back(hc); + + // Serve the viewer connection. err = do_serve_http(w, r); - alive_viewers_--; - + + // Remove viewer from the viewers list. + vector::iterator it = std::find(viewers_.begin(), viewers_.end(), hc); + srs_assert (it != viewers_.end()); + viewers_.erase(it); + + // Do hook after serving. http_hooks_on_stop(r); return err; @@ -645,7 +655,16 @@ srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage bool SrsLiveStream::alive() { - return alive_viewers_ > 0; + return !viewers_.empty(); +} + +void SrsLiveStream::expire() +{ + vector::iterator it; + for (it = viewers_.begin(); it != viewers_.end(); ++it) { + ISrsExpire* conn = *it; + conn->expire(); + } } srs_error_t SrsLiveStream::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) @@ -1075,6 +1094,7 @@ void SrsHttpStreamServer::http_unmount(SrsRequest* r) // Notify cache and stream to stop. if (stream->entry) stream->entry->enabled = false; + stream->expire(); cache->stop(); // Wait for cache and stream to stop. diff --git a/trunk/src/app/srs_app_http_stream.hpp b/trunk/src/app/srs_app_http_stream.hpp index 2c557972b..2e233ced6 100755 --- a/trunk/src/app/srs_app_http_stream.hpp +++ b/trunk/src/app/srs_app_http_stream.hpp @@ -11,6 +11,8 @@ #include #include +#include + class SrsAacTransmuxer; class SrsMp3Transmuxer; class SrsFlvTransmuxer; @@ -176,7 +178,7 @@ public: // HTTP Live Streaming, to transmux RTMP to HTTP FLV or other format. // TODO: FIXME: Rename to SrsHttpLive -class SrsLiveStream : public ISrsHttpHandler +class SrsLiveStream : public ISrsHttpHandler, public ISrsExpire { private: SrsRequest* req; @@ -185,7 +187,7 @@ private: // For multiple viewers, which means there will more than one alive viewers for a live stream, so we must // use an int value to represent if there is any viewer is alive. We should never do cleanup unless all // viewers closed the connection. 
- int alive_viewers_; + std::vector viewers_; public: SrsLiveStream(SrsRequest* r, SrsBufferCache* c); virtual ~SrsLiveStream(); @@ -193,6 +195,9 @@ public: public: virtual srs_error_t serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r); virtual bool alive(); +// Interface ISrsExpire +public: + virtual void expire(); private: virtual srs_error_t do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r); virtual srs_error_t http_hooks_on_play(ISrsHttpMessage* r); diff --git a/trunk/src/core/srs_core_version6.hpp b/trunk/src/core/srs_core_version6.hpp index edc5171e3..28d5d9219 100644 --- a/trunk/src/core/srs_core_version6.hpp +++ b/trunk/src/core/srs_core_version6.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 6 #define VERSION_MINOR 0 -#define VERSION_REVISION 151 +#define VERSION_REVISION 152 #endif diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 7cac790c4..84ef6eefd 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 10 +#define VERSION_REVISION 11 #endif \ No newline at end of file From a7aa2eaf76e60e076f77f4747acb8816172d1be4 Mon Sep 17 00:00:00 2001 From: Winlin Date: Sun, 1 Sep 2024 06:40:16 +0800 Subject: [PATCH 02/12] Fix #3767: RTMP: Do not response empty data packet. v6.0.153 v7.0.12 (#4162) If SRS responds with this empty data packet, FFmpeg will receive an empty stream, like `Stream #0:0: Data: none` in following logs: ```bash ffmpeg -i rtmp://localhost:11935/live/livestream # Stream #0:0: Data: none # Stream #0:1: Audio: aac (LC), 44100 Hz, stereo, fltp, 30 kb/s # Stream #0:2: Video: h264 (High), yuv420p(progressive), 768x320 [SAR 1:1 DAR 12:5], 212 kb/s, 25 fps, 25 tbr, 1k tbn ``` This won't cause the player to fail, but it will inconvenience the user significantly. It may also cause FFmpeg slower to analysis the stream, see #3767 --------- Co-authored-by: Jacob Su --- trunk/doc/CHANGELOG.md | 2 ++ trunk/src/core/srs_core_version6.hpp | 2 +- trunk/src/core/srs_core_version7.hpp | 2 +- trunk/src/protocol/srs_protocol_rtmp_stack.cpp | 3 ++- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index b6f396c5c..0bea6e3cc 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v7.0.12 (#4162) * v7.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v7.0.11 (#4164) * v7.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. v7.0.10 (#4157) * v7.0, 2024-08-24, Merge [#4156](https://github.com/ossrs/srs/pull/4156): Build: Fix srs_mp4_parser compiling error. v7.0.9 (#4156) @@ -23,6 +24,7 @@ The changelog for SRS. ## SRS 6.0 Changelog +* v6.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v6.0.153 (#4162) * v6.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v6.0.152 (#4164) * v6.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. 
v6.0.151 (#4157) * v6.0, 2024-08-24, Merge [#4156](https://github.com/ossrs/srs/pull/4156): Build: Fix srs_mp4_parser compiling error. v6.0.150 (#4156) diff --git a/trunk/src/core/srs_core_version6.hpp b/trunk/src/core/srs_core_version6.hpp index 28d5d9219..c5007030a 100644 --- a/trunk/src/core/srs_core_version6.hpp +++ b/trunk/src/core/srs_core_version6.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 6 #define VERSION_MINOR 0 -#define VERSION_REVISION 152 +#define VERSION_REVISION 153 #endif diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 84ef6eefd..a97cb3697 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 11 +#define VERSION_REVISION 12 #endif \ No newline at end of file diff --git a/trunk/src/protocol/srs_protocol_rtmp_stack.cpp b/trunk/src/protocol/srs_protocol_rtmp_stack.cpp index 7b1a20395..2e75b0d4f 100644 --- a/trunk/src/protocol/srs_protocol_rtmp_stack.cpp +++ b/trunk/src/protocol/srs_protocol_rtmp_stack.cpp @@ -2569,7 +2569,8 @@ srs_error_t SrsRtmpServer::start_play(int stream_id) } // onStatus(NetStream.Data.Start) - if (true) { + // We should not response this packet, or there is an empty stream "Stream #0:0: Data: none" in FFmpeg. + if (false) { SrsOnStatusDataPacket* pkt = new SrsOnStatusDataPacket(); pkt->data->set(StatusCode, SrsAmf0Any::str(StatusCodeDataStart)); if ((err = protocol->send_and_free_packet(pkt, stream_id)) != srs_success) { From 740f0d38ec97a05b3eb96abb41bf24ef6970077e Mon Sep 17 00:00:00 2001 From: Winlin Date: Sun, 1 Sep 2024 06:44:35 +0800 Subject: [PATCH 03/12] Edge: Fix flv edge crash when http unmount. v6.0.154 v7.0.13 (#4166) Edge FLV is not working because it is stuck in an infinite loop waiting. Previously, there was no need to wait for exit since resources were not being cleaned up. Now, since resources need to be cleaned up, it must wait for all active connections to exit, which causes this issue. To reproduce the issue, start SRS edge, run the bellow command and press `CTRL+C` to stop the request: ```bash curl http://localhost:8080/live/livestream.flv -v >/dev/null ``` It will cause edge to fetch stream from origin, and free the consumer when client quit. When `SrsLiveStream::do_serve_http` return, it will free the consumer: ```cpp srs_error_t SrsLiveStream::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { SrsUniquePtr consumer(consumer_raw); ``` Keep in mind that in this moment, the stream is alive, because only set to not alive after this function return: ```cpp alive_viewers_++; err = do_serve_http(w, r); // Free 'this' alive stream. alive_viewers_--; // Crash here, because 'this' is freed. ``` When freeing the consumer, it will cause the source to unpublish and attempt to free the HTTP handler, which ultimately waits for the stream not to be alive: ```cpp SrsLiveConsumer::~SrsLiveConsumer() { source_->on_consumer_destroy(this); void SrsLiveSource::on_consumer_destroy(SrsLiveConsumer* consumer) { if (consumers.empty()) { play_edge->on_all_client_stop(); void SrsLiveSource::on_unpublish() { handler->on_unpublish(req); void SrsHttpStreamServer::http_unmount(SrsRequest* r) { if (stream->entry) stream->entry->enabled = false; for (; i < 1024; i++) { if (!cache->alive() && !stream->alive()) { break; } srs_usleep(100 * SRS_UTIME_MILLISECONDS); } ``` After 120 seconds, it will free the stream and cause SRS to crash because the stream is still active. 
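To avoid this long wait, the patch moves the consumer's lifetime from `do_serve_http()` up into `serve_http()`: the consumer is only released after the viewer has been removed from `viewers_`, so when its destructor triggers the edge unpublish and `http_unmount()`, the stream no longer reports itself as alive. Below is a simplified sketch of the reworked flow; the names follow this patch, while hooks, header writing and some error handling are omitted:

```cpp
srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r)
{
    srs_error_t err = srs_success;

    // Fetch the live source for this request.
    SrsSharedPtr<SrsLiveSource> live_source = _srs_sources->fetch(req);
    if (!live_source.get()) {
        return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str());
    }

    // Create the consumer here, not in do_serve_http(), so this frame owns its lifetime.
    SrsLiveConsumer* consumer_raw = NULL;
    if ((err = live_source->create_consumer(consumer_raw)) != srs_success) {
        return srs_error_wrap(err, "create consumer");
    }
    SrsUniquePtr<SrsLiveConsumer> consumer(consumer_raw);

    // Register the viewer (hc is the viewer's connection, obtained earlier in
    // this function), serve it, then unregister it.
    viewers_.push_back(hc);
    err = do_serve_http(live_source.get(), consumer.get(), w, r);
    viewers_.erase(std::find(viewers_.begin(), viewers_.end(), hc));

    // The consumer is destroyed only when this function returns. If that triggers
    // the edge unpublish and http_unmount(), the viewer is already gone, so the
    // unmount wait loop exits quickly instead of spinning for about 120 seconds.
    return err;
}
```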
In order to track this potential issue, also add an important warning log: ```cpp srs_warn("http: try to free a alive stream, cache=%d, stream=%d", cache->alive(), stream->alive()); ``` SRS may crash if got this log. --------- Co-authored-by: Jacob Su --- trunk/doc/CHANGELOG.md | 2 ++ trunk/src/app/srs_app_http_stream.cpp | 41 ++++++++++++++++----------- trunk/src/app/srs_app_http_stream.hpp | 2 +- trunk/src/app/srs_app_rtc_source.hpp | 2 +- trunk/src/app/srs_app_utility.hpp | 2 +- trunk/src/core/srs_core_version6.hpp | 2 +- trunk/src/core/srs_core_version7.hpp | 2 +- 7 files changed, 31 insertions(+), 22 deletions(-) diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 0bea6e3cc..2e5ad9b92 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) * v7.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v7.0.12 (#4162) * v7.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v7.0.11 (#4164) * v7.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. v7.0.10 (#4157) @@ -24,6 +25,7 @@ The changelog for SRS. ## SRS 6.0 Changelog +* v6.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v6.0.154 (#4166) * v6.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v6.0.153 (#4162) * v6.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v6.0.152 (#4164) * v6.0, 2024-08-24, Merge [#4157](https://github.com/ossrs/srs/pull/4157): Fix crash when quiting. v6.0.151 (#4157) diff --git a/trunk/src/app/srs_app_http_stream.cpp b/trunk/src/app/srs_app_http_stream.cpp index 0a4a8e5b8..df2c4a523 100755 --- a/trunk/src/app/srs_app_http_stream.cpp +++ b/trunk/src/app/srs_app_http_stream.cpp @@ -636,17 +636,32 @@ srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage return srs_error_wrap(err, "http hook"); } + SrsSharedPtr live_source = _srs_sources->fetch(req); + if (!live_source.get()) { + return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); + } + + // Create consumer of source, ignore gop cache, use the audio gop cache. + SrsLiveConsumer* consumer_raw = NULL; + if ((err = live_source->create_consumer(consumer_raw)) != srs_success) { + return srs_error_wrap(err, "create consumer"); + } + // When freeing the consumer, it may trigger the source unpublishing for edge. This will trigger the http + // unmount, which waiting for all http live stream to dispose, so we should free the consumer when this + // object is not alive. + SrsUniquePtr consumer(consumer_raw); + // Add the viewer to the viewers list. viewers_.push_back(hc); // Serve the viewer connection. - err = do_serve_http(w, r); + err = do_serve_http(live_source.get(), consumer.get(), w, r); // Remove viewer from the viewers list. vector::iterator it = std::find(viewers_.begin(), viewers_.end(), hc); srs_assert (it != viewers_.end()); viewers_.erase(it); - + // Do hook after serving. 
http_hooks_on_stop(r); @@ -667,7 +682,7 @@ void SrsLiveStream::expire() } } -srs_error_t SrsLiveStream::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) +srs_error_t SrsLiveStream::do_serve_http(SrsLiveSource* source, SrsLiveConsumer* consumer, ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { srs_error_t err = srs_success; @@ -711,19 +726,7 @@ srs_error_t SrsLiveStream::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMess // Enter chunked mode, because we didn't set the content-length. w->write_header(SRS_CONSTS_HTTP_OK); - SrsSharedPtr live_source = _srs_sources->fetch(req); - if (!live_source.get()) { - return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); - } - - // create consumer of souce, ignore gop cache, use the audio gop cache. - SrsLiveConsumer* consumer_raw = NULL; - if ((err = live_source->create_consumer(consumer_raw)) != srs_success) { - return srs_error_wrap(err, "create consumer"); - } - SrsUniquePtr consumer(consumer_raw); - - if ((err = live_source->consumer_dumps(consumer.get(), true, true, !enc->has_cache())) != srs_success) { + if ((err = source->consumer_dumps(consumer, true, true, !enc->has_cache())) != srs_success) { return srs_error_wrap(err, "dumps consumer"); } @@ -744,7 +747,7 @@ srs_error_t SrsLiveStream::do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMess // if gop cache enabled for encoder, dump to consumer. if (enc->has_cache()) { - if ((err = enc->dump_cache(consumer.get(), live_source->jitter())) != srs_success) { + if ((err = enc->dump_cache(consumer, source->jitter())) != srs_success) { return srs_error_wrap(err, "encoder dump cache"); } } @@ -1106,6 +1109,10 @@ void SrsHttpStreamServer::http_unmount(SrsRequest* r) srs_usleep(100 * SRS_UTIME_MILLISECONDS); } + if (cache->alive() || stream->alive()) { + srs_warn("http: try to free a alive stream, cache=%d, stream=%d", cache->alive(), stream->alive()); + } + // Unmount the HTTP handler, which will free the entry. Note that we must free it after cache and // stream stopped for it uses it. mux.unhandle(entry->mount, stream.get()); diff --git a/trunk/src/app/srs_app_http_stream.hpp b/trunk/src/app/srs_app_http_stream.hpp index 2e233ced6..cf3737982 100755 --- a/trunk/src/app/srs_app_http_stream.hpp +++ b/trunk/src/app/srs_app_http_stream.hpp @@ -199,7 +199,7 @@ public: public: virtual void expire(); private: - virtual srs_error_t do_serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r); + virtual srs_error_t do_serve_http(SrsLiveSource* source, SrsLiveConsumer* consumer, ISrsHttpResponseWriter* w, ISrsHttpMessage* r); virtual srs_error_t http_hooks_on_play(ISrsHttpMessage* r); virtual void http_hooks_on_stop(ISrsHttpMessage* r); virtual srs_error_t streaming_send_messages(ISrsBufferEncoder* enc, SrsSharedPtrMessage** msgs, int nb_msgs); diff --git a/trunk/src/app/srs_app_rtc_source.hpp b/trunk/src/app/srs_app_rtc_source.hpp index e917316d5..b8450a51b 100644 --- a/trunk/src/app/srs_app_rtc_source.hpp +++ b/trunk/src/app/srs_app_rtc_source.hpp @@ -253,7 +253,7 @@ public: void set_publish_stream(ISrsRtcPublishStream* v); // Consume the shared RTP packet, user must free it. 
srs_error_t on_rtp(SrsRtpPacket* pkt); - // Set and get stream description for souce + // Set and get stream description for source bool has_stream_desc(); void set_stream_desc(SrsRtcSourceDescription* stream_desc); std::vector get_track_desc(std::string type, std::string media_type); diff --git a/trunk/src/app/srs_app_utility.hpp b/trunk/src/app/srs_app_utility.hpp index 347cb54bf..b7b877154 100644 --- a/trunk/src/app/srs_app_utility.hpp +++ b/trunk/src/app/srs_app_utility.hpp @@ -55,7 +55,7 @@ extern std::string srs_path_build_timestamp(std::string template_path); // @return an int error code. extern srs_error_t srs_kill_forced(int& pid); -// Current process resouce usage. +// Current process resource usage. // @see: man getrusage class SrsRusage { diff --git a/trunk/src/core/srs_core_version6.hpp b/trunk/src/core/srs_core_version6.hpp index c5007030a..ed3ffb50d 100644 --- a/trunk/src/core/srs_core_version6.hpp +++ b/trunk/src/core/srs_core_version6.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 6 #define VERSION_MINOR 0 -#define VERSION_REVISION 153 +#define VERSION_REVISION 154 #endif diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index a97cb3697..0f7e659ef 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 12 +#define VERSION_REVISION 13 #endif \ No newline at end of file From 15fbe45a9a3aa93fc2b09d402379f698fb45c9bd Mon Sep 17 00:00:00 2001 From: Winlin Date: Sun, 1 Sep 2024 13:02:07 +0800 Subject: [PATCH 04/12] FLV: Refine source and http handler. v6.0.155 v7.0.14 (#4165) 1. Do not create a source when mounting FLV because it may not unmount FLV when freeing the source. If you access the FLV stream without any publisher, then wait for source cleanup and review the FLV stream again, there is an annoying warning message. ```bash # View HTTP FLV stream by curl, wait for stream to be ready. # curl http://localhost:8080/live/livestream.flv -v >/dev/null HTTP #0 127.0.0.1:58026 GET http://localhost:8080/live/livestream.flv, content-length=-1 new live source, stream_url=/live/livestream http: mount flv stream for sid=/live/livestream, mount=/live/livestream.flv # Cancel the curl and trigger source cleanup without http unmount. client disconnect peer. ret=1007 Live: cleanup die source, id=[], total=1 # View the stream again, it fails. # curl http://localhost:8080/live/livestream.flv -v >/dev/null HTTP #0 127.0.0.1:58040 GET http://localhost:8080/live/livestream.flv, content-length=-1 serve error code=1097(NoSource)(No source found) : process request=0 : cors serve : serve http : no source for /live/livestream serve_http() [srs_app_http_stream.cpp:641] ``` > Note: There is an inconsistency. The first time, you can access the FLV stream and wait for the publisher, but the next time, you cannot. 2. Create a source when starting to serve the FLV client. We do not need to create the source when creating the HTTP handler. Instead, we should try to create the source in the cache or stream. Because the source cleanup does not unmount the HTTP handler, the handler remains after the source is destroyed. The next time you access the FLV stream, the source is not found. 
```cpp srs_error_t SrsHttpStreamServer::hijack(ISrsHttpMessage* request, ISrsHttpHandler** ph) { SrsSharedPtr live_source; if ((err = _srs_sources->fetch_or_create(r.get(), server, live_source)) != srs_success) { } if ((err = http_mount(r.get())) != srs_success) { } srs_error_t SrsBufferCache::cycle() { SrsSharedPtr live_source = _srs_sources->fetch(req); if (!live_source.get()) { return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); } srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { SrsSharedPtr live_source = _srs_sources->fetch(req); if (!live_source.get()) { return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); } ``` > Note: We should not create the source in hijack, instead, we create it in cache or stream: ```cpp srs_error_t SrsHttpStreamServer::hijack(ISrsHttpMessage* request, ISrsHttpHandler** ph) { if ((err = http_mount(r.get())) != srs_success) { } srs_error_t SrsBufferCache::cycle() { SrsSharedPtr live_source; if ((err = _srs_sources->fetch_or_create(req, server_, live_source)) != srs_success) { } srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { SrsSharedPtr live_source; if ((err = _srs_sources->fetch_or_create(req, server_, live_source)) != srs_success) { } ``` > Note: This fixes the failure and annoying warning message, and maintains consistency by always waiting for the stream to be ready if there is no publisher. 3. Fail the http request if the HTTP handler is disposing, and also keep the handler entry when disposing the stream, because we should dispose the handler entry and stream at the same time. ```cpp srs_error_t SrsHttpStreamServer::http_mount(SrsRequest* r) { entry = streamHandlers[sid]; if (entry->disposing) { return srs_error_new(ERROR_STREAM_DISPOSING, "stream is disposing"); } void SrsHttpStreamServer::http_unmount(SrsRequest* r) { std::map::iterator it = streamHandlers.find(sid); SrsUniquePtr entry(it->second); entry->disposing = true; ``` > Note: If the disposal process takes a long time, this will prevent unexpected behavior or access to the resource that is being disposed of. 4. In edge mode, the edge ingester will unpublish the source when the last consumer quits, which is actually triggered by the HTTP stream. While it also waits for the stream to quit when the HTTP unmounts, there is a self-destruction risk: the HTTP live stream object destroys itself. ```cpp srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage* r) { SrsUniquePtr consumer(consumer_raw); // Trigger destroy. void SrsHttpStreamServer::http_unmount(SrsRequest* r) { for (;;) { if (!cache->alive() && !stream->alive()) { break; } // A circle reference. mux.unhandle(entry->mount, stream.get()); // Free the SrsLiveStream itself. ``` > Note: It also introduces a circular reference in the object relationships, the stream reference to itself when unmount: ```text SrsLiveStream::serve_http -> SrsLiveConsumer::~SrsLiveConsumer -> SrsEdgeIngester::stop -> SrsLiveSource::on_unpublish -> SrsHttpStreamServer::http_unmount -> SrsLiveStream::alive ``` > Note: We should use an asynchronous worker to perform the cleanup to avoid the stream destroying itself and to prevent self-referencing. 
```cpp void SrsHttpStreamServer::http_unmount(SrsRequest* r) { entry->disposing = true; if ((err = async_->execute(new SrsHttpStreamDestroy(&mux, &streamHandlers, sid))) != srs_success) { } ``` > Note: This also ensures there are no circular references and no self-destruction. --------- Co-authored-by: Jacob Su --- trunk/doc/CHANGELOG.md | 2 + trunk/src/app/srs_app_http_stream.cpp | 183 +++++++++++++++++--------- trunk/src/app/srs_app_http_stream.hpp | 26 +++- trunk/src/core/srs_core_version6.hpp | 2 +- trunk/src/core/srs_core_version7.hpp | 2 +- trunk/src/kernel/srs_kernel_error.hpp | 3 +- 6 files changed, 154 insertions(+), 64 deletions(-) diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2e5ad9b92..c6eae0b8f 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) * v7.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v7.0.12 (#4162) * v7.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v7.0.11 (#4164) @@ -25,6 +26,7 @@ The changelog for SRS. ## SRS 6.0 Changelog +* v6.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v6.0.155 (#4165) * v6.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v6.0.154 (#4166) * v6.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v6.0.153 (#4162) * v6.0, 2024-08-31, Merge [#4164](https://github.com/ossrs/srs/pull/4164): HTTP-FLV: Notify connection to expire when unpublishing. v6.0.152 (#4164) diff --git a/trunk/src/app/srs_app_http_stream.cpp b/trunk/src/app/srs_app_http_stream.cpp index df2c4a523..36131c451 100755 --- a/trunk/src/app/srs_app_http_stream.cpp +++ b/trunk/src/app/srs_app_http_stream.cpp @@ -39,8 +39,9 @@ using namespace std; #include #include #include +#include -SrsBufferCache::SrsBufferCache(SrsRequest* r) +SrsBufferCache::SrsBufferCache(SrsServer* s, SrsRequest* r) { req = r->copy()->as_http(); queue = new SrsMessageQueue(true); @@ -48,6 +49,7 @@ SrsBufferCache::SrsBufferCache(SrsRequest* r) // TODO: FIXME: support reload. fast_cache = _srs_config->get_vhost_http_remux_fast_cache(req->vhost); + server_ = s; } SrsBufferCache::~SrsBufferCache() @@ -69,6 +71,11 @@ srs_error_t SrsBufferCache::update_auth(SrsRequest* r) srs_error_t SrsBufferCache::start() { srs_error_t err = srs_success; + + // Not enabled. + if (fast_cache <= 0) { + return err; + } if ((err = trd->start()) != srs_success) { return srs_error_wrap(err, "corotine"); @@ -79,11 +86,21 @@ srs_error_t SrsBufferCache::start() void SrsBufferCache::stop() { + // Not enabled. + if (fast_cache <= 0) { + return; + } + trd->stop(); } bool SrsBufferCache::alive() { + // Not enabled. + if (fast_cache <= 0) { + return false; + } + srs_error_t err = trd->pull(); if (err == srs_success) { return true; @@ -115,17 +132,12 @@ srs_error_t SrsBufferCache::dump_cache(SrsLiveConsumer* consumer, SrsRtmpJitterA srs_error_t SrsBufferCache::cycle() { srs_error_t err = srs_success; - - // TODO: FIXME: support reload. 
- if (fast_cache <= 0) { - srs_usleep(SRS_STREAM_CACHE_CYCLE); - return err; - } - SrsSharedPtr live_source = _srs_sources->fetch(req); - if (!live_source.get()) { - return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); + SrsSharedPtr live_source; + if ((err = _srs_sources->fetch_or_create(req, server_, live_source)) != srs_success) { + return srs_error_wrap(err, "source create"); } + srs_assert(live_source.get() != NULL); // the stream cache will create consumer to cache stream, // which will trigger to fetch stream from origin for edge. @@ -578,11 +590,12 @@ srs_error_t SrsBufferWriter::writev(const iovec* iov, int iovcnt, ssize_t* pnwri return writer->writev(iov, iovcnt, pnwrite); } -SrsLiveStream::SrsLiveStream(SrsRequest* r, SrsBufferCache* c) +SrsLiveStream::SrsLiveStream(SrsServer* s, SrsRequest* r, SrsBufferCache* c) { cache = c; req = r->copy()->as_http(); security_ = new SrsSecurity(); + server_ = s; } SrsLiveStream::~SrsLiveStream() @@ -636,10 +649,17 @@ srs_error_t SrsLiveStream::serve_http(ISrsHttpResponseWriter* w, ISrsHttpMessage return srs_error_wrap(err, "http hook"); } - SrsSharedPtr live_source = _srs_sources->fetch(req); - if (!live_source.get()) { - return srs_error_new(ERROR_NO_SOURCE, "no source for %s", req->get_stream_url().c_str()); + // Always try to create the source, because http handler won't create it. + SrsSharedPtr live_source; + if ((err = _srs_sources->fetch_or_create(req, server_, live_source)) != srs_success) { + return srs_error_wrap(err, "source create"); } + srs_assert(live_source.get() != NULL); + + bool enabled_cache = _srs_config->get_gop_cache(req->vhost); + int gcmf = _srs_config->get_gop_cache_max_frames(req->vhost); + live_source->set_cache(enabled_cache); + live_source->set_gop_cache_max_frames(gcmf); // Create consumer of source, ignore gop cache, use the audio gop cache. SrsLiveConsumer* consumer_raw = NULL; @@ -926,6 +946,7 @@ srs_error_t SrsLiveStream::streaming_send_messages(ISrsBufferEncoder* enc, SrsSh SrsLiveEntry::SrsLiveEntry(std::string m) { mount = m; + disposing = false; stream = NULL; cache = NULL; @@ -967,6 +988,7 @@ bool SrsLiveEntry::is_mp3() SrsHttpStreamServer::SrsHttpStreamServer(SrsServer* svr) { server = svr; + async_ = new SrsAsyncCallWorker(); mux.hijack(this); _srs_config->subscribe(this); @@ -976,6 +998,9 @@ SrsHttpStreamServer::~SrsHttpStreamServer() { mux.unhijack(this); _srs_config->unsubscribe(this); + + async_->stop(); + srs_freep(async_); if (true) { std::map::iterator it; @@ -1003,6 +1028,10 @@ srs_error_t SrsHttpStreamServer::initialize() if ((err = initialize_flv_streaming()) != srs_success) { return srs_error_wrap(err, "http flv stream"); } + + if ((err = async_->start()) != srs_success) { + return srs_error_wrap(err, "async start"); + } return err; } @@ -1037,8 +1066,8 @@ srs_error_t SrsHttpStreamServer::http_mount(SrsRequest* r) entry = new SrsLiveEntry(mount); entry->req = r->copy()->as_http(); - entry->cache = new SrsBufferCache(r); - entry->stream = new SrsLiveStream(r, entry->cache); + entry->cache = new SrsBufferCache(server, r); + entry->stream = new SrsLiveStream(server, r, entry->cache); // TODO: FIXME: maybe refine the logic of http remux service. // if user push streams followed: @@ -1067,6 +1096,12 @@ srs_error_t SrsHttpStreamServer::http_mount(SrsRequest* r) } else { // The entry exists, we reuse it and update the request of stream and cache. entry = streamHandlers[sid]; + + // Fail if system is disposing the entry. 
+ if (entry->disposing) { + return srs_error_new(ERROR_STREAM_DISPOSING, "stream is disposing"); + } + entry->stream->update_auth(r); entry->cache->update_auth(r); } @@ -1088,36 +1123,19 @@ void SrsHttpStreamServer::http_unmount(SrsRequest* r) return; } - // Free all HTTP resources. - SrsUniquePtr entry(it->second); - streamHandlers.erase(it); - - SrsUniquePtr stream(entry->stream); - SrsUniquePtr cache(entry->cache); - - // Notify cache and stream to stop. - if (stream->entry) stream->entry->enabled = false; - stream->expire(); - cache->stop(); - - // Wait for cache and stream to stop. - int i = 0; - for (; i < 1024; i++) { - if (!cache->alive() && !stream->alive()) { - break; - } - srs_usleep(100 * SRS_UTIME_MILLISECONDS); + // Set the entry to disposing, which will prevent the stream to be reused. + SrsLiveEntry* entry = it->second; + if (entry->disposing) { + return; } + entry->disposing = true; - if (cache->alive() || stream->alive()) { - srs_warn("http: try to free a alive stream, cache=%d, stream=%d", cache->alive(), stream->alive()); + // Use async worker to execute the task, which will destroy the stream. + srs_error_t err = srs_success; + if ((err = async_->execute(new SrsHttpStreamDestroy(&mux, &streamHandlers, sid))) != srs_success) { + srs_warn("http: ignore unmount stream failed, sid=%s, err=%s", sid.c_str(), srs_error_desc(err).c_str()); + srs_freep(err); } - - // Unmount the HTTP handler, which will free the entry. Note that we must free it after cache and - // stream stopped for it uses it. - mux.unhandle(entry->mount, stream.get()); - - srs_trace("http: unmount flv stream for sid=%s, i=%d", sid.c_str(), i); } srs_error_t SrsHttpStreamServer::hijack(ISrsHttpMessage* request, ISrsHttpHandler** ph) @@ -1214,17 +1232,6 @@ srs_error_t SrsHttpStreamServer::hijack(ISrsHttpMessage* request, ISrsHttpHandle } } - SrsSharedPtr live_source; - if ((err = _srs_sources->fetch_or_create(r.get(), server, live_source)) != srs_success) { - return srs_error_wrap(err, "source create"); - } - srs_assert(live_source.get() != NULL); - - bool enabled_cache = _srs_config->get_gop_cache(r->vhost); - int gcmf = _srs_config->get_gop_cache_max_frames(r->vhost); - live_source->set_cache(enabled_cache); - live_source->set_gop_cache_max_frames(gcmf); - // create http streaming handler. if ((err = http_mount(r.get())) != srs_success) { return srs_error_wrap(err, "http mount"); @@ -1235,11 +1242,8 @@ srs_error_t SrsHttpStreamServer::hijack(ISrsHttpMessage* request, ISrsHttpHandle entry = streamHandlers[sid]; *ph = entry->stream; } - - // trigger edge to fetch from origin. - bool vhost_is_edge = _srs_config->get_vhost_is_edge(r->vhost); - srs_trace("flv: source url=%s, is_edge=%d, source_id=%s/%s", - r->get_stream_url().c_str(), vhost_is_edge, live_source->source_id().c_str(), live_source->pre_source_id().c_str()); + + srs_trace("flv: hijack %s ok", upath.c_str()); return err; } @@ -1281,3 +1285,64 @@ srs_error_t SrsHttpStreamServer::initialize_flv_entry(std::string vhost) return err; } +SrsHttpStreamDestroy::SrsHttpStreamDestroy(SrsHttpServeMux* mux, map* handlers, string sid) +{ + mux_ = mux; + sid_ = sid; + streamHandlers_ = handlers; +} + +SrsHttpStreamDestroy::~SrsHttpStreamDestroy() +{ +} + +srs_error_t SrsHttpStreamDestroy::call() +{ + srs_error_t err = srs_success; + + std::map::iterator it = streamHandlers_->find(sid_); + if (it == streamHandlers_->end()) { + return err; + } + + // Free all HTTP resources. 
+ SrsUniquePtr entry(it->second); + srs_assert(entry->disposing); + + SrsUniquePtr stream(entry->stream); + SrsUniquePtr cache(entry->cache); + + // Notify cache and stream to stop. + if (stream->entry) stream->entry->enabled = false; + stream->expire(); + cache->stop(); + + // Wait for cache and stream to stop. + int i = 0; + for (; i < 1024; i++) { + if (!cache->alive() && !stream->alive()) { + break; + } + srs_usleep(100 * SRS_UTIME_MILLISECONDS); + } + + if (cache->alive() || stream->alive()) { + srs_warn("http: try to free a alive stream, cache=%d, stream=%d", cache->alive(), stream->alive()); + } + + // Remove the entry from handlers. + streamHandlers_->erase(it); + + // Unmount the HTTP handler, which will free the entry. Note that we must free it after cache and + // stream stopped for it uses it. + mux_->unhandle(entry->mount, stream.get()); + + srs_trace("http: unmount flv stream for sid=%s, i=%d", sid_.c_str(), i); + return err; +} + +string SrsHttpStreamDestroy::to_string() +{ + return "destroy"; +} + diff --git a/trunk/src/app/srs_app_http_stream.hpp b/trunk/src/app/srs_app_http_stream.hpp index cf3737982..352c4f99f 100755 --- a/trunk/src/app/srs_app_http_stream.hpp +++ b/trunk/src/app/srs_app_http_stream.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -17,18 +18,20 @@ class SrsAacTransmuxer; class SrsMp3Transmuxer; class SrsFlvTransmuxer; class SrsTsTransmuxer; +class SrsAsyncCallWorker; // A cache for HTTP Live Streaming encoder, to make android(weixin) happy. class SrsBufferCache : public ISrsCoroutineHandler { private: srs_utime_t fast_cache; + SrsServer* server_; private: SrsMessageQueue* queue; SrsRequest* req; SrsCoroutine* trd; public: - SrsBufferCache(SrsRequest* r); + SrsBufferCache(SrsServer* s, SrsRequest* r); virtual ~SrsBufferCache(); virtual srs_error_t update_auth(SrsRequest* r); public: @@ -184,12 +187,13 @@ private: SrsRequest* req; SrsBufferCache* cache; SrsSecurity* security_; + SrsServer* server_; // For multiple viewers, which means there will more than one alive viewers for a live stream, so we must // use an int value to represent if there is any viewer is alive. We should never do cleanup unless all // viewers closed the connection. std::vector viewers_; public: - SrsLiveStream(SrsRequest* r, SrsBufferCache* c); + SrsLiveStream(SrsServer* s, SrsRequest* r, SrsBufferCache* c); virtual ~SrsLiveStream(); virtual srs_error_t update_auth(SrsRequest* r); public: @@ -223,6 +227,9 @@ public: SrsLiveStream* stream; SrsBufferCache* cache; + + // Whether is disposing the entry. + bool disposing; SrsLiveEntry(std::string m); virtual ~SrsLiveEntry(); @@ -240,6 +247,7 @@ class SrsHttpStreamServer : public ISrsReloadHandler { private: SrsServer* server; + SrsAsyncCallWorker* async_; public: SrsHttpServeMux mux; // The http live streaming template, to create streams. 
@@ -263,5 +271,19 @@ private: virtual srs_error_t initialize_flv_entry(std::string vhost); }; +class SrsHttpStreamDestroy : public ISrsAsyncCallTask +{ +private: + std::string sid_; + std::map* streamHandlers_; + SrsHttpServeMux* mux_; +public: + SrsHttpStreamDestroy(SrsHttpServeMux* mux, std::map* handlers, std::string sid); + virtual ~SrsHttpStreamDestroy(); +public: + virtual srs_error_t call(); + virtual std::string to_string(); +}; + #endif diff --git a/trunk/src/core/srs_core_version6.hpp b/trunk/src/core/srs_core_version6.hpp index ed3ffb50d..32328d99a 100644 --- a/trunk/src/core/srs_core_version6.hpp +++ b/trunk/src/core/srs_core_version6.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 6 #define VERSION_MINOR 0 -#define VERSION_REVISION 154 +#define VERSION_REVISION 155 #endif diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 0f7e659ef..ee85bc820 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 13 +#define VERSION_REVISION 14 #endif \ No newline at end of file diff --git a/trunk/src/kernel/srs_kernel_error.hpp b/trunk/src/kernel/srs_kernel_error.hpp index af9acf12d..dcd818483 100644 --- a/trunk/src/kernel/srs_kernel_error.hpp +++ b/trunk/src/kernel/srs_kernel_error.hpp @@ -107,7 +107,8 @@ XX(ERROR_BACKTRACE_ADDR2LINE , 1094, "BacktraceAddr2Line", "Backtrace addr2line failed") \ XX(ERROR_SYSTEM_FILE_NOT_OPEN , 1095, "FileNotOpen", "File is not opened") \ XX(ERROR_SYSTEM_FILE_SETVBUF , 1096, "FileSetVBuf", "Failed to set file vbuf") \ - XX(ERROR_NO_SOURCE , 1097, "NoSource", "No source found") + XX(ERROR_NO_SOURCE , 1097, "NoSource", "No source found") \ + XX(ERROR_STREAM_DISPOSING , 1098, "StreamDisposing", "Stream is disposing") /**************************************************/ /* RTMP protocol error. */ From d70e7357cfe5b34eb8e8dd30723d07f4bdd48b24 Mon Sep 17 00:00:00 2001 From: winlin Date: Sun, 1 Sep 2024 15:57:55 +0800 Subject: [PATCH 05/12] Release v6.0-a1, 6.0 alpha1, v6.0.155, 169636 lines. --- .github/workflows/release.yml | 23 ++++++++++++----------- README.md | 1 + 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f8b218da6..b8d65a36e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,7 +4,7 @@ name: "Release" on: push: tags: - - v6* + - v7* # For draft, need write permission. 
permissions: @@ -408,6 +408,7 @@ jobs: echo "SRS_TAG=${{ needs.envs.outputs.SRS_TAG }}" >> $GITHUB_ENV echo "SRS_VERSION=${{ needs.envs.outputs.SRS_VERSION }}" >> $GITHUB_ENV echo "SRS_MAJOR=${{ needs.envs.outputs.SRS_MAJOR }}" >> $GITHUB_ENV + echo "SRS_XYZ=${{ needs.envs.outputs.SRS_XYZ }}" >> $GITHUB_ENV echo "SRS_RELEASE_ID=${{ needs.draft.outputs.SRS_RELEASE_ID }}" >> $GITHUB_ENV echo "SRS_PACKAGE_ZIP=${{ needs.linux.outputs.SRS_PACKAGE_ZIP }}" >> $GITHUB_ENV echo "SRS_PACKAGE_MD5=${{ needs.linux.outputs.SRS_PACKAGE_MD5 }}" >> $GITHUB_ENV @@ -448,23 +449,23 @@ jobs: * Binary: ${{ env.SRS_CYGWIN_MD5 }} [${{ env.SRS_CYGWIN_TAR }}](https://gitee.com/ossrs/srs/releases/download/${{ env.SRS_TAG }}/${{ env.SRS_CYGWIN_TAR }}) ## Docker - * [docker pull ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started) - * [docker pull ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started) - * [docker pull ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started) + * [docker pull ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started) + * [docker pull ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started) + * [docker pull ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started) ## Docker Mirror: aliyun.com - * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started) - * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started) - * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started) + * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_MAJOR }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started) + * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_TAG }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started) + * [docker pull registry.cn-hangzhou.aliyuncs.com/ossrs/srs:${{ env.SRS_XYZ }}](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started) ## Doc: ossrs.io - * [Getting Started](https://ossrs.io/lts/en-us/docs/v5/doc/getting-started) - * [Wiki home](https://ossrs.io/lts/en-us/docs/v5/doc/introduction) + * [Getting Started](https://ossrs.io/lts/en-us/docs/v7/doc/getting-started) + * [Wiki home](https://ossrs.io/lts/en-us/docs/v7/doc/introduction) * [FAQ](https://ossrs.io/lts/en-us/faq), [Features](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/Features.md#features) or [ChangeLogs](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/CHANGELOG.md#changelog) ## Doc: ossrs.net - * [快速入门](https://ossrs.net/lts/zh-cn/docs/v5/doc/getting-started) - * [中文Wiki首页](https://ossrs.net/lts/zh-cn/docs/v5/doc/introduction) + * [快速入门](https://ossrs.net/lts/zh-cn/docs/v7/doc/getting-started) + * [中文Wiki首页](https://ossrs.net/lts/zh-cn/docs/v7/doc/introduction) * [中文FAQ](https://ossrs.net/lts/zh-cn/faq), [功能列表](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/Features.md#features) 或 [修订历史](https://github.com/ossrs/srs/blob/${{ github.sha }}/trunk/doc/CHANGELOG.md#changelog) draft: false prerelease: true diff --git a/README.md b/README.md index bf345621a..c4d3b20a5 100755 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ distributed under their [licenses](https://ossrs.io/lts/en-us/license). 
## Releases +* 2024-09-01, [Release v6.0-a1](https://github.com/ossrs/srs/releases/tag/v6.0-a1), v6.0-a1, 6.0 alpha1, v6.0.155, 169636 lines. * 2024-07-27, [Release v6.0-a0](https://github.com/ossrs/srs/releases/tag/v6.0-a0), v6.0-a0, 6.0 alpha0, v6.0.145, 169259 lines. * 2024-07-04, [Release v6.0-d6](https://github.com/ossrs/srs/releases/tag/v6.0-d6), v6.0-d6, 6.0 dev6, v6.0.134, 168904 lines. * 2024-06-15, [Release v6.0-d5](https://github.com/ossrs/srs/releases/tag/v6.0-d5), v6.0-d5, 6.0 dev5, v6.0.129, 168454 lines. From b475d552aad8a97ff5da4cebd142f3f47b19a5af Mon Sep 17 00:00:00 2001 From: Winlin Date: Mon, 9 Sep 2024 10:37:41 +0800 Subject: [PATCH 06/12] Heartbeat: Report ports for proxy server. v5.0.215 v6.0.156 v7.0.15 (#4171) The heartbeat of SRS is a timer that requests an HTTP URL. We can use this heartbeat to report the necessary information for registering the backend server with the proxy server. ```text SRS(backend) --heartbeat---> Proxy server ``` A proxy server is a specialized load balancer for media servers. It operates at the application level rather than the TCP level. For more information about the proxy server, see issue #4158. Note that we will merge this PR into SRS 5.0+, allowing the use of SRS 5.0+ as the backend server, not limited to SRS 7.0. However, the proxy server is introduced in SRS 7.0. It's also possible to implement a registration service, allowing you to use other media servers as backend servers. For example, if you gather information about an nginx-rtmp server and register it with the proxy server, the proxy will forward RTMP streams to nginx-rtmp. The backend server is not limited to SRS. --------- Co-authored-by: Jacob Su --- .run/private.run.xml | 1 - trunk/conf/full.conf | 8 +++ trunk/doc/CHANGELOG.md | 3 ++ trunk/src/app/srs_app_config.cpp | 27 ++++++++-- trunk/src/app/srs_app_config.hpp | 1 + trunk/src/app/srs_app_heartbeat.cpp | 78 +++++++++++++++++++++++++--- trunk/src/core/srs_core_version5.hpp | 2 +- trunk/src/core/srs_core_version6.hpp | 2 +- trunk/src/core/srs_core_version7.hpp | 2 +- 9 files changed, 110 insertions(+), 14 deletions(-) diff --git a/.run/private.run.xml b/.run/private.run.xml index 458dabc61..5f4618462 100644 --- a/.run/private.run.xml +++ b/.run/private.run.xml @@ -1,7 +1,6 @@ - diff --git a/trunk/conf/full.conf b/trunk/conf/full.conf index 0579788fa..a84309c22 100644 --- a/trunk/conf/full.conf +++ b/trunk/conf/full.conf @@ -907,6 +907,14 @@ heartbeat { # Overwrite by env SRS_HEARTBEAT_SUMMARIES # default: off summaries off; + # Whether report with listen ports. + # if on, request with the ports of SRS: + # { + # "rtmp": ["1935"], "http": ["8080"], "api": ["1985"], "srt": ["10080"], "rtc": ["8000"] + # } + # Overwrite by env SRS_HEARTBEAT_PORTS + # default: off + ports off; } # system statistics section. diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index c6eae0b8f..2772c0bf2 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. 
v7.0.13 (#4166) * v7.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v7.0.12 (#4162) @@ -26,6 +27,7 @@ The changelog for SRS. ## SRS 6.0 Changelog +* v6.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v6.0.156 (#4171) * v6.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v6.0.155 (#4165) * v6.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v6.0.154 (#4166) * v6.0, 2024-08-31, Merge [#4162](https://github.com/ossrs/srs/pull/4162): Fix #3767: RTMP: Do not response empty data packet. v6.0.153 (#4162) @@ -185,6 +187,7 @@ The changelog for SRS. ## SRS 5.0 Changelog +* v5.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v5.0.215 (#4171) * v5.0, 2024-07-24, Merge [#4126](https://github.com/ossrs/srs/pull/4126): Edge: Improve stability for state and fd closing. v5.0.214 (#4126) * v5.0, 2024-06-03, Merge [#4057](https://github.com/ossrs/srs/pull/4057): RTC: Support dropping h.264 SEI from NALUs. v5.0.213 (#4057) * v5.0, 2024-04-23, Merge [#4038](https://github.com/ossrs/srs/pull/4038): RTMP: Do not response publish start message if hooks fail. v5.0.212 (#4038) diff --git a/trunk/src/app/srs_app_config.cpp b/trunk/src/app/srs_app_config.cpp index 0731f3cd2..eeae92ef1 100644 --- a/trunk/src/app/srs_app_config.cpp +++ b/trunk/src/app/srs_app_config.cpp @@ -2409,7 +2409,7 @@ srs_error_t SrsConfig::check_normal_config() for (int i = 0; conf && i < (int)conf->directives.size(); i++) { string n = conf->at(i)->name; if (n != "enabled" && n != "interval" && n != "url" - && n != "device_id" && n != "summaries") { + && n != "device_id" && n != "summaries" && n != "ports") { return srs_error_new(ERROR_SYSTEM_CONFIG_INVALID, "illegal heartbeat.%s", n.c_str()); } } @@ -8794,17 +8794,36 @@ bool SrsConfig::get_heartbeat_summaries() SRS_OVERWRITE_BY_ENV_BOOL("srs.heartbeat.summaries"); // SRS_HEARTBEAT_SUMMARIES static bool DEFAULT = false; - + SrsConfDirective* conf = get_heartbeart(); if (!conf) { return DEFAULT; } - + conf = conf->get("summaries"); if (!conf || conf->arg0().empty()) { return DEFAULT; } - + + return SRS_CONF_PREFER_FALSE(conf->arg0()); +} + +bool SrsConfig::get_heartbeat_ports() +{ + SRS_OVERWRITE_BY_ENV_BOOL("srs.heartbeat.ports"); // SRS_HEARTBEAT_PORTS + + static bool DEFAULT = false; + + SrsConfDirective* conf = get_heartbeart(); + if (!conf) { + return DEFAULT; + } + + conf = conf->get("ports"); + if (!conf || conf->arg0().empty()) { + return DEFAULT; + } + return SRS_CONF_PREFER_FALSE(conf->arg0()); } diff --git a/trunk/src/app/srs_app_config.hpp b/trunk/src/app/srs_app_config.hpp index 28aec179d..e7432d14c 100644 --- a/trunk/src/app/srs_app_config.hpp +++ b/trunk/src/app/srs_app_config.hpp @@ -1119,6 +1119,7 @@ public: virtual std::string get_heartbeat_device_id(); // Whether report with summaries of http api: /api/v1/summaries. virtual bool get_heartbeat_summaries(); + bool get_heartbeat_ports(); // stats section private: // Get the stats directive. 
diff --git a/trunk/src/app/srs_app_heartbeat.cpp b/trunk/src/app/srs_app_heartbeat.cpp index 819075a68..c57ee3757 100644 --- a/trunk/src/app/srs_app_heartbeat.cpp +++ b/trunk/src/app/srs_app_heartbeat.cpp @@ -18,6 +18,8 @@ using namespace std; #include #include #include +#include +#include SrsHttpHeartbeat::SrsHttpHeartbeat() { @@ -48,18 +50,28 @@ srs_error_t SrsHttpHeartbeat::do_heartbeat() return srs_error_wrap(err, "http uri parse hartbeart url failed. url=%s", url.c_str()); } - SrsIPAddress* ip = NULL; + string ip; std::string device_id = _srs_config->get_heartbeat_device_id(); - - vector& ips = srs_get_local_ips(); - if (!ips.empty()) { - ip = ips[_srs_config->get_stats_network() % (int)ips.size()]; + + // Try to load the ip from the environment variable. + ip = srs_getenv("srs.device.ip"); // SRS_DEVICE_IP + if (ip.empty()) { + // Use the local ip address specified by the stats.network config. + vector& ips = srs_get_local_ips(); + if (!ips.empty()) { + ip = ips[_srs_config->get_stats_network() % (int) ips.size()]->ip; + } } SrsUniquePtr obj(SrsJsonAny::object()); obj->set("device_id", SrsJsonAny::str(device_id.c_str())); - obj->set("ip", SrsJsonAny::str(ip->ip.c_str())); + obj->set("ip", SrsJsonAny::str(ip.c_str())); + + SrsStatistic* stat = SrsStatistic::instance(); + obj->set("server", SrsJsonAny::str(stat->server_id().c_str())); + obj->set("service", SrsJsonAny::str(stat->service_id().c_str())); + obj->set("pid", SrsJsonAny::str(stat->service_pid().c_str())); if (_srs_config->get_heartbeat_summaries()) { SrsJsonObject* summaries = SrsJsonAny::object(); @@ -67,6 +79,60 @@ srs_error_t SrsHttpHeartbeat::do_heartbeat() srs_api_dump_summaries(summaries); } + + if (_srs_config->get_heartbeat_ports()) { + // For RTMP listen endpoints. + if (true) { + SrsJsonArray* o = SrsJsonAny::array(); + obj->set("rtmp", o); + + vector endpoints = _srs_config->get_listens(); + for (int i = 0; i < (int) endpoints.size(); i++) { + o->append(SrsJsonAny::str(endpoints.at(i).c_str())); + } + } + + // For HTTP Stream listen endpoints. + if (_srs_config->get_http_stream_enabled()) { + SrsJsonArray* o = SrsJsonAny::array(); + obj->set("http", o); + + string endpoint = _srs_config->get_http_stream_listen(); + o->append(SrsJsonAny::str(endpoint.c_str())); + } + + // For HTTP API listen endpoints. + if (_srs_config->get_http_api_enabled()) { + SrsJsonArray* o = SrsJsonAny::array(); + obj->set("api", o); + + string endpoint = _srs_config->get_http_api_listen(); + o->append(SrsJsonAny::str(endpoint.c_str())); + } + + // For SRT listen endpoints. + if (_srs_config->get_srt_enabled()) { + SrsJsonArray* o = SrsJsonAny::array(); + obj->set("srt", o); + + uint16_t endpoint = _srs_config->get_srt_listen_port(); + o->append(SrsJsonAny::str(srs_fmt("udp://0.0.0.0:%d", endpoint).c_str())); + } + + // For WebRTC listen endpoints. 
+ if (_srs_config->get_rtc_server_enabled()) { + SrsJsonArray* o = SrsJsonAny::array(); + obj->set("rtc", o); + + int endpoint = _srs_config->get_rtc_server_listen(); + o->append(SrsJsonAny::str(srs_fmt("udp://0.0.0.0:%d", endpoint).c_str())); + + if (_srs_config->get_rtc_server_tcp_enabled()) { + endpoint = _srs_config->get_rtc_server_tcp_listen(); + o->append(SrsJsonAny::str(srs_fmt("tcp://0.0.0.0:%d", endpoint).c_str())); + } + } + } SrsHttpClient http; if ((err = http.initialize(uri.get_schema(), uri.get_host(), uri.get_port())) != srs_success) { diff --git a/trunk/src/core/srs_core_version5.hpp b/trunk/src/core/srs_core_version5.hpp index ef74aa44d..791b3ae6e 100644 --- a/trunk/src/core/srs_core_version5.hpp +++ b/trunk/src/core/srs_core_version5.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 5 #define VERSION_MINOR 0 -#define VERSION_REVISION 214 +#define VERSION_REVISION 215 #endif diff --git a/trunk/src/core/srs_core_version6.hpp b/trunk/src/core/srs_core_version6.hpp index 32328d99a..69725cd53 100644 --- a/trunk/src/core/srs_core_version6.hpp +++ b/trunk/src/core/srs_core_version6.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 6 #define VERSION_MINOR 0 -#define VERSION_REVISION 155 +#define VERSION_REVISION 156 #endif diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index ee85bc820..fed95c499 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 14 +#define VERSION_REVISION 15 #endif \ No newline at end of file From 2e4014ae1c84a5bde4381959c0ee8a79de32fdbc Mon Sep 17 00:00:00 2001 From: Winlin Date: Mon, 9 Sep 2024 12:06:02 +0800 Subject: [PATCH 07/12] Proxy: Support proxy server for SRS. v7.0.16 (#4158) Please note that the proxy server is a new architecture or the next version of the Origin Cluster, which allows the publication of multiple streams. The SRS origin cluster consists of a group of origin servers designed to handle a large number of streams. ```text +-----------------------+ +---+ SRS Proxy(Deployment) +------+---------------------+ +-----------------+ | +-----------+-----------+ + + | LB(K8s Service) +--+ +(Redis/MESH) + SRS Origin Servers + +-----------------+ | +-----------+-----------+ + (Deployment) + +---+ SRS Proxy(Deployment) +------+---------------------+ +-----------------------+ ``` The new origin cluster is designed as a collection of proxy servers. For more information, see [Discussion #3634](https://github.com/ossrs/srs/discussions/3634). If you prefer to use the old origin cluster, please switch to a version before SRS 6.0. A proxy server can be used for a set of origin servers, which are isolated and dedicated origin servers. The main improvement in the new architecture is to store the state for origin servers in the proxy server, rather than using MESH to communicate between origin servers. With a proxy server, you can deploy origin servers as stateless servers, such as in a Kubernetes (K8s) deployment. Now that the proxy server is a stateful server, it uses Redis to store the states. For faster development, we use Go to develop the proxy server, instead of C/C++. Therefore, the proxy server itself is also stateless, with all states stored in the Redis server or cluster. This makes the new origin cluster architecture very powerful and robust. The proxy server is also an architecture designed to solve multiple process bottlenecks. 
You can run hundreds of SRS origin servers with one proxy server on the same machine. This solution can utilize multi-core machines, such as servers with 128 CPUs. Thus, we can keep SRS single-threaded and very simple. See https://github.com/ossrs/srs/discussions/3665#discussioncomment-6474441 for details. ```text +--------------------+ +-------+ SRS Origin Server + + +--------------------+ + +-----------------------+ + +--------------------+ + SRS Proxy(Deployment) +------+-------+ SRS Origin Server + +-----------------------+ + +--------------------+ + + +--------------------+ +-------+ SRS Origin Server + +--------------------+ ``` Keep in mind that the proxy server for the Origin Cluster is designed to handle many streams. To address the issue of many viewers, we will enhance the Edge Cluster to support more protocols. ```text +------------------+ +--------------------+ + SRS Edge Server +--+ +-------+ SRS Origin Server + +------------------+ + + +--------------------+ + + +------------------+ + +-----------------------+ + +--------------------+ + SRS Edge Server +--+-----+ SRS Proxy(Deployment) +------+-------+ SRS Origin Server + +------------------+ + +-----------------------+ + +--------------------+ + + +------------------+ + + +--------------------+ + SRS Edge Server +--+ +-------+ SRS Origin Server + +------------------+ +--------------------+ ``` With the new Origin Cluster and Edge Cluster, you have a media system capable of supporting a large number of streams and viewers. For example, you can publish 10,000 streams, each with 100,000 viewers. --------- Co-authored-by: Jacob Su --- proxy/.gitignore | 4 + proxy/Makefile | 23 + proxy/api.go | 272 ++++ proxy/debug.go | 20 + proxy/env.go | 197 +++ proxy/errors/errors.go | 270 ++++ proxy/errors/stack.go | 187 +++ proxy/go.mod | 13 + proxy/go.sum | 17 + proxy/http.go | 419 ++++++ proxy/logger/context.go | 43 + proxy/logger/log.go | 87 ++ proxy/main.go | 121 ++ proxy/rtc.go | 515 ++++++++ proxy/rtmp.go | 655 ++++++++++ proxy/rtmp/amf0.go | 771 +++++++++++ proxy/rtmp/rtmp.go | 1792 ++++++++++++++++++++++++++ proxy/signal.go | 44 + proxy/srs.go | 553 ++++++++ proxy/srt.go | 574 +++++++++ proxy/sync/map.go | 45 + proxy/utils.go | 276 ++++ proxy/version.go | 27 + trunk/conf/origin1-for-proxy.conf | 57 + trunk/conf/origin2-for-proxy.conf | 57 + trunk/conf/origin3-for-proxy.conf | 57 + trunk/doc/CHANGELOG.md | 1 + trunk/src/app/srs_app_st.cpp | 7 +- trunk/src/core/srs_core_version7.hpp | 2 +- 29 files changed, 7104 insertions(+), 2 deletions(-) create mode 100644 proxy/.gitignore create mode 100644 proxy/Makefile create mode 100644 proxy/api.go create mode 100644 proxy/debug.go create mode 100644 proxy/env.go create mode 100644 proxy/errors/errors.go create mode 100644 proxy/errors/stack.go create mode 100644 proxy/go.mod create mode 100644 proxy/go.sum create mode 100644 proxy/http.go create mode 100644 proxy/logger/context.go create mode 100644 proxy/logger/log.go create mode 100644 proxy/main.go create mode 100644 proxy/rtc.go create mode 100644 proxy/rtmp.go create mode 100644 proxy/rtmp/amf0.go create mode 100644 proxy/rtmp/rtmp.go create mode 100644 proxy/signal.go create mode 100644 proxy/srs.go create mode 100644 proxy/srt.go create mode 100644 proxy/sync/map.go create mode 100644 proxy/utils.go create mode 100644 proxy/version.go create mode 100644 trunk/conf/origin1-for-proxy.conf create mode 100644 trunk/conf/origin2-for-proxy.conf create mode 100644 trunk/conf/origin3-for-proxy.conf diff --git a/proxy/.gitignore 
b/proxy/.gitignore new file mode 100644 index 000000000..c20f4b678 --- /dev/null +++ b/proxy/.gitignore @@ -0,0 +1,4 @@ +.idea +srs-proxy +.env +.go-formarted \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile new file mode 100644 index 000000000..29084d5b7 --- /dev/null +++ b/proxy/Makefile @@ -0,0 +1,23 @@ +.PHONY: all build test fmt clean run + +all: build + +build: fmt ./srs-proxy + +./srs-proxy: *.go + go build -o srs-proxy . + +test: + go test ./... + +fmt: ./.go-formarted + +./.go-formarted: *.go + touch .go-formarted + go fmt ./... + +clean: + rm -f srs-proxy .go-formarted + +run: fmt + go run . diff --git a/proxy/api.go b/proxy/api.go new file mode 100644 index 000000000..04baa9252 --- /dev/null +++ b/proxy/api.go @@ -0,0 +1,272 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// to proxy other HTTP API of SRS like the streams and clients, etc. +type srsHTTPAPIServer struct { + // The underlayer HTTP server. + server *http.Server + // The WebRTC server. + rtc *srsWebRTCServer + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer { + v := &srsHTTPAPIServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPAPIServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPAPIServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The WebRTC WHIP API handler. + logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) + mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // The WebRTC WHEP API handler. + logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) + mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // Run HTTP API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. 
+ logger.Wf(ctx, "HTTP API accept err %+v", err) + } else { + logger.Df(ctx, "HTTP API server done") + } + } + }() + + return nil +} + +// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service +// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter +// for Prometheus metrics. +type systemAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI { + v := &systemAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *systemAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *systemAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envSystemAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "System API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The register service for SRS media servers. + logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) + mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { + if err := func() error { + var deviceID, ip, serverID, serviceID, pid string + var rtmp, stream, api, srt, rtc []string + if err := ParseBody(r.Body, &struct { + // The IP of SRS, mandatory. + IP *string `json:"ip"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID *string `json:"server"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID *string `json:"service"` + // The process id of SRS, always change when restarted, mandatory. + PID *string `json:"pid"` + // The RTMP listen endpoints, mandatory. + RTMP *[]string `json:"rtmp"` + // The HTTP Stream listen endpoints, optional. + HTTP *[]string `json:"http"` + // The API listen endpoints, optional. + API *[]string `json:"api"` + // The SRT listen endpoints, optional. + SRT *[]string `json:"srt"` + // The RTC listen endpoints, optional. + RTC *[]string `json:"rtc"` + // The device id of SRS, optional. 
+ DeviceID *string `json:"device_id"` + }{ + IP: &ip, DeviceID: &deviceID, + ServerID: &serverID, ServiceID: &serviceID, PID: &pid, + RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, + }); err != nil { + return errors.Wrapf(err, "parse body") + } + + if ip == "" { + return errors.Errorf("empty ip") + } + if serverID == "" { + return errors.Errorf("empty server") + } + if serviceID == "" { + return errors.Errorf("empty service") + } + if pid == "" { + return errors.Errorf("empty pid") + } + if len(rtmp) == 0 { + return errors.Errorf("empty rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP, srs.DeviceID = ip, deviceID + srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid + srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api + srs.SRT, srs.RTC = srt, rtc + srs.UpdatedAt = time.Now() + }) + if err := srsLoadBalancer.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update SRS server %+v", server) + } + + logger.Df(ctx, "Register SRS media server, %+v", server) + return nil + }(); err != nil { + apiError(ctx, w, r, err) + } + + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + } + + apiResponse(ctx, w, r, &Response{ + Code: 0, PID: fmt.Sprintf("%v", os.Getpid()), + }) + }) + + // Run System API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If System API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "System API accept err %+v", err) + } else { + logger.Df(ctx, "System API server done") + } + } + }() + + return nil +} diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 000000000..3a389b8bb --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + + "srs-proxy/logger" +) + +func handleGoPprof(ctx context.Context) { + if addr := envGoPprof(); addr != "" { + go func() { + logger.Df(ctx, "Start Go pprof at %v", addr) + http.ListenAndServe(addr, nil) + }() + } +} diff --git a/proxy/env.go b/proxy/env.go new file mode 100644 index 000000000..0c201bb1d --- /dev/null +++ b/proxy/env.go @@ -0,0 +1,197 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "path" + + "github.com/joho/godotenv" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// loadEnvFile loads the environment variables from file. Note that we only use .env file. +func loadEnvFile(ctx context.Context) error { + if workDir, err := os.Getwd(); err != nil { + return errors.Wrapf(err, "getpwd") + } else { + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err == nil { + if err := godotenv.Load(envFile); err != nil { + return errors.Wrapf(err, "load %v", envFile) + } + } + } + + return nil +} + +// buildDefaultEnvironmentVariables setups the default environment variables. +func buildDefaultEnvironmentVariables(ctx context.Context) { + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + // Force shutdown timeout. + setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") + // Graceful quit timeout. + setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "11985") + // The HTTP web server. + setEnvDefault("PROXY_HTTP_SERVER", "18080") + // The RTMP media server. 
+ setEnvDefault("PROXY_RTMP_SERVER", "11935") + // The WebRTC media server, via UDP protocol. + setEnvDefault("PROXY_WEBRTC_SERVER", "18000") + // The SRT media server, via UDP protocol. + setEnvDefault("PROXY_SRT_SERVER", "20080") + // The API server of proxy itself. + setEnvDefault("PROXY_SYSTEM_API", "12025") + // The static directory for web server. + setEnvDefault("PROXY_STATIC_FILES", "../trunk/research") + + // The load balancer, use redis or memory. + setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") + // The redis server host. + setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") + // The redis server port. + setEnvDefault("PROXY_REDIS_PORT", "6379") + // The redis server password. + setEnvDefault("PROXY_REDIS_PASSWORD", "") + // The redis server db. + setEnvDefault("PROXY_REDIS_DB", "0") + + // Whether enable the default backend server, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") + // Default backend server IP, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") + // Default backend server port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") + // Default backend api port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") + // Default backend udp rtc port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") + // Default backend udp srt port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") + + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ + "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ + "PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ + "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ + "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ + "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ + "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", + envGoPprof(), + envForceQuitTimeout(), envGraceQuitTimeout(), + envHttpAPI(), envHttpServer(), envRtmpServer(), + envWebRTCServer(), envSRTServer(), + envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(), + envDefaultBackendIP(), envDefaultBackendRTMP(), + envDefaultBackendHttp(), envDefaultBackendAPI(), + envDefaultBackendRTC(), envDefaultBackendSRT(), + envLoadBalancerType(), envRedisHost(), envRedisPort(), + envRedisPassword(), envRedisDB(), + ) +} + +func envStaticFiles() string { + return os.Getenv("PROXY_STATIC_FILES") +} + +func envDefaultBackendSRT() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") +} + +func envDefaultBackendRTC() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") +} + +func envDefaultBackendAPI() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_API") +} + +func envSRTServer() string { + return os.Getenv("PROXY_SRT_SERVER") +} + +func envWebRTCServer() string { + return os.Getenv("PROXY_WEBRTC_SERVER") +} + +func envDefaultBackendHttp() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") +} + +func envRedisDB() string { + return os.Getenv("PROXY_REDIS_DB") +} + +func envRedisPassword() string { + return os.Getenv("PROXY_REDIS_PASSWORD") +} + +func envRedisPort() string { + return os.Getenv("PROXY_REDIS_PORT") +} + +func envRedisHost() string { + return os.Getenv("PROXY_REDIS_HOST") +} + +func envLoadBalancerType() string { + return os.Getenv("PROXY_LOAD_BALANCER_TYPE") +} + +func 
envDefaultBackendRTMP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") +} + +func envDefaultBackendIP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_IP") +} + +func envDefaultBackendEnabled() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") +} + +func envGraceQuitTimeout() string { + return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") +} + +func envForceQuitTimeout() string { + return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") +} + +func envGoPprof() string { + return os.Getenv("GO_PPROF") +} + +func envSystemAPI() string { + return os.Getenv("PROXY_SYSTEM_API") +} + +func envRtmpServer() string { + return os.Getenv("PROXY_RTMP_SERVER") +} + +func envHttpServer() string { + return os.Getenv("PROXY_HTTP_SERVER") +} + +func envHttpAPI() string { + return os.Getenv("PROXY_HTTP_API") +} + +// setEnvDefault set env key=value if not set. +func setEnvDefault(key, value string) { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } +} diff --git a/proxy/errors/errors.go b/proxy/errors/errors.go new file mode 100644 index 000000000..257bc3ccd --- /dev/null +++ b/proxy/errors/errors.go @@ -0,0 +1,270 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// and the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required the errors.WithStack and errors.WithMessage +// functions destructure errors.Wrap into its component operations of annotating +// an error with a stack trace and an a message, respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. 
This information can be retrieved with the following interface. +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// Where errors.StackTrace is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// stackTracer interface is not exported by this package, but is considered a part +// of stable public API. +// +// See the documentation for Frame.Format for more details. +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. 
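+//
+// An illustrative (hypothetical) caller:
+//
+//	if _, err := os.Open(name); err != nil {
+//		return errors.WithMessage(err, "open config")
+//	}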
+func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/proxy/errors/stack.go b/proxy/errors/stack.go new file mode 100644 index 000000000..6c42db5a8 --- /dev/null +++ b/proxy/errors/stack.go @@ -0,0 +1,187 @@ +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. 
+// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. 
+ const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/proxy/go.mod b/proxy/go.mod new file mode 100644 index 000000000..2e2a17ab3 --- /dev/null +++ b/proxy/go.mod @@ -0,0 +1,13 @@ +module srs-proxy + +go 1.18 + +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/joho/godotenv v1.5.1 +) + +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/proxy/go.sum b/proxy/go.sum new file mode 100644 index 000000000..1efc5318e --- /dev/null +++ b/proxy/go.sum @@ -0,0 +1,17 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/proxy/http.go b/proxy/http.go new file mode 100644 index 000000000..f02af02a3 --- /dev/null +++ b/proxy/http.go @@ -0,0 +1,419 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy +// the request to the origin server. +type srsHTTPStreamServer struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. 
+ wg stdSync.WaitGroup +} + +func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer { + v := &srsHTTPStreamServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPStreamServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPStreamServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpServer() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP Stream server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + } + + res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())} + res.Data.Major = VersionMajor() + res.Data.Minor = VersionMinor() + res.Data.Revision = VersionRevision() + res.Data.Version = Version() + + apiResponse(ctx, w, r, &res) + }) + + // The static web server, for the web pages. + var staticServer http.Handler + if staticFiles := envStaticFiles(); staticFiles != "" { + if _, err := os.Stat(staticFiles); err != nil { + return errors.Wrapf(err, "invalid static files %v", staticFiles) + } + + staticServer = http.FileServer(http.Dir(staticFiles)) + logger.Df(ctx, "Handle static files at %v", staticFiles) + } + + // The default handler, for both static web server and streaming server. + logger.Df(ctx, "Handle / by %v", addr) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // For HLS streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".m3u8") { + unifiedURL, fullURL := convertURLToStreamURL(r) + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest) + return + } + + stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + })) + + stream.Initialize(ctx).ServeHTTP(w, r) + return + } + + // For HTTP streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".flv") || + strings.HasSuffix(r.URL.Path, ".ts") { + // If SPBHID is specified, it must be a HLS stream client. + if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { + if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) + } else { + stream.Initialize(ctx).ServeHTTP(w, r) + } + return + } + + // Use HTTP pseudo streaming to proxy the request. 
+ NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { + c.ctx = ctx + }).ServeHTTP(w, r) + return + } + + // Serve by static server. + if staticServer != nil { + staticServer.ServeHTTP(w, r) + return + } + + http.NotFound(w, r) + }) + + // Run HTTP server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP Stream accept err %+v", err) + } else { + logger.Df(ctx, "HTTP Stream server done") + } + } + }() + + return nil +} + +// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS +// connection. There is no state need to be sync between proxy servers. +// +// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, +// then proxy to the corresponding backend server. All state is in the HTTP request, so this +// connection is stateless. +type HTTPFlvTsConnection struct { + // The context for HTTP streaming. + ctx context.Context +} + +func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { + v := &HTTPFlvTsConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + ctx := logger.WithContext(v.ctx) + + if err := v.serve(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } else { + logger.Df(ctx, "HTTP client done") + } +} + +func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no http stream server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Wrapf(err, "do request to %v", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. 
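+	// Note that net/http ignores header fields added after WriteHeader has been
+	// called; as written, the Header().Add calls run after w.WriteHeader, so the
+	// copied headers do not reach the client.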
+ w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + logger.Df(ctx, "HTTP start streaming") + + // Proxy the stream from backend to client. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil +} + +// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS +// clients will share this object, and they do not use the same ctx among proxy servers. +// +// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. +// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create +// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert +// to the stream URL and then query the backend server to serve it. +type HLSPlayStream struct { + // The context for HLS streaming. + ctx context.Context + + // The spbhid, used to identify the backend server. + SRSProxyBackendHLSID string `json:"spbhid"` + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The full request URL for HLS streaming + FullURL string `json:"full_url"` +} + +func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { + v := &HLSPlayStream{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + return v +} + +func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if err := v.serve(v.ctx, w, r); err != nil { + apiError(v.ctx, w, r, err) + } else { + logger.Df(v.ctx, "HLS client %v for %v with %v done", + v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) + } +} + +func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no rtmp server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" 
+ r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // For TS file, directly copy it. + if !strings.HasSuffix(r.URL.Path, ".m3u8") { + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil + } + + // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts + // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + m3u8 := string(b) + if strings.Contains(m3u8, ".ts?") { + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) + } else { + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) + } + + if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { + return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL) + } + + return nil +} diff --git a/proxy/logger/context.go b/proxy/logger/context.go new file mode 100644 index 000000000..ef15a7d4f --- /dev/null +++ b/proxy/logger/context.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" +) + +type key string + +var cidKey key = "cid.proxy.ossrs.org" + +// generateContextID generates a random context id in string. +func GenerateContextID() string { + randomBytes := make([]byte, 32) + _, _ = rand.Read(randomBytes) + hash := sha256.Sum256(randomBytes) + hashString := hex.EncodeToString(hash[:]) + cid := hashString[:7] + return cid +} + +// WithContext creates a new context with cid, which will be used for log. +func WithContext(ctx context.Context) context.Context { + return WithContextID(ctx, GenerateContextID()) +} + +// WithContextID creates a new context with cid, which will be used for log. +func WithContextID(ctx context.Context, cid string) context.Context { + return context.WithValue(ctx, cidKey, cid) +} + +// ContextID returns the cid in context, or empty string if not set. 
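+//
+// Illustrative use:
+//
+//	ctx := WithContext(context.Background())
+//	cid := ContextID(ctx) // a 7-character hex id that also appears in the log prefix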
+func ContextID(ctx context.Context) string { + if cid, ok := ctx.Value(cidKey).(string); ok { + return cid + } + return "" +} diff --git a/proxy/logger/log.go b/proxy/logger/log.go new file mode 100644 index 000000000..debbe1a84 --- /dev/null +++ b/proxy/logger/log.go @@ -0,0 +1,87 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "io/ioutil" + stdLog "log" + "os" +) + +type logger interface { + Printf(ctx context.Context, format string, v ...any) +} + +type loggerPlus struct { + logger *stdLog.Logger + level string +} + +func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { + v := &loggerPlus{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { + format, args := f, a + if cid := ContextID(ctx); cid != "" { + format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + } + + v.logger.Printf(format, args...) +} + +var verboseLogger logger + +func Vf(ctx context.Context, format string, a ...interface{}) { + verboseLogger.Printf(ctx, format, a...) +} + +var debugLogger logger + +func Df(ctx context.Context, format string, a ...interface{}) { + debugLogger.Printf(ctx, format, a...) +} + +var warnLogger logger + +func Wf(ctx context.Context, format string, a ...interface{}) { + warnLogger.Printf(ctx, format, a...) +} + +var errorLogger logger + +func Ef(ctx context.Context, format string, a ...interface{}) { + errorLogger.Printf(ctx, format, a...) +} + +const ( + logVerboseLabel = "verb" + logDebugLabel = "debug" + logWarnLabel = "warn" + logErrorLabel = "error" +) + +func init() { + verboseLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logVerboseLabel + }) + debugLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logDebugLabel + }) + warnLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logWarnLabel + }) + errorLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logErrorLabel + }) +} diff --git a/proxy/main.go b/proxy/main.go new file mode 100644 index 000000000..6327a7cf8 --- /dev/null +++ b/proxy/main.go @@ -0,0 +1,121 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func main() { + ctx := logger.WithContext(context.Background()) + logger.Df(ctx, "%v/%v started", Signature(), Version()) + + // Install signals. + ctx, cancel := context.WithCancel(ctx) + installSignals(ctx, cancel) + + // Start the main loop, ignore the user cancel error. + err := doMain(ctx) + if err != nil && ctx.Err() != context.Canceled { + logger.Ef(ctx, "main: %+v", err) + os.Exit(-1) + } + + logger.Df(ctx, "%v done", Signature()) +} + +func doMain(ctx context.Context) error { + // Setup the environment variables. + if err := loadEnvFile(ctx); err != nil { + return errors.Wrapf(err, "load env") + } + + buildDefaultEnvironmentVariables(ctx) + + // When cancelled, the program is forced to exit due to a timeout. 
Normally, this doesn't occur + // because the main thread exits after the context is cancelled. However, sometimes the main thread + // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. + if err := installForceQuit(ctx); err != nil { + return errors.Wrapf(err, "install force quit") + } + + // Start the Go pprof if enabled. + handleGoPprof(ctx) + + // Initialize SRS load balancers. + switch lbType := envLoadBalancerType(); lbType { + case "memory": + srsLoadBalancer = NewMemoryLoadBalancer() + case "redis": + srsLoadBalancer = NewRedisLoadBalancer() + default: + return errors.Errorf("invalid load balancer %v", lbType) + } + + if err := srsLoadBalancer.Initialize(ctx); err != nil { + return errors.Wrapf(err, "initialize srs load balancer") + } + + // Parse the gracefully quit timeout. + gracefulQuitTimeout, err := parseGracefullyQuitTimeout() + if err != nil { + return errors.Wrapf(err, "parse gracefully quit timeout") + } + + // Start the RTMP server. + srsRTMPServer := NewSRSRTMPServer() + defer srsRTMPServer.Close() + if err := srsRTMPServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtmp server") + } + + // Start the WebRTC server. + srsWebRTCServer := NewSRSWebRTCServer() + defer srsWebRTCServer.Close() + if err := srsWebRTCServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtc server") + } + + // Start the HTTP API server. + srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) { + server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer + }) + defer srsHTTPAPIServer.Close() + if err := srsHTTPAPIServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http api server") + } + + // Start the SRT server. + srsSRTServer := NewSRSSRTServer() + defer srsSRTServer.Close() + if err := srsSRTServer.Run(ctx); err != nil { + return errors.Wrapf(err, "srt server") + } + + // Start the System API server. + systemAPI := NewSystemAPI(func(server *systemAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer systemAPI.Close() + if err := systemAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "system api server") + } + + // Start the HTTP web server. + srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer srsHTTPStreamServer.Close() + if err := srsHTTPStreamServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http server") + } + + // Wait for the main loop to quit. + <-ctx.Done() + return nil +} diff --git a/proxy/rtc.go b/proxy/rtc.go new file mode 100644 index 000000000..5a7d9936c --- /dev/null +++ b/proxy/rtc.go @@ -0,0 +1,515 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "fmt" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + stdSync "sync" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the +// SDP answer. +type srsWebRTCServer struct { + // The UDP listener for WebRTC server. + listener *net.UDPConn + + // Fast cache for the username to identify the connection. + // The key is username, the value is the UDP address. + usernames sync.Map[string, *RTCConnection] + // Fast cache for the udp address to identify the connection. 
+ // The key is UDP address, the value is the username. + // TODO: Support fast earch by uint64 address. + addresses sync.Map[string, *RTCConnection] + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer { + v := &srsWebRTCServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsWebRTCServer) Close() error { + if v.listener != nil { + _ = v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) proxyApiToBackend( + ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, + remoteSDPOffer string, streamURL string, +) error { + // Parse HTTP port from backend. + if len(backend.API) == 0 { + return errors.Errorf("no http api server") + } + + var apiPort int + if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.API[0]) + } else { + apiPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" 
+ r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // Parse the local SDP answer from backend. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + // Replace the WebRTC UDP port in answer. + localSDPAnswer := string(b) + for _, endpoint := range backend.RTC { + _, _, port, err := parseListenEndpoint(endpoint) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", endpoint) + } + + from := fmt.Sprintf(" %v typ host", port) + to := fmt.Sprintf(" %v typ host", envWebRTCServer()) + localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) + } + + // Fetch the ice-ufrag and ice-pwd from local SDP answer. + remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) + if err != nil { + return errors.Wrapf(err, "parse remote sdp offer") + } + + localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) + if err != nil { + return errors.Wrapf(err, "parse local sdp answer") + } + + // Save the new WebRTC connection to LB. + icePair := &RTCICEPair{ + RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, + LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, + } + if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { + c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() + c.Initialize(ctx, v.listener) + + // Cache the connection for fast search by username. + v.usernames.Store(c.Ufrag, c) + })); err != nil { + return errors.Wrapf(err, "load or store webrtc %v", streamURL) + } + + // Response client with local answer. + if _, err = w.Write([]byte(localSDPAnswer)); err != nil { + return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) + } + + logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB", + len(localSDPAnswer), localICEUfrag, len(localICEPwd)) + return nil +} + +func (v *srsWebRTCServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envWebRTCServer() + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf(":%v", endpoint) + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "WebRTC server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := listener.ReadFromUDP(buf) + if err != nil { + // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. 
+ logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var connection *RTCConnection + + // If STUN binding request, parse the ufrag and identify the connection. + if err := func() error { + if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { + return nil + } + + var pkt RTCStunPacket + if err := pkt.UnmarshalBinary(data); err != nil { + return errors.Wrapf(err, "unmarshal stun packet") + } + + // Search the connection in fast cache. + if s, ok := v.usernames.Load(pkt.Username); ok { + connection = s + return nil + } + + // Load connection by username. + if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) + } else { + connection = s.Initialize(ctx, v.listener) + logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) + } + + // Cache connection for fast search. + if connection != nil { + v.usernames.Store(pkt.Username, connection) + } + return nil + }(); err != nil { + return err + } + + // Search the connection by addr. + if s, ok := v.addresses.Load(addr.String()); ok { + connection = s + } else if connection != nil { + // Cache the address for fast search. + v.addresses.Store(addr.String(), connection) + } + + // If connection is not found, ignore the packet. + if connection == nil { + // TODO: Should logging the dropped packet, only logging the first one for each address. + return nil + } + + // Proxy the packet to backend. + if err := connection.HandlePacket(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL) + } + + return nil +} + +// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC +// connection, identify by the ufrag in sdp offer/answer and ICE binding request. +// +// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is +// in the client request. The RTCConnection is stateful, and need to sync the ufrag between +// proxy servers. +// +// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch +// to another UDP address, it may connect to another WebRTC proxy, then we should discover the +// RTCConnection by the ufrag from the ICE binding request. +type RTCConnection struct { + // The stream context for WebRTC streaming. + ctx context.Context + + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The ufrag for this WebRTC connection. + Ufrag string `json:"ufrag"` + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The client UDP address. Note that it may change. + clientUDP *net.UDPAddr + // The listener UDP connection, used to send messages to client. 
+ listenerUDP *net.UDPConn +} + +func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { + v := &RTCConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + if listener != nil { + v.listenerUDP = listener + } + return v +} + +func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { + ctx := v.ctx + + // Update the current UDP address. + v.clientUDP = addr + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx); err != nil { + return errors.Wrapf(err, "connect backend for %v", v.StreamURL) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return nil + } + + // Proxy all messages from backend to client. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } + } + }() + + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) + } + + return nil +} + +func (v *RTCConnection) connectBackend(ctx context.Context) error { + if v.backendUDP != nil { + return nil + } + + // Pick a backend SRS server to proxy the RTC stream. + backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) + if err != nil { + return errors.Wrapf(err, "pick backend") + } + + // Parse UDP port from backend. + if len(backend.RTC) == 0 { + return errors.Errorf("no udp server") + } + + _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v", backendAddr) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +type RTCICEPair struct { + // The remote ufrag, used for ICE username and session id. + RemoteICEUfrag string `json:"remote_ufrag"` + // The remote pwd, used for ICE password. + RemoteICEPwd string `json:"remote_pwd"` + // The local ufrag, used for ICE username and session id. + LocalICEUfrag string `json:"local_ufrag"` + // The local pwd, used for ICE password. + LocalICEPwd string `json:"local_pwd"` +} + +// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. +func (v *RTCICEPair) Ufrag() string { + return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) +} + +type RTCStunPacket struct { + // The stun message type. + MessageType uint16 + // The stun username, or ufrag. 
+ Username string +} + +func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { + if len(data) < 20 { + return errors.Errorf("stun packet too short %v", len(data)) + } + + p := data + v.MessageType = binary.BigEndian.Uint16(p) + messageLen := binary.BigEndian.Uint16(p[2:]) + //magicCookie := p[:8] + //transactionID := p[:20] + p = p[20:] + + if len(p) != int(messageLen) { + return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen) + } + + for len(p) > 0 { + typ := binary.BigEndian.Uint16(p) + length := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(length) { + return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) + } + + value := p[:length] + p = p[length:] + + if length%4 != 0 { + p = p[4-length%4:] + } + + switch typ { + case 0x0006: + v.Username = string(value) + } + } + + return nil +} diff --git a/proxy/rtmp.go b/proxy/rtmp.go new file mode 100644 index 000000000..d93f04b3a --- /dev/null +++ b/proxy/rtmp.go @@ -0,0 +1,655 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/rtmp" +) + +// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// server. It will figure out the backend server to proxy to. Unlike the edge server, it will +// not cache the stream, but just proxy the stream to backend. +type srsRTMPServer struct { + // The TCP listener for RTMP server. + listener *net.TCPListener + // The random number generator. + rd *rand.Rand + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer { + v := &srsRTMPServer{ + rd: rand.New(rand.NewSource(time.Now().UnixNano())), + } + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsRTMPServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsRTMPServer) Run(ctx context.Context) error { + endpoint := envRtmpServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + addr, err := net.ResolveTCPAddr("tcp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return errors.Wrapf(err, "listen rtmp addr %v", addr) + } + v.listener = listener + logger.Df(ctx, "RTMP server listen at %v", addr) + + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for { + conn, err := v.listener.AcceptTCP() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "RTMP server accept err %+v", err) + } else { + logger.Df(ctx, "RTMP server done") + } + return + } + + v.wg.Add(1) + go func(ctx context.Context, conn *net.TCPConn) { + defer v.wg.Done() + defer conn.Close() + + handleErr := func(err error) { + if isPeerClosedError(err) { + logger.Df(ctx, "RTMP peer is closed") + } else { + logger.Wf(ctx, "RTMP serve err %+v", err) + } + } + + rc := NewRTMPConnection(func(client *RTMPConnection) { + client.rd = v.rd + }) + if err := rc.serve(ctx, conn); err != nil { + handleErr(err) + } else { + logger.Df(ctx, "RTMP client done") + } + }(logger.WithContext(ctx), conn) + } + }() + + return nil +} + +// RTMPConnection is an RTMP streaming connection. 
There is no state that needs to be synced between
+// proxy servers.
+//
+// When we get an RTMP request, we parse the stream URL from the RTMP publish or play request,
+// then proxy to the corresponding backend server. All state is in the RTMP request, so this
+// connection is stateless.
+type RTMPConnection struct {
+ // The random number generator.
+ rd *rand.Rand
+}
+
+func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
+ v := &RTMPConnection{}
+ for _, opt := range opts {
+ opt(v)
+ }
+ return v
+}
+
+func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
+ logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
+
+ // If either goroutine quits, cancel the other one.
+ parentCtx := ctx
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ var backend *RTMPClientToBackend
+ if true {
+ go func() {
+ <-ctx.Done()
+ conn.Close()
+ if backend != nil {
+ backend.Close()
+ }
+ }()
+ }
+
+ // Simple handshake with client.
+ hs := rtmp.NewHandshake(v.rd)
+ if _, err := hs.ReadC0S0(conn); err != nil {
+ return errors.Wrapf(err, "read c0")
+ }
+ if _, err := hs.ReadC1S1(conn); err != nil {
+ return errors.Wrapf(err, "read c1")
+ }
+ if err := hs.WriteC0S0(conn); err != nil {
+ return errors.Wrapf(err, "write s0")
+ }
+ if err := hs.WriteC1S1(conn); err != nil {
+ return errors.Wrapf(err, "write s1")
+ }
+ if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
+ return errors.Wrapf(err, "write s2")
+ }
+ if _, err := hs.ReadC2S2(conn); err != nil {
+ return errors.Wrapf(err, "read c2")
+ }
+
+ client := rtmp.NewProtocol(conn)
+ logger.Df(ctx, "RTMP simple handshake done")
+
+ // Expect RTMP connect command with tcUrl.
+ var connectReq *rtmp.ConnectAppPacket
+ if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
+ return errors.Wrapf(err, "expect connect req")
+ }
+
+ if true {
+ ack := rtmp.NewWindowAcknowledgementSize()
+ ack.AckSize = 2500000
+ if err := client.WritePacket(ctx, ack, 0); err != nil {
+ return errors.Wrapf(err, "write set ack size")
+ }
+ }
+ if true {
+ chunk := rtmp.NewSetChunkSize()
+ chunk.ChunkSize = 128
+ if err := client.WritePacket(ctx, chunk, 0); err != nil {
+ return errors.Wrapf(err, "write set chunk size")
+ }
+ }
+
+ connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
+ connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
+ connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
+ connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
+ connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
+ connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
+ connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
+ connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
+ connectResData := rtmp.NewAmf0EcmaArray()
+ connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
+ connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
+ connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
+ connectRes.Args.Set("data", connectResData)
+ if err := client.WritePacket(ctx, connectRes, 0); err != nil {
+ return errors.Wrapf(err, "write connect res")
+ }
+
+ tcUrl := connectReq.TcUrl()
+ logger.Df(ctx, "RTMP connect app %v", tcUrl)
+
+ // Expect RTMP command to identify the client, a publisher or viewer.
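+ // The client type is decided by the AMF0 command it sends: createStream allocates a stream ID,
+ // publish marks the client as a publisher, and play marks it as a viewer; other commands such as
+ // releaseStream and FCPublish are acknowledged with a generic _result.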
+ var currentStreamID, nextStreamID int + var streamName string + var clientType RTMPClientType + for clientType == "" { + var identifyReq rtmp.Packet + if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil { + return errors.Wrapf(err, "expect identify req") + } + + var response rtmp.Packet + switch pkt := identifyReq.(type) { + case *rtmp.CallPacket: + if pkt.CommandName == "createStream" { + identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) + response = identifyRes + + nextStreamID = 1 + identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + } else if pkt.CommandName == "getStreamLength" { + // Ignore and do not reply these packets. + } else { + // For releaseStream, FCPublish, etc. + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.TransactionID = pkt.TransactionID + identifyRes.CommandName = "_result" + identifyRes.CommandObject = rtmp.NewAmf0Null() + identifyRes.Args = rtmp.NewAmf0Undefined() + } + case *rtmp.PublishPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypePublisher + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onFCPublish" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + identifyRes.Args = data + case *rtmp.PlayPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypeViewer + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset")) + data.Set("description", rtmp.NewAmf0String("Playing and resetting stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + } + + if response != nil { + if err := client.WritePacket(ctx, response, currentStreamID); err != nil { + return errors.Wrapf(err, "write identify res for req=%v, stream=%v", + identifyReq, currentStreamID) + } + } + + // Update the stream ID for next request. + currentStreamID = nextStreamID + } + logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", + tcUrl, streamName, currentStreamID, clientType) + + // Find a backend SRS server to proxy the RTMP stream. + backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + client.rd, client.typ = v.rd, clientType + }) + defer backend.Close() + + if err := backend.Connect(ctx, tcUrl, streamName); err != nil { + return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) + } + + // Start the streaming. 
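+ // Reply with onStatus, NetStream.Publish.Start for publishers or NetStream.Play.Start for
+ // viewers, so the client knows the stream is ready, then relay messages in both directions.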
+ if clientType == RTMPClientTypePublisher { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start publish") + } + } else if clientType == RTMPClientTypeViewer { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start")) + data.Set("description", rtmp.NewAmf0String("Started playing stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start play") + } + } + logger.Df(ctx, "RTMP start streaming") + + // For all proxy goroutines. + var wg sync.WaitGroup + defer wg.Wait() + + // Proxy all message from backend to client. + wg.Add(1) + var r0 error + go func() { + defer wg.Done() + defer cancel() + + r0 = func() error { + for { + m, err := backend.client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Proxy all messages from client to backend. + wg.Add(1) + var r1 error + go func() { + defer wg.Done() + defer cancel() + + r1 = func() error { + for { + m, err := client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := backend.client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Wait until all goroutine quit. + wg.Wait() + + // Reset the error if caused by another goroutine. + if r0 != nil { + return errors.Wrapf(r0, "proxy backend->client") + } + if r1 != nil { + return errors.Wrapf(r1, "proxy client->backend") + } + + return parentCtx.Err() +} + +type RTMPClientType string + +const ( + RTMPClientTypePublisher RTMPClientType = "publisher" + RTMPClientTypeViewer RTMPClientType = "viewer" +) + +// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. +type RTMPClientToBackend struct { + // The random number generator. + rd *rand.Rand + // The underlayer tcp client. + tcpConn *net.TCPConn + // The RTMP protocol client. + client *rtmp.Protocol + // The stream type. 
+ typ RTMPClientType +} + +func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { + v := &RTMPClientToBackend{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPClientToBackend) Close() error { + if v.tcpConn != nil { + v.tcpConn.Close() + } + return nil +} + +func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { + // Build the stream URL in vhost/app/stream schema. + streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) + if err != nil { + return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse RTMP port from backend. + if len(backend.RTMP) == 0 { + return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) + } + + var rtmpPort int + if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0]) + } else { + rtmpPort = int(iv) + } + + // Connect to backend SRS server via TCP client. + addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} + c, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) + } + v.tcpConn = c + + hs := rtmp.NewHandshake(v.rd) + client := rtmp.NewProtocol(c) + v.client = client + + // Simple RTMP handshake with server. + if err := hs.WriteC0S0(c); err != nil { + return errors.Wrapf(err, "write c0") + } + if err := hs.WriteC1S1(c); err != nil { + return errors.Wrapf(err, "write c1") + } + + if _, err = hs.ReadC0S0(c); err != nil { + return errors.Wrapf(err, "read s0") + } + if _, err := hs.ReadC1S1(c); err != nil { + return errors.Wrapf(err, "read s1") + } + if _, err = hs.ReadC2S2(c); err != nil { + return errors.Wrapf(err, "read c2") + } + logger.Df(ctx, "backend simple handshake done, server=%v", addr) + + if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write c2") + } + + // Connect RTMP app on tcUrl with server. + if true { + connectApp := rtmp.NewConnectAppPacket() + connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) + if err := client.WritePacket(ctx, connectApp, 1); err != nil { + return errors.Wrapf(err, "write connect app") + } + } + + if true { + var connectAppRes *rtmp.ConnectAppResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { + return errors.Wrapf(err, "expect connect app res") + } + logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) + } + + // Play or view RTMP stream with server. + if v.typ == RTMPClientTypeViewer { + return v.play(ctx, client, streamName) + } + + // Publish RTMP stream with server. 
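+ // The publish flow is releaseStream, FCPublish, createStream and then publish, waiting for the
+ // backend to acknowledge each step.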
+ return v.publish(ctx, client, streamName) +} + +func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "releaseStream" + identifyReq.TransactionID = 2 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "releaseStream") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "FCPublish" + identifyReq.TransactionID = 3 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "FCPublish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect FCPublish res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + if true { + publishStream := rtmp.NewPublishPacket() + publishStream.TransactionID = 5 + publishStream.CommandObject = rtmp.NewAmf0Null() + publishStream.StreamName = *rtmp.NewAmf0String(streamName) + publishStream.StreamType = *rtmp.NewAmf0String("live") + if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { + return errors.Wrapf(err, "publish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect publish res") + } + // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). 
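+ // Any other command from the backend is skipped until onStatus arrives.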
+ if identifyRes.CommandName == "onStatus" { + if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil { + return errors.Errorf("onStatus args not object") + } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { + return errors.Errorf("onStatus code not string") + } else if *code != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } + break + } + } + logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID) + + return nil +} + +func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + playStream := rtmp.NewPlayPacket() + playStream.StreamName = *rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { + return errors.Wrapf(err, "play") + } + + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" { + break + } + } + return nil +} diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go new file mode 100644 index 000000000..a013d5ecc --- /dev/null +++ b/proxy/rtmp/amf0.go @@ -0,0 +1,771 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" + "math" + "sync" + + "srs-proxy/errors" +) + +// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview +type amf0Marker uint8 + +const ( + amf0MarkerNumber amf0Marker = iota // 0 + amf0MarkerBoolean // 1 + amf0MarkerString // 2 + amf0MarkerObject // 3 + amf0MarkerMovieClip // 4 + amf0MarkerNull // 5 + amf0MarkerUndefined // 6 + amf0MarkerReference // 7 + amf0MarkerEcmaArray // 8 + amf0MarkerObjectEnd // 9 + amf0MarkerStrictArray // 10 + amf0MarkerDate // 11 + amf0MarkerLongString // 12 + amf0MarkerUnsupported // 13 + amf0MarkerRecordSet // 14 + amf0MarkerXmlDocument // 15 + amf0MarkerTypedObject // 16 + amf0MarkerAvmPlusObject // 17 + + amf0MarkerForbidden amf0Marker = 0xff +) + +func (v amf0Marker) String() string { + switch v { + case amf0MarkerNumber: + return "Amf0Number" + case amf0MarkerBoolean: + return "amf0Boolean" + case amf0MarkerString: + return "Amf0String" + case amf0MarkerObject: + return "Amf0Object" + case amf0MarkerNull: + return "Null" + case amf0MarkerUndefined: + return "Undefined" + case amf0MarkerReference: + return "Reference" + case amf0MarkerEcmaArray: + return "EcmaArray" + case amf0MarkerObjectEnd: + return "ObjectEnd" + case amf0MarkerStrictArray: + return "StrictArray" + case amf0MarkerDate: + return "Date" + case amf0MarkerLongString: + return "LongString" + case amf0MarkerUnsupported: + return "Unsupported" + case amf0MarkerXmlDocument: + return "XmlDocument" + case amf0MarkerTypedObject: + return "TypedObject" + case 
amf0MarkerAvmPlusObject: + return "AvmPlusObject" + case amf0MarkerMovieClip: + return "MovieClip" + case amf0MarkerRecordSet: + return "RecordSet" + default: + return "Forbidden" + } +} + +// For utest to mock it. +type amf0Buffer interface { + Bytes() []byte + WriteByte(c byte) error + Write(p []byte) (n int, err error) +} + +var createBuffer = func() amf0Buffer { + return &bytes.Buffer{} +} + +// All AMF0 things. +type amf0Any interface { + // Binary marshaler and unmarshaler. + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + // Get the size of bytes to marshal this object. + Size() int + + // Get the Marker of any AMF0 stuff. + amf0Marker() amf0Marker +} + +type amf0Converter struct { + from amf0Any +} + +func NewAmf0Converter(from amf0Any) *amf0Converter { + return &amf0Converter{from: from} +} + +func (v *amf0Converter) ToNumber() *amf0Number { + return amf0AnyTo[*amf0Number](v.from) +} + +func (v *amf0Converter) ToBoolean() *amf0Boolean { + return amf0AnyTo[*amf0Boolean](v.from) +} + +func (v *amf0Converter) ToString() *amf0String { + return amf0AnyTo[*amf0String](v.from) +} + +func (v *amf0Converter) ToObject() *amf0Object { + return amf0AnyTo[*amf0Object](v.from) +} + +func (v *amf0Converter) ToNull() *amf0Null { + return amf0AnyTo[*amf0Null](v.from) +} + +func (v *amf0Converter) ToUndefined() *amf0Undefined { + return amf0AnyTo[*amf0Undefined](v.from) +} + +func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { + return amf0AnyTo[*amf0EcmaArray](v.from) +} + +func (v *amf0Converter) ToStrictArray() *amf0StrictArray { + return amf0AnyTo[*amf0StrictArray](v.from) +} + +// Convert any to specified object. +func amf0AnyTo[T amf0Any](a amf0Any) T { + var to T + if a != nil { + if v, ok := a.(T); ok { + return v + } + } + return to +} + +// Discovery the amf0 object from the bytes b. 
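+// Only the first marker byte of p is inspected; the returned value is empty, and the caller is
+// expected to unmarshal it from the same bytes.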
+func Amf0Discovery(p []byte) (a amf0Any, err error) { + if len(p) < 1 { + return nil, errors.Errorf("require 1 bytes only %v", len(p)) + } + m := amf0Marker(p[0]) + + switch m { + case amf0MarkerNumber: + return NewAmf0Number(0), nil + case amf0MarkerBoolean: + return NewAmf0Boolean(false), nil + case amf0MarkerString: + return NewAmf0String(""), nil + case amf0MarkerObject: + return NewAmf0Object(), nil + case amf0MarkerNull: + return NewAmf0Null(), nil + case amf0MarkerUndefined: + return NewAmf0Undefined(), nil + case amf0MarkerReference: + case amf0MarkerEcmaArray: + return NewAmf0EcmaArray(), nil + case amf0MarkerObjectEnd: + return &amf0ObjectEOF{}, nil + case amf0MarkerStrictArray: + return NewAmf0StrictArray(), nil + case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, + amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, + amf0MarkerRecordSet: + return nil, errors.Errorf("Marker %v is not supported", m) + } + return nil, errors.Errorf("Marker %v is invalid", m) +} + +// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 +type amf0UTF8 string + +func (v *amf0UTF8) Size() int { + return 2 + len(string(*v)) +} + +func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + size := uint16(p[0])<<8 | uint16(p[1]) + + if p = data[2:]; len(p) < int(size) { + return errors.Errorf("require %v bytes only %v", int(size), len(p)) + } + *v = amf0UTF8(string(p[:size])) + + return +} + +func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + + size := uint16(len(string(*v))) + data[0] = byte(size >> 8) + data[1] = byte(size) + + if size > 0 { + copy(data[2:], []byte(*v)) + } + + return +} + +// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type +type amf0Number float64 + +func NewAmf0Number(f float64) *amf0Number { + v := amf0Number(f) + return &v +} + +func (v *amf0Number) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func (v *amf0Number) Size() int { + return 1 + 8 +} + +func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 9 { + return errors.Errorf("require 9 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerNumber { + return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) + } + + f := binary.BigEndian.Uint64(p[1:]) + *v = amf0Number(math.Float64frombits(f)) + return +} + +func (v *amf0Number) MarshalBinary() (data []byte, err error) { + data = make([]byte, 9) + data[0] = byte(amf0MarkerNumber) + f := math.Float64bits(float64(*v)) + binary.BigEndian.PutUint64(data[1:], f) + return +} + +// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type +type amf0String string + +func NewAmf0String(s string) *amf0String { + v := amf0String(s) + return &v +} + +func (v *amf0String) amf0Marker() amf0Marker { + return amf0MarkerString +} + +func (v *amf0String) Size() int { + u := amf0UTF8(*v) + return 1 + u.Size() +} + +func (v *amf0String) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerString { + return errors.Errorf("Amf0String amf0Marker %v is illegal", m) + } + + var sv amf0UTF8 + if err = sv.UnmarshalBinary(p[1:]); err != nil { 
+ return errors.WithMessage(err, "utf8") + } + *v = amf0String(string(sv)) + return +} + +func (v *amf0String) MarshalBinary() (data []byte, err error) { + u := amf0UTF8(*v) + + var pb []byte + if pb, err = u.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "utf8") + } + + data = append([]byte{byte(amf0MarkerString)}, pb...) + return +} + +// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type +type amf0ObjectEOF struct { +} + +func (v *amf0ObjectEOF) amf0Marker() amf0Marker { + return amf0MarkerObjectEnd +} + +func (v *amf0ObjectEOF) Size() int { + return 3 +} + +func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { + p := data + + if len(p) < 3 { + return errors.Errorf("require 3 bytes only %v", len(p)) + } + + if p[0] != 0 || p[1] != 0 || p[2] != 9 { + return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + } + return +} + +func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { + return []byte{0, 0, 9}, nil +} + +// Use array for object and ecma array, to keep the original order. +type amf0Property struct { + key amf0UTF8 + value amf0Any +} + +// The object-like AMF0 structure, like object and ecma array and strict array. +type amf0ObjectBase struct { + properties []*amf0Property + lock sync.Mutex +} + +func (v *amf0ObjectBase) Size() int { + v.lock.Lock() + defer v.lock.Unlock() + + var size int + + for _, p := range v.properties { + key, value := p.key, p.value + size += key.Size() + value.Size() + } + + return size +} + +func (v *amf0ObjectBase) Get(key string) amf0Any { + v.lock.Lock() + defer v.lock.Unlock() + + for _, p := range v.properties { + if string(p.key) == key { + return p.value + } + } + + return nil +} + +func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { + v.lock.Lock() + defer v.lock.Unlock() + + prop := &amf0Property{key: amf0UTF8(key), value: value} + + var ok bool + for i, p := range v.properties { + if string(p.key) == key { + v.properties[i] = prop + ok = true + } + } + + if !ok { + v.properties = append(v.properties, prop) + } + + return v +} + +func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { + // if no eof, elems specified by maxElems. + if !eof && maxElems < 0 { + return errors.Errorf("maxElems=%v without eof", maxElems) + } + // if eof, maxElems must be -1. + if eof && maxElems != -1 { + return errors.Errorf("maxElems=%v with eof", maxElems) + } + + readOne := func() (amf0UTF8, amf0Any, error) { + var u amf0UTF8 + if err = u.UnmarshalBinary(p); err != nil { + return "", nil, errors.WithMessage(err, "prop name") + } + + p = p[u.Size():] + var a amf0Any + if a, err = Amf0Discovery(p); err != nil { + return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + } + return u, a, nil + } + + pushOne := func(u amf0UTF8, a amf0Any) error { + // For object property, consume the whole bytes. + if err = a.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + } + + v.Set(string(u), a) + p = p[a.Size():] + return nil + } + + for eof { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + // For object EOF, we should only consume total 3bytes. + if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd { + // 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte. 
+ p = p[1:] + return nil + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + for len(v.properties) < maxElems { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + return +} + +func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { + v.lock.Lock() + defer v.lock.Unlock() + + var pb []byte + for _, p := range v.properties { + key, value := p.key, p.value + + if pb, err = key.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "write %v", string(key)) + } + + if pb, err = value.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "marshal value for %v", string(key)) + } + } + + return +} + +// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type +type amf0Object struct { + amf0ObjectBase + eof amf0ObjectEOF +} + +func NewAmf0Object() *amf0Object { + v := &amf0Object{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0Object) amf0Marker() amf0Marker { + return amf0MarkerObject +} + +func (v *amf0Object) Size() int { + return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerObject { + return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) + } + p = p[1:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + + return +} + +func (v *amf0Object) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type +type amf0EcmaArray struct { + amf0ObjectBase + count uint32 + eof amf0ObjectEOF +} + +func NewAmf0EcmaArray() *amf0EcmaArray { + v := &amf0EcmaArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0EcmaArray) amf0Marker() amf0Marker { + return amf0MarkerEcmaArray +} + +func (v *amf0EcmaArray) Size() int { + return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { + return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { + return 
nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type +type amf0StrictArray struct { + amf0ObjectBase + count uint32 +} + +func NewAmf0StrictArray() *amf0StrictArray { + v := &amf0StrictArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0StrictArray) amf0Marker() amf0Marker { + return amf0MarkerStrictArray +} + +func (v *amf0StrictArray) Size() int { + return int(1) + 4 + v.amf0ObjectBase.Size() +} + +func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { + return errors.Errorf("StrictArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if int(v.count) <= 0 { + return + } + + if err = v.unmarshal(p, false, int(v.count)); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + return b.Bytes(), nil +} + +// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined. 
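+// It carries no payload: marshaling writes only the one-byte marker, and unmarshaling just
+// validates that the marker matches.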
+type amf0SingleMarkerObject struct { + target amf0Marker +} + +func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject { + return amf0SingleMarkerObject{target: m} +} + +func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker { + return v.target +} + +func (v *amf0SingleMarkerObject) Size() int { + return int(1) +} + +func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != v.target { + return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) + } + return +} + +func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { + return []byte{byte(v.target)}, nil +} + +// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type +type amf0Null struct { + amf0SingleMarkerObject +} + +func NewAmf0Null() *amf0Null { + v := amf0Null{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) + return &v +} + +// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type +type amf0Undefined struct { + amf0SingleMarkerObject +} + +func NewAmf0Undefined() amf0Any { + v := amf0Undefined{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) + return &v +} + +// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type +type amf0Boolean bool + +func NewAmf0Boolean(b bool) amf0Any { + v := amf0Boolean(b) + return &v +} + +func (v *amf0Boolean) amf0Marker() amf0Marker { + return amf0MarkerBoolean +} + +func (v *amf0Boolean) Size() int { + return int(2) +} + +func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerBoolean { + return errors.Errorf("BOOL amf0Marker %v is illegal", m) + } + if p[1] == 0 { + *v = false + } else { + *v = true + } + return +} + +func (v *amf0Boolean) MarshalBinary() (data []byte, err error) { + var b byte + if *v { + b = 1 + } + return []byte{byte(amf0MarkerBoolean), b}, nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go new file mode 100644 index 000000000..ee0970e96 --- /dev/null +++ b/proxy/rtmp/rtmp.go @@ -0,0 +1,1792 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bufio" + "bytes" + "context" + "encoding" + "encoding/binary" + "fmt" + "io" + "math/rand" + "sync" + + "srs-proxy/errors" +) + +// The handshake implements the RTMP handshake protocol. +type Handshake struct { + // The random number generator. + r *rand.Rand + // The c1s1 cache. 
+ c1s1 []byte +} + +func NewHandshake(r *rand.Rand) *Handshake { + return &Handshake{r: r} +} + +func (v *Handshake) C1S1() []byte { + return v.c1s1 +} + +func (v *Handshake) WriteC0S0(w io.Writer) (err error) { + r := bytes.NewReader([]byte{0x03}) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s0") + } + + return +} + +func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1); err != nil { + return nil, errors.Wrap(err, "read c0s0") + } + + c0 = b.Bytes() + + return +} + +func (v *Handshake) WriteC1S1(w io.Writer) (err error) { + p := make([]byte, 1536) + + for i := 8; i < len(p); i++ { + p[i] = byte(v.r.Int()) + } + + r := bytes.NewReader(p) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s1") + } + + return +} + +func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c1s1") + } + + c1s1 = b.Bytes() + v.c1s1 = c1s1 + + return +} + +func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { + r := bytes.NewReader(s1c1[:]) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c2s2") + } + + return +} + +func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c2s2") + } + + c2 = b.Bytes() + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 16, @section 6.1. Chunk Format +// Extended timestamp: 0 or 4 bytes +// This field MUST be sent when the normal timsestamp is set to +// 0xffffff, it MUST NOT be sent if the normal timestamp is set to +// anything else. So for values less than 0xffffff the normal +// timestamp field SHOULD be used in which case the extended timestamp +// MUST NOT be present. For values greater than or equal to 0xffffff +// the normal timestamp field MUST NOT be used and MUST be set to +// 0xffffff and the extended timestamp MUST be sent. +const extendedTimestamp = uint64(0xffffff) + +// The default chunk size of RTMP is 128 bytes. +const defaultChunkSize = 128 + +// The intput or output settings for RTMP protocol. +type settings struct { + chunkSize uint32 +} + +func newSettings() *settings { + return &settings{ + chunkSize: defaultChunkSize, + } +} + +// The chunk stream which transport a message once. +type chunkStream struct { + format formatType + cid chunkID + header messageHeader + message *Message + count uint64 + extendedTimestamp bool +} + +func newChunkStream() *chunkStream { + return &chunkStream{} +} + +// The protocol implements the RTMP command and chunk stack. 
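+// On read, it reassembles chunk streams into complete messages, and it tracks AMF0 transaction
+// IDs so that a _result or _error response can be matched back to the request that created it.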
+type Protocol struct { + r *bufio.Reader + w *bufio.Writer + input struct { + opt *settings + chunks map[chunkID]*chunkStream + + transactions map[amf0Number]amf0String + ltransactions sync.Mutex + } + output struct { + opt *settings + } +} + +func NewProtocol(rw io.ReadWriter) *Protocol { + v := &Protocol{ + r: bufio.NewReader(rw), + w: bufio.NewWriter(rw), + } + + v.input.opt = newSettings() + v.input.chunks = map[chunkID]*chunkStream{} + v.input.transactions = map[amf0Number]amf0String{} + + v.output.opt = newSettings() + + return v +} + +func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + var pkt Packet + if pkt, err = v.DecodeMessage(m); err != nil { + return nil, errors.WithMessage(err, "decode message") + } + + if p, ok := pkt.(T); ok { + *ppkt = p + break + } + } + + return +} + +// Deprecated: Please use rtmp.ExpectPacket instead. +func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { + panic("Please use rtmp.ExpectPacket instead") +} + +func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + if len(types) == 0 { + return + } + + for _, t := range types { + if m.MessageType == t { + return + } + } + } + + return +} + +func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { + var commandName amf0String + if err = commandName.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, "unmarshal command name") + } + + switch commandName { + case commandResult, commandError: + var transactionID amf0Number + if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { + return nil, errors.WithMessage(err, "unmarshal tid") + } + + var requestName amf0String + if err = func() error { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + var ok bool + if requestName, ok = v.input.transactions[transactionID]; !ok { + return errors.Errorf("No matched request for tid=%v", transactionID) + } + delete(v.input.transactions, transactionID) + + return nil + }(); err != nil { + return nil, errors.WithMessage(err, "discovery request name") + } + + switch requestName { + case commandConnect: + return NewConnectAppResPacket(transactionID), nil + case commandCreateStream: + return NewCreateStreamResPacket(transactionID), nil + case commandReleaseStream, commandFCPublish, commandFCUnpublish: + call := NewCallPacket() + call.TransactionID = transactionID + return call, nil + default: + return nil, errors.Errorf("No request for %v", string(requestName)) + } + case commandConnect: + return NewConnectAppPacket(), nil + case commandPublish: + return NewPublishPacket(), nil + case commandPlay: + return NewPlayPacket(), nil + default: + return NewCallPacket(), nil + } +} + +func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { + p := m.Payload[:] + if len(p) == 0 { + return nil, errors.New("Empty packet") + } + + switch m.MessageType { + case MessageTypeAMF3Command, MessageTypeAMF3Data: + p = p[1:] + } + + switch m.MessageType { + case MessageTypeSetChunkSize: + pkt = NewSetChunkSize() + case MessageTypeWindowAcknowledgementSize: + pkt = NewWindowAcknowledgementSize() + case MessageTypeSetPeerBandwidth: + pkt = NewSetPeerBandwidth() + case MessageTypeAMF0Command, 
MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: + if pkt, err = v.parseAMFObject(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + } + case MessageTypeUserControl: + pkt = NewUserControl() + default: + return nil, errors.Errorf("Unknown message %v", m.MessageType) + } + + if err = pkt.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + } + + return +} + +func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { + for m == nil { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + var cid chunkID + var format formatType + if format, cid, err = v.readBasicHeader(ctx); err != nil { + return nil, errors.WithMessage(err, "read basic header") + } + + var ok bool + var chunk *chunkStream + if chunk, ok = v.input.chunks[cid]; !ok { + chunk = newChunkStream() + v.input.chunks[cid] = chunk + chunk.header.betterCid = cid + } + + if err = v.readMessageHeader(ctx, chunk, format); err != nil { + return nil, errors.WithMessage(err, "read message header") + } + + if m, err = v.readMessagePayload(ctx, chunk); err != nil { + return nil, errors.WithMessage(err, "read message payload") + } + + if err = v.onMessageArrivated(m); err != nil { + return nil, errors.WithMessage(err, "on message") + } + } + + return +} + +func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { + // Empty payload message. + if chunk.message.payloadLength == 0 { + m = chunk.message + chunk.message = nil + return + } + + // Calculate the chunk payload size. + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + if chunkedPayloadSize > int(v.input.opt.chunkSize) { + chunkedPayloadSize = int(v.input.opt.chunkSize) + } + + b := make([]byte, chunkedPayloadSize) + if _, err = io.ReadFull(v.r, b); err != nil { + return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + } + chunk.message.Payload = append(chunk.message.Payload, b...) + + // Got entire RTMP message? + if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + m = chunk.message + chunk.message = nil + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 18, @section 6.1.2. Chunk Message Header +// There are four different formats for the chunk message header, +// selected by the "fmt" field in the chunk basic header. +type formatType uint8 + +const ( + // 6.1.2.1. Type 0 + // Chunks of Type 0 are 11 bytes long. This type MUST be used at the + // start of a chunk stream, and whenever the stream timestamp goes + // backward (e.g., because of a backward seek). + formatType0 formatType = iota + // 6.1.2.2. Type 1 + // Chunks of Type 1 are 7 bytes long. The message stream ID is not + // included; this chunk takes the same stream ID as the preceding chunk. + // Streams with variable-sized messages (for example, many video + // formats) SHOULD use this format for the first chunk of each new + // message after the first. + formatType1 + // 6.1.2.3. Type 2 + // Chunks of Type 2 are 3 bytes long. Neither the stream ID nor the + // message length is included; this chunk has the same stream ID and + // message length as the preceding chunk. 
Streams with constant-sized + // messages (for example, some audio and data formats) SHOULD use this + // format for the first chunk of each message after the first. + formatType2 + // 6.1.2.4. Type 3 + // Chunks of Type 3 have no header. Stream ID, message length and + // timestamp delta are not present; chunks of this type take values from + // the preceding chunk. When a single message is split into chunks, all + // chunks of a message except the first one, SHOULD use this type. Refer + // to example 2 in section 6.2.2. Stream consisting of messages of + // exactly the same size, stream ID and spacing in time SHOULD use this + // type for all chunks after chunk of Type 2. Refer to example 1 in + // section 6.2.1. If the delta between the first message and the second + // message is same as the time stamp of first message, then chunk of + // type 3 would immediately follow the chunk of type 0 as there is no + // need for a chunk of type 2 to register the delta. If Type 3 chunk + // follows a Type 0 chunk, then timestamp delta for this Type 3 chunk is + // the same as the timestamp of Type 0 chunk. + formatType3 +) + +// The message header size, index is format. +var messageHeaderSizes = []int{11, 7, 3, 0} + +// Parse the chunk message header. +// 3bytes: timestamp delta, fmt=0,1,2 +// 3bytes: payload length, fmt=0,1 +// 1bytes: message type, fmt=0,1 +// 4bytes: stream id, fmt=0 +// where: +// fmt=0, 0x0X +// fmt=1, 0x4X +// fmt=2, 0x8X +// fmt=3, 0xCX +func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { + // We should not assert anything about fmt, for the first packet. + // (when first packet, the chunk.message is nil). + // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. + // the previous packet is: + // 04 // fmt=0, cid=4 + // 00 00 1a // timestamp=26 + // 00 00 9d // payload_length=157 + // 08 // message_type=8(audio) + // 01 00 00 00 // stream_id=1 + // the current packet maybe: + // c4 // fmt=3, cid=4 + // it's ok, for the packet is audio, and timestamp delta is 26. + // the current packet must be parsed as: + // fmt=0, cid=4 + // timestamp=26+26=52 + // payload_length=157 + // message_type=8(audio) + // stream_id=1 + // so we must update the timestamp even fmt=3 for first packet. + // + // The fresh packet used to update the timestamp even fmt=3 for first packet. + // fresh packet always means the chunk is the first one of message. + var isFirstChunkOfMsg bool + if chunk.message == nil { + isFirstChunkOfMsg = true + } + + // But, we can ensure that when a chunk stream is fresh, + // the fmt must be 0, a new stream. + if chunk.count == 0 && format != formatType0 { + // For librtmp, if ping, it will send a fresh stream with fmt=1, + // 0x42 where: fmt=1, cid=2, protocol contorl user-control message + // 0x00 0x00 0x00 where: timestamp=0 + // 0x00 0x00 0x06 where: payload_length=6 + // 0x04 where: message_type=4(protocol control user-control message) + // 0x00 0x06 where: event Ping(0x06) + // 0x00 0x00 0x0d 0x0f where: event data 4bytes ping timestamp. + // @see: https://github.com/ossrs/srs/issues/98 + if chunk.cid == chunkIDProtocolControl && format == formatType1 { + // We accept cid=2, fmt=1 to make librtmp happy. + } else { + return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + } + } + + // When exists cache msg, means got an partial message, + // the fmt must not be type0 which means new message. 
+ if chunk.message != nil && format == formatType0 { + return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + } + + // Create msg when new chunk stream start + if chunk.message == nil { + chunk.message = NewMessage() + } + + // Read the message header. + p := make([]byte, messageHeaderSizes[format]) + if _, err = io.ReadFull(v.r, p); err != nil { + return errors.Wrapf(err, "read %vB message header", len(p)) + } + + // Prse the message header. + // 3bytes: timestamp delta, fmt=0,1,2 + // 3bytes: payload length, fmt=0,1 + // 1bytes: message type, fmt=0,1 + // 4bytes: stream id, fmt=0 + // where: + // fmt=0, 0x0X + // fmt=1, 0x4X + // fmt=2, 0x8X + // fmt=3, 0xCX + if format <= formatType2 { + chunk.header.timestampDelta = uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // fmt: 0 + // timestamp: 3 bytes + // If the timestamp is greater than or equal to 16777215 + // (hexadecimal 0x00ffffff), this value MUST be 16777215, and the + // 'extended timestamp header' MUST be present. Otherwise, this value + // SHOULD be the entire timestamp. + // + // fmt: 1 or 2 + // timestamp delta: 3 bytes + // If the delta is greater than or equal to 16777215 (hexadecimal + // 0x00ffffff), this value MUST be 16777215, and the 'extended + // timestamp header' MUST be present. Otherwise, this value SHOULD be + // the entire delta. + chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp + if !chunk.extendedTimestamp { + // Extended timestamp: 0 or 4 bytes + // This field MUST be sent when the normal timsestamp is set to + // 0xffffff, it MUST NOT be sent if the normal timestamp is set to + // anything else. So for values less than 0xffffff the normal + // timestamp field SHOULD be used in which case the extended timestamp + // MUST NOT be present. For values greater than or equal to 0xffffff + // the normal timestamp field MUST NOT be used and MUST be set to + // 0xffffff and the extended timestamp MUST be sent. + if format == formatType0 { + // 6.1.2.1. Type 0 + // For a type-0 chunk, the absolute timestamp of the message is sent + // here. + chunk.header.Timestamp = uint64(chunk.header.timestampDelta) + } else { + // 6.1.2.2. Type 1 + // 6.1.2.3. Type 2 + // For a type-1 or type-2 chunk, the difference between the previous + // chunk's timestamp and the current chunk's timestamp is sent here. + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + if format <= formatType1 { + payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // For a message, if msg exists in cache, the size must not changed. + // always use the actual msg size to compare, for the cache payload length can changed, + // for the fmt type1(stream_id not changed), user can change the payload + // length(it's not allowed in the continue chunks). 
+			if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength {
+				return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength)
+			}
+			chunk.header.payloadLength = payloadLength
+
+			chunk.header.MessageType = MessageType(p[0])
+			p = p[1:]
+
+			if format == formatType0 {
+				chunk.header.streamID = uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
+				p = p[4:]
+			}
+		}
+	} else {
+		// Update the timestamp even for fmt=3 for the first chunk of a message.
+		if isFirstChunkOfMsg && !chunk.extendedTimestamp {
+			chunk.header.Timestamp += uint64(chunk.header.timestampDelta)
+		}
+	}
+
+	// Read extended-timestamp
+	if chunk.extendedTimestamp {
+		var timestamp uint32
+		if err = binary.Read(v.r, binary.BigEndian, &timestamp); err != nil {
+			return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp)
+		}
+
+		// We always use 31bits timestamp, for some servers may use 32bits extended timestamp.
+		// @see https://github.com/ossrs/srs/issues/111
+		timestamp &= 0x7fffffff
+
+		// TODO: FIXME: Support detecting the extended timestamp.
+		// @see http://blog.csdn.net/win_lin/article/details/13363699
+		chunk.header.Timestamp = uint64(timestamp)
+	}
+
+	// The extended-timestamp must be unsigned-int,
+	//     24bits timestamp: 0xffffff = 16777215ms = 16777.215s = 4.66h
+	//     32bits timestamp: 0xffffffff = 4294967295ms = 4294967.295s = 1193.046h = 49.71d
+	// because the rtmp protocol says the 32bits timestamp is about "50 days":
+	//     3. Byte Order, Alignment, and Time Format
+	//         Because timestamps are generally only 32 bits long, they will roll
+	//         over after fewer than 50 days.
+	//
+	// but, its sample says the timestamp is 31bits:
+	//     An application could assume, for example, that all
+	//     adjacent timestamps are within 2^31 milliseconds of each other, so
+	//     10000 comes after 4000000000, while 3000000000 comes before
+	//     4000000000.
+	// and flv specification says timestamp is 31bits:
+	//     Extension of the Timestamp field to form a SI32 value. This
+	//     field represents the upper 8 bits, while the previous
+	//     Timestamp field represents the lower 24 bits of the time in
+	//     milliseconds.
+	// in a word, 31bits timestamp is ok.
+	// convert extended timestamp to 31bits.
+	chunk.header.Timestamp &= 0x7fffffff
+
+	// Copy header to msg
+	chunk.message.messageHeader = chunk.header
+
+	// Increase the msg count, the chunk stream can accept fmt=1/2/3 message now.
+	chunk.count++
+
+	return
+}
+
+// Please read @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header
+// The Chunk Basic Header encodes the chunk stream ID and the chunk
+// type (represented by the fmt field in the figure below). Chunk type
+// determines the format of the encoded message header. Chunk Basic
+// Header field may be 1, 2, or 3 bytes, depending on the chunk stream
+// ID.
+//
+// The bits 0-5 (least significant) in the chunk basic header represent
+// the chunk stream ID.
+//
+// Chunk stream IDs 2-63 can be encoded in the 1-byte version of this
+// field.
+//    0 1 2 3 4 5 6 7
+//   +-+-+-+-+-+-+-+-+
+//   |fmt|   cs id   |
+//   +-+-+-+-+-+-+-+-+
+//   Figure 6 Chunk basic header 1
+//
+// Chunk stream IDs 64-319 can be encoded in the 2-byte version of this
+// field. ID is computed as (the second byte + 64).
+// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 0 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 7 Chunk basic header 2 +// +// Chunk stream IDs 64-65599 can be encoded in the 3-byte version of +// this field. ID is computed as ((the third byte)*256 + the second byte +// + 64). +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 1 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 8 Chunk basic header 3 +// +// cs id: 6 bits +// fmt: 2 bits +// cs id - 64: 8 or 16 bits +// +// Chunk stream IDs with values 64-319 could be represented by both 2- +// byte version and 3-byte version of this field. +func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { + // 2-63, 1B chunk header + var t uint8 + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrap(err, "read basic header") + } + cid = chunkID(t & 0x3f) + format = formatType((t >> 6) & 0x03) + + if cid > 1 { + return + } + + // 64-319, 2B chunk header + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid = chunkID(64 + uint32(t)) + + // 64-65599, 3B chunk header + if cid == 1 { + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid += chunkID(uint32(t) * 256) + } + + return +} + +func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := NewMessage() + + if m.Payload, err = pkt.MarshalBinary(); err != nil { + return errors.WithMessage(err, "marshal payload") + } + + m.MessageType = pkt.Type() + m.streamID = uint32(streamID) + m.betterCid = pkt.BetterCid() + + if err = v.WriteMessage(ctx, m); err != nil { + return errors.WithMessage(err, "write message") + } + + if err = v.onPacketWriten(m, pkt); err != nil { + return errors.WithMessage(err, "on write packet") + } + + return +} + +func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { + var tid amf0Number + var name amf0String + + switch pkt := pkt.(type) { + case *ConnectAppPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CreateStreamPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CallPacket: + tid, name = pkt.TransactionID, pkt.CommandName + } + + if tid > 0 && len(name) > 0 { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + v.input.transactions[tid] = name + } + + return +} + +func (v *Protocol) onMessageArrivated(m *Message) (err error) { + if m == nil { + return + } + + var pkt Packet + switch m.MessageType { + case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: + if pkt, err = v.DecodeMessage(m); err != nil { + return errors.Errorf("decode message %v", m.MessageType) + } + } + + switch pkt := pkt.(type) { + case *SetChunkSize: + v.input.opt.chunkSize = pkt.ChunkSize + } + + return +} + +func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { + m.payloadLength = uint32(len(m.Payload)) + + var c0h, c3h []byte + if c0h, err = m.generateC0Header(); err != nil { + return errors.WithMessage(err, "generate c0 header") + } + if c3h, err = m.generateC3Header(); err != nil { + return errors.WithMessage(err, "generate c3 header") + } + + var h []byte + p := m.Payload + for len(p) > 0 { + // TODO: 
We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + if h == nil { + h = c0h + } else { + h = c3h + } + + if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { + return errors.Wrapf(err, "write c0c3 header %x", h) + } + + size := len(p) + if size > int(v.output.opt.chunkSize) { + size = int(v.output.opt.chunkSize) + } + + if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { + return errors.Wrapf(err, "write chunk payload %vB", size) + } + p = p[size:] + } + + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + // TODO: FIXME: Use writev to write for high performance. + if err = v.w.Flush(); err != nil { + return errors.Wrapf(err, "flush writer") + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +// 1byte. One byte field to represent the message type. A range of type IDs +// (1-7) are reserved for protocol control messages. +type MessageType uint8 + +const ( + // Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 5. Protocol Control Messages + // RTMP reserves message type IDs 1-7 for protocol control messages. + // These messages contain information needed by the RTM Chunk Stream + // protocol or RTMP itself. Protocol messages with IDs 1 & 2 are + // reserved for usage with RTM Chunk Stream protocol. Protocol messages + // with IDs 3-6 are reserved for usage of RTMP. Protocol message with ID + // 7 is used between edge server and origin server. + MessageTypeSetChunkSize MessageType = 0x01 + MessageTypeAbort MessageType = 0x02 // 0x02 + MessageTypeAcknowledgement MessageType = 0x03 // 0x03 + MessageTypeUserControl MessageType = 0x04 // 0x04 + MessageTypeWindowAcknowledgementSize MessageType = 0x05 // 0x05 + MessageTypeSetPeerBandwidth MessageType = 0x06 // 0x06 + MessageTypeEdgeAndOriginServerCommand MessageType = 0x07 // 0x07 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3. Types of messages + // The server and the client send messages over the network to + // communicate with each other. The messages can be of any type which + // includes audio messages, video messages, command messages, shared + // object messages, data messages, and user control messages. + // + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.4. Audio message + // The client or the server sends this message to send audio data to the + // peer. The message type value of 8 is reserved for audio messages. + MessageTypeAudio MessageType = 0x08 + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.5. Video message + // The client or the server sends this message to send video data to the + // peer. The message type value of 9 is reserved for video messages. + // These messages are large and can delay the sending of other type of + // messages. To avoid such a situation, the video message is assigned + // the lowest priority. + MessageTypeVideo MessageType = 0x09 // 0x09 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.1. Command message + // Command messages carry the AMF-encoded commands between the client + // and the server. 
These messages have been assigned message type value + // of 20 for AMF0 encoding and message type value of 17 for AMF3 + // encoding. These messages are sent to perform some operations like + // connect, createStream, publish, play, pause on the peer. Command + // messages like onstatus, result etc. are used to inform the sender + // about the status of the requested commands. A command message + // consists of command name, transaction ID, and command object that + // contains related parameters. A client or a server can request Remote + // Procedure Calls (RPC) over streams that are communicated using the + // command messages to the peer. + MessageTypeAMF3Command MessageType = 17 // 0x11 + MessageTypeAMF0Command MessageType = 20 // 0x14 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.2. Data message + // The client or the server sends this message to send Metadata or any + // user data to the peer. Metadata includes details about the + // data(audio, video etc.) like creation time, duration, theme and so + // on. These messages have been assigned message type value of 18 for + // AMF0 and message type value of 15 for AMF3. + MessageTypeAMF0Data MessageType = 18 // 0x12 + MessageTypeAMF3Data MessageType = 15 // 0x0f +) + +// The header of message. +type messageHeader struct { + // 3bytes. + // Three-byte field that contains a timestamp delta of the message. + // @remark, only used for decoding message from chunk stream. + timestampDelta uint32 + // 3bytes. + // Three-byte field that represents the size of the payload in bytes. + // It is set in big-endian format. + payloadLength uint32 + // 1byte. + // One byte field to represent the message type. A range of type IDs + // (1-7) are reserved for protocol control messages. + MessageType MessageType + // 4bytes. + // Four-byte field that identifies the stream of the message. These + // bytes are set in little-endian format. + streamID uint32 + + // The chunk stream id over which transport. + betterCid chunkID + + // Four-byte field that contains a timestamp of the message. + // The 4 bytes are packed in the big-endian order. + // @remark, we use 64bits for large time for jitter detect and for large tbn like HLS. + Timestamp uint64 +} + +// The RTMP message, transport over chunk stream in RTMP. +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +type Message struct { + messageHeader + + // The payload which carries the RTMP packet. + Payload []byte +} + +func NewMessage() *Message { + return &Message{} +} + +func NewStreamMessage(streamID int) *Message { + v := NewMessage() + v.streamID = uint32(streamID) + v.betterCid = chunkIDOverStream + return v +} + +func (v *Message) generateC3Header() ([]byte, error) { + var c3h []byte + if v.Timestamp < extendedTimestamp { + c3h = make([]byte, 1) + } else { + c3h = make([]byte, 1+4) + } + + p := c3h + p[0] = 0xc0 | byte(v.betterCid&0x3f) + p = p[1:] + + // In RTMP protocol, there must not any timestamp in C3 header, + // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, + // always carry a extended timestamp in C3 header. 
+ // @see: http://blog.csdn.net/win_lin/article/details/13363699 + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c3h, nil +} + +func (v *Message) generateC0Header() ([]byte, error) { + var c0h []byte + if v.Timestamp < extendedTimestamp { + c0h = make([]byte, 1+3+3+1+4) + } else { + c0h = make([]byte, 1+3+3+1+4+4) + } + + p := c0h + p[0] = byte(v.betterCid) & 0x3f + p = p[1:] + + if v.Timestamp < extendedTimestamp { + p[0] = byte(v.Timestamp >> 16) + p[1] = byte(v.Timestamp >> 8) + p[2] = byte(v.Timestamp) + } else { + p[0] = 0xff + p[1] = 0xff + p[2] = 0xff + } + p = p[3:] + + p[0] = byte(v.payloadLength >> 16) + p[1] = byte(v.payloadLength >> 8) + p[2] = byte(v.payloadLength) + p = p[3:] + + p[0] = byte(v.MessageType) + p = p[1:] + + p[0] = byte(v.streamID) + p[1] = byte(v.streamID >> 8) + p[2] = byte(v.streamID >> 16) + p[3] = byte(v.streamID >> 24) + p = p[4:] + + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c0h, nil +} + +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +type chunkID uint32 + +const ( + chunkIDProtocolControl chunkID = 0x02 + chunkIDOverConnection chunkID = 0x03 + chunkIDOverConnection2 chunkID = 0x04 + chunkIDOverStream chunkID = 0x05 + chunkIDOverStream2 chunkID = 0x06 + chunkIDVideo chunkID = 0x07 + chunkIDAudio chunkID = 0x08 +) + +// The Command Name of message. +const ( + commandConnect amf0String = amf0String("connect") + commandCreateStream amf0String = amf0String("createStream") + commandCloseStream amf0String = amf0String("closeStream") + commandPlay amf0String = amf0String("play") + commandPause amf0String = amf0String("pause") + commandOnBWDone amf0String = amf0String("onBWDone") + commandOnStatus amf0String = amf0String("onStatus") + commandResult amf0String = amf0String("_result") + commandError amf0String = amf0String("_error") + commandReleaseStream amf0String = amf0String("releaseStream") + commandFCPublish amf0String = amf0String("FCPublish") + commandFCUnpublish amf0String = amf0String("FCUnpublish") + commandPublish amf0String = amf0String("publish") + commandRtmpSampleAccess amf0String = amf0String("|RtmpSampleAccess") +) + +// The RTMP packet, transport as payload of RTMP message. +type Packet interface { + // Marshaler and unmarshaler + Size() int + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + + // RTMP protocol fields for each packet. + BetterCid() chunkID + Type() MessageType +} + +// A Call packet, both object and args are AMF0 objects. 
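+// It is embedded by the concrete AMF0 command packets below and already
+// satisfies the Packet interface declared above.
+//
+// For illustration only: callers do not build messages by hand, they hand a
+// Packet to Protocol.WritePacket, which marshals the body and fills the
+// message header from Type() and BetterCid(). A minimal sketch, where
+// connectApp is a hypothetical helper name and populating the AMF0 command
+// object (tcUrl, app, ...) is assumed to be done through the AMF0 setters
+// defined elsewhere in this package:
+//
+//	func connectApp(ctx context.Context, p *Protocol) error {
+//		pkt := NewConnectAppPacket()
+//		// Populate pkt.CommandObject here, e.g. with the tcUrl.
+//		return p.WritePacket(ctx, pkt, 0)
+//	}
+//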
+type objectCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject *amf0Object + Args *amf0Object +} + +func (v *objectCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *objectCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *objectCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + v.CommandObject.Size() + if v.Args != nil { + size += v.Args.Size() + } + return size +} + +func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command") + } + p = p[v.CommandObject.Size():] + + if len(p) == 0 { + return + } + + v.Args = NewAmf0Object() + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + + return +} + +func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 45, @section 4.1.1. connect +// The client sends the connect command to the server to request +// connection to a server application instance. +type ConnectAppPacket struct { + objectCallPacket +} + +func NewConnectAppPacket() *ConnectAppPacket { + v := &ConnectAppPacket{} + v.CommandName = commandConnect + v.CommandObject = NewAmf0Object() + v.TransactionID = amf0Number(1.0) + return v +} + +func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandConnect { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + if v.TransactionID != 1.0 { + return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + } + + return +} + +func (v *ConnectAppPacket) TcUrl() string { + if v.CommandObject != nil { + if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { + return string(*v) + } + } + return "" +} + +// The response for ConnectAppPacket. 
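+//
+// For illustration only: a client typically keeps reading and decoding
+// messages until the matching _result arrives. A minimal sketch, where
+// readConnectAppRes is a hypothetical helper name, assuming DecodeMessage
+// resolves the pending connect transaction recorded by onPacketWriten into a
+// *ConnectAppResPacket, and treating any decode error as fatal for brevity:
+//
+//	func readConnectAppRes(ctx context.Context, p *Protocol) (*ConnectAppResPacket, error) {
+//		for {
+//			m, err := p.ReadMessage(ctx)
+//			if err != nil {
+//				return nil, err
+//			}
+//
+//			pkt, err := p.DecodeMessage(m)
+//			if err != nil {
+//				return nil, err
+//			}
+//
+//			if res, ok := pkt.(*ConnectAppResPacket); ok {
+//				return res, nil
+//			}
+//		}
+//	}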
+type ConnectAppResPacket struct { + objectCallPacket +} + +func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { + v := &ConnectAppResPacket{} + v.CommandName = commandResult + v.CommandObject = NewAmf0Object() + v.Args = NewAmf0Object() + v.TransactionID = tid + return v +} + +func (v *ConnectAppResPacket) SrsID() string { + if v.Args != nil { + if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(*amf0String); ok { + return string(*v) + } + } + } + return "" +} + +func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandResult { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + return +} + +// A Call object, command object is variant. +type variantCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject amf0Any // object or null +} + +func (v *variantCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *variantCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *variantCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + + if v.CommandObject != nil { + size += v.CommandObject.Size() + } + + return size +} + +func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if len(p) > 0 { + if v.CommandObject, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery command object") + } + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command object") + } + p = p[v.CommandObject.Size():] + } + + return +} + +func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if v.CommandObject != nil { + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 51, @section 4.1.2. Call +// The call method of the NetConnection object runs remote procedure +// calls (RPC) at the receiving end. The called RPC name is passed as a +// parameter to the call command. +// @remark onStatus packet is a call packet. 
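+//
+// For illustration only: an onStatus notification can be expressed as a
+// CallPacket whose Args carry the status object. A minimal sketch, where
+// newOnStatusPublishStart is a hypothetical helper name and the level, code
+// and description fields are assumed to be filled through the AMF0 object
+// setters defined elsewhere in this package:
+//
+//	func newOnStatusPublishStart() *CallPacket {
+//		pkt := NewCallPacket()
+//		pkt.CommandName = commandOnStatus
+//		pkt.TransactionID = 0
+//		pkt.CommandObject = NewAmf0Null()
+//		// Fill the status object here, e.g. code "NetStream.Publish.Start",
+//		// which the peer can read back through ArgsCode() below.
+//		pkt.Args = NewAmf0Object()
+//		return pkt
+//	}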
+type CallPacket struct { + variantCallPacket + Args amf0Any // optional or object or null +} + +func NewCallPacket() *CallPacket { + return &CallPacket{} +} + +func (v *CallPacket) ArgsCode() string { + if v.Args != nil { + if v, ok := v.Args.(*amf0Object); ok { + if code, ok := v.Get("code").(*amf0String); ok { + return string(*code) + } + } + } + return "" +} + +func (v *CallPacket) Size() int { + size := v.variantCallPacket.Size() + + if v.Args != nil { + size += v.Args.Size() + } + + return size +} + +func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if len(p) > 0 { + if v.Args, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery args") + } + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + } + + return +} + +func (v *CallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 52, @section 4.1.3. createStream +// The client sends this command to the server to create a logical +// channel for message communication The publishing of audio, video, and +// metadata is carried out over stream channel created using the +// createStream command. +type CreateStreamPacket struct { + variantCallPacket +} + +func NewCreateStreamPacket() *CreateStreamPacket { + v := &CreateStreamPacket{} + v.CommandName = commandCreateStream + v.TransactionID = amf0Number(2) + v.CommandObject = NewAmf0Null() + return v +} + +// The response for create stream +type CreateStreamResPacket struct { + variantCallPacket + StreamID amf0Number +} + +func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { + v := &CreateStreamResPacket{} + v.CommandName = commandResult + v.TransactionID = tid + v.CommandObject = NewAmf0Null() + v.StreamID = 0 + return v +} + +func (v *CreateStreamResPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamID.Size() +} + +func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal sid") + } + + return +} + +func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal sid") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. 
Publish +type PublishPacket struct { + variantCallPacket + StreamName amf0String + StreamType amf0String +} + +func NewPublishPacket() *PublishPacket { + v := &PublishPacket{} + v.CommandName = commandPublish + v.CommandObject = NewAmf0Null() + v.StreamType = "live" + return v +} + +func (v *PublishPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() + v.StreamType.Size() +} + +func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + if err = v.StreamType.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream type") + } + + return +} + +func (v *PublishPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + if pb, err = v.StreamType.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream type") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play +type PlayPacket struct { + variantCallPacket + StreamName amf0String +} + +func NewPlayPacket() *PlayPacket { + v := &PlayPacket{} + v.CommandName = commandPlay + v.CommandObject = NewAmf0Null() + return v +} + +func (v *PlayPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() +} + +func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + return +} + +func (v *PlayPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 31, @section 5.1. Set Chunk Size +// Protocol control message 1, Set Chunk Size, is used to notify the +// peer about the new maximum chunk size. 
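+//
+// For illustration only: the reader side of this protocol already applies the
+// peer's value in onMessageArrivated above (v.input.opt.chunkSize), while
+// announcing our own chunk size is just another packet write. A minimal
+// sketch, where announceChunkSize is a hypothetical helper name; note that
+// v.output.opt.chunkSize is not shown being updated here, so a real sender
+// would also have to keep that option in sync:
+//
+//	func announceChunkSize(ctx context.Context, p *Protocol, size uint32) error {
+//		pkt := NewSetChunkSize()
+//		pkt.ChunkSize = size
+//		return p.WritePacket(ctx, pkt, 0)
+//	}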
+type SetChunkSize struct { + ChunkSize uint32 +} + +func NewSetChunkSize() *SetChunkSize { + return &SetChunkSize{ + ChunkSize: defaultChunkSize, + } +} + +func (v *SetChunkSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetChunkSize) Type() MessageType { + return MessageTypeSetChunkSize +} + +func (v *SetChunkSize) Size() int { + return 4 +} + +func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.ChunkSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *SetChunkSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.ChunkSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.5. Window Acknowledgement Size (5) +// The client or the server sends this message to inform the peer which +// window size to use when sending acknowledgment. +type WindowAcknowledgementSize struct { + AckSize uint32 +} + +func NewWindowAcknowledgementSize() *WindowAcknowledgementSize { + return &WindowAcknowledgementSize{} +} + +func (v *WindowAcknowledgementSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *WindowAcknowledgementSize) Type() MessageType { + return MessageTypeWindowAcknowledgementSize +} + +func (v *WindowAcknowledgementSize) Size() int { + return 4 +} + +func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.AckSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *WindowAcknowledgementSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.AckSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The sender can mark this message hard (0), soft (1), or dynamic (2) +// using the Limit type field. +type LimitType uint8 + +const ( + LimitTypeHard LimitType = iota + LimitTypeSoft + LimitTypeDynamic +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The client or the server sends this message to update the output +// bandwidth of the peer. +type SetPeerBandwidth struct { + Bandwidth uint32 + LimitType LimitType +} + +func NewSetPeerBandwidth() *SetPeerBandwidth { + return &SetPeerBandwidth{} +} + +func (v *SetPeerBandwidth) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetPeerBandwidth) Type() MessageType { + return MessageTypeSetPeerBandwidth +} + +func (v *SetPeerBandwidth) Size() int { + return 4 + 1 +} + +func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { + if len(data) < 5 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + v.Bandwidth = binary.BigEndian.Uint32(data) + v.LimitType = LimitType(data[4]) + + return +} + +func (v *SetPeerBandwidth) MarshalBinary() (data []byte, err error) { + data = make([]byte, 5) + binary.BigEndian.PutUint32(data, v.Bandwidth) + data[4] = byte(v.LimitType) + + return +} + +type EventType uint16 + +const ( + // Generally, 4bytes event-data + + // The server sends this event to notify the client + // that a stream has become functional and can be + // used for communication. By default, this event + // is sent on ID 0 after the application connect + // command is successfully received from the + // client. 
The event data is 4-byte and represents + // The stream ID of the stream that became + // Functional. + EventTypeStreamBegin = 0x00 + + // The server sends this event to notify the client + // that the playback of data is over as requested + // on this stream. No more data is sent without + // issuing additional commands. The client discards + // The messages received for the stream. The + // 4 bytes of event data represent the ID of the + // stream on which playback has ended. + EventTypeStreamEOF = 0x01 + + // The server sends this event to notify the client + // that there is no more data on the stream. If the + // server does not detect any message for a time + // period, it can notify the subscribed clients + // that the stream is dry. The 4 bytes of event + // data represent the stream ID of the dry stream. + EventTypeStreamDry = 0x02 + + // The client sends this event to inform the server + // of the buffer size (in milliseconds) that is + // used to buffer any data coming over a stream. + // This event is sent before the server starts + // processing the stream. The first 4 bytes of the + // event data represent the stream ID and the next + // 4 bytes represent the buffer length, in + // milliseconds. + EventTypeSetBufferLength = 0x03 // 8bytes event-data + + // The server sends this event to notify the client + // that the stream is a recorded stream. The + // 4 bytes event data represent the stream ID of + // The recorded stream. + EventTypeStreamIsRecorded = 0x04 + + // The server sends this event to test whether the + // client is reachable. Event data is a 4-byte + // timestamp, representing the local server time + // When the server dispatched the command. The + // client responds with kMsgPingResponse on + // receiving kMsgPingRequest. + EventTypePingRequest = 0x06 + + // The client sends this event to the server in + // Response to the ping request. The event data is + // a 4-byte timestamp, which was received with the + // kMsgPingRequest request. + EventTypePingResponse = 0x07 + + // For PCUC size=3, for example the payload is "00 1A 01", + // it's a FMS control event, where the event type is 0x001a and event data is 0x01, + // please notice that the event data is only 1 byte for this event. + EventTypeFmsEvent0 = 0x1a +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 32, @5.4. User Control Message (4) +// The client or the server sends this message to notify the peer about the user control events. +// This message carries Event type and Event data. +type UserControl struct { + // Event type is followed by Event data. + // @see: SrcPCUCEventType + EventType EventType + // The event data generally in 4bytes. + // @remark for event type is 0x001a, only 1bytes. + // @see SrsPCUCFmsEvent0 + EventData int32 + // 4bytes if event_type is SetBufferLength; otherwise 0. 
+ ExtraData int32 +} + +func NewUserControl() *UserControl { + return &UserControl{} +} + +func (v *UserControl) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *UserControl) Type() MessageType { + return MessageTypeUserControl +} + +func (v *UserControl) Size() int { + size := 2 + + if v.EventType == EventTypeFmsEvent0 { + size += 1 + } else { + size += 4 + } + + if v.EventType == EventTypeSetBufferLength { + size += 4 + } + + return size +} + +func (v *UserControl) UnmarshalBinary(data []byte) (err error) { + if len(data) < 3 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + + v.EventType = EventType(binary.BigEndian.Uint16(data)) + if len(data) < v.Size() { + return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + } + + if v.EventType == EventTypeFmsEvent0 { + v.EventData = int32(uint8(data[2])) + } else { + v.EventData = int32(binary.BigEndian.Uint32(data[2:])) + } + + if v.EventType == EventTypeSetBufferLength { + v.ExtraData = int32(binary.BigEndian.Uint32(data[6:])) + } + + return +} + +func (v *UserControl) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + binary.BigEndian.PutUint16(data, uint16(v.EventType)) + + if v.EventType == EventTypeFmsEvent0 { + data[2] = uint8(v.EventData) + } else { + binary.BigEndian.PutUint32(data[2:], uint32(v.EventData)) + } + + if v.EventType == EventTypeSetBufferLength { + binary.BigEndian.PutUint32(data[6:], uint32(v.ExtraData)) + } + + return +} diff --git a/proxy/signal.go b/proxy/signal.go new file mode 100644 index 000000000..367543f4a --- /dev/null +++ b/proxy/signal.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func installSignals(ctx context.Context, cancel context.CancelFunc) { + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() +} + +func installForceQuit(ctx context.Context) error { + var forceTimeout time.Duration + if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) + } else { + forceTimeout = t + } + + go func() { + <-ctx.Done() + time.Sleep(forceTimeout) + logger.Wf(ctx, "Force to exit by timeout") + os.Exit(1) + }() + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go new file mode 100644 index 000000000..d05a39c61 --- /dev/null +++ b/proxy/srs.go @@ -0,0 +1,553 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "strconv" + "strings" + "time" + + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// If server heartbeat in this duration, it's alive. +const srsServerAliveDuration = 300 * time.Second + +// If HLS streaming update in this duration, it's alive. +const srsHLSAliveDuration = 120 * time.Second + +// If WebRTC streaming update in this duration, it's alive. +const srsRTCAliveDuration = 120 * time.Second + +type SRSServer struct { + // The server IP. + IP string `json:"ip,omitempty"` + // The server device ID, configured by user. 
+ DeviceID string `json:"device_id,omitempty"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID string `json:"server_id,omitempty"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID string `json:"service_id,omitempty"` + // The process id of SRS, always change when restarted, mandatory. + PID string `json:"pid,omitempty"` + // The RTMP listen endpoints. + RTMP []string `json:"rtmp,omitempty"` + // The HTTP Stream listen endpoints. + HTTP []string `json:"http,omitempty"` + // The HTTP API listen endpoints. + API []string `json:"api,omitempty"` + // The SRT server listen endpoints. + SRT []string `json:"srt,omitempty"` + // The RTC server listen endpoints. + RTC []string `json:"rtc,omitempty"` + // Last update time. + UpdatedAt time.Time `json:"update_at,omitempty"` +} + +func (v *SRSServer) ID() string { + return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) +} + +func (v *SRSServer) String() string { + return fmt.Sprintf("%v", v) +} + +func (v *SRSServer) Format(f fmt.State, c rune) { + switch c { + case 'v', 's': + if f.Flag('+') { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)) + if v.DeviceID != "" { + sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID)) + } + if len(v.RTMP) > 0 { + sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ","))) + } + if len(v.HTTP) > 0 { + sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ","))) + } + if len(v.API) > 0 { + sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ","))) + } + if len(v.SRT) > 0 { + sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ","))) + } + if len(v.RTC) > 0 { + sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) + } + sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))) + fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) + } else { + fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) + } + default: + fmt.Fprintf(f, "%v, fmt=%%%c", v, c) + } +} + +func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { + v := &SRSServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. +func NewDefaultSRSForDebugging() (*SRSServer, error) { + if envDefaultBackendEnabled() != "on" { + return nil, nil + } + + if envDefaultBackendIP() == "" { + return nil, fmt.Errorf("empty default backend ip") + } + if envDefaultBackendRTMP() == "" { + return nil, fmt.Errorf("empty default backend rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP = envDefaultBackendIP() + srs.RTMP = []string{envDefaultBackendRTMP()} + srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) + srs.ServiceID = logger.GenerateContextID() + srs.PID = fmt.Sprintf("%v", os.Getpid()) + srs.UpdatedAt = time.Now() + }) + + if envDefaultBackendHttp() != "" { + server.HTTP = []string{envDefaultBackendHttp()} + } + if envDefaultBackendAPI() != "" { + server.API = []string{envDefaultBackendAPI()} + } + if envDefaultBackendRTC() != "" { + server.RTC = []string{envDefaultBackendRTC()} + } + if envDefaultBackendSRT() != "" { + server.SRT = []string{envDefaultBackendSRT()} + } + return server, nil +} + +// SRSLoadBalancer is the interface to load balance the SRS servers. +type SRSLoadBalancer interface { + // Initialize the load balancer. + Initialize(ctx context.Context) error + // Update the backer server. 
+ Update(ctx context.Context, server *SRSServer) error + // Pick a backend server for the specified stream URL. + Pick(ctx context.Context, streamURL string) (*SRSServer, error) + // Load or store the HLS streaming for the specified stream URL. + LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) + // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. + LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) + // Store the WebRTC streaming for the specified stream URL. + StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error + // Load the WebRTC streaming by ufrag, the ICE username. + LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) +} + +// srsLoadBalancer is the global SRS load balancer. +var srsLoadBalancer SRSLoadBalancer + +// srsMemoryLoadBalancer stores state in memory. +type srsMemoryLoadBalancer struct { + // All available SRS servers, key is server ID. + servers sync.Map[string, *SRSServer] + // The picked server to servce client by specified stream URL, key is stream url. + picked sync.Map[string, *SRSServer] + // The HLS streaming, key is stream URL. + hlsStreamURL sync.Map[string, *HLSPlayStream] + // The HLS streaming, key is SPBHID. + hlsSPBHID sync.Map[string, *HLSPlayStream] + // The WebRTC streaming, key is stream URL. + rtcStreamURL sync.Map[string, *RTCConnection] + // The WebRTC streaming, key is ufrag. + rtcUfrag sync.Map[string, *RTCConnection] +} + +func NewMemoryLoadBalancer() SRSLoadBalancer { + return &srsMemoryLoadBalancer{} +} + +func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + v.servers.Store(server.ID(), server) + return nil +} + +func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + // Always proxy to the same server for the same stream URL. + if server, ok := v.picked.Load(streamURL); ok { + return server, nil + } + + // Gather all servers that were alive within the last few seconds. + var servers []*SRSServer + v.servers.Range(func(key string, server *SRSServer) bool { + if time.Since(server.UpdatedAt) < srsServerAliveDuration { + servers = append(servers, server) + } + return true + }) + + // If no servers available, use all possible servers. + if len(servers) == 0 { + v.servers.Range(func(key string, server *SRSServer) bool { + servers = append(servers, server) + return true + }) + } + + // No server found, failed. + if len(servers) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // Pick a server randomly from servers. 
+ server := servers[rand.Intn(len(servers))] + v.picked.Store(streamURL, server) + return server, nil +} + +func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + // Load the HLS streaming for the SPBHID, for TS files. + if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { + return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) + } else { + return actual, nil + } +} + +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + // Update the HLS streaming for the stream URL, for M3u8. + actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL) + } + + // Update the HLS streaming for the SPBHID, for TS files. + v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) + + return actual, nil +} + +func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + // Update the WebRTC streaming for the stream URL. + v.rtcStreamURL.Store(streamURL, value) + + // Update the WebRTC streaming for the ufrag. + v.rtcUfrag.Store(value.Ufrag, value) + return nil +} + +func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + if actual, ok := v.rtcUfrag.Load(ufrag); !ok { + return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) + } else { + return actual, nil + } +} + +type srsRedisLoadBalancer struct { + // The redis client sdk. + rdb *redis.Client +} + +func NewRedisLoadBalancer() SRSLoadBalancer { + return &srsRedisLoadBalancer{} +} + +func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { + redisDatabase, err := strconv.Atoi(envRedisDB()) + if err != nil { + return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB()) + } + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()), + Password: envRedisPassword(), + DB: redisDatabase, + }) + v.rdb = rdb + + if err := rdb.Ping(ctx).Err(); err != nil { + return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) + } + logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String()) + + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + b, err := json.Marshal(server) + if err != nil { + return errors.Wrapf(err, "marshal server %+v", server) + } + + key := v.redisKeyServer(server.ID()) + if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v server %+v", key, server) + } + + // Query all servers from redis, in json string. 
+ var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // Check each server expiration, if not exists in redis, remove from servers. + for i := len(serverKeys) - 1; i >= 0; i-- { + if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil { + serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) + } + } + + // Add server to servers if not exists. + var found bool + for _, serverKey := range serverKeys { + if serverKey == key { + found = true + break + } + } + if !found { + serverKeys = append(serverKeys, key) + } + + // Update all servers to redis. + b, err = json.Marshal(serverKeys) + if err != nil { + return errors.Wrapf(err, "marshal servers %+v", serverKeys) + } + if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil { + return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys) + } + + return nil +} + +func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + key := fmt.Sprintf("srs-proxy-url:%v", streamURL) + + // Always proxy to the same server for the same stream URL. + if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { + // If server not exists, ignore and pick another server for the stream URL. + if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { + var server SRSServer + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) + } + + // TODO: If server fail, we should migrate the streams to another server. + return &server, nil + } + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // All server should be alive, if not, should have been removed by redis. So we only + // random pick one that is always available. + var serverKey string + var server SRSServer + for i := 0; i < 3; i++ { + tryServerKey := serverKeys[rand.Intn(len(serverKeys))] + b, err := v.rdb.Get(ctx, tryServerKey).Bytes() + if err == nil && len(b) > 0 { + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b)) + } + + serverKey = tryServerKey + break + } + } + if serverKey == "" { + return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL) + } + + // Update the picked server for the stream URL. 
+ if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey) + } + + return &server, nil +} + +func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + key := v.redisKeySPBHID(spbhid) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + b, err := json.Marshal(value) + if err != nil { + return nil, errors.Wrapf(err, "marshal HLS %v", value) + } + + key := v.redisKeyHLS(streamURL) + if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) + } + + key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID) + if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value) + } + + // Query the HLS streaming from redis. + b2, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b2, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + b, err := json.Marshal(value) + if err != nil { + return errors.Wrapf(err, "marshal WebRTC %v", value) + } + + key := v.redisKeyRTC(streamURL) + if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key, value) + } + + key2 := v.redisKeyUfrag(value.Ufrag) + if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value) + } + + return nil +} + +func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + key := v.redisKeyUfrag(ufrag) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v WebRTC", key) + } + + var actual RTCConnection + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { + return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) +} + +func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string { + return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { + return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) +} + +func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string { + return fmt.Sprintf("srs-proxy-hls:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { + return fmt.Sprintf("srs-proxy-server:%v", serverID) +} + +func (v *srsRedisLoadBalancer) redisKeyServers() string { + return fmt.Sprintf("srs-proxy-all-servers") +} diff --git a/proxy/srt.go b/proxy/srt.go new file mode 100644 index 000000000..e4c629af8 --- /dev/null +++ b/proxy/srt.go @@ 
-0,0 +1,574 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "net" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to +// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the +// backend server. +type srsSRTServer struct { + // The UDP listener for SRT server. + listener *net.UDPConn + + // The SRT connections, identify by the socket ID. + sockets sync.Map[uint32, *SRTConnection] + // The system start time. + start time.Time + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer { + v := &srsSRTServer{ + start: time.Now(), + } + + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsSRTServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsSRTServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envSRTServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "SRT server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := v.listener.ReadFromUDP(buf) + if err != nil { + // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + socketID := srtParseSocketID(data) + + var pkt *SRTHandshakePacket + if srtIsHandshake(data) { + pkt = &SRTHandshakePacket{} + if err := pkt.UnmarshalBinary(data); err != nil { + return err + } + + if socketID == 0 { + socketID = pkt.SRTSocketID + } + } + + conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(ctx) + c.listenerUDP, c.socketID = v.listener, socketID + c.start = v.start + })) + + ctx = conn.ctx + if !ok { + logger.Df(ctx, "Create new SRT connection skt=%v", socketID) + } + + if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { + return errors.Wrapf(err, "handle packet") + } else if newSocketID != 0 && newSocketID != socketID { + // The connection may use a new socket ID. + // TODO: FIXME: Should cleanup the dead SRT connection. + v.sockets.Store(newSocketID, conn) + } + + return nil +} + +// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT +// connection, identify by the socket ID. +// +// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in +// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. 
+// +// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the +// client should never switch to another network or port. If this occurs, the client may be served +// by a different proxy server and fail because the other proxy server cannot identify the client. +type SRTConnection struct { + // The stream context for SRT connection. + ctx context.Context + + // The current socket ID. + socketID uint32 + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn + + // Listener start time. + start time.Time + + // Handshake packets with client. + handshake0 *SRTHandshakePacket + handshake1 *SRTHandshakePacket + handshake2 *SRTHandshakePacket + handshake3 *SRTHandshakePacket +} + +func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { + v := &SRTConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { + ctx := v.ctx + + // If not handshake, try to proxy to backend directly. + if pkt == nil { + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return v.socketID, errors.Wrapf(err, "write to backend") + } + } + + return v.socketID, nil + } + + // Handle handshake messages. + if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { + return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) + } + + return v.socketID, nil +} + +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { + // Handle handshake 0 and 1 messages. + if pkt.SynCookie == 0 { + // Save handshake 0 packet. + v.handshake0 = pkt + logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) + + // Response handshake 1. + v.handshake1 = &SRTHandshakePacket{ + ControlFlag: pkt.ControlFlag, + ControlType: 0, + SubType: 0, + AdditionalInfo: 0, + Timestamp: uint32(time.Since(v.start).Microseconds()), + SocketID: pkt.SRTSocketID, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: pkt.InitSequence, + MTU: pkt.MTU, + FlowWindow: pkt.FlowWindow, + HandshakeType: 1, + SRTSocketID: pkt.SRTSocketID, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) + + if b, err := v.handshake1.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 1") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 1") + } + + return nil + } + + // Handle handshake 2 and 3 messages. + // Parse stream id from packet. + streamID, err := pkt.StreamID() + if err != nil { + return errors.Wrapf(err, "parse stream id") + } + + // Save handshake packet. + v.handshake2 = pkt + logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx, streamID); err != nil { + return errors.Wrapf(err, "connect backend for %v", streamID) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return errors.Errorf("no backend for %v", streamID) + } + + // Proxy handshake 0 to backend server. 
+ if b, err := v.handshake0.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 0") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 0") + } + logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) + + // Read handshake 1 from backend server. + b := make([]byte, 4096) + handshake1p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 1") + } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 1") + } + logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) + + // Proxy handshake 2 to backend server. + handshake2p := *v.handshake2 + handshake2p.SynCookie = handshake1p.SynCookie + if b, err := handshake2p.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 2") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 2") + } + logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) + + // Read handshake 3 from backend server. + handshake3p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 3") + } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 3") + } + logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) + + // Response handshake 3 to client. + v.handshake3 = &*handshake3p + v.handshake3.SynCookie = v.handshake1.SynCookie + v.socketID = handshake3p.SRTSocketID + logger.Df(ctx, "Handshake 3: %v", v.handshake3) + + if b, err := v.handshake3.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 3") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 3") + } + + // Start a goroutine to proxy message from backend to client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + go func() { + for ctx.Err() == nil { + nn, err := v.backendUDP.Read(b) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + return + } + if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + return + } + } + }() + return nil +} + +func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { + if v.backendUDP != nil { + return nil + } + + // Parse stream id to host and resource. + host, resource, err := parseSRTStreamID(streamID) + if err != nil { + return errors.Wrapf(err, "parse stream id %v", streamID) + } + + if host == "" { + host = "localhost" + } + + streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) + if err != nil { + return errors.Wrapf(err, "build stream url %v", streamID) + } + + // Pick a backend SRS server to proxy the SRT stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse UDP port from backend. 
+ if len(backend.SRT) == 0 { + return errors.Errorf("no udp server %v for %v", backend, streamURL) + } + + _, _, udpPort, err := parseListenEndpoint(backend.SRT[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 +type SRTHandshakePacket struct { + // F: 1 bit. Packet Type Flag. The control packet has this flag set to + // "1". The data packet has this flag set to "0". + ControlFlag uint8 + // Control Type: 15 bits. Control Packet Type. The use of these bits + // is determined by the control packet type definition. + // Handshake control packets (Control Type = 0x0000) are used to + // exchange peer configurations, to agree on connection parameters, and + // to establish a connection. + ControlType uint16 + // Subtype: 16 bits. This field specifies an additional subtype for + // specific packets. + SubType uint16 + // Type-specific Information: 32 bits. The use of this field depends on + // the particular control packet type. Handshake packets do not use + // this field. + AdditionalInfo uint32 + // Timestamp: 32 bits. + Timestamp uint32 + // Destination Socket ID: 32 bits. + SocketID uint32 + + // Version: 32 bits. A base protocol version number. Currently used + // values are 4 and 5. Values greater than 5 are reserved for future + // use. + Version uint32 + // Encryption Field: 16 bits. Block cipher family and key size. The + // values of this field are described in Table 2. The default value + // is AES-128. + // 0 | No Encryption Advertised + // 2 | AES-128 + // 3 | AES-192 + // 4 | AES-256 + EncryptionField uint16 + // Extension Field: 16 bits. This field is message specific extension + // related to Handshake Type field. The value MUST be set to 0 + // except for the following cases. (1) If the handshake control + // packet is the INDUCTION message, this field is sent back by the + // Listener. (2) In the case of a CONCLUSION message, this field + // value should contain a combination of Extension Type values. + // 0x00000001 | HSREQ + // 0x00000002 | KMREQ + // 0x00000004 | CONFIG + // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 + ExtensionField uint16 + // Initial Packet Sequence Number: 32 bits. The sequence number of the + // very first data packet to be sent. + InitSequence uint32 + // Maximum Transmission Unit Size: 32 bits. This value is typically set + // to 1500, which is the default Maximum Transmission Unit (MTU) size + // for Ethernet, but can be less. + MTU uint32 + // Maximum Flow Window Size: 32 bits. The value of this field is the + // maximum number of data packets allowed to be "in flight" (i.e. the + // number of sent packets for which an ACK control packet has not yet + // been received). + FlowWindow uint32 + // Handshake Type: 32 bits. This field indicates the handshake packet + // type. 
+ // 0xFFFFFFFD | DONE + // 0xFFFFFFFE | AGREEMENT + // 0xFFFFFFFF | CONCLUSION + // 0x00000000 | WAVEHAND + // 0x00000001 | INDUCTION + HandshakeType uint32 + // SRT Socket ID: 32 bits. This field holds the ID of the source SRT + // socket from which a handshake packet is issued. + SRTSocketID uint32 + // SYN Cookie: 32 bits. Randomized value for processing a handshake. + // The value of this field is specified by the handshake message + // type. + SynCookie uint32 + // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's + // sender. The value consists of four 32-bit fields. + PeerIP net.IP + // Extensions. + // Extension Type: 16 bits. The value of this field is used to process + // an integrated handshake. Each extension can have a pair of + // request and response types. + // Extension Length: 16 bits. The length of the Extension Contents + // field in four-byte blocks. + // Extension Contents: variable length. The payload of the extension. + ExtraData []byte +} + +func (v *SRTHandshakePacket) IsData() bool { + return v.ControlFlag == 0x00 +} + +func (v *SRTHandshakePacket) IsControl() bool { + return v.ControlFlag == 0x80 +} + +func (v *SRTHandshakePacket) IsHandshake() bool { + return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 +} + +func (v *SRTHandshakePacket) StreamID() (string, error) { + p := v.ExtraData + for { + if len(p) < 2 { + return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) + } + + extType := binary.BigEndian.Uint16(p) + extSize := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(extSize*4) { + return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) + } + + // Ignore other packets except stream id. + if extType != 0x05 { + p = p[extSize*4:] + continue + } + + // We must copy it, because we will decode the stream id. + data := append([]byte{}, p[:extSize*4]...) + + // Reverse the stream id encoded in little-endian to big-endian. + for i := 0; i < len(data); i += 4 { + value := binary.LittleEndian.Uint32(data[i:]) + binary.BigEndian.PutUint32(data[i:], value) + } + + // Trim the trailing zero bytes. 
+ data = bytes.TrimRight(data, "\x00") + return string(data), nil + } +} + +func (v *SRTHandshakePacket) String() string { + return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", + v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) +} + +func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { + if len(b) < 4 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.ControlFlag = b[0] & 0x80 + v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff + v.SubType = binary.BigEndian.Uint16(b[2:4]) + + if len(b) < 64 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) + v.Timestamp = binary.BigEndian.Uint32(b[8:]) + v.SocketID = binary.BigEndian.Uint32(b[12:]) + v.Version = binary.BigEndian.Uint32(b[16:]) + v.EncryptionField = binary.BigEndian.Uint16(b[20:]) + v.ExtensionField = binary.BigEndian.Uint16(b[22:]) + v.InitSequence = binary.BigEndian.Uint32(b[24:]) + v.MTU = binary.BigEndian.Uint32(b[28:]) + v.FlowWindow = binary.BigEndian.Uint32(b[32:]) + v.HandshakeType = binary.BigEndian.Uint32(b[36:]) + v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) + v.SynCookie = binary.BigEndian.Uint32(b[44:]) + + // Only support IPv4. + v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) + + v.ExtraData = b[64:] + + return nil +} + +func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { + b := make([]byte, 64+len(v.ExtraData)) + binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) + binary.BigEndian.PutUint16(b[2:], v.SubType) + binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) + binary.BigEndian.PutUint32(b[8:], v.Timestamp) + binary.BigEndian.PutUint32(b[12:], v.SocketID) + binary.BigEndian.PutUint32(b[16:], v.Version) + binary.BigEndian.PutUint16(b[20:], v.EncryptionField) + binary.BigEndian.PutUint16(b[22:], v.ExtensionField) + binary.BigEndian.PutUint32(b[24:], v.InitSequence) + binary.BigEndian.PutUint32(b[28:], v.MTU) + binary.BigEndian.PutUint32(b[32:], v.FlowWindow) + binary.BigEndian.PutUint32(b[36:], v.HandshakeType) + binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) + binary.BigEndian.PutUint32(b[44:], v.SynCookie) + + // Only support IPv4. 
+ ip := v.PeerIP.To4() + b[48] = ip[3] + b[49] = ip[2] + b[50] = ip[1] + b[51] = ip[0] + + if len(v.ExtraData) > 0 { + copy(b[64:], v.ExtraData) + } + + return b, nil +} diff --git a/proxy/sync/map.go b/proxy/sync/map.go new file mode 100644 index 000000000..75db12f9a --- /dev/null +++ b/proxy/sync/map.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import "sync" + +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + return value, ok + } + return v.(V), ok +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.m.LoadAndDelete(key) + if !loaded { + return value, loaded + } + return v.(V), loaded +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + a, loaded := m.m.LoadOrStore(key, value) + return a.(V), loaded +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 000000000..f3c393076 --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,276 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "encoding/json" + stdErr "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "path" + "reflect" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { + w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) + + b, err := json.Marshal(data) + if err != nil { + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) +} + +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + logger.Wf(ctx, "HTTP API error %+v", err) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, fmt.Sprintf("%v", err)) +} + +func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. 
+ w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return true + } + + return false +} + +func parseGracefullyQuitTimeout() (time.Duration, error) { + if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { + return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) + } else { + return t, nil + } +} + +// ParseBody read the body from r, and unmarshal JSON to v. +func ParseBody(r io.ReadCloser, v interface{}) error { + b, err := ioutil.ReadAll(r) + if err != nil { + return errors.Wrapf(err, "read body") + } + defer r.Close() + + if len(b) == 0 { + return nil + } + + if err := json.Unmarshal(b, v); err != nil { + return errors.Wrapf(err, "json unmarshal %v", string(b)) + } + + return nil +} + +// buildStreamURL build as vhost/app/stream for stream URL r. +func buildStreamURL(r string) (string, error) { + u, err := url.Parse(r) + if err != nil { + return "", errors.Wrapf(err, "parse url %v", r) + } + + // If not domain or ip in hostname, it's __defaultVhost__. + defaultVhost := !strings.Contains(u.Hostname(), ".") + + // If hostname is actually an IP address, it's __defaultVhost__. + if ip := net.ParseIP(u.Hostname()); ip.To4() != nil { + defaultVhost = true + } + + if defaultVhost { + return fmt.Sprintf("__defaultVhost__%v", u.Path), nil + } + + // Ignore port, only use hostname as vhost. + return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil +} + +// isPeerClosedError indicates whether peer object closed the connection. +func isPeerClosedError(err error) bool { + causeErr := errors.Cause(err) + + if stdErr.Is(causeErr, io.EOF) { + return true + } + + if stdErr.Is(causeErr, syscall.EPIPE) { + return true + } + + if netErr, ok := causeErr.(*net.OpError); ok { + if sysErr, ok := netErr.Err.(*os.SyscallError); ok { + if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { + return true + } + } + } + + return false +} + +// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL +// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL +// with extension. +func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + hostname := "__defaultVhost__" + if strings.Contains(r.Host, ":") { + if v, _, err := net.SplitHostPort(r.Host); err == nil { + hostname = v + } + } + + var appStream, streamExt string + + // Parse app/stream from query string. + q := r.URL.Query() + if app := q.Get("app"); app != "" { + appStream = "/" + app + } + if stream := q.Get("stream"); stream != "" { + appStream = fmt.Sprintf("%v/%v", appStream, stream) + } + + // Parse app/stream from path. + if appStream == "" { + streamExt = path.Ext(r.URL.Path) + appStream = strings.TrimSuffix(r.URL.Path, streamExt) + } + + unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream) + fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) + return +} + +// rtcIsSTUN returns true if data of UDP payload is a STUN packet. 
+func rtcIsSTUN(data []byte) bool { + return len(data) > 0 && (data[0] == 0 || data[0] == 1) +} + +// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. +func rtcIsRTPOrRTCP(data []byte) bool { + return len(data) >= 12 && (data[0]&0xC0) == 0x80 +} + +// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. +func srtIsHandshake(data []byte) bool { + return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 +} + +// srtParseSocketID parse the socket id from the SRT packet. +func srtParseSocketID(data []byte) uint32 { + if len(data) >= 16 { + return binary.BigEndian.Uint32(data[12:]) + } + return 0 +} + +// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. +func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(sdp) + if len(ufragMatch) <= 1 { + return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) + } + ufrag = ufragMatch[1] + } + + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(sdp) + if len(pwdMatch) <= 1 { + return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) + } + pwd = pwdMatch[1] + } + + return ufrag, pwd, nil +} + +// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). +// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url +func parseSRTStreamID(sid string) (host, resource string, err error) { + if true { + hostRe := regexp.MustCompile(`h=([^,]+)`) + hostMatch := hostRe.FindStringSubmatch(sid) + if len(hostMatch) > 1 { + host = hostMatch[1] + } + } + + if true { + resourceRe := regexp.MustCompile(`r=([^,]+)`) + resourceMatch := resourceRe.FindStringSubmatch(sid) + if len(resourceMatch) <= 1 { + return "", "", errors.Errorf("no resource in sid %v", sid) + } + resource = resourceMatch[1] + } + + return host, resource, nil +} + +// parseListenEndpoint parse the listen endpoint as: +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { + // If no colon in ep, it's port in string. + if !strings.Contains(ep, ":") { + if p, err := strconv.Atoi(ep); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) + } else { + return "tcp", nil, uint16(p), nil + } + } + + // Must be protocol://ip:port schema. + parts := strings.Split(ep, ":") + if len(parts) != 3 { + return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) + } + + if p, err := strconv.Atoi(parts[2]); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) + } else { + return parts[0], net.ParseIP(parts[1]), uint16(p), nil + } +} diff --git a/proxy/version.go b/proxy/version.go new file mode 100644 index 000000000..94f668f96 --- /dev/null +++ b/proxy/version.go @@ -0,0 +1,27 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "fmt" + +func VersionMajor() int { + return 1 +} + +// VersionMinor specifies the typical version of SRS we adapt to. 
+func VersionMinor() int { + return 5 +} + +func VersionRevision() int { + return 0 +} + +func Version() string { + return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision()) +} + +func Signature() string { + return "SRSProxy" +} diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf new file mode 100644 index 000000000..baca5c9f4 --- /dev/null +++ b/trunk/conf/origin1-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19351; +max_connections 1000; +pid objs/origin1.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8081; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19851; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10081; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin1; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf new file mode 100644 index 000000000..48f639893 --- /dev/null +++ b/trunk/conf/origin2-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19352; +max_connections 1000; +pid objs/origin2.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8082; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19853; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10082; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin2; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf new file mode 100644 index 000000000..95624fb77 --- /dev/null +++ b/trunk/conf/origin3-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19353; +max_connections 1000; +pid objs/origin3.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8083; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19852; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10083; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin3; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + 
hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2772c0bf2..9e676930f 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. v7.0.16 (#4158) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) diff --git a/trunk/src/app/srs_app_st.cpp b/trunk/src/app/srs_app_st.cpp index 3e21e468c..466cbe068 100755 --- a/trunk/src/app/srs_app_st.cpp +++ b/trunk/src/app/srs_app_st.cpp @@ -342,7 +342,12 @@ SrsWaitGroup::SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup() { - wait(); + // In the destructor, we should NOT wait for all coroutines to be done, because user should decide + // to wait or not. Similar to the Go's sync.WaitGroup, it also requires user to wait explicitly. For + // some special use scenarios, such as error handling, for example, if we started three servers with + // wait group, and one of them failed, user may want to return error and quit directly, without wait + // for other running servers to be done. If we wait in the destructor, it will continue to run without + // some servers, in unknown behaviors. srs_cond_destroy(done_); } diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index fed95c499..458a6c3d8 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 15 +#define VERSION_REVISION 16 #endif \ No newline at end of file From e674f8266a7360e553fcd2c1177b3158d0cd5643 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 10 Sep 2024 10:53:09 +0800 Subject: [PATCH 08/12] Proxy: Remove dependency of godotenv. #4158 --- proxy/env.go | 43 +++++++++++++++++++++++++++++++++++-------- proxy/go.mod | 5 +---- proxy/go.sum | 2 -- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/proxy/env.go b/proxy/env.go index 0c201bb1d..26b014609 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -5,10 +5,10 @@ package main import ( "context" + "io/ioutil" "os" "path" - - "github.com/joho/godotenv" + "strings" "srs-proxy/errors" "srs-proxy/logger" @@ -16,14 +16,41 @@ import ( // loadEnvFile loads the environment variables from file. Note that we only use .env file. 
func loadEnvFile(ctx context.Context) error { - if workDir, err := os.Getwd(); err != nil { + workDir, err := os.Getwd() + if err != nil { return errors.Wrapf(err, "getpwd") - } else { - envFile := path.Join(workDir, ".env") - if _, err := os.Stat(envFile); err == nil { - if err := godotenv.Load(envFile); err != nil { - return errors.Wrapf(err, "load %v", envFile) + } + + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err != nil { + return nil + } + + file, err := os.Open(envFile) + if err != nil { + return errors.Wrapf(err, "open %v", envFile) + } + defer file.Close() + + b, err := ioutil.ReadAll(file) + if err != nil { + return errors.Wrapf(err, "read %v", envFile) + } + + lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n") + for _, line := range lines { + if strings.HasPrefix(strings.TrimSpace(line), "#") { + continue + } + + if pos := strings.IndexByte(line, '='); pos > 0 { + key := strings.TrimSpace(line[:pos]) + value := strings.TrimSpace(line[pos+1:]) + if v := os.Getenv(key); v != "" { + continue } + + os.Setenv(key, value) } } diff --git a/proxy/go.mod b/proxy/go.mod index 2e2a17ab3..e9e196d2f 100644 --- a/proxy/go.mod +++ b/proxy/go.mod @@ -2,10 +2,7 @@ module srs-proxy go 1.18 -require ( - github.com/go-redis/redis/v8 v8.11.5 - github.com/joho/godotenv v1.5.1 -) +require github.com/go-redis/redis/v8 v8.11.5 require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/proxy/go.sum b/proxy/go.sum index 1efc5318e..7342ff813 100644 --- a/proxy/go.sum +++ b/proxy/go.sum @@ -5,8 +5,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= From 40e8ed458615e58fae5eb4c161d6bf0971ce4b37 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 10 Sep 2024 16:41:34 +0800 Subject: [PATCH 09/12] VSCode: Support IDE vscode to run and debug. 
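
A quick way to exercise the proxy under the debugger is the Go extension's test
code lens ("run test | debug test"). The sketch below is a hypothetical
proxy/env_test.go and is not part of this patch; the file name and the variable
SRS_PROXY_DEMO_KEY are made up for illustration. It also documents the behavior
of loadEnvFile() from the previous commit: a real environment variable wins over
the value in .env.

```go
package main

import (
	"context"
	"os"
	"path/filepath"
	"testing"
)

// Hypothetical test, not included in this patch: loadEnvFile() keeps
// existing environment variables and only fills in missing keys from .env.
func TestLoadEnvFilePrefersRealEnv(t *testing.T) {
	// Write a scratch .env and make its directory the working directory.
	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, ".env"), []byte("SRS_PROXY_DEMO_KEY=from-file\n"), 0644); err != nil {
		t.Fatal(err)
	}
	wd, _ := os.Getwd()
	defer os.Chdir(wd)
	os.Chdir(dir)

	// The variable is already set, so the .env value must not override it.
	t.Setenv("SRS_PROXY_DEMO_KEY", "from-env")

	if err := loadEnvFile(context.Background()); err != nil {
		t.Fatal(err)
	}
	if got := os.Getenv("SRS_PROXY_DEMO_KEY"); got != "from-env" {
		t.Fatalf("expect from-env, got %v", got)
	}
}
```
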
---
 .gitignore            |  6 +++---
 .vscode/README.md     | 38 ++++++++++++++++++++++++++++++++++++++
 .vscode/launch.json   | 36 ++++++++++++++++++++++++++++++++++++
 .vscode/settings.json |  5 +++++
 .vscode/tasks.json    | 17 +++++++++++++++++
 proxy/env.go          |  6 ++++--
 6 files changed, 103 insertions(+), 5 deletions(-)
 create mode 100644 .vscode/README.md
 create mode 100644 .vscode/launch.json
 create mode 100644 .vscode/settings.json
 create mode 100644 .vscode/tasks.json

diff --git a/.gitignore b/.gitignore
index 63880f228..0319e6d68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,8 +16,6 @@
 *.pyc
 *.swp
 .DS_Store
-.vscode
-.vscode/*
 /trunk/Makefile
 /trunk/objs
 /trunk/src/build-qt-Desktop-Debug
@@ -42,4 +40,6 @@ cmake-build-debug
 /trunk/ide/srs_clion/Makefile
 /trunk/ide/srs_clion/cmake_install.cmake
 /trunk/ide/srs_clion/srs
-/trunk/ide/srs_clion/Testing/
\ No newline at end of file
+/trunk/ide/srs_clion/Testing/
+/trunk/ide/vscode-build
+
diff --git a/.vscode/README.md b/.vscode/README.md
new file mode 100644
index 000000000..eb513707b
--- /dev/null
+++ b/.vscode/README.md
@@ -0,0 +1,38 @@
+# Debug with VSCode
+
+Support running and debugging SRS with VSCode.
+
+## SRS
+
+Install the following extensions:
+
+- CMake Tools
+- CodeLLDB
+- C/C++ Extension Pack
+
+Open the folder like `~/git/srs` in VSCode.
+Run the command `> CMake: Configure` to configure the project.
+
+> Note: You can press `Ctrl+R`, then type `CMake: Configure`, then select `Clang` as the toolchain.
+
+> Note: The `settings.json` is used to configure cmake. It will use `${workspaceFolder}/trunk/ide/srs_clion/CMakeLists.txt`
+> and `${workspaceFolder}/trunk/ide/vscode-build` as the source file and build directory.
+
+Click `Run > Run Without Debugging` to start the server.
+
+> Note: The `launch.json` is used for running and debugging. The build will output the binary to
+> `${workspaceFolder}/trunk/ide/vscode-build/srs`.
+
+## Proxy
+
+Install the following extensions:
+
+- Go
+
+Open the folder like `~/git/srs` in VSCode.
+
+Open the `View > Run` panel and select `Launch srs-proxy` to start the proxy server.
+
+Click `Run > Run Without Debugging` to start it.
+
+> Note: The `launch.json` is used for running and debugging.
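
(Not part of this patch.) When stepping through the proxy, the parsing helpers
added in proxy/utils.go earlier in this series make convenient breakpoint
targets. Below is a minimal sketch as a hypothetical proxy/utils_test.go; the
expected values follow from the implementations shown in that file, and the
host name live.srs.cloud is made up for illustration. It can be run or debugged
via the Go extension's test code lens.

```go
package main

import "testing"

// Hypothetical test, not included in this patch: exercise the URL and
// endpoint parsing helpers from proxy/utils.go.
func TestProxyParsingHelpers(t *testing.T) {
	// A hostname without dots maps to the default vhost.
	if url, err := buildStreamURL("rtmp://localhost/live/livestream"); err != nil {
		t.Fatal(err)
	} else if url != "__defaultVhost__/live/livestream" {
		t.Fatalf("unexpected stream url %v", url)
	}

	// A bare port implies a TCP listen endpoint.
	if proto, _, port, err := parseListenEndpoint("1935"); err != nil {
		t.Fatal(err)
	} else if proto != "tcp" || port != 1935 {
		t.Fatalf("unexpected endpoint %v %v", proto, port)
	}

	// An SRT stream id carries an optional host (h=) and a required resource (r=).
	if host, resource, err := parseSRTStreamID("#!::h=live.srs.cloud,r=live/livestream"); err != nil {
		t.Fatal(err)
	} else if host != "live.srs.cloud" || resource != "live/livestream" {
		t.Fatalf("unexpected srt sid %v %v", host, resource)
	}
}
```
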
diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..047efcbfd --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,36 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Launch SRS", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/trunk/ide/vscode-build/srs", + "args": ["-c", "conf/console.conf"], + "stopAtEntry": false, + "cwd": "${workspaceFolder}/trunk", + "environment": [], + "externalConsole": false, + "MIMode": "lldb", + "setupCommands": [ + { + "description": "Enable pretty-printing for gdb", + "text": "-enable-pretty-printing", + "ignoreFailures": true + } + ], + "preLaunchTask": "build", + "logging": { + "engineLogging": true + } + }, + { + "name": "Launch srs-proxy", + "type": "go", + "request": "launch", + "mode": "auto", + "cwd": "${workspaceFolder}/proxy", + "program": "${workspaceFolder}/proxy" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..0d9dbf97b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "cmake.sourceDirectory": "${workspaceFolder}/trunk/ide/srs_clion", + "cmake.buildDirectory": "${workspaceFolder}/trunk/ide/vscode-build", + "cmake.configureOnOpen": false +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 000000000..98388f3b3 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,17 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "build", + "type": "shell", + "command": "cmake --build ${workspaceFolder}/trunk/ide/vscode-build", + "group": { + "kind": "build", + "isDefault": true + }, + "problemMatcher": ["$gcc"], + "detail": "Build SRS by cmake." + } + ] + } + \ No newline at end of file diff --git a/proxy/env.go b/proxy/env.go index 26b014609..d60726354 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -5,7 +5,7 @@ package main import ( "context" - "io/ioutil" + "io" "os" "path" "strings" @@ -32,12 +32,14 @@ func loadEnvFile(ctx context.Context) error { } defer file.Close() - b, err := ioutil.ReadAll(file) + b, err := io.ReadAll(file) if err != nil { return errors.Wrapf(err, "read %v", envFile) } lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n") + logger.Df(ctx, "load env file %v, lines=%v", envFile, len(lines)) + for _, line := range lines { if strings.HasPrefix(strings.TrimSpace(line), "#") { continue From 2068aa4659fcfaef9c54098fa98f8214c76bbfc1 Mon Sep 17 00:00:00 2001 From: Winlin Date: Sat, 28 Sep 2024 10:41:35 +0800 Subject: [PATCH 10/12] Update FUNDING.yml --- .github/FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index f4170d9b1..0e11853ba 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,6 +1,6 @@ # These are supported funding model platforms -github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +github: [winlinvip] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with patreon id. 
 open_collective: srs-server
 ko_fi: # Replace with a single Ko-fi username

From 0de887d374120b9f66725e71848ce890ec09c856 Mon Sep 17 00:00:00 2001
From: Winlin
Date: Wed, 9 Oct 2024 11:29:35 +0800
Subject: [PATCH 11/12] Update bug_report.md

---
 .github/ISSUE_TEMPLATE/bug_report.md | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 037369f12..708a1ba9e 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -7,8 +7,7 @@ assignees: ''
 
 ---
 
-!!! Before submitting a new bug report, please ensure you have searched for any existing bugs and utilized
-the `Ask AI` feature at https://ossrs.io or https://ossrs.net (for users in China). Duplicate issues or
+!!! Before submitting a new bug report, please ensure you have searched for any existing bugs. Duplicate issues or
 questions that are overly simple or already addressed in the documentation will be removed without any
 response.
 
From e7d78462fe38e7390f049b0d8794bc1ca301f93b Mon Sep 17 00:00:00 2001
From: Jacob Su
Date: Tue, 15 Oct 2024 17:52:17 +0800
Subject: [PATCH 12/12] ST: Use clock_gettime to prevent time jumping backwards. v7.0.17 (#3979)

Try to fix #3978.

**Background**

See #3978.

**Research**

I referred to the Android platform's solution, because I have an Android
background, and Android handles messages in a similar internal loop.

https://github.com/aosp-mirror/platform_frameworks_base/blob/ff007a03c01bf936d1e961a13adff9f266d5189c/core/java/android/os/Handler.java#L701-L706C6
```
public final boolean sendMessageDelayed(@NonNull Message msg, long delayMillis) {
    if (delayMillis < 0) {
        delayMillis = 0;
    }
    return sendMessageAtTime(msg, SystemClock.uptimeMillis() + delayMillis);
}
```

https://github.com/aosp-mirror/platform_system_core/blob/59d9dc1f50b1ae8630ec11a431858a3cb66487b7/libutils/SystemClock.cpp#L37-L51
```
/*
 * native public static long uptimeMillis();
 */
int64_t uptimeMillis()
{
    return nanoseconds_to_milliseconds(uptimeNanos());
}

/*
 * public static native long uptimeNanos();
 */
int64_t uptimeNanos()
{
    return systemTime(SYSTEM_TIME_MONOTONIC);
}
```

https://github.com/aosp-mirror/platform_system_core/blob/59d9dc1f50b1ae8630ec11a431858a3cb66487b7/libutils/Timers.cpp#L32-L55
```
#if defined(__linux__)
nsecs_t systemTime(int clock) {
    checkClockId(clock);
    static constexpr clockid_t clocks[] = {CLOCK_REALTIME, CLOCK_MONOTONIC,
                                           CLOCK_PROCESS_CPUTIME_ID, CLOCK_THREAD_CPUTIME_ID,
                                           CLOCK_BOOTTIME};
    static_assert(clock_id_max == arraysize(clocks));
    timespec t = {};
    clock_gettime(clocks[clock], &t);
    return nsecs_t(t.tv_sec)*1000000000LL + t.tv_nsec;
}
#else
nsecs_t systemTime(int clock) {
    // TODO: is this ever called with anything but REALTIME on mac/windows?
    checkClockId(clock);
    // Clock support varies widely across hosts. Mac OS doesn't support
    // CLOCK_BOOTTIME (and doesn't even have clock_gettime until 10.12).
    // Windows is windows.
    timeval t = {};
    gettimeofday(&t, nullptr);
    return nsecs_t(t.tv_sec)*1000000000LL + nsecs_t(t.tv_usec)*1000LL;
}
#endif
```

For Linux, we can use the `clock_gettime` API; on macOS, it first appeared in
Mac OS X 10.12 (see `man clock_gettime`).

The requirement is to find an alternative way to get the timestamp in
microseconds. `clock_gettime` returns nanoseconds, so the conversion is simply
nanoseconds / 1000 = microseconds. I then checked the performance of this API
plus the division.

I used the following code to check the `clock_gettime` performance.
``` #include #include #include #include int main() { struct timeval tv; struct timespec ts; clock_t start; clock_t end; long t; while (1) { start = clock(); gettimeofday(&tv, NULL); end = clock(); printf("gettimeofday clock is %lu\n", end - start); printf("gettimeofday is %lld\n", (tv.tv_sec * 1000000LL + tv.tv_usec)); start = clock(); clock_gettime(CLOCK_MONOTONIC, &ts); t = ts.tv_sec * 1000000L + ts.tv_nsec / 1000L; end = clock(); printf("clock_monotonic clock is %lu\n", end - start); printf("clock_monotonic: seconds is %ld, nanoseconds is %ld, sum is %ld\n", ts.tv_sec, ts.tv_nsec, t); start = clock(); clock_gettime(CLOCK_MONOTONIC_RAW, &ts); t = ts.tv_sec * 1000000L + ts.tv_nsec / 1000L; end = clock(); printf("clock_monotonic_raw clock is %lu\n", end - start); printf("clock_monotonic_raw: nanoseconds is %ld, sum is %ld\n", ts.tv_nsec, t); sleep(3); } return 0; } ``` Here is output: env: Mac OS M2 chip. ``` gettimeofday clock is 11 gettimeofday is 1709775727153949 clock_monotonic clock is 2 clock_monotonic: seconds is 1525204, nanoseconds is 409453000, sum is 1525204409453 clock_monotonic_raw clock is 2 clock_monotonic_raw: nanoseconds is 770493000, sum is 1525222770493 ``` We can see the `clock_gettime` is faster than `gettimeofday`, so there are no performance risks. **MacOS solution** `clock_gettime` api only available until mac os 10.12, for the mac os older than 10.12, just keep the `gettimeofday`. check osx version in `auto/options.sh`, then add MACRO in `auto/depends.sh`, the MACRO is `MD_OSX_HAS_NO_CLOCK_GETTIME`. **CYGWIN** According to google search, it seems the `clock_gettime(CLOCK_MONOTONIC)` is not support well at least 10 years ago, but I didn't own an windows machine, so can't verify it. so keep win's solution. --------- Co-authored-by: winlin --- trunk/3rdparty/st-srs/Makefile | 5 +- trunk/3rdparty/st-srs/md.h | 31 ++++--- trunk/auto/depends.sh | 3 + trunk/auto/options.sh | 6 +- trunk/configure | 3 +- trunk/doc/CHANGELOG.md | 1 + trunk/src/core/srs_core_version7.hpp | 2 +- trunk/src/utest/srs_utest_st.cpp | 117 +++++++++++++++++++++++++++ trunk/src/utest/srs_utest_st.hpp | 15 ++++ 9 files changed, 168 insertions(+), 15 deletions(-) create mode 100644 trunk/src/utest/srs_utest_st.cpp create mode 100644 trunk/src/utest/srs_utest_st.hpp diff --git a/trunk/3rdparty/st-srs/Makefile b/trunk/3rdparty/st-srs/Makefile index 11fdc95c9..8fb418c57 100644 --- a/trunk/3rdparty/st-srs/Makefile +++ b/trunk/3rdparty/st-srs/Makefile @@ -185,11 +185,9 @@ endif # make EXTRA_CFLAGS=-UMD_HAVE_EPOLL # # or to enable sendmmsg(2) support: -# # make EXTRA_CFLAGS="-DMD_HAVE_SENDMMSG -D_GNU_SOURCE" # # or to enable stats for ST: -# # make EXTRA_CFLAGS=-DDEBUG_STATS # # or cache the stack and reuse it: @@ -201,6 +199,9 @@ endif # or enable support for asan: # make EXTRA_CFLAGS="-DMD_ASAN -fsanitize=address -fno-omit-frame-pointer" # +# or to disable the clock_gettime for MacOS before 10.12, see https://github.com/ossrs/srs/issues/3978 +# make EXTRA_CFLAGS=-DMD_OSX_NO_CLOCK_GETTIME +# # or enable the coverage for utest: # make UTEST_FLAGS="-fprofile-arcs -ftest-coverage" # diff --git a/trunk/3rdparty/st-srs/md.h b/trunk/3rdparty/st-srs/md.h index 677d6fb46..a25c0087a 100644 --- a/trunk/3rdparty/st-srs/md.h +++ b/trunk/3rdparty/st-srs/md.h @@ -101,10 +101,21 @@ extern void _st_md_cxt_restore(_st_jmp_buf_t env, int val); #error Unknown CPU architecture #endif - #define MD_GET_UTIME() \ - struct timeval tv; \ - (void) gettimeofday(&tv, NULL); \ - return (tv.tv_sec * 1000000LL + tv.tv_usec) + #if defined 
(MD_OSX_NO_CLOCK_GETTIME) + #define MD_GET_UTIME() \ + struct timeval tv; \ + (void) gettimeofday(&tv, NULL); \ + return (tv.tv_sec * 1000000LL + tv.tv_usec) + #else + /* + * https://github.com/ossrs/srs/issues/3978 + * use clock_gettime to get the timestamp in microseconds. + */ + #define MD_GET_UTIME() \ + struct timespec ts; \ + clock_gettime(CLOCK_MONOTONIC, &ts); \ + return (ts.tv_sec * 1000000LL + ts.tv_nsec / 1000) + #endif #elif defined (LINUX) @@ -120,13 +131,13 @@ extern void _st_md_cxt_restore(_st_jmp_buf_t env, int val); #define MD_HAVE_SOCKLEN_T /* - * All architectures and flavors of linux have the gettimeofday - * function but if you know of a faster way, use it. + * https://github.com/ossrs/srs/issues/3978 + * use clock_gettime to get the timestamp in microseconds. */ - #define MD_GET_UTIME() \ - struct timeval tv; \ - (void) gettimeofday(&tv, NULL); \ - return (tv.tv_sec * 1000000LL + tv.tv_usec) + #define MD_GET_UTIME() \ + struct timespec ts; \ + clock_gettime(CLOCK_MONOTONIC, &ts); \ + return (ts.tv_sec * 1000000LL + ts.tv_nsec / 1000) #if defined(__i386__) #define MD_GET_SP(_t) *((long *)&((_t)->context[0].__jmpbuf[4])) diff --git a/trunk/auto/depends.sh b/trunk/auto/depends.sh index a73b6b8b5..349feec2f 100755 --- a/trunk/auto/depends.sh +++ b/trunk/auto/depends.sh @@ -266,6 +266,9 @@ fi # for osx, use darwin for st, donot use epoll. if [[ $SRS_OSX == YES ]]; then _ST_MAKE=darwin-debug && _ST_OBJ="DARWIN_`uname -r`_DBG" + if [[ $SRS_OSX_HAS_CLOCK_GETTIME != YES ]]; then + _ST_EXTRA_CFLAGS="$_ST_EXTRA_CFLAGS -DMD_OSX_NO_CLOCK_GETTIME" + fi fi # for windows/cygwin if [[ $SRS_CYGWIN64 = YES ]]; then diff --git a/trunk/auto/options.sh b/trunk/auto/options.sh index 03d58c54c..e5dcfeda9 100755 --- a/trunk/auto/options.sh +++ b/trunk/auto/options.sh @@ -112,6 +112,8 @@ SRS_CROSS_BUILD_HOST= SRS_CROSS_BUILD_PREFIX= # For cache build SRS_BUILD_CACHE=YES +# Only support MacOS 10.12+ for clock_gettime, see https://github.com/ossrs/srs/issues/3978 +SRS_OSX_HAS_CLOCK_GETTIME=YES # ##################################################################################### # Toolchain for cross-build on Ubuntu for ARM or MIPS. @@ -150,7 +152,9 @@ function apply_system_options() { OS_IS_RISCV=$(gcc -dM -E - /dev/null || echo 1); fi diff --git a/trunk/configure b/trunk/configure index b9475493e..b7f004059 100755 --- a/trunk/configure +++ b/trunk/configure @@ -464,7 +464,8 @@ if [[ $SRS_UTEST == YES ]]; then MODULE_FILES=("srs_utest" "srs_utest_amf0" "srs_utest_kernel" "srs_utest_core" "srs_utest_config" "srs_utest_rtmp" "srs_utest_http" "srs_utest_avc" "srs_utest_reload" "srs_utest_mp4" "srs_utest_service" "srs_utest_app" "srs_utest_rtc" "srs_utest_config2" - "srs_utest_protocol" "srs_utest_protocol2" "srs_utest_kernel2" "srs_utest_protocol3") + "srs_utest_protocol" "srs_utest_protocol2" "srs_utest_kernel2" "srs_utest_protocol3" + "srs_utest_st") if [[ $SRS_SRT == YES ]]; then MODULE_FILES+=("srs_utest_srt") fi diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 9e676930f..18a7b36a4 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-10-15, Merge [#3979](https://github.com/ossrs/srs/pull/3979): ST: Use clock_gettime to prevent time jumping backwards. v7.0.17 (#3979) * v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. 
v7.0.16 (#4158) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index 458a6c3d8..67461fab9 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 16 +#define VERSION_REVISION 17 #endif \ No newline at end of file diff --git a/trunk/src/utest/srs_utest_st.cpp b/trunk/src/utest/srs_utest_st.cpp new file mode 100644 index 000000000..553bc0c95 --- /dev/null +++ b/trunk/src/utest/srs_utest_st.cpp @@ -0,0 +1,117 @@ +// +// Copyright (c) 2013-2024 The SRS Authors +// +// SPDX-License-Identifier: MIT +// +#include +#include +#include +#include + +using namespace std; + +VOID TEST(StTest, StUtimeInMicroseconds) +{ + st_utime_t st_time_1 = st_utime(); + // sleep 1 microsecond +#if !defined(SRS_CYGWIN64) + usleep(1); +#endif + st_utime_t st_time_2 = st_utime(); + + EXPECT_GT(st_time_1, 0); + EXPECT_GT(st_time_2, 0); + EXPECT_GE(st_time_2, st_time_1); + // st_time_2 - st_time_1 should be in range of [1, 100] microseconds + EXPECT_GE(st_time_2 - st_time_1, 0); + EXPECT_LE(st_time_2 - st_time_1, 100); +} + +static inline st_utime_t time_gettimeofday() { + struct timeval tv; + gettimeofday(&tv, NULL); + return (tv.tv_sec * 1000000LL + tv.tv_usec); +} + +VOID TEST(StTest, StUtimePerformance) +{ + clock_t start; + int gettimeofday_elapsed_time = 0; + int st_utime_elapsed_time = 0; + + // Both the st_utime(clock_gettime or gettimeofday) and gettimeofday's + // elpased time to execute is dependence on whether it is the first time be called. + // In general, the gettimeofday has better performance, but the gap between + // them is really small, maybe less than 10 clock ~ 10 microseconds. + + // check st_utime first, then gettimeofday + { + start = clock(); + st_utime_t t2 = st_utime(); + int elapsed_time = clock() - start; + st_utime_elapsed_time += elapsed_time; + EXPECT_GT(t2, 0); + + start = clock(); + st_utime_t t1 = time_gettimeofday(); + elapsed_time = clock() - start; + gettimeofday_elapsed_time += elapsed_time; + EXPECT_GT(t1, 0); + + + EXPECT_GE(gettimeofday_elapsed_time, 0); + EXPECT_GE(st_utime_elapsed_time, 0); + + // pass the test, if + EXPECT_LT(gettimeofday_elapsed_time > st_utime_elapsed_time ? + gettimeofday_elapsed_time - st_utime_elapsed_time : + st_utime_elapsed_time - gettimeofday_elapsed_time, 10); + } + + // check gettimeofday first, then st_utime + { + start = clock(); + st_utime_t t1 = time_gettimeofday(); + int elapsed_time = clock() - start; + gettimeofday_elapsed_time += elapsed_time; + EXPECT_GT(t1, 0); + + start = clock(); + st_utime_t t2 = st_utime(); + elapsed_time = clock() - start; + st_utime_elapsed_time += elapsed_time; + EXPECT_GT(t2, 0); + + EXPECT_GE(gettimeofday_elapsed_time, 0); + EXPECT_GE(st_utime_elapsed_time, 0); + + EXPECT_LT(gettimeofday_elapsed_time > st_utime_elapsed_time ? 
+ gettimeofday_elapsed_time - st_utime_elapsed_time : + st_utime_elapsed_time - gettimeofday_elapsed_time, 10); + } + + // compare st_utime & gettimeofday in a loop + for (int i = 0; i < 100; i++) { + start = clock(); + st_utime_t t2 = st_utime(); + int elapsed_time = clock() - start; + st_utime_elapsed_time = elapsed_time; + EXPECT_GT(t2, 0); + usleep(1); + + start = clock(); + st_utime_t t1 = time_gettimeofday(); + elapsed_time = clock() - start; + gettimeofday_elapsed_time = elapsed_time; + EXPECT_GT(t1, 0); + usleep(1); + + EXPECT_GE(gettimeofday_elapsed_time, 0); + EXPECT_GE(st_utime_elapsed_time, 0); + + EXPECT_LT(gettimeofday_elapsed_time > st_utime_elapsed_time ? + gettimeofday_elapsed_time - st_utime_elapsed_time : + st_utime_elapsed_time - gettimeofday_elapsed_time, 10); + + } +} diff --git a/trunk/src/utest/srs_utest_st.hpp b/trunk/src/utest/srs_utest_st.hpp new file mode 100644 index 000000000..32862447a --- /dev/null +++ b/trunk/src/utest/srs_utest_st.hpp @@ -0,0 +1,15 @@ +// +// Copyright (c) 2013-2024 The SRS Authors +// +// SPDX-License-Identifier: MIT +// + +#ifndef SRS_UTEST_ST_HPP +#define SRS_UTEST_ST_HPP + +#include + +#include + +#endif // SRS_UTEST_ST_HPP +
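
As a closing aside (not part of any patch above): the Go proxy added earlier in
this series gets the same protection by default, because since Go 1.9
time.Now() carries a monotonic clock reading and time.Since()/Time.Sub() prefer
it, so elapsed-time arithmetic cannot go backwards when the wall clock is
stepped. A minimal illustration:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	start := time.Now() // carries both wall-clock and monotonic readings

	time.Sleep(10 * time.Millisecond)

	// Uses the monotonic reading, so the result stays non-negative even if
	// the system wall clock jumps backwards in between.
	fmt.Println("monotonic elapsed:", time.Since(start))

	// Round(0) strips the monotonic reading; the subtraction then falls back
	// to wall-clock arithmetic, which is the behavior the ST patch above
	// moves away from by switching to clock_gettime(CLOCK_MONOTONIC).
	wallOnly := start.Round(0)
	fmt.Println("wall-clock elapsed:", time.Now().Sub(wallOnly))
}
```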