change to using txroutes

This commit is contained in:
Hay1tsme 2023-02-16 17:13:41 -05:00
parent 32879491f4
commit f18e939dd0
11 changed files with 322 additions and 69 deletions

View File

@ -1,4 +1,5 @@
from core.config import CoreConfig
from core.allnet import AllnetServlet, BillingServlet
from core.allnet import AllnetServlet
from core.aimedb import AimedbFactory
from core.title import TitleServlet
from core.utils import Utils

View File

@ -212,7 +212,7 @@ class AimedbFactory(Factory):
log_fmt = logging.Formatter(log_fmt_str)
self.logger = logging.getLogger("aimedb")
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.logs, "aimedb"), when="d", backupCount=10)
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "aimedb"), when="d", backupCount=10)
fileHandler.setFormatter(log_fmt)
consoleHandler = logging.StreamHandler()

View File

@ -1,12 +1,17 @@
from twisted.web import resource
from typing import Dict, List, Any, Optional
import logging, coloredlogs
from logging.handlers import TimedRotatingFileHandler
from twisted.web.http import Request
from datetime import datetime
import pytz
import base64
import zlib
from core.config import CoreConfig
from core.data import Data
from core.utils import Utils
class AllnetServlet(resource.Resource):
isLeaf = True
class AllnetServlet():
def __init__(self, core_cfg: CoreConfig, cfg_folder: str):
super().__init__()
self.config = core_cfg
@ -14,26 +19,211 @@ class AllnetServlet(resource.Resource):
self.data = Data(core_cfg)
self.logger = logging.getLogger("allnet")
log_fmt_str = "[%(asctime)s] Allnet | %(levelname)s | %(message)s"
log_fmt = logging.Formatter(log_fmt_str)
if not hasattr(self.logger, "initialized"):
log_fmt_str = "[%(asctime)s] Allnet | %(levelname)s | %(message)s"
log_fmt = logging.Formatter(log_fmt_str)
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "allnet"), when="d", backupCount=10)
fileHandler.setFormatter(log_fmt)
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "allnet"), when="d", backupCount=10)
fileHandler.setFormatter(log_fmt)
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(log_fmt)
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(log_fmt)
self.logger.addHandler(fileHandler)
self.logger.addHandler(consoleHandler)
self.logger.addHandler(fileHandler)
self.logger.addHandler(consoleHandler)
self.logger.setLevel(core_cfg.allnet.loglevel)
coloredlogs.install(level=core_cfg.allnet.loglevel, logger=self.logger, fmt=log_fmt_str)
self.logger.setLevel(core_cfg.allnet.loglevel)
coloredlogs.install(level=core_cfg.allnet.loglevel, logger=self.logger, fmt=log_fmt_str)
self.logger.initialized = True
class BillingServlet(resource.Resource):
isLeaf = True
def __init__(self, core_cfg: CoreConfig, cfg_folder: str):
super().__init__()
self.config = core_cfg
self.config_folder = cfg_folder
self.data = Data(core_cfg)
self.logger = logging.getLogger('allnet')
if "game_registry" not in globals():
globals()["game_registry"] = Utils.get_all_titles()
if len(globals()["game_registry"]) == 0:
self.logger.error("No games detected!")
def handle_poweron(self, request: Request):
try:
req = AllnetPowerOnRequest(self.allnet_req_to_dict(request.content.getvalue()))
# Validate the request. Currently we only validate the fields we plan on using
if not req.game_id or not req.ver or not req.token or not req.serial or not req.ip:
raise AllnetRequestException(f"Bad request params {vars(req)}")
except AllnetRequestException as e:
self.logger.error(e)
return b""
def handle_dlorder(self, request: Request):
pass
def handle_billing_request(self, request: Request):
pass
def kvp_to_dict(self, *kvp: str) -> List[Dict[str, Any]]:
ret: List[Dict[str, Any]] = []
for x in kvp:
items = x.split('&')
tmp = {}
for item in items:
kvp = item.split('=')
if len(kvp) == 2:
tmp[kvp[0]] = kvp[1]
ret.append(tmp)
def allnet_req_to_dict(self, data: bytes):
"""
Parses an billing request string into a python dictionary
"""
try:
decomp = zlib.decompressobj(-zlib.MAX_WBITS)
unzipped = decomp.decompress(data)
sections = unzipped.decode('ascii').split('\r\n')
return Utils.kvp_to_dict(sections)
except Exception as e:
print(e)
return None
def billing_req_to_dict(self, data: str) -> Optional[List[Dict[str, Any]]]:
"""
Parses an allnet request string into a python dictionary
"""
try:
zipped = base64.b64decode(data)
unzipped = zlib.decompress(zipped)
sections = unzipped.decode('utf-8').split('\r\n')
return Utils.kvp_to_dict(sections)
except Exception as e:
print(e)
return None
def dict_to_http_form_string(self, data:List[Dict[str, Any]], crlf: bool = False, trailing_newline: bool = True) -> Optional[str]:
"""
Takes a python dictionary and parses it into an allnet response string
"""
try:
urlencode = ""
for item in data:
for k,v in item.items():
urlencode += f"{k}={v}&"
if crlf:
urlencode = urlencode[:-1] + "\r\n"
else:
urlencode = urlencode[:-1] + "\n"
if not trailing_newline:
if crlf:
urlencode = urlencode[:-2]
else:
urlencode = urlencode[:-1]
return urlencode
except Exception as e:
print(e)
return None
class AllnetPowerOnRequest():
def __init__(self, req: Dict) -> None:
if req is None:
raise AllnetRequestException("Request processing failed")
self.game_id: str = req["game_id"] if "game_id" in req else ""
self.ver: str = req["ver"] if "ver" in req else ""
self.serial: str = req["serial"] if "serial" in req else ""
self.ip: str = req["ip"] if "ip" in req else ""
self.firm_ver: str = req["firm_ver"] if "firm_ver" in req else ""
self.boot_ver: str = req["boot_ver"] if "boot_ver" in req else ""
self.encode: str = req["encode"] if "encode" in req else ""
try:
self.hops = int(req["hops"]) if "hops" in req else 0
self.format_ver = int(req["format_ver"]) if "format_ver" in req else 2
self.token = int(req["token"]) if "token" in req else 0
except ValueError as e:
raise AllnetRequestException(f"Failed to parse int: {e}")
class AllnetPowerOnResponse3():
def __init__(self, token) -> None:
self.stat = 1
self.uri = ""
self.host = ""
self.place_id = "123"
self.name = ""
self.nickname = ""
self.region0 = "1"
self.region_name0 = "W"
self.region_name1 = ""
self.region_name2 = ""
self.region_name3 = ""
self.country = "JPN"
self.allnet_id = "123"
self.client_timezone = "+0900"
self.utc_time = datetime.now(tz=pytz.timezone('UTC')).strftime("%Y-%m-%dT%H:%M:%SZ")
self.setting = ""
self.res_ver = "3"
self.token = str(token)
class AllnetPowerOnResponse2():
def __init__(self) -> None:
self.stat = 1
self.uri = ""
self.host = ""
self.place_id = "123"
self.name = "Test"
self.nickname = "Test123"
self.region0 = "1"
self.region_name0 = "W"
self.region_name1 = "X"
self.region_name2 = "Y"
self.region_name3 = "Z"
self.country = "JPN"
self.year = datetime.now().year
self.month = datetime.now().month
self.day = datetime.now().day
self.hour = datetime.now().hour
self.minute = datetime.now().minute
self.second = datetime.now().second
self.setting = "1"
self.timezone = "+0900"
self.res_class = "PowerOnResponseV2"
class AllnetDownloadOrderRequest():
def __init__(self, req: Dict) -> None:
self.game_id = req["game_id"] if "game_id" in req else ""
self.ver = req["ver"] if "ver" in req else ""
self.serial = req["serial"] if "serial" in req else ""
self.encode = req["encode"] if "encode" in req else ""
class AllnetDownloadOrderResponse():
def __init__(self, stat: int = 1, serial: str = "", uri: str = "null") -> None:
self.stat = stat
self.serial = serial
self.uri = uri
class BillingResponse():
def __init__(self, playlimit: str, playlimit_sig: str, nearfull: str, nearfull_sig: str,
playhistory: str = "000000/0:000000/0:000000/0") -> None:
self.result = "0"
self.waitime = "100"
self.linelimit = "1"
self.message = ""
self.playlimit = playlimit
self.playlimitsig = playlimit_sig
self.protocolver = "1.000"
self.nearfull = nearfull
self.nearfullsig = nearfull_sig
self.fixlogincnt = "0"
self.fixinterval = "5"
self.playhistory = playhistory
# playhistory -> YYYYMM/C:...
# YYYY -> 4 digit year, MM -> 2 digit month, C -> Playcount during that period
class AllnetRequestException(Exception):
pass

