[core] Address SQLAlchemy deprecations

This commit is contained in:
2024-11-18 10:49:41 +07:00
parent 58a5177a30
commit c11aae58a6
6 changed files with 156 additions and 79 deletions

View File

@ -1,14 +1,18 @@
import logging, coloredlogs
from Crypto.Cipher import AES
from typing import Dict, Tuple, Callable, Union, Optional
import asyncio import asyncio
import logging
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import Callable, Dict, Optional, Tuple, Union
import coloredlogs
from Crypto.Cipher import AES
from core.config import CoreConfig from core.config import CoreConfig
from core.utils import create_sega_auth_key
from core.data import Data from core.data import Data
from core.utils import create_sega_auth_key
from .adb_handlers import * from .adb_handlers import *
class AimedbServlette(): class AimedbServlette():
request_list: Dict[int, Tuple[Callable[[bytes, int], Union[ADBBaseResponse, bytes]], int, str]] = {} request_list: Dict[int, Tuple[Callable[[bytes, int], Union[ADBBaseResponse, bytes]], int, str]] = {}
def __init__(self, core_cfg: CoreConfig) -> None: def __init__(self, core_cfg: CoreConfig) -> None:

View File

@ -181,6 +181,11 @@ class AllnetServlet:
if machine is not None: if machine is not None:
arcade = await self.data.arcade.get_arcade(machine["arcade"]) arcade = await self.data.arcade.get_arcade(machine["arcade"])
if arcade is None:
self.logger.error("The arcade %s belongs to (ID %s) does not exist!", req.serial, machine["arcade"])
return PlainTextResponse("stat=-3\n", status_code=500)
if self.config.server.check_arcade_ip: if self.config.server.check_arcade_ip:
if arcade["ip"] and arcade["ip"] is not None and arcade["ip"] != req.ip: if arcade["ip"] and arcade["ip"] is not None and arcade["ip"] != req.ip:
msg = f"{req.serial} attempted allnet auth from bad IP {req.ip} (expected {arcade['ip']})." msg = f"{req.serial} attempted allnet auth from bad IP {req.ip} (expected {arcade['ip']})."

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from sqlalchemy import Column, Table, and_, or_ from sqlalchemy import Column, Table, and_, or_
from sqlalchemy.dialects.mysql import insert from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.sql import func, select from sqlalchemy.sql import func, select
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import JSON, Boolean, Integer, String from sqlalchemy.types import JSON, Boolean, Integer, String
@ -69,21 +69,23 @@ arcade_owner: Table = Table(
class ArcadeData(BaseData): class ArcadeData(BaseData):
async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]: async def get_machine(
self, serial: Optional[str] = None, id: Optional[int] = None
) -> Optional[RowMapping]:
if serial is not None: if serial is not None:
serial = serial.replace("-", "") serial = serial.replace("-", "")
if len(serial) == 11: if len(serial) == 11:
sql = machine.select(machine.c.serial.like(f"{serial}%")) sql = machine.select().where(machine.c.serial.like(f"{serial}%"))
elif len(serial) == 15: elif len(serial) == 15:
sql = machine.select(machine.c.serial == serial) sql = machine.select().where(machine.c.serial == serial)
else: else:
self.logger.error(f"{__name__ }: Malformed serial {serial}") self.logger.error(f"{__name__ }: Malformed serial {serial}")
return None return None
elif id is not None: elif id is not None:
sql = machine.select(machine.c.id == id) sql = machine.select().where(machine.c.id == id)
else: else:
self.logger.error(f"{__name__ }: Need either serial or ID to look up!") self.logger.error(f"{__name__ }: Need either serial or ID to look up!")
@ -92,7 +94,7 @@ class ArcadeData(BaseData):
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchone() return result.mappings().fetchone()
async def create_machine( async def create_machine(
self, self,
@ -117,7 +119,7 @@ class ArcadeData(BaseData):
async def set_machine_serial(self, machine_id: int, serial: str) -> None: async def set_machine_serial(self, machine_id: int, serial: str) -> None:
result = await self.execute( result = await self.execute(
machine.update(machine.c.id == machine_id).values(keychip=serial) machine.update().where(machine.c.id == machine_id).values(keychip=serial)
) )
if result is None: if result is None:
self.logger.error( self.logger.error(
@ -127,26 +129,26 @@ class ArcadeData(BaseData):
async def set_machine_boardid(self, machine_id: int, boardid: str) -> None: async def set_machine_boardid(self, machine_id: int, boardid: str) -> None:
result = await self.execute( result = await self.execute(
machine.update(machine.c.id == machine_id).values(board=boardid) machine.update().where(machine.c.id == machine_id).values(board=boardid)
) )
if result is None: if result is None:
self.logger.error( self.logger.error(
f"Failed to update board id for machine {machine_id} -> {boardid}" f"Failed to update board id for machine {machine_id} -> {boardid}"
) )
async def get_arcade(self, id: int) -> Optional[Row]: async def get_arcade(self, id: int) -> Optional[RowMapping]:
sql = arcade.select(arcade.c.id == id) sql = arcade.select().where(arcade.c.id == id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchone() return result.mappings().fetchone()
async def get_arcade_machines(self, id: int) -> Optional[List[Row]]: async def get_arcade_machines(self, id: int) -> Optional[List[RowMapping]]:
sql = machine.select(machine.c.arcade == id) sql = machine.select().where(machine.c.arcade == id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchall() return result.mappings().fetchall()
async def create_arcade( async def create_arcade(
self, self,
@ -177,14 +179,22 @@ class ArcadeData(BaseData):
return result.lastrowid return result.lastrowid
async def get_arcades_managed_by_user(self, user_id: int) -> Optional[List[Row]]: async def get_arcades_managed_by_user(self, user_id: int) -> Optional[List[Row]]:
sql = select(arcade).join(arcade_owner, arcade_owner.c.arcade == arcade.c.id).where(arcade_owner.c.user == user_id) sql = (
select(arcade)
.join(arcade_owner, arcade_owner.c.arcade == arcade.c.id)
.where(arcade_owner.c.user == user_id)
)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return False return False
return result.fetchall() return result.fetchall()
async def get_manager_permissions(self, user_id: int, arcade_id: int) -> Optional[int]: async def get_manager_permissions(
sql = select(arcade_owner.c.permissions).where(and_(arcade_owner.c.user == user_id, arcade_owner.c.arcade == arcade_id)) self, user_id: int, arcade_id: int
) -> Optional[int]:
sql = select(arcade_owner.c.permissions).where(
and_(arcade_owner.c.user == user_id, arcade_owner.c.arcade == arcade_id)
)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return False return False
@ -207,7 +217,9 @@ class ArcadeData(BaseData):
return result.lastrowid return result.lastrowid
async def get_arcade_by_name(self, name: str) -> Optional[List[Row]]: async def get_arcade_by_name(self, name: str) -> Optional[List[Row]]:
sql = arcade.select(or_(arcade.c.name.like(f"%{name}%"), arcade.c.nickname.like(f"%{name}%"))) sql = arcade.select().where(
or_(arcade.c.name.like(f"%{name}%"), arcade.c.nickname.like(f"%{name}%"))
)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
@ -219,25 +231,38 @@ class ArcadeData(BaseData):
if result is None: if result is None:
return None return None
return result.fetchall() return result.fetchall()
async def get_num_generated_keychips(self) -> Optional[int]: async def get_num_generated_keychips(self) -> Optional[int]:
result = await self.execute(select(func.count("serial LIKE 'A69A%'")).select_from(machine)) result = await self.execute(
select(func.count("serial LIKE 'A69A%'")).select_from(machine)
)
if result: if result:
return result.fetchone()['count_1'] return result.mappings().fetchone()["count_1"]
self.logger.error("Failed to count machine serials that start with A69A!") self.logger.error("Failed to count machine serials that start with A69A!")
def format_serial( def format_serial(
self, platform_code: str, platform_rev: int, serial_letter: str, serial_num: int, append: int, dash: bool = False self,
platform_code: str,
platform_rev: int,
serial_letter: str,
serial_num: int,
append: int,
dash: bool = False,
) -> str: ) -> str:
return f"{platform_code}{'-' if dash else ''}{platform_rev:02d}{serial_letter}{serial_num:04d}{append:04d}" return f"{platform_code}{'-' if dash else ''}{platform_rev:02d}{serial_letter}{serial_num:04d}{append:04d}"
def validate_keychip_format(self, serial: str) -> bool: def validate_keychip_format(self, serial: str) -> bool:
# For the 2nd letter, E and X are the only "real" values that have been observed (A is used for generated keychips) # For the 2nd letter, E and X are the only "real" values that have been observed (A is used for generated keychips)
if re.fullmatch(r"^A[0-9]{2}[A-Z][-]?[0-9]{2}[A-HJ-NP-Z][0-9]{4}([0-9]{4})?$", serial) is None: if (
re.fullmatch(
r"^A[0-9]{2}[A-Z][-]?[0-9]{2}[A-HJ-NP-Z][0-9]{4}([0-9]{4})?$", serial
)
is None
):
return False return False
return True return True
# Thanks bottersnike! # Thanks bottersnike!
def get_keychip_suffix(self, year: int, month: int) -> str: def get_keychip_suffix(self, year: int, month: int) -> str:
assert year > 1957 assert year > 1957
@ -252,7 +277,6 @@ class ArcadeData(BaseData):
month = ((month - 1) + 9) % 12 # Offset so April=0 month = ((month - 1) + 9) % 12 # Offset so April=0
return f"{year:02}{month // 6:01}{month % 6 + 1:01}" return f"{year:02}{month // 6:01}{month % 6 + 1:01}"
def parse_keychip_suffix(self, suffix: str) -> tuple[int, int]: def parse_keychip_suffix(self, suffix: str) -> tuple[int, int]:
year = int(suffix[0:2]) year = int(suffix[0:2])
half = int(suffix[2]) half = int(suffix[2])

View File

@ -1,7 +1,7 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
from sqlalchemy import Column, Table, UniqueConstraint from sqlalchemy import Column, Table, UniqueConstraint
from sqlalchemy.engine import Row from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.sql.schema import ForeignKey from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String
@ -12,7 +12,11 @@ aime_card: Table = Table(
"aime_card", "aime_card",
metadata, metadata,
Column("id", Integer, primary_key=True, nullable=False), Column("id", Integer, primary_key=True, nullable=False),
Column("user", ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), nullable=False), Column(
"user",
ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"),
nullable=False,
),
Column("access_code", String(20), nullable=False, unique=True), Column("access_code", String(20), nullable=False, unique=True),
Column("idm", String(16), unique=True), Column("idm", String(16), unique=True),
Column("chip_id", BIGINT, unique=True), Column("chip_id", BIGINT, unique=True),
@ -28,27 +32,29 @@ aime_card: Table = Table(
class CardData(BaseData): class CardData(BaseData):
moble_os_codes = set([0x06, 0x07, 0x10, 0x12, 0x13, 0x14, 0x15, 0x17, 0x18]) moble_os_codes = set([0x06, 0x07, 0x10, 0x12, 0x13, 0x14, 0x15, 0x17, 0x18])
card_os_codes = set([0x20, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7]) card_os_codes = set([0x20, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7])
async def get_card_by_access_code(self, access_code: str) -> Optional[Row]: async def get_card_by_access_code(self, access_code: str) -> Optional[RowMapping]:
sql = aime_card.select(aime_card.c.access_code == access_code) sql = aime_card.select().where(aime_card.c.access_code == access_code)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchone() return result.mappings().fetchone()
async def get_card_by_id(self, card_id: int) -> Optional[Row]: async def get_card_by_id(self, card_id: int) -> Optional[RowMapping]:
sql = aime_card.select(aime_card.c.id == card_id) sql = aime_card.select().where(aime_card.c.id == card_id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchone() return result.mappings().fetchone()
async def update_access_code(self, old_ac: str, new_ac: str) -> None: async def update_access_code(self, old_ac: str, new_ac: str) -> None:
sql = aime_card.update(aime_card.c.access_code == old_ac).values( sql = (
access_code=new_ac aime_card.update()
.where(aime_card.c.access_code == old_ac)
.values(access_code=new_ac)
) )
result = await self.execute(sql) result = await self.execute(sql)
@ -65,7 +71,7 @@ class CardData(BaseData):
if card is None: if card is None:
return None return None
return int(card["user"]) return int(card.user)
async def get_card_banned(self, access_code: str) -> Optional[bool]: async def get_card_banned(self, access_code: str) -> Optional[bool]:
""" """
@ -74,10 +80,10 @@ class CardData(BaseData):
card = await self.get_card_by_access_code(access_code) card = await self.get_card_by_access_code(access_code)
if card is None: if card is None:
return None return None
if card["is_banned"]: if card.is_banned:
return True return True
return False return False
async def get_card_locked(self, access_code: str) -> Optional[bool]: async def get_card_locked(self, access_code: str) -> Optional[bool]:
""" """
Given a 20 digit access code as a string, check if the card is locked Given a 20 digit access code as a string, check if the card is locked
@ -85,26 +91,26 @@ class CardData(BaseData):
card = await self.get_card_by_access_code(access_code) card = await self.get_card_by_access_code(access_code)
if card is None: if card is None:
return None return None
if card["is_locked"]: if card.is_locked:
return True return True
return False return False
async def delete_card(self, card_id: int) -> None: async def delete_card(self, card_id: int) -> None:
sql = aime_card.delete(aime_card.c.id == card_id) sql = aime_card.delete().where(aime_card.c.id == card_id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
self.logger.error(f"Failed to delete card with id {card_id}") self.logger.error(f"Failed to delete card with id {card_id}")
async def get_user_cards(self, aime_id: int) -> Optional[List[Row]]: async def get_user_cards(self, aime_id: int) -> Optional[List[RowMapping]]:
""" """
Returns all cards owned by a user Returns all cards owned by a user
""" """
sql = aime_card.select(aime_card.c.user == aime_id) sql = aime_card.select().where(aime_card.c.user == aime_id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return None return None
return result.fetchall() return result.mappings().fetchall()
async def create_card(self, user_id: int, access_code: str) -> Optional[int]: async def create_card(self, user_id: int, access_code: str) -> Optional[int]:
""" """
@ -117,41 +123,67 @@ class CardData(BaseData):
return result.lastrowid return result.lastrowid
async def update_card_last_login(self, access_code: str) -> None: async def update_card_last_login(self, access_code: str) -> None:
sql = aime_card.update(aime_card.c.access_code == access_code).values( sql = (
last_login_date=func.now() aime_card.update()
.where(aime_card.c.access_code == access_code)
.values(last_login_date=func.now())
) )
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
self.logger.warn(f"Failed to update last login time for {access_code}") self.logger.warn(f"Failed to update last login time for {access_code}")
async def get_card_by_idm(self, idm: str) -> Optional[Row]: async def get_card_by_idm(self, idm: str) -> Optional[RowMapping]:
result = await self.execute(aime_card.select(aime_card.c.idm == idm)) result = await self.execute(aime_card.select().where(aime_card.c.idm == idm))
if result: if result:
return result.fetchone() return result.mappings().fetchone()
async def get_card_by_chip_id(self, chip_id: int) -> Optional[Row]: async def get_card_by_chip_id(self, chip_id: int) -> Optional[Row]:
result = await self.execute(aime_card.select(aime_card.c.chip_id == chip_id)) result = await self.execute(
aime_card.select().where(aime_card.c.chip_id == chip_id)
)
if result: if result:
return result.fetchone() return result.fetchone()
async def set_chip_id_by_access_code(self, access_code: str, chip_id: int) -> Optional[Row]: async def set_chip_id_by_access_code(
result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(chip_id=chip_id)) self, access_code: str, chip_id: int
) -> Optional[Row]:
result = await self.execute(
aime_card.update()
.where(aime_card.c.access_code == access_code)
.values(chip_id=chip_id)
)
if not result: if not result:
self.logger.error(f"Failed to update chip ID to {chip_id} for {access_code}") self.logger.error(
f"Failed to update chip ID to {chip_id} for {access_code}"
)
async def set_idm_by_access_code(self, access_code: str, idm: str) -> Optional[Row]: async def set_idm_by_access_code(self, access_code: str, idm: str) -> Optional[Row]:
result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(idm=idm)) result = await self.execute(
aime_card.update()
.where(aime_card.c.access_code == access_code)
.values(idm=idm)
)
if not result: if not result:
self.logger.error(f"Failed to update IDm to {idm} for {access_code}") self.logger.error(f"Failed to update IDm to {idm} for {access_code}")
async def set_access_code_by_access_code(self, old_ac: str, new_ac: str) -> None: async def set_access_code_by_access_code(self, old_ac: str, new_ac: str) -> None:
result = await self.execute(aime_card.update(aime_card.c.access_code == old_ac).values(access_code=new_ac)) result = await self.execute(
aime_card.update()
.where(aime_card.c.access_code == old_ac)
.values(access_code=new_ac)
)
if not result: if not result:
self.logger.error(f"Failed to change card access code from {old_ac} to {new_ac}") self.logger.error(
f"Failed to change card access code from {old_ac} to {new_ac}"
)
async def set_memo_by_access_code(self, access_code: str, memo: str) -> None: async def set_memo_by_access_code(self, access_code: str, memo: str) -> None:
result = await self.execute(aime_card.update(aime_card.c.access_code == access_code).values(memo=memo)) result = await self.execute(
aime_card.update()
.where(aime_card.c.access_code == access_code)
.values(memo=memo)
)
if not result: if not result:
self.logger.error(f"Failed to add memo to card {access_code}") self.logger.error(f"Failed to add memo to card {access_code}")

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import bcrypt import bcrypt
from sqlalchemy import Column, Table from sqlalchemy import Column, Table
from sqlalchemy.dialects.mysql import insert from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row from sqlalchemy.engine import Row, RowMapping
from sqlalchemy.sql import func, select from sqlalchemy.sql import func, select
from sqlalchemy.types import TIMESTAMP, Integer, String from sqlalchemy.types import TIMESTAMP, Integer, String
@ -23,6 +23,7 @@ aime_user: Table = Table(
mysql_charset="utf8mb4", mysql_charset="utf8mb4",
) )
class UserData(BaseData): class UserData(BaseData):
async def create_user( async def create_user(
self, self,
@ -57,28 +58,28 @@ class UserData(BaseData):
return None return None
return result.lastrowid return result.lastrowid
async def get_user(self, user_id: int) -> Optional[Row]: async def get_user(self, user_id: int) -> Optional[RowMapping]:
sql = select(aime_user).where(aime_user.c.id == user_id) sql = select(aime_user).where(aime_user.c.id == user_id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return False return False
return result.fetchone() return result.mappings().fetchone()
async def check_password(self, user_id: int, passwd: bytes = None) -> bool: async def check_password(self, user_id: int, passwd: bytes = None) -> bool:
usr = await self.get_user(user_id) usr = await self.get_user(user_id)
if usr is None: if usr is None:
return False return False
if usr["password"] is None: if usr.password is None:
return False return False
if passwd is None or not passwd: if passwd is None or not passwd:
return False return False
return bcrypt.checkpw(passwd, usr["password"].encode()) return bcrypt.checkpw(passwd, usr.password.encode())
async def delete_user(self, user_id: int) -> None: async def delete_user(self, user_id: int) -> None:
sql = aime_user.delete(aime_user.c.id == user_id) sql = aime_user.delete().where(aime_user.c.id == user_id)
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
@ -103,24 +104,35 @@ class UserData(BaseData):
return result.fetchone() return result.fetchone()
async def find_user_by_username(self, username: str) -> List[Row]: async def find_user_by_username(self, username: str) -> List[Row]:
sql = aime_user.select(aime_user.c.username.like(f"%{username}%")) sql = aime_user.select().where(aime_user.c.username.like(f"%{username}%"))
result = await self.execute(sql) result = await self.execute(sql)
if result is None: if result is None:
return False return False
return result.fetchall() return result.fetchall()
async def change_password(self, user_id: int, new_passwd: str) -> bool: async def change_password(self, user_id: int, new_passwd: str) -> bool:
sql = aime_user.update(aime_user.c.id == user_id).values(password = new_passwd) sql = (
aime_user.update()
.where(aime_user.c.id == user_id)
.values(password=new_passwd)
)
result = await self.execute(sql) result = await self.execute(sql)
return result is not None return result is not None
async def change_username(self, user_id: int, new_name: str) -> bool: async def change_username(self, user_id: int, new_name: str) -> bool:
sql = aime_user.update(aime_user.c.id == user_id).values(username = new_name) sql = (
aime_user.update()
.where(aime_user.c.id == user_id)
.values(username=new_name)
)
result = await self.execute(sql) result = await self.execute(sql)
return result is not None return result is not None
async def get_user_by_username(self, username: str) -> Optional[Row]: async def get_user_by_username(self, username: str) -> Optional[RowMapping]:
result = await self.execute(aime_user.select(aime_user.c.username == username)) result = await self.execute(
if result: return result.fetchone() aime_user.select().where(aime_user.c.username == username)
)
if result:
return result.mappings().fetchone()

View File

@ -111,7 +111,7 @@ class FrontendServlet():
self.arcade = FE_Arcade(cfg, self.environment) self.arcade = FE_Arcade(cfg, self.environment)
self.machine = FE_Machine(cfg, self.environment) self.machine = FE_Machine(cfg, self.environment)
def get_routes(self) -> List[Route]: def get_routes(self) -> List[Union[Route, Mount]]:
g_routes = [] g_routes = []
for nav_name, g_data in self.environment.globals["game_list"].items(): for nav_name, g_data in self.environment.globals["game_list"].items():
g_routes.append(Mount(g_data['url'], routes=g_data['class'].get_routes())) g_routes.append(Mount(g_data['url'], routes=g_data['class'].get_routes()))