From f07d60f8cef7daa91e3d5392a211ed367409c7ac Mon Sep 17 00:00:00 2001 From: beerpiss Date: Thu, 17 Aug 2023 17:16:05 +0700 Subject: [PATCH] db: Migrate to SQLite, improve performance --- .gitignore | 7 +- core/allnet.py | 66 +- core/config.py | 2 +- core/data/__init__.py | 4 +- core/data/cache.py | 134 +- core/data/database.py | 727 ++++---- core/data/schema/__init__.py | 12 +- core/data/schema/arcade.py | 438 ++--- core/data/schema/base.py | 335 ++-- core/data/schema/card.py | 208 +-- core/data/schema/user.py | 225 +-- core/frontend.py | 21 +- core/frontend/user/index.jinja | 6 +- core/mucha.py | 27 +- core/title.py | 4 +- dbutils.py | 2 +- example_config/core.yaml | 9 +- index.py | 19 +- readme.md | 4 +- titles/chuni/__init__.py | 2 + titles/chuni/base.py | 284 +-- titles/chuni/const.py | 2 +- titles/chuni/index.py | 4 +- titles/chuni/new.py | 148 +- titles/chuni/read.py | 13 +- titles/chuni/schema/__init__.py | 12 +- titles/chuni/schema/item.py | 1187 +++++++------ titles/chuni/schema/profile.py | 1286 +++++++------- titles/chuni/schema/score.py | 404 ++--- titles/chuni/schema/static.py | 1190 ++++++------- titles/chuni/sun.py | 8 +- titles/cm/index.py | 2 +- titles/cm/read.py | 12 +- titles/cm/schema/__init__.py | 2 +- titles/cxb/base.py | 356 ++-- titles/cxb/read.py | 4 +- titles/cxb/schema/__init__.py | 12 +- titles/cxb/schema/item.py | 84 +- titles/cxb/schema/profile.py | 156 +- titles/cxb/schema/score.py | 374 ++-- titles/cxb/schema/static.py | 190 +- titles/diva/base.py | 184 +- titles/diva/read.py | 16 +- titles/diva/schema/__init__.py | 34 +- titles/diva/schema/customize.py | 132 +- titles/diva/schema/item.py | 138 +- titles/diva/schema/module.py | 124 +- titles/diva/schema/profile.py | 236 +-- titles/diva/schema/pv_customize.py | 174 +- titles/diva/schema/score.py | 484 ++--- titles/diva/schema/static.py | 624 +++---- titles/idz/userdb.py | 10 +- titles/mai2/base.py | 452 ++--- titles/mai2/config.py | 6 +- titles/mai2/const.py | 16 +- titles/mai2/dx.py | 247 +-- titles/mai2/finale.py | 4 +- titles/mai2/index.py | 42 +- titles/mai2/read.py | 129 +- titles/mai2/schema/__init__.py | 12 +- titles/mai2/schema/item.py | 1106 ++++++------ titles/mai2/schema/profile.py | 1655 +++++++++--------- titles/mai2/schema/score.py | 741 ++++---- titles/mai2/schema/static.py | 502 +++--- titles/mai2/universe.py | 68 +- titles/ongeki/base.py | 217 +-- titles/ongeki/bright.py | 197 ++- titles/ongeki/index.py | 4 +- titles/ongeki/read.py | 8 +- titles/ongeki/schema/__init__.py | 26 +- titles/ongeki/schema/item.py | 1434 +++++++-------- titles/ongeki/schema/log.py | 138 +- titles/ongeki/schema/profile.py | 1022 +++++------ titles/ongeki/schema/score.py | 360 ++-- titles/ongeki/schema/static.py | 676 +++---- titles/pokken/base.py | 189 +- titles/pokken/config.py | 13 +- titles/pokken/frontend.py | 2 +- titles/pokken/index.py | 5 +- titles/pokken/schema/__init__.py | 8 +- titles/pokken/schema/item.py | 96 +- titles/pokken/schema/match.py | 104 +- titles/pokken/schema/profile.py | 826 +++++---- titles/pokken/schema/static.py | 26 +- titles/pokken/services.py | 40 +- titles/sao/base.py | 2048 +++++++++++++++------- titles/sao/const.py | 2 +- titles/sao/database.py | 2 +- titles/sao/handlers/__init__.py | 2 +- titles/sao/handlers/base.py | 2526 +++++++++++++++------------ titles/sao/index.py | 12 +- titles/sao/read.py | 61 +- titles/sao/schema/__init__.py | 6 +- titles/sao/schema/item.py | 1066 +++++------ titles/sao/schema/profile.py | 172 +- titles/sao/schema/static.py | 781 +++++---- titles/wacca/base.py | 568 +++--- titles/wacca/frontend.py | 4 +- titles/wacca/handlers/user_music.py | 4 +- titles/wacca/lilyr.py | 50 +- titles/wacca/read.py | 4 +- titles/wacca/reverse.py | 14 +- titles/wacca/schema/__init__.py | 12 +- titles/wacca/schema/item.py | 416 ++--- titles/wacca/schema/profile.py | 1048 +++++------ titles/wacca/schema/score.py | 658 +++---- titles/wacca/schema/static.py | 170 +- 107 files changed, 15925 insertions(+), 14210 deletions(-) diff --git a/.gitignore b/.gitignore index b5a0e6e..5c1952f 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,9 @@ config/* deliver/* *.gz -dbdump-*.json \ No newline at end of file +dbdump-*.json + +*.sqlite3 +*.sqlite3-journal +*.sqlite3-shm +*.sqlite3-wal diff --git a/core/allnet.py b/core/allnet.py index 9ad5949..3c7f64b 100644 --- a/core/allnet.py +++ b/core/allnet.py @@ -80,7 +80,14 @@ class AllnetServlet: req = AllnetPowerOnRequest(req_dict[0]) # Validate the request. Currently we only validate the fields we plan on using - if not req.game_id or not req.ver or not req.serial or not req.ip or not req.firm_ver or not req.boot_ver: + if ( + not req.game_id + or not req.ver + or not req.serial + or not req.ip + or not req.firm_ver + or not req.boot_ver + ): raise AllnetRequestException( f"Bad auth request params from {request_ip} - {vars(req)}" ) @@ -97,7 +104,7 @@ class AllnetServlet: else: resp = AllnetPowerOnResponse() - self.logger.debug(f"Allnet request: {vars(req)}") + self.logger.debug(f"Allnet request: {vars(req)}") if req.game_id not in self.uri_registry: if not self.config.server.is_develop: msg = f"Unrecognised game {req.game_id} attempted allnet auth from {request_ip}." @@ -108,7 +115,9 @@ class AllnetServlet: resp.stat = -1 resp_dict = {k: v for k, v in vars(resp).items() if v is not None} - return (urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n").encode("utf-8") + return ( + urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n" + ).encode("utf-8") else: self.logger.info( @@ -116,16 +125,16 @@ class AllnetServlet: ) resp.uri = f"http://{self.config.title.hostname}:{self.config.title.port}/{req.game_id}/{req.ver.replace('.', '')}/" resp.host = f"{self.config.title.hostname}:{self.config.title.port}" - + resp_dict = {k: v for k, v in vars(resp).items() if v is not None} resp_str = urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) - + self.logger.debug(f"Allnet response: {resp_str}") return (resp_str + "\n").encode("utf-8") resp.uri, resp.host = self.uri_registry[req.game_id] - machine = self.data.arcade.get_machine(req.serial) + machine = self.data.arcade.get_machine(req.serial) if machine is None and not self.config.server.allow_unregistered_serials: msg = f"Unrecognised serial {req.serial} attempted allnet auth from {request_ip}." self.data.base.log_event( @@ -135,7 +144,9 @@ class AllnetServlet: resp.stat = -2 resp_dict = {k: v for k, v in vars(resp).items() if v is not None} - return (urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n").encode("utf-8") + return ( + urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) + "\n" + ).encode("utf-8") if machine is not None: arcade = self.data.arcade.get_arcade(machine["arcade"]) @@ -180,7 +191,7 @@ class AllnetServlet: resp_dict = {k: v for k, v in vars(resp).items() if v is not None} resp_str = urllib.parse.unquote(urllib.parse.urlencode(resp_dict)) - self.logger.debug(f"Allnet response: {resp_dict}") + self.logger.debug(f"Allnet response: {resp_dict}") resp_str += "\n" return resp_str.encode("utf-8") @@ -228,7 +239,12 @@ class AllnetServlet: resp.uri += f"|http://{self.config.title.hostname}:{self.config.title.port}/dl/ini/{req.game_id}-{req.ver.replace('.', '')}-opt.ini" self.logger.debug(f"Sending download uri {resp.uri}") - self.data.base.log_event("allnet", "DLORDER_REQ_SUCCESS", logging.INFO, f"{Utils.get_ip_addr(request)} requested DL Order for {req.serial} {req.game_id} v{req.ver}") + self.data.base.log_event( + "allnet", + "DLORDER_REQ_SUCCESS", + logging.INFO, + f"{Utils.get_ip_addr(request)} requested DL Order for {req.serial} {req.game_id} v{req.ver}", + ) return urllib.parse.unquote(urllib.parse.urlencode(vars(resp))) + "\n" @@ -239,8 +255,15 @@ class AllnetServlet: req_file = match["file"].replace("%0A", "") if path.exists(f"{self.config.allnet.update_cfg_folder}/{req_file}"): - self.logger.info(f"Request for DL INI file {req_file} from {Utils.get_ip_addr(request)} successful") - self.data.base.log_event("allnet", "DLORDER_INI_SENT", logging.INFO, f"{Utils.get_ip_addr(request)} successfully recieved {req_file}") + self.logger.info( + f"Request for DL INI file {req_file} from {Utils.get_ip_addr(request)} successful" + ) + self.data.base.log_event( + "allnet", + "DLORDER_INI_SENT", + logging.INFO, + f"{Utils.get_ip_addr(request)} successfully recieved {req_file}", + ) return open( f"{self.config.allnet.update_cfg_folder}/{req_file}", "rb" ).read() @@ -257,7 +280,7 @@ class AllnetServlet: def handle_loaderstaterecorder(self, request: Request, match: Dict) -> bytes: req_data = request.content.getvalue() sections = req_data.decode("utf-8").split("\r\n") - + req_dict = dict(urllib.parse.parse_qsl(sections[0])) serial: Union[str, None] = req_dict.get("serial", None) @@ -266,12 +289,19 @@ class AllnetServlet: dl_state: Union[str, None] = req_dict.get("dld_st", None) ip = Utils.get_ip_addr(request) - if serial is None or num_files_dld is None or num_files_to_dl is None or dl_state is None: + if ( + serial is None + or num_files_dld is None + or num_files_to_dl is None + or dl_state is None + ): return "NG".encode() - self.logger.info(f"LoaderStateRecorder Request from {ip} {serial}: {num_files_dld}/{num_files_to_dl} Files download (State: {dl_state})") + self.logger.info( + f"LoaderStateRecorder Request from {ip} {serial}: {num_files_dld}/{num_files_to_dl} Files download (State: {dl_state})" + ) return "OK".encode() - + def handle_alive(self, request: Request, match: Dict) -> bytes: return "OK".encode() @@ -297,7 +327,7 @@ class AllnetServlet: kc_game: str = req_dict[0]["gameid"] kc_date = strptime(req_dict[0]["date"], "%Y%m%d%H%M%S") kc_serial_bytes = kc_serial.encode() - + except KeyError as e: return f"result=5&linelimit=&message={e} field is missing".encode() @@ -431,6 +461,7 @@ class AllnetPowerOnRequest: self.format_ver = float(req.get("format_ver", "1.00")) self.token: str = req.get("token", "0") + class AllnetPowerOnResponse: def __init__(self) -> None: self.stat = 1 @@ -443,7 +474,7 @@ class AllnetPowerOnResponse: self.region_name0 = "W" self.region_name1 = "" self.region_name2 = "" - self.region_name3 = "" + self.region_name3 = "" self.setting = "1" self.year = datetime.now().year self.month = datetime.now().month @@ -452,6 +483,7 @@ class AllnetPowerOnResponse: self.minute = datetime.now().minute self.second = datetime.now().second + class AllnetPowerOnResponse3(AllnetPowerOnResponse): def __init__(self, token) -> None: super().__init__() diff --git a/core/config.py b/core/config.py index 8b85353..3427014 100644 --- a/core/config.py +++ b/core/config.py @@ -111,7 +111,7 @@ class DatabaseConfig: @property def protocol(self) -> str: return CoreConfig.get_config_field( - self.__config, "core", "database", "type", default="mysql" + self.__config, "core", "database", "protocol", default="mysql" ) @property diff --git a/core/data/__init__.py b/core/data/__init__.py index eb30d05..cfd31c3 100644 --- a/core/data/__init__.py +++ b/core/data/__init__.py @@ -1,2 +1,2 @@ -from core.data.database import Data -from core.data.cache import cached +from core.data.database import Data +from core.data.cache import cached diff --git a/core/data/cache.py b/core/data/cache.py index cabf597..df48ca7 100644 --- a/core/data/cache.py +++ b/core/data/cache.py @@ -1,67 +1,67 @@ -from typing import Any, Callable -from functools import wraps -import hashlib -import pickle -import logging -from core.config import CoreConfig - -cfg: CoreConfig = None # type: ignore -# Make memcache optional -try: - import pylibmc # type: ignore - - has_mc = True -except ModuleNotFoundError: - has_mc = False - - -def cached(lifetime: int = 10, extra_key: Any = None) -> Callable: - def _cached(func: Callable) -> Callable: - if has_mc: - hostname = "127.0.0.1" - if cfg: - hostname = cfg.database.memcached_host - memcache = pylibmc.Client([hostname], binary=True) - memcache.behaviors = {"tcp_nodelay": True, "ketama": True} - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - if lifetime is not None: - # Hash function args - items = kwargs.items() - hashable_args = (args[1:], sorted(list(items))) - args_key = hashlib.md5(pickle.dumps(hashable_args)).hexdigest() - - # Generate unique cache key - cache_key = f'{func.__module__}-{func.__name__}-{args_key}-{extra_key() if hasattr(extra_key, "__call__") else extra_key}' - - # Return cached version if allowed and available - try: - result = memcache.get(cache_key) - except pylibmc.Error as e: - logging.getLogger("database").error(f"Memcache failed: {e}") - result = None - - if result is not None: - logging.getLogger("database").debug(f"Cache hit: {result}") - return result - - # Generate output - result = func(*args, **kwargs) - - # Cache output if allowed - if lifetime is not None and result is not None: - logging.getLogger("database").debug(f"Setting cache: {result}") - memcache.set(cache_key, result, lifetime) - - return result - - else: - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - return func(*args, **kwargs) - - return wrapper - - return _cached +from typing import Any, Callable +from functools import wraps +import hashlib +import pickle +import logging +from core.config import CoreConfig + +cfg: CoreConfig = None # type: ignore +# Make memcache optional +try: + import pylibmc # type: ignore + + has_mc = True +except ModuleNotFoundError: + has_mc = False + + +def cached(lifetime: int = 10, extra_key: Any = None) -> Callable: + def _cached(func: Callable) -> Callable: + if has_mc: + hostname = "127.0.0.1" + if cfg: + hostname = cfg.database.memcached_host + memcache = pylibmc.Client([hostname], binary=True) + memcache.behaviors = {"tcp_nodelay": True, "ketama": True} + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if lifetime is not None: + # Hash function args + items = kwargs.items() + hashable_args = (args[1:], sorted(list(items))) + args_key = hashlib.md5(pickle.dumps(hashable_args)).hexdigest() + + # Generate unique cache key + cache_key = f'{func.__module__}-{func.__name__}-{args_key}-{extra_key() if hasattr(extra_key, "__call__") else extra_key}' + + # Return cached version if allowed and available + try: + result = memcache.get(cache_key) + except pylibmc.Error as e: + logging.getLogger("database").error(f"Memcache failed: {e}") + result = None + + if result is not None: + logging.getLogger("database").debug(f"Cache hit: {result}") + return result + + # Generate output + result = func(*args, **kwargs) + + # Cache output if allowed + if lifetime is not None and result is not None: + logging.getLogger("database").debug(f"Setting cache: {result}") + memcache.set(cache_key, result, lifetime) + + return result + + else: + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + return wrapper + + return _cached diff --git a/core/data/database.py b/core/data/database.py index 9fb2606..88c8837 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -1,357 +1,370 @@ -import logging, coloredlogs -from typing import Optional, Dict, List -from sqlalchemy.orm import scoped_session, sessionmaker -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import create_engine -from logging.handlers import TimedRotatingFileHandler -import importlib, os -import secrets, string -import bcrypt -from hashlib import sha256 - -from core.config import CoreConfig -from core.data.schema import * -from core.utils import Utils - - -class Data: - current_schema_version = 4 - engine = None - session = None - user = None - arcade = None - card = None - base = None - def __init__(self, cfg: CoreConfig) -> None: - self.config = cfg - - if self.config.database.sha2_password: - passwd = sha256(self.config.database.password.encode()).digest() - self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" - else: - self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" - - if Data.engine is None: - Data.engine = create_engine(self.__url, pool_recycle=3600) - self.__engine = Data.engine - - if Data.session is None: - s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True) - Data.session = scoped_session(s) - - if Data.user is None: - Data.user = UserData(self.config, self.session) - - if Data.arcade is None: - Data.arcade = ArcadeData(self.config, self.session) - - if Data.card is None: - Data.card = CardData(self.config, self.session) - - if Data.base is None: - Data.base = BaseData(self.config, self.session) - - self.logger = logging.getLogger("database") - - # Prevent the logger from adding handlers multiple times - if not getattr(self.logger, "handler_set", None): - log_fmt_str = "[%(asctime)s] %(levelname)s | Database | %(message)s" - log_fmt = logging.Formatter(log_fmt_str) - fileHandler = TimedRotatingFileHandler( - "{0}/{1}.log".format(self.config.server.log_dir, "db"), - encoding="utf-8", - when="d", - backupCount=10, - ) - fileHandler.setFormatter(log_fmt) - - consoleHandler = logging.StreamHandler() - consoleHandler.setFormatter(log_fmt) - - self.logger.addHandler(fileHandler) - self.logger.addHandler(consoleHandler) - - self.logger.setLevel(self.config.database.loglevel) - coloredlogs.install( - cfg.database.loglevel, logger=self.logger, fmt=log_fmt_str - ) - self.logger.handler_set = True # type: ignore - - def create_database(self): - self.logger.info("Creating databases...") - try: - metadata.create_all(self.__engine.connect()) - except SQLAlchemyError as e: - self.logger.error(f"Failed to create databases! {e}") - return - - games = Utils.get_all_titles() - for game_dir, game_mod in games.items(): - try: - if hasattr(game_mod, "database") and hasattr( - game_mod, "current_schema_version" - ): - game_mod.database(self.config) - metadata.create_all(self.__engine.connect()) - - self.base.touch_schema_ver( - game_mod.current_schema_version, game_mod.game_codes[0] - ) - - except Exception as e: - self.logger.warning( - f"Could not load database schema from {game_dir} - {e}" - ) - - self.logger.info(f"Setting base_schema_ver to {self.current_schema_version}") - self.base.set_schema_ver(self.current_schema_version) - - self.logger.info( - f"Setting user auto_incrememnt to {self.config.database.user_table_autoincrement_start}" - ) - self.user.reset_autoincrement( - self.config.database.user_table_autoincrement_start - ) - - def recreate_database(self): - self.logger.info("Dropping all databases...") - self.base.execute("SET FOREIGN_KEY_CHECKS=0") - try: - metadata.drop_all(self.__engine.connect()) - except SQLAlchemyError as e: - self.logger.error(f"Failed to drop databases! {e}") - return - - for root, dirs, files in os.walk("./titles"): - for dir in dirs: - if not dir.startswith("__"): - try: - mod = importlib.import_module(f"titles.{dir}") - - try: - if hasattr(mod, "database"): - mod.database(self.config) - metadata.drop_all(self.__engine.connect()) - - except Exception as e: - self.logger.warning( - f"Could not load database schema from {dir} - {e}" - ) - - except ImportError as e: - self.logger.warning( - f"Failed to load database schema dir {dir} - {e}" - ) - break - - self.base.execute("SET FOREIGN_KEY_CHECKS=1") - - self.create_database() - - def migrate_database(self, game: str, version: Optional[int], action: str) -> None: - old_ver = self.base.get_schema_ver(game) - sql = "" - if version is None: - if not game == "CORE": - titles = Utils.get_all_titles() - - for folder, mod in titles.items(): - if not mod.game_codes[0] == game: - continue - - if hasattr(mod, "current_schema_version"): - version = mod.current_schema_version - - else: - self.logger.warn( - f"current_schema_version not found for {folder}" - ) - - else: - version = self.current_schema_version - - if version is None: - self.logger.warn( - f"Could not determine latest version for {game}, please specify --version" - ) - - if old_ver is None: - self.logger.error( - f"Schema for game {game} does not exist, did you run the creation script?" - ) - return - - if old_ver == version: - self.logger.info( - f"Schema for game {game} is already version {old_ver}, nothing to do" - ) - return - - if action == "upgrade": - for x in range(old_ver, version): - if not os.path.exists( - f"core/data/schema/versions/{game.upper()}_{x + 1}_{action}.sql" - ): - self.logger.error( - f"Could not find {action} script {game.upper()}_{x + 1}_{action}.sql in core/data/schema/versions folder" - ) - return - - with open( - f"core/data/schema/versions/{game.upper()}_{x + 1}_{action}.sql", - "r", - encoding="utf-8", - ) as f: - sql = f.read() - - result = self.base.execute(sql) - if result is None: - self.logger.error("Error execuing sql script!") - return None - - else: - for x in range(old_ver, version, -1): - if not os.path.exists( - f"core/data/schema/versions/{game.upper()}_{x - 1}_{action}.sql" - ): - self.logger.error( - f"Could not find {action} script {game.upper()}_{x - 1}_{action}.sql in core/data/schema/versions folder" - ) - return - - with open( - f"core/data/schema/versions/{game.upper()}_{x - 1}_{action}.sql", - "r", - encoding="utf-8", - ) as f: - sql = f.read() - - result = self.base.execute(sql) - if result is None: - self.logger.error("Error execuing sql script!") - return None - - result = self.base.set_schema_ver(version, game) - if result is None: - self.logger.error("Error setting version in schema_version table!") - return None - - self.logger.info(f"Successfully migrated {game} to schema version {version}") - - def create_owner(self, email: Optional[str] = None) -> None: - pw = "".join( - secrets.choice(string.ascii_letters + string.digits) for i in range(20) - ) - hash = bcrypt.hashpw(pw.encode(), bcrypt.gensalt()) - - user_id = self.user.create_user(email=email, permission=255, password=hash) - if user_id is None: - self.logger.error(f"Failed to create owner with email {email}") - return - - card_id = self.card.create_card(user_id, "00000000000000000000") - if card_id is None: - self.logger.error(f"Failed to create card for owner with id {user_id}") - return - - self.logger.warn( - f"Successfully created owner with email {email}, access code 00000000000000000000, and password {pw} Make sure to change this password and assign a real card ASAP!" - ) - - def migrate_card(self, old_ac: str, new_ac: str, should_force: bool) -> None: - if old_ac == new_ac: - self.logger.error("Both access codes are the same!") - return - - new_card = self.card.get_card_by_access_code(new_ac) - if new_card is None: - self.card.update_access_code(old_ac, new_ac) - return - - if not should_force: - self.logger.warn( - f"Card already exists for access code {new_ac} (id {new_card['id']}). If you wish to continue, rerun with the '--force' flag." - f" All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}." - ) - return - - self.logger.info( - f"All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}." - ) - self.card.delete_card(new_card["id"]) - self.card.update_access_code(old_ac, new_ac) - - hanging_user = self.user.get_user(new_card["user"]) - if hanging_user["password"] is None: - self.logger.info(f"Delete hanging user {hanging_user['id']}") - self.user.delete_user(hanging_user["id"]) - - def delete_hanging_users(self) -> None: - """ - Finds and deletes users that have not registered for the webui that have no cards assocated with them. - """ - unreg_users = self.user.get_unregistered_users() - if unreg_users is None: - self.logger.error("Error occoured finding unregistered users") - - for user in unreg_users: - cards = self.card.get_user_cards(user["id"]) - if cards is None: - self.logger.error(f"Error getting cards for user {user['id']}") - continue - - if not cards: - self.logger.info(f"Delete hanging user {user['id']}") - self.user.delete_user(user["id"]) - - def autoupgrade(self) -> None: - all_game_versions = self.base.get_all_schema_vers() - if all_game_versions is None: - self.logger.warn("Failed to get schema versions") - return - - all_games = Utils.get_all_titles() - all_games_list: Dict[str, int] = {} - for _, mod in all_games.items(): - if hasattr(mod, "current_schema_version"): - all_games_list[mod.game_codes[0]] = mod.current_schema_version - - for x in all_game_versions: - failed = False - game = x["game"].upper() - update_ver = int(x["version"]) - latest_ver = all_games_list.get(game, 1) - if game == "CORE": - latest_ver = self.current_schema_version - - if update_ver == latest_ver: - self.logger.info(f"{game} is already latest version") - continue - - for y in range(update_ver + 1, latest_ver + 1): - if os.path.exists(f"core/data/schema/versions/{game}_{y}_upgrade.sql"): - with open( - f"core/data/schema/versions/{game}_{y}_upgrade.sql", - "r", - encoding="utf-8", - ) as f: - sql = f.read() - - result = self.base.execute(sql) - if result is None: - self.logger.error( - f"Error execuing sql script for game {game} v{y}!" - ) - failed = True - break - else: - self.logger.warning(f"Could not find script {game}_{y}_upgrade.sql") - failed = True - - if not failed: - self.base.set_schema_ver(latest_ver, game) - - def show_versions(self) -> None: - all_game_versions = self.base.get_all_schema_vers() - for ver in all_game_versions: - self.logger.info(f"{ver['game']} -> v{ver['version']}") +import logging, coloredlogs +from typing import Optional, Dict, List +from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy import create_engine +from logging.handlers import TimedRotatingFileHandler +import importlib, os +import secrets, string +import bcrypt +from hashlib import sha256 + +from core.config import CoreConfig +from core.data.schema import * +from core.utils import Utils + + +class Data: + current_schema_version = 4 + engine = None + session = None + user = None + arcade = None + card = None + base = None + + def __init__(self, cfg: CoreConfig) -> None: + self.config = cfg + + if self.config.database.protocol == "sqlite": + self.__url = ( + f"{self.config.database.protocol}:///{self.config.database.name}" + ) + else: + if self.config.database.sha2_password: + passwd = sha256(self.config.database.password.encode()).digest() + self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" + else: + self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" + + if Data.engine is None: + Data.engine = create_engine(self.__url, pool_recycle=3600) + self.__engine = Data.engine + + if Data.session is None: + s = sessionmaker( + bind=Data.engine, + # There are a billion transactions, autoflushing will take a lot of time + # and potentially locks up the database. + autoflush=self.config.database.protocol != "sqlite", + autocommit=True, + ) + Data.session = scoped_session(s) + + if Data.user is None: + Data.user = UserData(self.config, self.session) + + if Data.arcade is None: + Data.arcade = ArcadeData(self.config, self.session) + + if Data.card is None: + Data.card = CardData(self.config, self.session) + + if Data.base is None: + Data.base = BaseData(self.config, self.session) + + self.logger = logging.getLogger("database") + + # Prevent the logger from adding handlers multiple times + if not getattr(self.logger, "handler_set", None): + log_fmt_str = "[%(asctime)s] %(levelname)s | Database | %(message)s" + log_fmt = logging.Formatter(log_fmt_str) + fileHandler = TimedRotatingFileHandler( + "{0}/{1}.log".format(self.config.server.log_dir, "db"), + encoding="utf-8", + when="d", + backupCount=10, + ) + fileHandler.setFormatter(log_fmt) + + consoleHandler = logging.StreamHandler() + consoleHandler.setFormatter(log_fmt) + + self.logger.addHandler(fileHandler) + self.logger.addHandler(consoleHandler) + + self.logger.setLevel(self.config.database.loglevel) + coloredlogs.install( + cfg.database.loglevel, logger=self.logger, fmt=log_fmt_str + ) + self.logger.handler_set = True # type: ignore + + def create_database(self): + self.logger.info("Creating databases...") + self.base.setup_sqlite() + try: + metadata.create_all(self.__engine.connect()) + except SQLAlchemyError as e: + self.logger.error(f"Failed to create databases! {e}") + return + + games = Utils.get_all_titles() + for game_dir, game_mod in games.items(): + try: + if hasattr(game_mod, "database") and hasattr( + game_mod, "current_schema_version" + ): + game_mod.database(self.config) + metadata.create_all(self.__engine.connect()) + + self.base.touch_schema_ver( + game_mod.current_schema_version, game_mod.game_codes[0] + ) + + except Exception as e: + self.logger.warning( + f"Could not load database schema from {game_dir} - {e}" + ) + + self.logger.info(f"Setting base_schema_ver to {self.current_schema_version}") + self.base.set_schema_ver(self.current_schema_version) + + self.logger.info( + f"Setting user auto_incrememnt to {self.config.database.user_table_autoincrement_start}" + ) + self.user.reset_autoincrement( + self.config.database.user_table_autoincrement_start + ) + + def recreate_database(self): + self.logger.info("Dropping all databases...") + self.base.execute("SET FOREIGN_KEY_CHECKS=0") + try: + metadata.drop_all(self.__engine.connect()) + except SQLAlchemyError as e: + self.logger.error(f"Failed to drop databases! {e}") + return + + for root, dirs, files in os.walk("./titles"): + for dir in dirs: + if not dir.startswith("__"): + try: + mod = importlib.import_module(f"titles.{dir}") + + try: + if hasattr(mod, "database"): + mod.database(self.config) + metadata.drop_all(self.__engine.connect()) + + except Exception as e: + self.logger.warning( + f"Could not load database schema from {dir} - {e}" + ) + + except ImportError as e: + self.logger.warning( + f"Failed to load database schema dir {dir} - {e}" + ) + break + + self.base.execute("SET FOREIGN_KEY_CHECKS=1") + + self.create_database() + + def migrate_database(self, game: str, version: Optional[int], action: str) -> None: + old_ver = self.base.get_schema_ver(game) + sql = "" + if version is None: + if not game == "CORE": + titles = Utils.get_all_titles() + + for folder, mod in titles.items(): + if not mod.game_codes[0] == game: + continue + + if hasattr(mod, "current_schema_version"): + version = mod.current_schema_version + + else: + self.logger.warn( + f"current_schema_version not found for {folder}" + ) + + else: + version = self.current_schema_version + + if version is None: + self.logger.warn( + f"Could not determine latest version for {game}, please specify --version" + ) + + if old_ver is None: + self.logger.error( + f"Schema for game {game} does not exist, did you run the creation script?" + ) + return + + if old_ver == version: + self.logger.info( + f"Schema for game {game} is already version {old_ver}, nothing to do" + ) + return + + if action == "upgrade": + for x in range(old_ver, version): + if not os.path.exists( + f"core/data/schema/versions/{game.upper()}_{x + 1}_{action}.sql" + ): + self.logger.error( + f"Could not find {action} script {game.upper()}_{x + 1}_{action}.sql in core/data/schema/versions folder" + ) + return + + with open( + f"core/data/schema/versions/{game.upper()}_{x + 1}_{action}.sql", + "r", + encoding="utf-8", + ) as f: + sql = f.read() + + result = self.base.execute(sql) + if result is None: + self.logger.error("Error execuing sql script!") + return None + + else: + for x in range(old_ver, version, -1): + if not os.path.exists( + f"core/data/schema/versions/{game.upper()}_{x - 1}_{action}.sql" + ): + self.logger.error( + f"Could not find {action} script {game.upper()}_{x - 1}_{action}.sql in core/data/schema/versions folder" + ) + return + + with open( + f"core/data/schema/versions/{game.upper()}_{x - 1}_{action}.sql", + "r", + encoding="utf-8", + ) as f: + sql = f.read() + + result = self.base.execute(sql) + if result is None: + self.logger.error("Error execuing sql script!") + return None + + result = self.base.set_schema_ver(version, game) + if result is None: + self.logger.error("Error setting version in schema_version table!") + return None + + self.logger.info(f"Successfully migrated {game} to schema version {version}") + + def create_owner(self, email: Optional[str] = None) -> None: + pw = "".join( + secrets.choice(string.ascii_letters + string.digits) for i in range(20) + ) + hash = bcrypt.hashpw(pw.encode(), bcrypt.gensalt()) + + user_id = self.user.create_user(email=email, permission=255, password=hash) + if user_id is None: + self.logger.error(f"Failed to create owner with email {email}") + return + + card_id = self.card.create_card(user_id, "00000000000000000000") + if card_id is None: + self.logger.error(f"Failed to create card for owner with id {user_id}") + return + + self.logger.warn( + f"Successfully created owner with email {email}, access code 00000000000000000000, and password {pw} Make sure to change this password and assign a real card ASAP!" + ) + + def migrate_card(self, old_ac: str, new_ac: str, should_force: bool) -> None: + if old_ac == new_ac: + self.logger.error("Both access codes are the same!") + return + + new_card = self.card.get_card_by_access_code(new_ac) + if new_card is None: + self.card.update_access_code(old_ac, new_ac) + return + + if not should_force: + self.logger.warn( + f"Card already exists for access code {new_ac} (id {new_card['id']}). If you wish to continue, rerun with the '--force' flag." + f" All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}." + ) + return + + self.logger.info( + f"All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}." + ) + self.card.delete_card(new_card["id"]) + self.card.update_access_code(old_ac, new_ac) + + hanging_user = self.user.get_user(new_card["user"]) + if hanging_user["password"] is None: + self.logger.info(f"Delete hanging user {hanging_user['id']}") + self.user.delete_user(hanging_user["id"]) + + def delete_hanging_users(self) -> None: + """ + Finds and deletes users that have not registered for the webui that have no cards assocated with them. + """ + unreg_users = self.user.get_unregistered_users() + if unreg_users is None: + self.logger.error("Error occoured finding unregistered users") + + for user in unreg_users: + cards = self.card.get_user_cards(user["id"]) + if cards is None: + self.logger.error(f"Error getting cards for user {user['id']}") + continue + + if not cards: + self.logger.info(f"Delete hanging user {user['id']}") + self.user.delete_user(user["id"]) + + def autoupgrade(self) -> None: + all_game_versions = self.base.get_all_schema_vers() + if all_game_versions is None: + self.logger.warn("Failed to get schema versions") + return + + all_games = Utils.get_all_titles() + all_games_list: Dict[str, int] = {} + for _, mod in all_games.items(): + if hasattr(mod, "current_schema_version"): + all_games_list[mod.game_codes[0]] = mod.current_schema_version + + for x in all_game_versions: + failed = False + game = x["game"].upper() + update_ver = int(x["version"]) + latest_ver = all_games_list.get(game, 1) + if game == "CORE": + latest_ver = self.current_schema_version + + if update_ver == latest_ver: + self.logger.info(f"{game} is already latest version") + continue + + for y in range(update_ver + 1, latest_ver + 1): + if os.path.exists(f"core/data/schema/versions/{game}_{y}_upgrade.sql"): + with open( + f"core/data/schema/versions/{game}_{y}_upgrade.sql", + "r", + encoding="utf-8", + ) as f: + sql = f.read() + + result = self.base.execute(sql) + if result is None: + self.logger.error( + f"Error execuing sql script for game {game} v{y}!" + ) + failed = True + break + else: + self.logger.warning(f"Could not find script {game}_{y}_upgrade.sql") + failed = True + + if not failed: + self.base.set_schema_ver(latest_ver, game) + + def show_versions(self) -> None: + all_game_versions = self.base.get_all_schema_vers() + for ver in all_game_versions: + self.logger.info(f"{ver['game']} -> v{ver['version']}") diff --git a/core/data/schema/__init__.py b/core/data/schema/__init__.py index 45931d7..810e24a 100644 --- a/core/data/schema/__init__.py +++ b/core/data/schema/__init__.py @@ -1,6 +1,6 @@ -from core.data.schema.user import UserData -from core.data.schema.card import CardData -from core.data.schema.base import BaseData, metadata -from core.data.schema.arcade import ArcadeData - -__all__ = ["UserData", "CardData", "BaseData", "metadata", "ArcadeData"] +from core.data.schema.user import UserData +from core.data.schema.card import CardData +from core.data.schema.base import BaseData, metadata +from core.data.schema.arcade import ArcadeData + +__all__ = ["UserData", "CardData", "BaseData", "metadata", "ArcadeData"] diff --git a/core/data/schema/arcade.py b/core/data/schema/arcade.py index e1d9b1f..c6753f2 100644 --- a/core/data/schema/arcade.py +++ b/core/data/schema/arcade.py @@ -1,219 +1,219 @@ -from typing import Optional, Dict -from sqlalchemy import Table, Column -from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint -from sqlalchemy.types import Integer, String, Boolean -from sqlalchemy.sql import func, select -from sqlalchemy.dialects.mysql import insert -import re - -from core.data.schema.base import BaseData, metadata -from core.const import * - -arcade = Table( - "arcade", - metadata, - Column("id", Integer, primary_key=True, nullable=False), - Column("name", String(255)), - Column("nickname", String(255)), - Column("country", String(3)), - Column("country_id", Integer), - Column("state", String(255)), - Column("city", String(255)), - Column("region_id", Integer), - Column("timezone", String(255)), - mysql_charset="utf8mb4", -) - -machine = Table( - "machine", - metadata, - Column("id", Integer, primary_key=True, nullable=False), - Column( - "arcade", - ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), - nullable=False, - ), - Column("serial", String(15), nullable=False), - Column("board", String(15)), - Column("game", String(4)), - Column("country", String(3)), # overwrites if not null - Column("timezone", String(255)), - Column("ota_enable", Boolean), - Column("is_cab", Boolean), - mysql_charset="utf8mb4", -) - -arcade_owner = Table( - "arcade_owner", - metadata, - Column( - "user", - Integer, - ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), - nullable=False, - ), - Column( - "arcade", - Integer, - ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), - nullable=False, - ), - Column("permissions", Integer, nullable=False), - PrimaryKeyConstraint("user", "arcade", name="arcade_owner_pk"), - mysql_charset="utf8mb4", -) - - -class ArcadeData(BaseData): - def get_machine(self, serial: str = None, id: int = None) -> Optional[Dict]: - if serial is not None: - serial = serial.replace("-", "") - if len(serial) == 11: - sql = machine.select(machine.c.serial.like(f"{serial}%")) - - elif len(serial) == 15: - sql = machine.select(machine.c.serial == serial) - - else: - self.logger.error(f"{__name__ }: Malformed serial {serial}") - return None - - elif id is not None: - sql = machine.select(machine.c.id == id) - - else: - self.logger.error(f"{__name__ }: Need either serial or ID to look up!") - return None - - result = self.execute(sql) - if result is None: - return None - return result.fetchone() - - def put_machine( - self, - arcade_id: int, - serial: str = "", - board: str = None, - game: str = None, - is_cab: bool = False, - ) -> Optional[int]: - if arcade_id: - self.logger.error(f"{__name__ }: Need arcade id!") - return None - - sql = machine.insert().values( - arcade=arcade_id, keychip=serial, board=board, game=game, is_cab=is_cab - ) - - result = self.execute(sql) - if result is None: - return None - return result.lastrowid - - def set_machine_serial(self, machine_id: int, serial: str) -> None: - result = self.execute( - machine.update(machine.c.id == machine_id).values(keychip=serial) - ) - if result is None: - self.logger.error( - f"Failed to update serial for machine {machine_id} -> {serial}" - ) - return result.lastrowid - - def set_machine_boardid(self, machine_id: int, boardid: str) -> None: - result = self.execute( - machine.update(machine.c.id == machine_id).values(board=boardid) - ) - if result is None: - self.logger.error( - f"Failed to update board id for machine {machine_id} -> {boardid}" - ) - - def get_arcade(self, id: int) -> Optional[Dict]: - sql = arcade.select(arcade.c.id == id) - result = self.execute(sql) - if result is None: - return None - return result.fetchone() - - def put_arcade( - self, - name: str, - nickname: str = None, - country: str = "JPN", - country_id: int = 1, - state: str = "", - city: str = "", - regional_id: int = 1, - ) -> Optional[int]: - if nickname is None: - nickname = name - - sql = arcade.insert().values( - name=name, - nickname=nickname, - country=country, - country_id=country_id, - state=state, - city=city, - regional_id=regional_id, - ) - - result = self.execute(sql) - if result is None: - return None - return result.lastrowid - - def get_arcade_owners(self, arcade_id: int) -> Optional[Dict]: - sql = select(arcade_owner).where(arcade_owner.c.arcade == arcade_id) - - result = self.execute(sql) - if result is None: - return None - return result.fetchall() - - def add_arcade_owner(self, arcade_id: int, user_id: int) -> None: - sql = insert(arcade_owner).values(arcade=arcade_id, user=user_id) - - result = self.execute(sql) - if result is None: - return None - return result.lastrowid - - def format_serial( - self, platform_code: str, platform_rev: int, serial_num: int, append: int = 4152 - ) -> str: - return f"{platform_code}{platform_rev:02d}A{serial_num:04d}{append:04d}" # 0x41 = A, 0x52 = R - - def validate_keychip_format(self, serial: str) -> bool: - serial = serial.replace("-", "") - if len(serial) != 11 or len(serial) != 15: - self.logger.error( - f"Serial validate failed: Incorrect length for {serial} (len {len(serial)})" - ) - return False - - platform_code = serial[:4] - platform_rev = serial[4:6] - const_a = serial[6] - num = serial[7:11] - append = serial[11:15] - - if re.match("A[7|6]\d[E|X][0|1][0|1|2]A\d{4,8}", serial) is None: - self.logger.error(f"Serial validate failed: {serial} failed regex") - return False - - if len(append) != 0 or len(append) != 4: - self.logger.error( - f"Serial validate failed: {serial} had malformed append {append}" - ) - return False - - if len(num) != 4: - self.logger.error( - f"Serial validate failed: {serial} had malformed number {num}" - ) - return False - - return True +from typing import Optional, Dict +from sqlalchemy import Table, Column +from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint +from sqlalchemy.types import Integer, String, Boolean +from sqlalchemy.sql import func, select +from sqlalchemy.dialects.sqlite import insert +import re + +from core.data.schema.base import BaseData, metadata +from core.const import * + +arcade = Table( + "arcade", + metadata, + Column("id", Integer, primary_key=True, nullable=False), + Column("name", String(255)), + Column("nickname", String(255)), + Column("country", String(3)), + Column("country_id", Integer), + Column("state", String(255)), + Column("city", String(255)), + Column("region_id", Integer), + Column("timezone", String(255)), + mysql_charset="utf8mb4", +) + +machine = Table( + "machine", + metadata, + Column("id", Integer, primary_key=True, nullable=False), + Column( + "arcade", + ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), + nullable=False, + ), + Column("serial", String(15), nullable=False), + Column("board", String(15)), + Column("game", String(4)), + Column("country", String(3)), # overwrites if not null + Column("timezone", String(255)), + Column("ota_enable", Boolean), + Column("is_cab", Boolean), + mysql_charset="utf8mb4", +) + +arcade_owner = Table( + "arcade_owner", + metadata, + Column( + "user", + Integer, + ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), + nullable=False, + ), + Column( + "arcade", + Integer, + ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), + nullable=False, + ), + Column("permissions", Integer, nullable=False), + PrimaryKeyConstraint("user", "arcade", name="arcade_owner_pk"), + mysql_charset="utf8mb4", +) + + +class ArcadeData(BaseData): + def get_machine(self, serial: str = None, id: int = None) -> Optional[Dict]: + if serial is not None: + serial = serial.replace("-", "") + if len(serial) == 11: + sql = machine.select(machine.c.serial.like(f"{serial}%")) + + elif len(serial) == 15: + sql = machine.select(machine.c.serial == serial) + + else: + self.logger.error(f"{__name__ }: Malformed serial {serial}") + return None + + elif id is not None: + sql = machine.select(machine.c.id == id) + + else: + self.logger.error(f"{__name__ }: Need either serial or ID to look up!") + return None + + result = self.execute(sql) + if result is None: + return None + return result.fetchone() + + def put_machine( + self, + arcade_id: int, + serial: str = "", + board: str = None, + game: str = None, + is_cab: bool = False, + ) -> Optional[int]: + if arcade_id: + self.logger.error(f"{__name__ }: Need arcade id!") + return None + + sql = machine.insert().values( + arcade=arcade_id, keychip=serial, board=board, game=game, is_cab=is_cab + ) + + result = self.execute(sql) + if result is None: + return None + return result.lastrowid + + def set_machine_serial(self, machine_id: int, serial: str) -> None: + result = self.execute( + machine.update(machine.c.id == machine_id).values(keychip=serial) + ) + if result is None: + self.logger.error( + f"Failed to update serial for machine {machine_id} -> {serial}" + ) + return result.lastrowid + + def set_machine_boardid(self, machine_id: int, boardid: str) -> None: + result = self.execute( + machine.update(machine.c.id == machine_id).values(board=boardid) + ) + if result is None: + self.logger.error( + f"Failed to update board id for machine {machine_id} -> {boardid}" + ) + + def get_arcade(self, id: int) -> Optional[Dict]: + sql = arcade.select(arcade.c.id == id) + result = self.execute(sql) + if result is None: + return None + return result.fetchone() + + def put_arcade( + self, + name: str, + nickname: str = None, + country: str = "JPN", + country_id: int = 1, + state: str = "", + city: str = "", + regional_id: int = 1, + ) -> Optional[int]: + if nickname is None: + nickname = name + + sql = arcade.insert().values( + name=name, + nickname=nickname, + country=country, + country_id=country_id, + state=state, + city=city, + regional_id=regional_id, + ) + + result = self.execute(sql) + if result is None: + return None + return result.lastrowid + + def get_arcade_owners(self, arcade_id: int) -> Optional[Dict]: + sql = select(arcade_owner).where(arcade_owner.c.arcade == arcade_id) + + result = self.execute(sql) + if result is None: + return None + return result.fetchall() + + def add_arcade_owner(self, arcade_id: int, user_id: int) -> None: + sql = insert(arcade_owner).values(arcade=arcade_id, user=user_id) + + result = self.execute(sql) + if result is None: + return None + return result.lastrowid + + def format_serial( + self, platform_code: str, platform_rev: int, serial_num: int, append: int = 4152 + ) -> str: + return f"{platform_code}{platform_rev:02d}A{serial_num:04d}{append:04d}" # 0x41 = A, 0x52 = R + + def validate_keychip_format(self, serial: str) -> bool: + serial = serial.replace("-", "") + if len(serial) != 11 or len(serial) != 15: + self.logger.error( + f"Serial validate failed: Incorrect length for {serial} (len {len(serial)})" + ) + return False + + platform_code = serial[:4] + platform_rev = serial[4:6] + const_a = serial[6] + num = serial[7:11] + append = serial[11:15] + + if re.match("A[7|6]\d[E|X][0|1][0|1|2]A\d{4,8}", serial) is None: + self.logger.error(f"Serial validate failed: {serial} failed regex") + return False + + if len(append) != 0 or len(append) != 4: + self.logger.error( + f"Serial validate failed: {serial} had malformed append {append}" + ) + return False + + if len(num) != 4: + self.logger.error( + f"Serial validate failed: {serial} had malformed number {num}" + ) + return False + + return True diff --git a/core/data/schema/base.py b/core/data/schema/base.py index a53392f..e26640e 100644 --- a/core/data/schema/base.py +++ b/core/data/schema/base.py @@ -1,165 +1,170 @@ -import json -import logging -from random import randrange -from typing import Any, Optional, Dict, List -from sqlalchemy.engine import Row -from sqlalchemy.engine.cursor import CursorResult -from sqlalchemy.engine.base import Connection -from sqlalchemy.sql import text, func, select -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import MetaData, Table, Column -from sqlalchemy.types import Integer, String, TIMESTAMP, JSON -from sqlalchemy.dialects.mysql import insert - -from core.config import CoreConfig - -metadata = MetaData() - -schema_ver = Table( - "schema_versions", - metadata, - Column("game", String(4), primary_key=True, nullable=False), - Column("version", Integer, nullable=False, server_default="1"), - mysql_charset="utf8mb4", -) - -event_log = Table( - "event_log", - metadata, - Column("id", Integer, primary_key=True, nullable=False), - Column("system", String(255), nullable=False), - Column("type", String(255), nullable=False), - Column("severity", Integer, nullable=False), - Column("message", String(1000), nullable=False), - Column("details", JSON, nullable=False), - Column("when_logged", TIMESTAMP, nullable=False, server_default=func.now()), - mysql_charset="utf8mb4", -) - - -class BaseData: - def __init__(self, cfg: CoreConfig, conn: Connection) -> None: - self.config = cfg - self.conn = conn - self.logger = logging.getLogger("database") - - def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]: - res = None - - try: - self.logger.info(f"SQL Execute: {''.join(str(sql).splitlines())}") - res = self.conn.execute(text(sql), opts) - - except SQLAlchemyError as e: - self.logger.error(f"SQLAlchemy error {e}") - return None - - except UnicodeEncodeError as e: - self.logger.error(f"UnicodeEncodeError error {e}") - return None - - except Exception: - try: - res = self.conn.execute(sql, opts) - - except SQLAlchemyError as e: - self.logger.error(f"SQLAlchemy error {e}") - return None - - except UnicodeEncodeError as e: - self.logger.error(f"UnicodeEncodeError error {e}") - return None - - except Exception: - self.logger.error(f"Unknown error") - raise - - return res - - def generate_id(self) -> int: - """ - Generate a random 5-7 digit id - """ - return randrange(10000, 9999999) - - def get_all_schema_vers(self) -> Optional[List[Row]]: - sql = select(schema_ver) - - result = self.execute(sql) - if result is None: - return None - return result.fetchall() - - def get_schema_ver(self, game: str) -> Optional[int]: - sql = select(schema_ver).where(schema_ver.c.game == game) - - result = self.execute(sql) - if result is None: - return None - - row = result.fetchone() - if row is None: - return None - - return row["version"] - - def touch_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]: - sql = insert(schema_ver).values(game=game, version=ver) - conflict = sql.on_duplicate_key_update(version=schema_ver.c.version) - - result = self.execute(conflict) - if result is None: - self.logger.error( - f"Failed to update schema version for game {game} (v{ver})" - ) - return None - return result.lastrowid - - def set_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]: - sql = insert(schema_ver).values(game=game, version=ver) - conflict = sql.on_duplicate_key_update(version=ver) - - result = self.execute(conflict) - if result is None: - self.logger.error( - f"Failed to update schema version for game {game} (v{ver})" - ) - return None - return result.lastrowid - - def log_event( - self, system: str, type: str, severity: int, message: str, details: Dict = {} - ) -> Optional[int]: - sql = event_log.insert().values( - system=system, - type=type, - severity=severity, - message=message, - details=json.dumps(details), - ) - result = self.execute(sql) - - if result is None: - self.logger.error( - f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}" - ) - return None - - return result.lastrowid - - def get_event_log(self, entries: int = 100) -> Optional[List[Dict]]: - sql = event_log.select().limit(entries).all() - result = self.execute(sql) - - if result is None: - return None - return result.fetchall() - - def fix_bools(self, data: Dict) -> Dict: - for k, v in data.items(): - if type(v) == str and v.lower() == "true": - data[k] = True - elif type(v) == str and v.lower() == "false": - data[k] = False - - return data +import json +import logging +from random import randrange +from typing import Any, Optional, Dict, List +from sqlalchemy.engine import Row +from sqlalchemy.engine.cursor import CursorResult +from sqlalchemy.engine.base import Connection +from sqlalchemy.sql import text, func, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy import MetaData, Table, Column +from sqlalchemy.types import Integer, String, TIMESTAMP, JSON +from sqlalchemy.dialects.sqlite import insert + +from core.config import CoreConfig + +metadata = MetaData() + +schema_ver = Table( + "schema_versions", + metadata, + Column("game", String(4), primary_key=True, nullable=False), + Column("version", Integer, nullable=False, server_default="1"), + mysql_charset="utf8mb4", +) + +event_log = Table( + "event_log", + metadata, + Column("id", Integer, primary_key=True, nullable=False), + Column("system", String(255), nullable=False), + Column("type", String(255), nullable=False), + Column("severity", Integer, nullable=False), + Column("message", String(1000), nullable=False), + Column("details", JSON, nullable=False), + Column("when_logged", TIMESTAMP, nullable=False, server_default=func.now()), + mysql_charset="utf8mb4", +) + + +class BaseData: + def __init__(self, cfg: CoreConfig, conn: Connection) -> None: + self.config = cfg + self.conn = conn + self.logger = logging.getLogger("database") + + def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]: + res = None + + try: + self.logger.info(f"SQL Execute: {''.join(str(sql).splitlines())}") + res = self.conn.execute(text(sql), opts) + + except SQLAlchemyError as e: + self.logger.error(f"SQLAlchemy error {e}") + return None + + except UnicodeEncodeError as e: + self.logger.error(f"UnicodeEncodeError error {e}") + return None + + except Exception: + try: + res = self.conn.execute(sql, opts) + + except SQLAlchemyError as e: + self.logger.error(f"SQLAlchemy error {e}") + return None + + except UnicodeEncodeError as e: + self.logger.error(f"UnicodeEncodeError error {e}") + return None + + except Exception: + self.logger.error(f"Unknown error") + raise + + return res + + def setup_sqlite(self) -> None: + if self.config.database.protocol == "sqlite": + self.execute("PRAGMA journal_mode=WAL") + self.execute("PRAGMA synchronous=NORMAL") + + def generate_id(self) -> int: + """ + Generate a random 5-7 digit id + """ + return randrange(10000, 9999999) + + def get_all_schema_vers(self) -> Optional[List[Row]]: + sql = select(schema_ver) + + result = self.execute(sql) + if result is None: + return None + return result.fetchall() + + def get_schema_ver(self, game: str) -> Optional[int]: + sql = select(schema_ver).where(schema_ver.c.game == game) + + result = self.execute(sql) + if result is None: + return None + + row = result.fetchone() + if row is None: + return None + + return row["version"] + + def touch_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]: + sql = insert(schema_ver).values(game=game, version=ver) + conflict = sql.on_conflict_do_update(set_=dict(version=schema_ver.c.version)) + + result = self.execute(conflict) + if result is None: + self.logger.error( + f"Failed to update schema version for game {game} (v{ver})" + ) + return None + return result.lastrowid + + def set_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]: + sql = insert(schema_ver).values(game=game, version=ver) + conflict = sql.on_conflict_do_update(set_=dict(version=ver)) + + result = self.execute(conflict) + if result is None: + self.logger.error( + f"Failed to update schema version for game {game} (v{ver})" + ) + return None + return result.lastrowid + + def log_event( + self, system: str, type: str, severity: int, message: str, details: Dict = {} + ) -> Optional[int]: + sql = event_log.insert().values( + system=system, + type=type, + severity=severity, + message=message, + details=json.dumps(details), + ) + result = self.execute(sql) + + if result is None: + self.logger.error( + f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}" + ) + return None + + return result.lastrowid + + def get_event_log(self, entries: int = 100) -> Optional[List[Dict]]: + sql = event_log.select().limit(entries).all() + result = self.execute(sql) + + if result is None: + return None + return result.fetchall() + + def fix_bools(self, data: Dict) -> Dict: + for k, v in data.items(): + if type(v) == str and v.lower() == "true": + data[k] = True + elif type(v) == str and v.lower() == "false": + data[k] = False + + return data diff --git a/core/data/schema/card.py b/core/data/schema/card.py index d8f5fc0..f5800f2 100644 --- a/core/data/schema/card.py +++ b/core/data/schema/card.py @@ -1,104 +1,104 @@ -from typing import Dict, List, Optional -from sqlalchemy import Table, Column, UniqueConstraint -from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP -from sqlalchemy.sql.schema import ForeignKey -from sqlalchemy.sql import func -from sqlalchemy.engine import Row - -from core.data.schema.base import BaseData, metadata - -aime_card = Table( - "aime_card", - metadata, - Column("id", Integer, primary_key=True, nullable=False), - Column( - "user", - ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), - nullable=False, - ), - Column("access_code", String(20)), - Column("created_date", TIMESTAMP, server_default=func.now()), - Column("last_login_date", TIMESTAMP, onupdate=func.now()), - Column("is_locked", Boolean, server_default="0"), - Column("is_banned", Boolean, server_default="0"), - UniqueConstraint("user", "access_code", name="aime_card_uk"), - mysql_charset="utf8mb4", -) - - -class CardData(BaseData): - def get_card_by_access_code(self, access_code: str) -> Optional[Row]: - sql = aime_card.select(aime_card.c.access_code == access_code) - - result = self.execute(sql) - if result is None: - return None - return result.fetchone() - - def get_card_by_id(self, card_id: int) -> Optional[Row]: - sql = aime_card.select(aime_card.c.id == card_id) - - result = self.execute(sql) - if result is None: - return None - return result.fetchone() - - def update_access_code(self, old_ac: str, new_ac: str) -> None: - sql = aime_card.update(aime_card.c.access_code == old_ac).values( - access_code=new_ac - ) - - result = self.execute(sql) - if result is None: - self.logger.error( - f"Failed to change card access code from {old_ac} to {new_ac}" - ) - - def get_user_id_from_card(self, access_code: str) -> Optional[int]: - """ - Given a 20 digit access code as a string, get the user id associated with that card - """ - card = self.get_card_by_access_code(access_code) - if card is None: - return None - - return int(card["user"]) - - def delete_card(self, card_id: int) -> None: - sql = aime_card.delete(aime_card.c.id == card_id) - - result = self.execute(sql) - if result is None: - self.logger.error(f"Failed to delete card with id {card_id}") - - def get_user_cards(self, aime_id: int) -> Optional[List[Row]]: - """ - Returns all cards owned by a user - """ - sql = aime_card.select(aime_card.c.user == aime_id) - result = self.execute(sql) - if result is None: - return None - return result.fetchall() - - def create_card(self, user_id: int, access_code: str) -> Optional[int]: - """ - Given a aime_user id and a 20 digit access code as a string, create a card and return the ID if successful - """ - sql = aime_card.insert().values(user=user_id, access_code=access_code) - result = self.execute(sql) - if result is None: - return None - return result.lastrowid - - def to_access_code(self, luid: str) -> str: - """ - Given a felica cards internal 16 hex character luid, convert it to a 0-padded 20 digit access code as a string - """ - return f"{int(luid, base=16):0{20}}" - - def to_idm(self, access_code: str) -> str: - """ - Given a 20 digit access code as a string, return the 16 hex character luid - """ - return f"{int(access_code):0{16}x}" +from typing import Dict, List, Optional +from sqlalchemy import Table, Column, UniqueConstraint +from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP +from sqlalchemy.sql.schema import ForeignKey +from sqlalchemy.sql import func +from sqlalchemy.engine import Row + +from core.data.schema.base import BaseData, metadata + +aime_card = Table( + "aime_card", + metadata, + Column("id", Integer, primary_key=True, nullable=False), + Column( + "user", + ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), + nullable=False, + ), + Column("access_code", String(20)), + Column("created_date", TIMESTAMP, server_default=func.now()), + Column("last_login_date", TIMESTAMP, onupdate=func.now()), + Column("is_locked", Boolean, server_default="0"), + Column("is_banned", Boolean, server_default="0"), + UniqueConstraint("user", "access_code", name="aime_card_uk"), + mysql_charset="utf8mb4", +) + + +class CardData(BaseData): + def get_card_by_access_code(self, access_code: str) -> Optional[Row]: + sql = aime_card.select(aime_card.c.access_code == access_code) + + result = self.execute(sql) + if result is None: + return None + return result.fetchone() + + def get_card_by_id(self, card_id: int) -> Optional[Row]: + sql = aime_card.select(aime_card.c.id == card_id) + + result = self.execute(sql) + if result is None: + return None + return result.fetchone() + + def update_access_code(self, old_ac: str, new_ac: str) -> None: + sql = aime_card.update(aime_card.c.access_code == old_ac).values( + access_code=new_ac + ) + + result = self.execute(sql) + if result is None: + self.logger.error( + f"Failed to change card access code from {old_ac} to {new_ac}" + ) + + def get_user_id_from_card(self, access_code: str) -> Optional[int]: + """ + Given a 20 digit access code as a string, get the user id associated with that card + """ + card = self.get_card_by_access_code(access_code) + if card is None: + return None + + return int(card["user"]) + + def delete_card(self, card_id: int) -> None: + sql = aime_card.delete(aime_card.c.id == card_id) + + result = self.execute(sql) + if result is None: + self.logger.error(f"Failed to delete card with id {card_id}") + + def get_user_cards(self, aime_id: int) -> Optional[List[Row]]: + """ + Returns all cards owned by a user + """ + sql = aime_card.select(aime_card.c.user == aime_id) + result = self.execute(sql) + if result is None: + return None + return result.fetchall() + + def create_card(self, user_id: int, access_code: str) -> Optional[int]: + """ + Given a aime_user id and a 20 digit access code as a string, create a card and return the ID if successful + """ + sql = aime_card.insert().values(user=user_id, access_code=access_code) + result = self.execute(sql) + if result is None: + return None + return result.lastrowid + + def to_access_code(self, luid: str) -> str: + """ + Given a felica cards internal 16 hex character luid, convert it to a 0-padded 20 digit access code as a string + """ + return f"{int(luid, base=16):0{20}}" + + def to_idm(self, access_code: str) -> str: + """ + Given a 20 digit access code as a string, return the 16 hex character luid + """ + return f"{int(access_code):0{16}x}" diff --git a/core/data/schema/user.py b/core/data/schema/user.py index 6a95005..3d18436 100644 --- a/core/data/schema/user.py +++ b/core/data/schema/user.py @@ -1,109 +1,116 @@ -from enum import Enum -from typing import Optional, List -from sqlalchemy import Table, Column -from sqlalchemy.types import Integer, String, TIMESTAMP -from sqlalchemy.sql import func -from sqlalchemy.dialects.mysql import insert -from sqlalchemy.sql import func, select -from sqlalchemy.engine import Row -import bcrypt - -from core.data.schema.base import BaseData, metadata - -aime_user = Table( - "aime_user", - metadata, - Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), - Column("username", String(25), unique=True), - Column("email", String(255), unique=True), - Column("password", String(255)), - Column("permissions", Integer), - Column("created_date", TIMESTAMP, server_default=func.now()), - Column("last_login_date", TIMESTAMP, onupdate=func.now()), - Column("suspend_expire_time", TIMESTAMP), - mysql_charset="utf8mb4", -) - - -class PermissionBits(Enum): - PermUser = 1 - PermMod = 2 - PermSysAdmin = 4 - - -class UserData(BaseData): - def create_user( - self, - id: int = None, - username: str = None, - email: str = None, - password: str = None, - permission: int = 1, - ) -> Optional[int]: - if id is None: - sql = insert(aime_user).values( - username=username, - email=email, - password=password, - permissions=permission, - ) - else: - sql = insert(aime_user).values( - id=id, - username=username, - email=email, - password=password, - permissions=permission, - ) - - conflict = sql.on_duplicate_key_update( - username=username, email=email, password=password, permissions=permission - ) - - result = self.execute(conflict) - if result is None: - return None - return result.lastrowid - - def get_user(self, user_id: int) -> Optional[Row]: - sql = select(aime_user).where(aime_user.c.id == user_id) - result = self.execute(sql) - if result is None: - return False - return result.fetchone() - - def check_password(self, user_id: int, passwd: bytes = None) -> bool: - usr = self.get_user(user_id) - if usr is None: - return False - - if usr["password"] is None: - return False - - if passwd is None or not passwd: - return False - - return bcrypt.checkpw(passwd, usr["password"].encode()) - - def reset_autoincrement(self, ai_value: int) -> None: - # ALTER TABLE isn't in sqlalchemy so we do this the ugly way - sql = f"ALTER TABLE aime_user AUTO_INCREMENT={ai_value}" - self.execute(sql) - - def delete_user(self, user_id: int) -> None: - sql = aime_user.delete(aime_user.c.id == user_id) - - result = self.execute(sql) - if result is None: - self.logger.error(f"Failed to delete user with id {user_id}") - - def get_unregistered_users(self) -> List[Row]: - """ - Returns a list of users who have not registered with the webui. They may or may not have cards. - """ - sql = select(aime_user).where(aime_user.c.password == None) - - result = self.execute(sql) - if result is None: - return None - return result.fetchall() +from enum import Enum +from typing import Optional, List +from sqlalchemy import Table, Column +from sqlalchemy.types import Integer, String, TIMESTAMP +from sqlalchemy.sql import func +from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.sql import func, select +from sqlalchemy.engine import Row +import bcrypt + +from core.data.schema.base import BaseData, metadata + +aime_user = Table( + "aime_user", + metadata, + Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), + Column("username", String(25), unique=True), + Column("email", String(255), unique=True), + Column("password", String(255)), + Column("permissions", Integer), + Column("created_date", TIMESTAMP, server_default=func.now()), + Column("last_login_date", TIMESTAMP, onupdate=func.now()), + Column("suspend_expire_time", TIMESTAMP), + mysql_charset="utf8mb4", + sqlite_autoincrement=True, +) + + +class PermissionBits(Enum): + PermUser = 1 + PermMod = 2 + PermSysAdmin = 4 + + +class UserData(BaseData): + def create_user( + self, + id: int = None, + username: str = None, + email: str = None, + password: str = None, + permission: int = 1, + ) -> Optional[int]: + if id is None: + sql = insert(aime_user).values( + username=username, + email=email, + password=password, + permissions=permission, + ) + else: + sql = insert(aime_user).values( + id=id, + username=username, + email=email, + password=password, + permissions=permission, + ) + + conflict = sql.on_conflict_do_update( + set_=dict( + username=username, + email=email, + password=password, + permissions=permission, + ) + ) + + result = self.execute(conflict) + if result is None: + return None + return result.lastrowid + + def get_user(self, user_id: int) -> Optional[Row]: + sql = select(aime_user).where(aime_user.c.id == user_id) + result = self.execute(sql) + if result is None: + return False + return result.fetchone() + + def check_password(self, user_id: int, passwd: bytes = None) -> bool: + usr = self.get_user(user_id) + if usr is None: + return False + + if usr["password"] is None: + return False + + if passwd is None or not passwd: + return False + + return bcrypt.checkpw(passwd, usr["password"].encode()) + + def reset_autoincrement(self, ai_value: int) -> None: + # ALTER TABLE isn't in sqlalchemy so we do this the ugly way + # sql = f"ALTER TABLE aime_user AUTO_INCREMENT={ai_value}" + sql = f"INSERT OR REPLACE INTO sqlite_sequence VALUES ('aime_user', {ai_value})" + self.execute(sql) + + def delete_user(self, user_id: int) -> None: + sql = aime_user.delete(aime_user.c.id == user_id) + + result = self.execute(sql) + if result is None: + self.logger.error(f"Failed to delete user with id {user_id}") + + def get_unregistered_users(self) -> List[Row]: + """ + Returns a list of users who have not registered with the webui. They may or may not have cards. + """ + sql = select(aime_user).where(aime_user.c.password == None) + + result = self.execute(sql) + if result is None: + return None + return result.fetchall() diff --git a/core/frontend.py b/core/frontend.py index f01be50..128032b 100644 --- a/core/frontend.py +++ b/core/frontend.py @@ -227,22 +227,25 @@ class FE_User(FE_Base): usr_sesh = IUserSession(sesh) if usr_sesh.userId == 0: return redirectTo(b"/gate", request) - + cards = self.data.card.get_user_cards(usr_sesh.userId) user = self.data.user.get_user(usr_sesh.userId) card_data = [] for c in cards: - if c['is_locked']: - status = 'Locked' - elif c['is_banned']: - status = 'Banned' + if c["is_locked"]: + status = "Locked" + elif c["is_banned"]: + status = "Banned" else: - status = 'Active' - - card_data.append({'access_code': c['access_code'], 'status': status}) + status = "Active" + + card_data.append({"access_code": c["access_code"], "status": status}) return template.render( - title=f"{self.core_config.server.name} | Account", sesh=vars(usr_sesh), cards=card_data, username=user['username'] + title=f"{self.core_config.server.name} | Account", + sesh=vars(usr_sesh), + cards=card_data, + username=user["username"], ).encode("utf-16") diff --git a/core/frontend/user/index.jinja b/core/frontend/user/index.jinja index 2911e67..1c7c9c9 100644 --- a/core/frontend/user/index.jinja +++ b/core/frontend/user/index.jinja @@ -1,6 +1,6 @@ {% extends "core/frontend/index.jinja" %} {% block content %} -

Management for {{ username }}

+

Management for {{ username.decode("utf-8") }}

Cards