View File

@ -7,27 +7,27 @@ class ServerConfig:
@property
def listen_address(self) -> str:
return CoreConfig.get_config_field(self.__config, '127.0.0.1', 'core', 'server', 'listen_address')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'listen_address', default='127.0.0.1')
@property
def allow_user_registration(self) -> bool:
return CoreConfig.get_config_field(self.__config, True, 'core', 'server', 'allow_user_registration')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'allow_user_registration', default=True)
@property
def allow_unregistered_games(self) -> bool:
return CoreConfig.get_config_field(self.__config, True, 'core', 'server', 'allow_unregistered_games')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'allow_unregistered_games', default=True)
@property
def name(self) -> str:
return CoreConfig.get_config_field(self.__config, "ARTEMiS", 'core', 'server', 'name')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'name', default="ARTEMiS")
@property
def is_develop(self) -> bool:
return CoreConfig.get_config_field(self.__config, True, 'core', 'server', 'is_develop')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'is_develop', default=True)
@property
def log_dir(self) -> str:
return CoreConfig.get_config_field(self.__config, 'logs', 'core', 'server', 'log_dir')
return CoreConfig.get_config_field(self.__config, 'core', 'server', 'log_dir', default='logs')
class TitleConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -35,15 +35,15 @@ class TitleConfig:
@property
def loglevel(self) -> int:
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, "info", 'core', 'title', 'loglevel'))
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'title', 'loglevel', default="info"))
@property
def hostname(self) -> str:
return CoreConfig.get_config_field(self.__config, "localhost", 'core', 'title', 'hostname')
return CoreConfig.get_config_field(self.__config, 'core', 'title', 'hostname', default="localhost")
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 8080, 'core', 'title', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'title', 'port', default=8080)
class DatabaseConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -51,43 +51,43 @@ class DatabaseConfig:
@property
def host(self) -> str:
return CoreConfig.get_config_field(self.__config, "localhost", 'core', 'database', 'host')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'host', default="localhost")
@property
def username(self) -> str:
return CoreConfig.get_config_field(self.__config, 'aime', 'core', 'database', 'username')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'username', default='aime')
@property
def password(self) -> str:
return CoreConfig.get_config_field(self.__config, 'aime', 'core', 'database', 'password')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'password', default='aime')
@property
def name(self) -> str:
return CoreConfig.get_config_field(self.__config, 'aime', 'core', 'database', 'name')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'name', default='aime')
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 3306, 'core', 'database', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'port', default=3306)
@property
def protocol(self) -> str:
return CoreConfig.get_config_field(self.__config, "mysql", 'core', 'database', 'type')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'type', default="mysql")
@property
def sha2_password(self) -> bool:
return CoreConfig.get_config_field(self.__config, False, 'core', 'database', 'sha2_password')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'sha2_password', default=False)
@property
def loglevel(self) -> int:
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, "info", 'core', 'database', 'loglevel'))
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'database', 'loglevel', default="info"))
@property
def user_table_autoincrement_start(self) -> int:
return CoreConfig.get_config_field(self.__config, 10000, 'core', 'database', 'user_table_autoincrement_start')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'user_table_autoincrement_start', default=10000)
@property
def memcached_host(self) -> str:
return CoreConfig.get_config_field(self.__config, "localhost", 'core', 'database', 'memcached_host')
return CoreConfig.get_config_field(self.__config, 'core', 'database', 'memcached_host', default="localhost")
class FrontendConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -95,15 +95,15 @@ class FrontendConfig:
@property
def enable(self) -> int:
return CoreConfig.get_config_field(self.__config, False, 'core', 'frontend', 'enable')
return CoreConfig.get_config_field(self.__config, 'core', 'frontend', 'enable', default=False)
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 8090, 'core', 'frontend', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'frontend', 'port', default=8090)
@property
def loglevel(self) -> int:
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'frontend', 'loglevel', "info"))
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'frontend', 'loglevel', default="info"))
class AllnetConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -111,15 +111,15 @@ class AllnetConfig:
@property
def loglevel(self) -> int:
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, "info", 'core', 'allnet', 'loglevel'))
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'allnet', 'loglevel', default="info"))
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 80, 'core', 'allnet', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'allnet', 'port', default=80)
@property
def allow_online_updates(self) -> int:
return CoreConfig.get_config_field(self.__config, False, 'core', 'allnet', 'allow_online_updates')
return CoreConfig.get_config_field(self.__config, 'core', 'allnet', 'allow_online_updates', default=False)
class BillingConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -127,19 +127,19 @@ class BillingConfig:
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 8443, 'core', 'billing', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'billing', 'port', default=8443)
@property
def ssl_key(self) -> str:
return CoreConfig.get_config_field(self.__config, "cert/server.key", 'core', 'billing', 'ssl_key')
return CoreConfig.get_config_field(self.__config, 'core', 'billing', 'ssl_key', default="cert/server.key")
@property
def ssl_cert(self) -> str:
return CoreConfig.get_config_field(self.__config, "cert/server.pem", 'core', 'billing', 'ssl_cert')
return CoreConfig.get_config_field(self.__config, 'core', 'billing', 'ssl_cert', default="cert/server.pem")
@property
def signing_key(self) -> str:
return CoreConfig.get_config_field(self.__config, "cert/billing.key", 'core', 'billing', 'signing_key')
return CoreConfig.get_config_field(self.__config, 'core', 'billing', 'signing_key', default="cert/billing.key")
class AimedbConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
@ -147,15 +147,15 @@ class AimedbConfig:
@property
def loglevel(self) -> int:
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, "info", 'core', 'aimedb', 'loglevel'))
return CoreConfig.str_to_loglevel(CoreConfig.get_config_field(self.__config, 'core', 'aimedb', 'loglevel', default="info"))
@property
def port(self) -> int:
return CoreConfig.get_config_field(self.__config, 22345, 'core', 'aimedb', 'port')
return CoreConfig.get_config_field(self.__config, 'core', 'aimedb', 'port', default=22345)
@property
def key(self) -> str:
return CoreConfig.get_config_field(self.__config, "", 'core', 'aimedb', 'key')
return CoreConfig.get_config_field(self.__config, 'core', 'aimedb', 'key', default="")
class CoreConfig(dict):
def __init__(self) -> None:
@ -179,8 +179,8 @@ class CoreConfig(dict):
return logging.INFO
@classmethod
def get_config_field(cls, __config: dict, default: Any, *path: str) -> Any:
envKey = 'CFG_'
def get_config_field(cls, __config: dict, module, *path: str, default: Any = "") -> Any:
envKey = f'CFG_{module}_'
for arg in path:
envKey += arg + '_'

View File

@ -36,7 +36,7 @@ class Data:
# Prevent the logger from adding handlers multiple times
if not getattr(self.logger, 'handler_set', None):
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.logs, "db"), encoding="utf-8",
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "db"), encoding="utf-8",
when="d", backupCount=10)
fileHandler.setFormatter(log_fmt)

View File

@ -1,14 +1,42 @@
from twisted.web import resource
import logging, coloredlogs
from logging.handlers import TimedRotatingFileHandler
from twisted.web.http import Request
from core.config import CoreConfig
from core.data import Data
from core.utils import Utils
class TitleServlet(resource.Resource):
isLeaf = True
class TitleServlet():
def __init__(self, core_cfg: CoreConfig, cfg_folder: str):
super().__init__()
self.config = core_cfg
self.config_folder = cfg_folder
self.data = Data(core_cfg)
self.logger = logging.getLogger("title")
if not hasattr(self.logger, "initialized"):
log_fmt_str = "[%(asctime)s] Title | %(levelname)s | %(message)s"
log_fmt = logging.Formatter(log_fmt_str)
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "title"), when="d", backupCount=10)
fileHandler.setFormatter(log_fmt)
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(log_fmt)
self.logger.addHandler(fileHandler)
self.logger.addHandler(consoleHandler)
self.logger.setLevel(core_cfg.title.loglevel)
coloredlogs.install(level=core_cfg.title.loglevel, logger=self.logger, fmt=log_fmt_str)
self.logger.initialized = True
if "game_registry" not in globals():
globals()["game_registry"] = Utils.get_all_titles()
def handle_GET(self, request: Request):
pass
def handle_POST(self, request: Request):
pass

