diff --git a/core/aimedb.py b/core/aimedb.py index 6d5bd57..5a84662 100644 --- a/core/aimedb.py +++ b/core/aimedb.py @@ -1,14 +1,18 @@ -import logging, coloredlogs -from Crypto.Cipher import AES -from typing import Dict, Tuple, Callable, Union, Optional import asyncio +import logging from logging.handlers import TimedRotatingFileHandler +from typing import Callable, Dict, Optional, Tuple, Union + +import coloredlogs +from Crypto.Cipher import AES from core.config import CoreConfig -from core.utils import create_sega_auth_key from core.data import Data +from core.utils import create_sega_auth_key + from .adb_handlers import * + class AimedbServlette(): request_list: Dict[int, Tuple[Callable[[bytes, int], Union[ADBBaseResponse, bytes]], int, str]] = {} def __init__(self, core_cfg: CoreConfig) -> None: diff --git a/core/allnet.py b/core/allnet.py index 9eb6595..3113cf9 100644 --- a/core/allnet.py +++ b/core/allnet.py @@ -181,6 +181,11 @@ class AllnetServlet: if machine is not None: arcade = await self.data.arcade.get_arcade(machine["arcade"]) + + if arcade is None: + self.logger.error("The arcade %s belongs to (ID %s) does not exist!", req.serial, machine["arcade"]) + return PlainTextResponse("stat=-3\n", status_code=500) + if self.config.server.check_arcade_ip: if arcade["ip"] and arcade["ip"] is not None and arcade["ip"] != req.ip: msg = f"{req.serial} attempted allnet auth from bad IP {req.ip} (expected {arcade['ip']})." diff --git a/core/data/schema/arcade.py b/core/data/schema/arcade.py index 653fe7c..16eb81e 100644 --- a/core/data/schema/arcade.py +++ b/core/data/schema/arcade.py @@ -3,7 +3,7 @@ from typing import List, Optional from sqlalchemy import Column, Table, and_, or_ from sqlalchemy.dialects.mysql import insert -from sqlalchemy.engine import Row +from sqlalchemy.engine import Row, RowMapping from sqlalchemy.sql import func, select from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint from sqlalchemy.types import JSON, Boolean, Integer, String @@ -69,21 +69,23 @@ arcade_owner: Table = Table( class ArcadeData(BaseData): - async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]: + async def get_machine( + self, serial: Optional[str] = None, id: Optional[int] = None + ) -> Optional[RowMapping]: if serial is not None: serial = serial.replace("-", "") if len(serial) == 11: - sql = machine.select(machine.c.serial.like(f"{serial}%")) + sql = machine.select().where(machine.c.serial.like(f"{serial}%")) elif len(serial) == 15: - sql = machine.select(machine.c.serial == serial) + sql = machine.select().where(machine.c.serial == serial) else: self.logger.error(f"{__name__ }: Malformed serial {serial}") return None elif id is not None: - sql = machine.select(machine.c.id == id) + sql = machine.select().where(machine.c.id == id) else: self.logger.error(f"{__name__ }: Need either serial or ID to look up!") @@ -92,7 +94,7 @@ class ArcadeData(BaseData): result = await self.execute(sql) if result is None: return None - return result.fetchone() + return result.mappings().fetchone() async def create_machine( self, @@ -117,7 +119,7 @@ class ArcadeData(BaseData): async def set_machine_serial(self, machine_id: int, serial: str) -> None: result = await self.execute( - machine.update(machine.c.id == machine_id).values(keychip=serial) + machine.update().where(machine.c.id == machine_id).values(keychip=serial) ) if result is None: self.logger.error( @@ -127,26 +129,26 @@ class ArcadeData(BaseData): async def set_machine_boardid(self, machine_id: int, boardid: str) -> None: result = await self.execute( - machine.update(machine.c.id == machine_id).values(board=boardid) + machine.update().where(machine.c.id == machine_id).values(board=boardid) ) if result is None: self.logger.error( f"Failed to update board id for machine {machine_id} -> {boardid}" ) - async def get_arcade(self, id: int) -> Optional[Row]: - sql = arcade.select(arcade.c.id == id) + async def get_arcade(self, id: int) -> Optional[RowMapping]: + sql = arcade.select().where(arcade.c.id == id) result = await self.execute(sql) if result is None: return None - return result.fetchone() - - async def get_arcade_machines(self, id: int) -> Optional[List[Row]]: - sql = machine.select(machine.c.arcade == id) + return result.mappings().fetchone() + + async def get_arcade_machines(self, id: int) -> Optional[List[RowMapping]]: + sql = machine.select().where(machine.c.arcade == id) result = await self.execute(sql) if result is None: return None - return result.fetchall() + return result.mappings().fetchall() async def create_arcade( self, @@ -177,14 +179,22 @@ class ArcadeData(BaseData): return result.lastrowid async def get_arcades_managed_by_user(self, user_id: int) -> Optional[List[Row]]: - sql = select(arcade).join(arcade_owner, arcade_owner.c.arcade == arcade.c.id).where(arcade_owner.c.user == user_id) + sql = ( + select(arcade) + .join(arcade_owner, arcade_owner.c.arcade == arcade.c.id) + .where(arcade_owner.c.user == user_id) + ) result = await self.execute(sql) if result is None: return False return result.fetchall() - - async def get_manager_permissions(self, user_id: int, arcade_id: int) -> Optional[int]: - sql = select(arcade_owner.c.permissions).where(and_(arcade_owner.c.user == user_id, arcade_owner.c.arcade == arcade_id)) + + async def get_manager_permissions( + self, user_id: int, arcade_id: int + ) -> Optional[int]: + sql = select(arcade_owner.c.permissions).where( + and_(arcade_owner.c.user == user_id, arcade_owner.c.arcade == arcade_id) + ) result = await self.execute(sql) if result is None: return False @@ -207,7 +217,9 @@ class ArcadeData(BaseData): return result.lastrowid async def get_arcade_by_name(self, name: str) -> Optional[List[Row]]: - sql = arcade.select(or_(arcade.c.name.like(f"%{name}%"), arcade.c.nickname.like(f"%{name}%"))) + sql = arcade.select().where( + or_(arcade.c.name.like(f"%{name}%"), arcade.c.nickname.like(f"%{name}%")) + ) result = await self.execute(sql) if result is None: return None @@ -219,25 +231,38 @@ class ArcadeData(BaseData): if result is None: return None return result.fetchall() - + async def get_num_generated_keychips(self) -> Optional[int]: - result = await self.execute(select(func.count("serial LIKE 'A69A%'")).select_from(machine)) + result = await self.execute( + select(func.count("serial LIKE 'A69A%'")).select_from(machine) + ) if result: - return result.fetchone()['count_1'] + return result.mappings().fetchone()["count_1"] self.logger.error("Failed to count machine serials that start with A69A!") def format_serial( - self, platform_code: str, platform_rev: int, serial_letter: str, serial_num: int, append: int, dash: bool = False + self, + platform_code: str, + platform_rev: int, + serial_letter: str, + serial_num: int, + append: int, + dash: bool = False, ) -> str: return f"{platform_code}{'-' if dash else ''}{platform_rev:02d}{serial_letter}{serial_num:04d}{append:04d}" def validate_keychip_format(self, serial: str) -> bool: # For the 2nd letter, E and X are the only "real" values that have been observed (A is used for generated keychips) - if re.fullmatch(r"^A[0-9]{2}[A-Z][-]?[0-9]{2}[A-HJ-NP-Z][0-9]{4}([0-9]{4})?$", serial) is None: + if ( + re.fullmatch( + r"^A[0-9]{2}[A-Z][-]?[0-9]{2}[A-HJ-NP-Z][0-9]{4}([0-9]{4})?$", serial + ) + is None + ): return False - + return True - + # Thanks bottersnike! def get_keychip_suffix(self, year: int, month: int) -> str: assert year > 1957 @@ -252,7 +277,6 @@ class ArcadeData(BaseData): month = ((month - 1) + 9) % 12 # Offset so April=0 return f"{year:02}{month // 6:01}{month % 6 + 1:01}" - def parse_keychip_suffix(self, suffix: str) -> tuple[int, int]: year = int(suffix[0:2]) half = int(suffix[2]) diff --git a/core/data/schema/card.py b/core/data/schema/card.py index 254b19e..9815e68 100644 --- a/core/data/schema/card.py +++ b/core/data/schema/card.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from sqlalchemy import Column, Table, UniqueConstraint -from sqlalchemy.engine import Row +from sqlalchemy.engine import Row, RowMapping from sqlalchemy.sql import func from sqlalchemy.sql.schema import ForeignKey from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String @@ -12,7 +12,11 @@ aime_card: Table = Table( "aime_card", metadata, Column("id", Integer, primary_key=True, nullable=False), - Column("user", ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), nullable=False), + Column( + "user", + ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), + nullable=False, + ), Column("access_code", String(20), nullable=False, unique=True), Column("idm", String(16), unique=True), Column("chip_id", BIGINT, unique=True), @@ -28,27 +32,29 @@ aime_card: Table = Table( class CardData(BaseData): moble_os_codes = set([0x06, 0x07, 0x10, 0x12, 0x13, 0x14, 0x15, 0x17, 0x18]) - card_os_codes = set([0x20, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7]) + card_os_codes = set([0x20, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7]) - async def get_card_by_access_code(self, access_code: str) -> Optional[Row]: - sql = aime_card.select(aime_card.c.access_code == access_code) + async def get_card_by_access_code(self, access_code: str) -> Optional[RowMapping]: + sql = aime_card.select().where(aime_card.c.access_code == access_code) result = await self.execute(sql) if result is None: return None - return result.fetchone() + return result.mappings().fetchone() - async def get_card_by_id(self, card_id: int) -> Optional[Row]: - sql = aime_card.select(aime_card.c.id == card_id) + async def get_card_by_id(self, card_id: int) -> Optional[RowMapping]: + sql = aime_card.select().where(aime_card.c.id == card_id) result = await self.execute(sql) if result is None: return None - return result.fetchone() + return result.mappings().fetchone() async def update_access_code(self, old_ac: str, new_ac: str) -> None: - sql = aime_card.update(aime_card.c.access_code == old_ac).values( - access_code=new_ac + sql = ( + aime_card.update() + .where(aime_card.c.access_code == old_ac) + .values(access_code=new_ac) ) result = await self.execute(sql) @@ -65,7 +71,7 @@ class CardData(BaseData): if card is None: return None - return int(card["user"]) + return int(card.user) async def get_card_banned(self, access_code: str) -> Optional[bool]: """ @@ -74,10 +80,10 @@ class CardData(BaseData): card = await self.get_card_by_access_code(access_code) if card is None: return None - if card["is_banned"]: + if card.is_banned: return True return False - + async def get_card_locked(self, access_code: str) -> Optional[bool]: """ Given a 20 digit access code as a string, check if the card is locked @@ -85,26 +91,26 @@ class CardData(BaseData): card = await self.get_card_by_access_code(access_code) if card is None: return None - if card["is_locked"]: + if card.is_locked: return True return False async def delete_card(self, card_id: int) -> None: - sql = aime_card.delete(aime_card.c.id == card_id) + sql = aime_card.delete().where(aime_card.c.id == card_id) result = await self.execute(sql) if result is None: self.logger.error(f"Failed to delete card with id {card_id}") - async def get_user_cards(self, aime_id: int) -> Optional[List[Row]]: + async def get_user_cards(self, aime_id: int) -> Optional[List[RowMapping]]: """ Returns all cards owned by a user """ - sql = aime_card.select(aime_card.c.user == aime_id) + sql = aime_card.select().where(aime_card.c.user == aime_id) result = await self.execute(sql) if result is None: return None - return result.fetchall() + return result.mappings().fetchall() async def create_card(self, user_id: int, access_code: str) -> Optional[int]: """ @@ -117,41 +123,67 @@ class CardData(BaseData): return result.lastrowid async def update_card_last_login(self, access_code: str) -> None: - sql = aime_card.update(aime_card.c.access_code == access_code).values( - last_login_date=func.now() + sql = ( + aime_card.update() + .where(aime_card.c.access_code == access_code) + .values(last_login_date=func.now()) ) - + result = await self.execute(sql) if result is None: self.logger.warn(f"Failed to update last login time for {access_code}") - async def get_card_by_idm(self, idm: str) -> Optional[Row]: - result = await self.execute(aime_card.select(aime_card.c.idm == idm)) + async def get_card_by_idm(self, idm: str) -> Optional[RowMapping]: + result = await self.execute(aime_card.select().where(aime_card.c.idm == idm)) if result: - return result.fetchone() + return result.mappings().fetchone() async def get_card_by_chip_id(self, chip_id: int) -> Optional[Row]: - result = await self.execute(aime_card.select(aime_card.c.chip_id == chip_id)) + result = await self.execute( + aime_card.select().where(aime_card.c.chip_id == chip_id) + ) if result: return result.fetchone() - async def set_chip_id_by_access_code(self, access_code: str, chip_id: int) -> Optional[Row]: - result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(chip_id=chip_id)) + async def set_chip_id_by_access_code( + self, access_code: str, chip_id: int + ) -> Optional[Row]: + result = await self.execute( + aime_card.update() + .where(aime_card.c.access_code == access_code) + .values(chip_id=chip_id) + ) if not result: - self.logger.error(f"Failed to update chip ID to {chip_id} for {access_code}") + self.logger.error( + f"Failed to update chip ID to {chip_id} for {access_code}" + ) async def set_idm_by_access_code(self, access_code: str, idm: str) -> Optional[Row]: - result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(idm=idm)) + result = await self.execute( + aime_card.update() + .where(aime_card.c.access_code == access_code) + .values(idm=idm) + ) if not result: self.logger.error(f"Failed to update IDm to {idm} for {access_code}") async def set_access_code_by_access_code(self, old_ac: str, new_ac: str) -> None: - result = await self.execute(aime_card.update(aime_card.c.access_code == old_ac).values(access_code=new_ac)) + result = await self.execute( + aime_card.update() + .where(aime_card.c.access_code == old_ac) + .values(access_code=new_ac) + ) if not result: - self.logger.error(f"Failed to change card access code from {old_ac} to {new_ac}") + self.logger.error( + f"Failed to change card access code from {old_ac} to {new_ac}" + ) async def set_memo_by_access_code(self, access_code: str, memo: str) -> None: - result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(memo=memo)) + result = await self.execute( + aime_card.update() + .where(aime_card.c.access_code == access_code) + .values(memo=memo) + ) if not result: self.logger.error(f"Failed to add memo to card {access_code}") diff --git a/core/data/schema/user.py b/core/data/schema/user.py index 8686f08..03deb4b 100644 --- a/core/data/schema/user.py +++ b/core/data/schema/user.py @@ -3,7 +3,7 @@ from typing import List, Optional import bcrypt from sqlalchemy import Column, Table from sqlalchemy.dialects.mysql import insert -from sqlalchemy.engine import Row +from sqlalchemy.engine import Row, RowMapping from sqlalchemy.sql import func, select from sqlalchemy.types import TIMESTAMP, Integer, String @@ -23,6 +23,7 @@ aime_user: Table = Table( mysql_charset="utf8mb4", ) + class UserData(BaseData): async def create_user( self, @@ -57,28 +58,28 @@ class UserData(BaseData): return None return result.lastrowid - async def get_user(self, user_id: int) -> Optional[Row]: + async def get_user(self, user_id: int) -> Optional[RowMapping]: sql = select(aime_user).where(aime_user.c.id == user_id) result = await self.execute(sql) if result is None: return False - return result.fetchone() + return result.mappings().fetchone() async def check_password(self, user_id: int, passwd: bytes = None) -> bool: usr = await self.get_user(user_id) if usr is None: return False - if usr["password"] is None: + if usr.password is None: return False - + if passwd is None or not passwd: return False - return bcrypt.checkpw(passwd, usr["password"].encode()) + return bcrypt.checkpw(passwd, usr.password.encode()) async def delete_user(self, user_id: int) -> None: - sql = aime_user.delete(aime_user.c.id == user_id) + sql = aime_user.delete().where(aime_user.c.id == user_id) result = await self.execute(sql) if result is None: @@ -103,24 +104,35 @@ class UserData(BaseData): return result.fetchone() async def find_user_by_username(self, username: str) -> List[Row]: - sql = aime_user.select(aime_user.c.username.like(f"%{username}%")) + sql = aime_user.select().where(aime_user.c.username.like(f"%{username}%")) result = await self.execute(sql) if result is None: return False return result.fetchall() async def change_password(self, user_id: int, new_passwd: str) -> bool: - sql = aime_user.update(aime_user.c.id == user_id).values(password = new_passwd) + sql = ( + aime_user.update() + .where(aime_user.c.id == user_id) + .values(password=new_passwd) + ) result = await self.execute(sql) return result is not None async def change_username(self, user_id: int, new_name: str) -> bool: - sql = aime_user.update(aime_user.c.id == user_id).values(username = new_name) + sql = ( + aime_user.update() + .where(aime_user.c.id == user_id) + .values(username=new_name) + ) result = await self.execute(sql) return result is not None - async def get_user_by_username(self, username: str) -> Optional[Row]: - result = await self.execute(aime_user.select(aime_user.c.username == username)) - if result: return result.fetchone() + async def get_user_by_username(self, username: str) -> Optional[RowMapping]: + result = await self.execute( + aime_user.select().where(aime_user.c.username == username) + ) + if result: + return result.mappings().fetchone() diff --git a/core/frontend.py b/core/frontend.py index c593828..7025a14 100644 --- a/core/frontend.py +++ b/core/frontend.py @@ -111,7 +111,7 @@ class FrontendServlet(): self.arcade = FE_Arcade(cfg, self.environment) self.machine = FE_Machine(cfg, self.environment) - def get_routes(self) -> List[Route]: + def get_routes(self) -> List[Union[Route, Mount]]: g_routes = [] for nav_name, g_data in self.environment.globals["game_list"].items(): g_routes.append(Mount(g_data['url'], routes=g_data['class'].get_routes()))