From 8b785f2f2b36b4e1251384977a8c1641d72e01cd Mon Sep 17 00:00:00 2001 From: topilski Date: Wed, 4 Sep 2019 21:20:51 -0400 Subject: [PATCH] Close clients if dos --- app/service/service_client.py | 4 ++ app/service/service_manager.py | 90 ++++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/app/service/service_client.py b/app/service/service_client.py index c1ca49a..534285e 100644 --- a/app/service/service_client.py +++ b/app/service/service_client.py @@ -171,12 +171,16 @@ class ServiceClient(IClientHandler): return if req.method == Commands.STATISTIC_STREAM_COMMAND: + assert req.is_notification() self._handler.on_stream_statistic_received(req.params) elif req.method == Commands.CHANGED_STREAM_COMMAND: + assert req.is_notification() self._handler.on_stream_sources_changed(req.params) elif req.method == Commands.STATISTIC_SERVICE_COMMAND: + assert req.is_notification() self._handler.on_service_statistic_received(req.params) elif req.method == Commands.QUIT_STATUS_STREAM_COMMAND: + assert req.is_notification() self._handler.on_quit_status_stream(req.params) elif req.method == Commands.CLIENT_PING_COMMAND: self._handler.on_ping_received(req.params) diff --git a/app/service/service_manager.py b/app/service/service_manager.py index bd1230c..ef62562 100644 --- a/app/service/service_manager.py +++ b/app/service/service_manager.py @@ -72,8 +72,7 @@ class ServiceManager(IClientHandler): if client.socket() == read: res = client.recv_data() if not res: - self.__remove_subscriber(client) - client.disconnect() + self.__close_subscriber(client) break for server in self._servers_pool: @@ -83,8 +82,9 @@ class ServiceManager(IClientHandler): for client in self._subscribers: if ts_sec - client.last_ping_ts > ServiceManager.PING_SUBSCRIBERS_SEC: - client.ping(client.gen_request_id()) - client.last_ping_ts = ts_sec + if client.is_active(): + client.ping(client.gen_request_id()) + client.last_ping_ts = ts_sec def process_response(self, client, req: Request, resp: Response): if req.method == Commands.SERVER_PING_COMMAND: @@ -96,19 +96,23 @@ class ServiceManager(IClientHandler): if not req: return + result = False if req.method == Commands.ACTIVATE_COMMAND: - self._handle_activate_subscriber(client, req.id, req.params) + result = self._handle_activate_subscriber(client, req.id, req.params) elif req.method == Commands.GET_SERVER_INFO_COMMAND: - self._handle_get_server_info(client, req.id, req.params) + result = self._handle_get_server_info(client, req.id, req.params) elif req.method == Commands.CLIENT_PING_COMMAND: - self._handle_client_ping(client, req.id, req.params) + result = self._handle_client_ping(client, req.id, req.params) elif req.method == Commands.GET_CHANNELS: - self._handle_get_channels(client, req.id, req.params) + result = self._handle_get_channels(client, req.id, req.params) elif req.method == Commands.GET_RUNTIME_CHANNEL_INFO: - self._handle_get_runtime_channel_info(client, req.id, req.params) + result = self._handle_get_runtime_channel_info(client, req.id, req.params) else: pass + if not result: + self.__close_subscriber(client) + def on_client_state_changed(self, client, status: ClientStatus): pass @@ -120,7 +124,7 @@ class ServiceManager(IClientHandler): def _handle_server_get_client_info(self, client, resp: Response): pass - def _handle_activate_subscriber(self, client, cid: str, params: dict): + def _handle_activate_subscriber(self, client, cid: str, params: dict) -> bool: login = params[Subscriber.EMAIL_FIELD] password_hash = params[Subscriber.PASSWORD_FIELD] device_id = params['device_id'] @@ -128,75 +132,71 @@ class ServiceManager(IClientHandler): check_user = Subscriber.objects(email=login, class_check=False).first() if not check_user: client.activate_fail(cid, 'User not found') - return + return False if check_user.status == Subscriber.Status.NOT_ACTIVE: client.activate_fail(cid, 'User not active') - return + return False if check_user.status == Subscriber.Status.BANNED: client.activate_fail(cid, 'Banned user') - return + return False if check_user[Subscriber.PASSWORD_FIELD] != password_hash: client.activate_fail(cid, 'User invalid password') - return + return False found_device = check_user.find_device(device_id) if not found_device: client.activate_fail(cid, 'Device not found') - return + return False user_connections = self.get_user_connections_by_email(login) for conn in user_connections: if conn.device == found_device: client.activate_fail(cid, 'Device in use') - return + return False client.activate_success(cid) client.info = check_user client.device = found_device + return True - def _handle_get_server_info(self, client, cid: str, params: dict): + def _handle_get_server_info(self, client, cid: str, params: dict) -> bool: if not check_is_auth_client(client): client.check_activate_fail(cid, 'User not active') - client.disconnect() - return + return False client.get_server_info_success(cid, '{0}:{1}'.format(self._host, ServiceManager.BANDWIDTH_PORT)) + return True - def _handle_client_ping(self, client, cid: str, params: dict): - pass - - def _handle_get_channels(self, client, cid: str, params: dict): + def _handle_client_ping(self, client, cid: str, params: dict) -> bool: if not check_is_auth_client(client): client.check_activate_fail(cid, 'User not active') - client.disconnect() - return + return False + + client.pong(cid) + return True + + def _handle_get_channels(self, client, cid: str, params: dict) -> bool: + if not check_is_auth_client(client): + client.check_activate_fail(cid, 'User not active') + return False channels = client.info.get_streams() client.get_channels_success(cid, channels) + return True - def _handle_get_runtime_channel_info(self, client, cid: str, params: dict): + def _handle_get_runtime_channel_info(self, client, cid: str, params: dict) -> bool: if not check_is_auth_client(client): client.check_activate_fail(cid, 'User not active') - client.disconnect() - return + return False sid = params['id'] watchers = self.get_watchers_by_stream_id(sid) client.current_stream_id = sid client.get_runtime_channel_info_success(cid, sid, watchers) - - # private - def __add_server(self, server: Service): - self._servers_pool.append(server) - - def __add_subscriber(self, subs: SubscriberConnection): - self._subscribers.append(subs) - - def __remove_subscriber(self, subs: SubscriberConnection): - self._subscribers.remove(subs) + return True def get_watchers_by_stream_id(self, sid: str): total = 0 @@ -218,3 +218,17 @@ class ServiceManager(IClientHandler): for user in self._subscribers: if user.info and user.info.email == email: user.send_message(user.gen_request_id(), message.message, message.type, message.ttl * 1000) + + # private + def __close_subscriber(self, subs: SubscriberConnection): + self.__remove_subscriber(subs) + subs.disconnect() + + def __add_server(self, server: Service): + self._servers_pool.append(server) + + def __add_subscriber(self, subs: SubscriberConnection): + self._subscribers.append(subs) + + def __remove_subscriber(self, subs: SubscriberConnection): + self._subscribers.remove(subs)