21
core/utils.py Normal file
View File

@ -0,0 +1,21 @@
from typing import Dict, List, Any, Optional
from types import ModuleType
import zlib, base64
import importlib
from os import walk
class Utils:
@classmethod
def get_all_titles(cls) -> Dict[str, ModuleType]:
ret: Dict[str, Any] = {}
for root, dirs, files in walk("titles"):
for dir in dirs:
if not dir.startswith("__"):
try:
mod = importlib.import_module(f"titles.{dir}")
ret[dir] = mod
except ImportError as e:
print(f"{dir} - {e}")
return ret

View File

@ -4,7 +4,7 @@ server:
allow_unregistered_games: True
name: "ARTEMiS"
is_develop: True
log_dir: False
log_dir: "logs"
title:
loglevel: "info"

View File

@ -6,6 +6,7 @@ from core import *
from twisted.web import server
from twisted.internet import reactor, endpoints
from txroutes import Dispatcher
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ARTEMiS main entry point")
@ -26,14 +27,14 @@ if __name__ == "__main__":
print(f"Log directory {cfg.server.log_dir} NOT writable, please check permissions")
exit(1)
if cfg.aimedb.key == "":
if not cfg.aimedb.key:
print("!!AIMEDB KEY BLANK, SET KEY IN CORE.YAML!!")
exit(1)
print(f"ARTEMiS starting in {'develop' if cfg.server.is_develop else 'production'} mode")
allnet_server_str = f"tcp:{cfg.allnet.port}:interface={cfg.server.listen_address}"
title_server_str = f"tcp:{cfg.billing.port}:interface={cfg.server.listen_address}"
title_server_str = f"tcp:{cfg.title.port}:interface={cfg.server.listen_address}"
adb_server_str = f"tcp:{cfg.aimedb.port}:interface={cfg.server.listen_address}"
billing_server_str = f"tcp:{cfg.billing.port}:interface={cfg.server.listen_address}"
@ -41,13 +42,23 @@ if __name__ == "__main__":
billing_server_str = f"ssl:{cfg.billing.port}:interface={cfg.server.listen_address}"\
f":privateKey={cfg.billing.ssl_key}:certKey={cfg.billing.ssl_cert}"
endpoints.serverFromString(reactor, allnet_server_str).listen(server.Site(AllnetServlet(cfg, args.config)))
allnet_cls = AllnetServlet(cfg, args.config)
title_cls = TitleServlet(cfg, args.config)
dispatcher = Dispatcher()
dispatcher.connect('allnet_poweron', '/sys/servlet/PowerOn', allnet_cls, action='handle_poweron', conditions=dict(method=['POST']))
dispatcher.connect('allnet_downloadorder', '/sys/servlet/DownloadOrder', allnet_cls, action='handle_dlorder', conditions=dict(method=['POST']))
dispatcher.connect('allnet_billing', '/request', allnet_cls, action='handle_billing_request', conditions=dict(method=['POST']))
dispatcher.connect("title_get", "/{game}/{version}/{endpoint}", title_cls, action="handle_GET", conditions=dict(method=['GET']))
dispatcher.connect("title_post", "/{game}/{version}/{endpoint}", title_cls, action="handle_POST", conditions=dict(method=['POST']))
endpoints.serverFromString(reactor, allnet_server_str).listen(server.Site(dispatcher))
endpoints.serverFromString(reactor, adb_server_str).listen(AimedbFactory(cfg))
if cfg.billing.port > 0:
endpoints.serverFromString(reactor, billing_server_str).listen(server.Site(BillingServlet(cfg)))
endpoints.serverFromString(reactor, billing_server_str).listen(server.Site(dispatcher))
if cfg.title.port > 0:
endpoints.serverFromString(reactor, title_server_str).listen(server.Site(TitleServlet(cfg, args.config)))
endpoints.serverFromString(reactor, title_server_str).listen(server.Site(dispatcher))
reactor.run() # type: ignore

View File

@ -12,3 +12,4 @@ inflection
coloredlogs
pylibmc
wacky
txroutes

View File

@ -11,3 +11,4 @@ PyCryptodome
inflection
coloredlogs
wacky
txroutes