diff --git a/core/aimedb.py b/core/aimedb.py index b8ce3b7..78ce350 100644 --- a/core/aimedb.py +++ b/core/aimedb.py @@ -164,7 +164,7 @@ class AimedbProtocol(Protocol): def handle_register(self, data: bytes) -> bytes: luid = data[0x20: 0x2a].hex() - if self.config.server.allow_registration: + if self.config.server.allow_user_registration: user_id = self.data.user.create_user() if user_id is None: diff --git a/core/allnet.py b/core/allnet.py index 3b371ce..bd2ecce 100644 --- a/core/allnet.py +++ b/core/allnet.py @@ -92,10 +92,10 @@ class AllnetServlet: self.uri_registry[code] = (uri, host) self.logger.info(f"Allnet serving {len(self.uri_registry)} games on port {core_cfg.allnet.port}") - def handle_poweron(self, request: Request): + def handle_poweron(self, request: Request, _: Dict): request_ip = request.getClientAddress().host try: - req = AllnetPowerOnRequest(self.allnet_req_to_dict(request.content.getvalue())) + req = AllnetPowerOnRequest(self.allnet_req_to_dict(request.content.getvalue())[0]) # Validate the request. Currently we only validate the fields we plan on using if not req.game_id or not req.ver or not req.token or not req.serial or not req.ip: @@ -131,31 +131,35 @@ class AllnetServlet: if machine is not None: arcade = self.data.arcade.get_arcade(machine["arcade"]) - req.country = arcade["country"] if machine["country"] is None else machine["country"] - req.place_id = arcade["id"] - req.allnet_id = machine["id"] - req.name = arcade["name"] - req.nickname = arcade["nickname"] - req.region0 = arcade["region_id"] - req.region_name0 = arcade["country"] - req.region_name1 = arcade["state"] - req.region_name2 = arcade["city"] - req.client_timezone = arcade["timezone"] if arcade["timezone"] is not None else "+0900" + resp.country = arcade["country"] if machine["country"] is None else machine["country"] + resp.place_id = arcade["id"] + resp.allnet_id = machine["id"] + resp.name = arcade["name"] + resp.nickname = arcade["nickname"] + resp.region0 = arcade["region_id"] + resp.region_name0 = arcade["country"] + resp.region_name1 = arcade["state"] + resp.region_name2 = arcade["city"] + resp.client_timezone = arcade["timezone"] if arcade["timezone"] is not None else "+0900" + + int_ver = req.ver.replace(".", "") + resp.uri = resp.uri.replace("$v", int_ver) + resp.host = resp.host.replace("$v", int_ver) msg = f"{req.serial} authenticated from {request_ip}: {req.game_id} v{req.ver}" self.data.base.log_event("allnet", "ALLNET_AUTH_SUCCESS", logging.INFO, msg) self.logger.info(msg) - return self.dict_to_http_form_string([vars(resp)]) + return self.dict_to_http_form_string([vars(resp)]).encode("utf-8") - def handle_dlorder(self, request: Request): + def handle_dlorder(self, request: Request, _: Dict): request_ip = request.getClientAddress().host try: - req = AllnetDownloadOrderRequest(self.allnet_req_to_dict(request.content.getvalue())) + req = AllnetDownloadOrderRequest(self.allnet_req_to_dict(request.content.getvalue())[0]) # Validate the request. Currently we only validate the fields we plan on using - if not req.game_id or not req.ver or not req.token or not req.serial or not req.ip: - raise AllnetRequestException(f"Bad auth request params from {request_ip} - {vars(req)}") + if not req.game_id or not req.ver or not req.serial: + raise AllnetRequestException(f"Bad download request params from {request_ip} - {vars(req)}") except AllnetRequestException as e: self.logger.error(e) @@ -163,12 +167,12 @@ class AllnetServlet: resp = AllnetDownloadOrderResponse() if not self.config.allnet.allow_online_updates: - return self.dict_to_http_form_string(vars(resp)) + return self.dict_to_http_form_string([vars(resp)]) else: # TODO: Actual dlorder response - return self.dict_to_http_form_string(vars(resp)) + return self.dict_to_http_form_string([vars(resp)]) - def handle_billing_request(self, request: Request): + def handle_billing_request(self, request: Request, _: Dict): req_dict = self.billing_req_to_dict(request.content.getvalue()) request_ip = request.getClientAddress() if req_dict is None: @@ -223,14 +227,14 @@ class AllnetServlet: resp = BillingResponse(playlimit, playlimit_sig, nearfull, nearfull_sig) - resp_str = self.dict_to_http_form_string([vars(resp)]) + resp_str = self.dict_to_http_form_string([vars(resp)], True) if resp_str is None: self.logger.error(f"Failed to parse response {vars(resp)}") self.logger.debug(f"response {vars(resp)}") return resp_str.encode("utf-8") - def kvp_to_dict(self, *kvp: str) -> List[Dict[str, Any]]: + def kvp_to_dict(self, kvp: List[str]) -> List[Dict[str, Any]]: ret: List[Dict[str, Any]] = [] for x in kvp: items = x.split('&') @@ -242,8 +246,10 @@ class AllnetServlet: tmp[kvp[0]] = kvp[1] ret.append(tmp) + + return ret - def allnet_req_to_dict(self, data: bytes): + def billing_req_to_dict(self, data: bytes): """ Parses an billing request string into a python dictionary """ @@ -252,13 +258,13 @@ class AllnetServlet: unzipped = decomp.decompress(data) sections = unzipped.decode('ascii').split('\r\n') - return Utils.kvp_to_dict(sections) + return self.kvp_to_dict(sections) except Exception as e: - print(e) + self.logger.error(e) return None - def billing_req_to_dict(self, data: str) -> Optional[List[Dict[str, Any]]]: + def allnet_req_to_dict(self, data: str) -> Optional[List[Dict[str, Any]]]: """ Parses an allnet request string into a python dictionary """ @@ -267,10 +273,10 @@ class AllnetServlet: unzipped = zlib.decompress(zipped) sections = unzipped.decode('utf-8').split('\r\n') - return Utils.kvp_to_dict(sections) + return self.kvp_to_dict(sections) except Exception as e: - print(e) + self.logger.error(e) return None def dict_to_http_form_string(self, data:List[Dict[str, Any]], crlf: bool = False, trailing_newline: bool = True) -> Optional[str]: @@ -297,7 +303,7 @@ class AllnetServlet: return urlencode except Exception as e: - print(e) + self.logger.error(e) return None class AllnetPowerOnRequest(): diff --git a/core/mucha.py b/core/mucha.py index a861101..312d83f 100644 --- a/core/mucha.py +++ b/core/mucha.py @@ -28,7 +28,7 @@ class MuchaServlet: self.logger.setLevel(logging.INFO) coloredlogs.install(level=logging.INFO, logger=self.logger, fmt=log_fmt_str) - def handle_boardauth(self, request: Request) -> bytes: + def handle_boardauth(self, request: Request, _: Dict) -> bytes: req_dict = self.mucha_preprocess(request.content.getvalue()) if req_dict is None: self.logger.error(f"Error processing mucha request {request.content.getvalue()}") @@ -41,7 +41,7 @@ class MuchaServlet: return self.mucha_postprocess(vars(resp)) - def handle_updatecheck(self, request: Request) -> bytes: + def handle_updatecheck(self, request: Request, _: Dict) -> bytes: req_dict = self.mucha_preprocess(request.content.getvalue()) if req_dict is None: self.logger.error(f"Error processing mucha request {request.content.getvalue()}") diff --git a/index.py b/index.py index 7fae586..0049427 100644 --- a/index.py +++ b/index.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import argparse +from typing import Dict import yaml from os import path, mkdir, access, W_OK from core import * @@ -36,46 +37,27 @@ class HttpDispatcher(resource.Resource): if test is None: return b"" - controller = getattr(self, test["controller"], None) - if controller is None: - return b"" - - handler = getattr(controller, test["action"], None) - if handler is None: - return b"" - - url_vars = test - url_vars.pop("controller") - url_vars.pop("action") - - if len(url_vars) > 0: - ret = handler(request, url_vars) - else: - ret = handler(request) - - if type(ret) == str: - return ret.encode() - elif type(ret) == bytes: - return ret - else: - return b"" + return self.dispatch(test, request) def render_POST(self, request: Request) -> bytes: test = self.map_post.match(request.uri.decode()) if test is None: return b"" + + return self.dispatch(test, request) - controller = getattr(self, test["controller"], None) + def dispatch(self, matcher: Dict, request: Request) -> bytes: + controller = getattr(self, matcher["controller"], None) if controller is None: return b"" - handler = getattr(controller, test["action"], None) + handler = getattr(controller, matcher["action"], None) if handler is None: return b"" - url_vars = test + url_vars = matcher url_vars.pop("controller") - url_vars.pop("action") + url_vars.pop("action") ret = handler(request, url_vars) if type(ret) == str: