diff --git a/trunk/src/rtmp/srs_protocol_handshake.cpp b/trunk/src/rtmp/srs_protocol_handshake.cpp index ebef178e9..1d4810fb0 100644 --- a/trunk/src/rtmp/srs_protocol_handshake.cpp +++ b/trunk/src/rtmp/srs_protocol_handshake.cpp @@ -759,6 +759,29 @@ int c1s1::c1_validate_digest(bool& is_valid) return ret; } +int c1s1::s1_validate_digest(bool& is_valid) +{ + int ret = ERROR_SUCCESS; + + char* s1_digest = NULL; + + if ((ret = calc_s1_digest(s1_digest)) != ERROR_SUCCESS) { + srs_error("validate s1 error, failed to calc digest. ret=%d", ret); + return ret; + } + + srs_assert(s1_digest != NULL); + SrsAutoFree(char, s1_digest, true); + + if (schema == srs_schema0) { + is_valid = srs_bytes_equals(block1.digest.digest, s1_digest, 32); + } else { + is_valid = srs_bytes_equals(block0.digest.digest, s1_digest, 32); + } + + return ret; +} + int c1s1::s1_create(c1s1* c1) { int ret = ERROR_SUCCESS; @@ -1076,6 +1099,13 @@ int SrsComplexHandshake::handshake_with_client(ISrsProtocolReaderWriter* skt, ch return ret; } srs_verbose("create s1 from c1 success."); + // verify s1 + if ((ret = s1.s1_validate_digest(is_valid)) != ERROR_SUCCESS || !is_valid) { + ret = ERROR_RTMP_TRY_SIMPLE_HS; + srs_info("valid s1 failed, try simple handshake. ret=%d", ret); + return ret; + } + srs_verbose("verify s1 from c1 success."); c2s2 s2; if ((ret = s2.s2_create(&c1)) != ERROR_SUCCESS) { diff --git a/trunk/src/rtmp/srs_protocol_handshake.hpp b/trunk/src/rtmp/srs_protocol_handshake.hpp index bc5ce161a..d38bc34e1 100644 --- a/trunk/src/rtmp/srs_protocol_handshake.hpp +++ b/trunk/src/rtmp/srs_protocol_handshake.hpp @@ -206,13 +206,17 @@ namespace srs */ virtual int c1_parse(char* _c1s1, srs_schema_type _schema); /** - * server: validate the parsed schema and c1s1 + * server: validate the parsed c1 schema */ virtual int c1_validate_digest(bool& is_valid); /** * server: create and sign the s1 from c1. */ virtual int s1_create(c1s1* c1); + /** + * server: validate the parsed s1 schema + */ + virtual int s1_validate_digest(bool& is_valid); private: virtual int calc_s1_digest(char*& digest); virtual int calc_c1_digest(char*& digest);