From 8b43d554fc35cf26ff1b9a68f631d9f79dad682a Mon Sep 17 00:00:00 2001 From: Kevin Trocolli Date: Fri, 30 Jun 2023 01:19:17 -0400 Subject: [PATCH] allnet: make use of urllib.parse where applicable --- core/allnet.py | 137 +++++++++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 61 deletions(-) diff --git a/core/allnet.py b/core/allnet.py index edb704c..bc5eedf 100644 --- a/core/allnet.py +++ b/core/allnet.py @@ -11,6 +11,7 @@ from Crypto.Hash import SHA from Crypto.Signature import PKCS1_v1_5 from time import strptime from os import path +import urllib.parse from core.config import CoreConfig from core.utils import Utils @@ -79,7 +80,7 @@ class AllnetServlet: req = AllnetPowerOnRequest(req_dict[0]) # Validate the request. Currently we only validate the fields we plan on using - if not req.game_id or not req.ver or not req.serial or not req.ip: + if not req.game_id or not req.ver or not req.serial or not req.ip or not req.firm_ver or not req.boot_ver: raise AllnetRequestException( f"Bad auth request params from {request_ip} - {vars(req)}" ) @@ -89,12 +90,14 @@ class AllnetServlet: self.logger.error(e) return b"" - if req.format_ver == "3": + if req.format_ver == 3: resp = AllnetPowerOnResponse3(req.token) - else: + elif req.format_ver == 2: resp = AllnetPowerOnResponse2() + else: + resp = AllnetPowerOnResponse() - self.logger.debug(f"Allnet request: {vars(req)}") + self.logger.debug(f"Allnet request: {vars(req)}") if req.game_id not in self.uri_registry: if not self.config.server.is_develop: msg = f"Unrecognised game {req.game_id} attempted allnet auth from {request_ip}." @@ -103,8 +106,9 @@ class AllnetServlet: ) self.logger.warn(msg) - resp.stat = 0 - return self.dict_to_http_form_string([vars(resp)]) + resp.stat = -1 + resp_dict = {k: v for k, v in vars(resp).items() if v is not None} + return (urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n").encode("utf-8") else: self.logger.info( @@ -113,12 +117,15 @@ class AllnetServlet: resp.uri = f"http://{self.config.title.hostname}:{self.config.title.port}/{req.game_id}/{req.ver.replace('.', '')}/" resp.host = f"{self.config.title.hostname}:{self.config.title.port}" - self.logger.debug(f"Allnet response: {vars(resp)}") - return self.dict_to_http_form_string([vars(resp)]) + resp_dict = {k: v for k, v in vars(resp).items() if v is not None} + resp_str = urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + + self.logger.debug(f"Allnet response: {resp_str}") + return (resp_str + "\n").encode("utf-8") resp.uri, resp.host = self.uri_registry[req.game_id] - machine = self.data.arcade.get_machine(req.serial) + machine = self.data.arcade.get_machine(req.serial) if machine is None and not self.config.server.allow_unregistered_serials: msg = f"Unrecognised serial {req.serial} attempted allnet auth from {request_ip}." self.data.base.log_event( @@ -126,8 +133,9 @@ class AllnetServlet: ) self.logger.warn(msg) - resp.stat = 0 - return self.dict_to_http_form_string([vars(resp)]) + resp.stat = -2 + resp_dict = {k: v for k, v in vars(resp).items() if v is not None} + return (urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n").encode("utf-8") if machine is not None: arcade = self.data.arcade.get_arcade(machine["arcade"]) @@ -169,9 +177,13 @@ class AllnetServlet: msg = f"{req.serial} authenticated from {request_ip}: {req.game_id} v{req.ver}" self.data.base.log_event("allnet", "ALLNET_AUTH_SUCCESS", logging.INFO, msg) self.logger.info(msg) - self.logger.debug(f"Allnet response: {vars(resp)}") - return self.dict_to_http_form_string([vars(resp)]).encode("utf-8") + resp_dict = {k: v for k, v in vars(resp).items() if v is not None} + resp_str = urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + self.logger.debug(f"Allnet response: {resp_dict}") + resp_str += "\n" + + return resp_str.encode("utf-8") def handle_dlorder(self, request: Request, _: Dict): request_ip = Utils.get_ip_addr(request) @@ -202,7 +214,7 @@ class AllnetServlet: not self.config.allnet.allow_online_updates or not self.config.allnet.update_cfg_folder ): - return self.dict_to_http_form_string([vars(resp)]) + return urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\n" else: # TODO: Keychip check if path.exists( @@ -217,7 +229,8 @@ class AllnetServlet: self.logger.debug(f"Sending download uri {resp.uri}") self.data.base.log_event("allnet", "DLORDER_REQ_SUCCESS", logging.INFO, f"{Utils.get_ip_addr(request)} requested DL Order for {req.serial} {req.game_id} v{req.ver}") - return self.dict_to_http_form_string([vars(resp)]) + + return urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\n" def handle_dlorder_ini(self, request: Request, match: Dict) -> bytes: if "file" not in match: @@ -323,7 +336,7 @@ class AllnetServlet: resp = BillingResponse(playlimit, playlimit_sig, nearfull, nearfull_sig) - resp_str = self.dict_to_http_form_string([vars(resp)], True) + resp_str = self.dict_to_http_form_string([vars(resp)]) if resp_str is None: self.logger.error(f"Failed to parse response {vars(resp)}") @@ -382,7 +395,7 @@ class AllnetServlet: def dict_to_http_form_string( self, data: List[Dict[str, Any]], - crlf: bool = False, + crlf: bool = True, trailing_newline: bool = True, ) -> Optional[str]: """ @@ -392,21 +405,19 @@ class AllnetServlet: urlencode = "" for item in data: for k, v in item.items(): + if k is None or v is None: + continue urlencode += f"{k}={v}&" - if crlf: urlencode = urlencode[:-1] + "\r\n" else: urlencode = urlencode[:-1] + "\n" - if not trailing_newline: if crlf: urlencode = urlencode[:-2] else: urlencode = urlencode[:-1] - return urlencode - except Exception as e: self.logger.error(f"dict_to_http_form_string: {e} while parsing {data}") return None @@ -416,20 +427,19 @@ class AllnetPowerOnRequest: def __init__(self, req: Dict) -> None: if req is None: raise AllnetRequestException("Request processing failed") - self.game_id: str = req.get("game_id", "") - self.ver: str = req.get("ver", "") - self.serial: str = req.get("serial", "") - self.ip: str = req.get("ip", "") - self.firm_ver: str = req.get("firm_ver", "") - self.boot_ver: str = req.get("boot_ver", "") - self.encode: str = req.get("encode", "") - self.hops = int(req.get("hops", "0")) - self.format_ver = req.get("format_ver", "2") - self.token = int(req.get("token", "0")) + self.game_id: str = req.get("game_id", None) + self.ver: str = req.get("ver", None) + self.serial: str = req.get("serial", None) + self.ip: str = req.get("ip", None) + self.firm_ver: str = req.get("firm_ver", None) + self.boot_ver: str = req.get("boot_ver", None) + self.encode: str = req.get("encode", "EUC-JP") + self.hops = int(req.get("hops", "-1")) + self.format_ver = float(req.get("format_ver", "1.00")) + self.token: str = req.get("token", "0") - -class AllnetPowerOnResponse3: - def __init__(self, token) -> None: +class AllnetPowerOnResponse: + def __init__(self) -> None: self.stat = 1 self.uri = "" self.host = "" @@ -440,40 +450,45 @@ class AllnetPowerOnResponse3: self.region_name0 = "W" self.region_name1 = "" self.region_name2 = "" - self.region_name3 = "" - self.country = "JPN" - self.allnet_id = "123" - self.client_timezone = "+0900" - self.utc_time = datetime.now(tz=pytz.timezone("UTC")).strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + self.region_name3 = "" self.setting = "1" - self.res_ver = "3" - self.token = str(token) - - -class AllnetPowerOnResponse2: - def __init__(self) -> None: - self.stat = 1 - self.uri = "" - self.host = "" - self.place_id = "123" - self.name = "ARTEMiS" - self.nickname = "ARTEMiS" - self.region0 = "1" - self.region_name0 = "W" - self.region_name1 = "X" - self.region_name2 = "Y" - self.region_name3 = "Z" - self.country = "JPN" self.year = datetime.now().year self.month = datetime.now().month self.day = datetime.now().day self.hour = datetime.now().hour self.minute = datetime.now().minute self.second = datetime.now().second - self.setting = "1" - self.timezone = "+0900" + +class AllnetPowerOnResponse3(AllnetPowerOnResponse): + def __init__(self, token) -> None: + super().__init__() + + # Added in v3 + self.country = "JPN" + self.allnet_id = "123" + self.client_timezone = "+0900" + self.utc_time = datetime.now(tz=pytz.timezone("UTC")).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + self.res_ver = "3" + self.token = token + + # Removed in v3 + self.year = None + self.month = None + self.day = None + self.hour = None + self.minute = None + self.second = None + + +class AllnetPowerOnResponse2(AllnetPowerOnResponse): + def __init__(self) -> None: + super().__init__() + + # Added in v2 + self.country = "JPN" + self.timezone = "+09:00" self.res_class = "PowerOnResponseV2"