forked from Hay1tsme/artemis
		
	fix: make database async
This commit is contained in:
		| @ -1,8 +1,14 @@ | |||||||
| from __future__ import with_statement | from __future__ import with_statement | ||||||
| from alembic import context |  | ||||||
| from sqlalchemy import engine_from_config, pool | import asyncio | ||||||
|  | import threading | ||||||
| from logging.config import fileConfig | from logging.config import fileConfig | ||||||
|  |  | ||||||
|  | from alembic import context | ||||||
|  | from sqlalchemy import pool | ||||||
|  | from sqlalchemy.engine import Connection | ||||||
|  | from sqlalchemy.ext.asyncio import async_engine_from_config | ||||||
|  |  | ||||||
| from core.data.schema.base import metadata | from core.data.schema.base import metadata | ||||||
|  |  | ||||||
| # this is the Alembic Config object, which provides | # this is the Alembic Config object, which provides | ||||||
| @ -37,20 +43,29 @@ def run_migrations_offline(): | |||||||
|     script output. |     script output. | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|     raise Exception('Not implemented or configured!') |     raise Exception("Not implemented or configured!") | ||||||
|  |  | ||||||
|     url = config.get_main_option("sqlalchemy.url") |     url = config.get_main_option("sqlalchemy.url") | ||||||
|     context.configure( |     context.configure(url=url, target_metadata=target_metadata, literal_binds=True) | ||||||
|         url=url, target_metadata=target_metadata, literal_binds=True) |  | ||||||
|  |  | ||||||
|     with context.begin_transaction(): |     with context.begin_transaction(): | ||||||
|         context.run_migrations() |         context.run_migrations() | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_migrations_online(): | def do_run_migrations(connection: Connection) -> None: | ||||||
|     """Run migrations in 'online' mode. |     context.configure( | ||||||
|  |         connection=connection, | ||||||
|  |         target_metadata=target_metadata, | ||||||
|  |         compare_type=True, | ||||||
|  |         compare_server_default=True, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     In this scenario we need to create an Engine |     with context.begin_transaction(): | ||||||
|  |         context.run_migrations() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def run_async_migrations() -> None: | ||||||
|  |     """In this scenario we need to create an Engine | ||||||
|     and associate a connection with the context. |     and associate a connection with the context. | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
| @ -59,21 +74,32 @@ def run_migrations_online(): | |||||||
|     for override in overrides: |     for override in overrides: | ||||||
|         ini_section[override] = overrides[override] |         ini_section[override] = overrides[override] | ||||||
|  |  | ||||||
|     connectable = engine_from_config( |     connectable = async_engine_from_config( | ||||||
|         ini_section, |         ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool | ||||||
|         prefix='sqlalchemy.', |     ) | ||||||
|         poolclass=pool.NullPool) |  | ||||||
|  |  | ||||||
|     with connectable.connect() as connection: |     async with connectable.connect() as connection: | ||||||
|         context.configure( |         await connection.run_sync(do_run_migrations) | ||||||
|             connection=connection, |  | ||||||
|             target_metadata=target_metadata, |     await connectable.dispose() | ||||||
|             compare_type=True, |  | ||||||
|             compare_server_default=True, |  | ||||||
|         ) | def run_migrations_online(): | ||||||
|  |     try: | ||||||
|  |         loop = asyncio.get_running_loop() | ||||||
|  |     except RuntimeError: | ||||||
|  |         # there's no event loop | ||||||
|  |         asyncio.run(run_async_migrations()) | ||||||
|  |     else: | ||||||
|  |         # there's currently an event loop and trying to wait for a coroutine | ||||||
|  |         # to finish without using `await` is pretty wormy. nested event loops | ||||||
|  |         # are explicitly forbidden by asyncio. | ||||||
|  |         # | ||||||
|  |         # take the easy way out, spawn it in another thread. | ||||||
|  |         thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),)) | ||||||
|  |         thread.start() | ||||||
|  |         thread.join() | ||||||
|  |  | ||||||
|         with context.begin_transaction(): |  | ||||||
|             context.run_migrations() |  | ||||||
|  |  | ||||||
| if context.is_offline_mode(): | if context.is_offline_mode(): | ||||||
|     run_migrations_offline() |     run_migrations_offline() | ||||||
|  | |||||||
| @ -1,54 +1,65 @@ | |||||||
| import logging, coloredlogs | import asyncio | ||||||
| from typing import Optional | import logging | ||||||
| from sqlalchemy.orm import scoped_session, sessionmaker |  | ||||||
| from sqlalchemy import create_engine |  | ||||||
| from logging.handlers import TimedRotatingFileHandler |  | ||||||
| import os | import os | ||||||
| import secrets, string | import secrets | ||||||
| import bcrypt | import string | ||||||
|  | import warnings | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
|  | from logging.handlers import TimedRotatingFileHandler | ||||||
|  | from typing import ClassVar, Optional | ||||||
|  |  | ||||||
| import alembic.config | import alembic.config | ||||||
| import glob | import bcrypt | ||||||
|  | import coloredlogs | ||||||
|  | import pymysql.err | ||||||
|  | from sqlalchemy.ext.asyncio import ( | ||||||
|  |     AsyncEngine, | ||||||
|  |     AsyncSession, | ||||||
|  |     async_scoped_session, | ||||||
|  |     create_async_engine, | ||||||
|  | ) | ||||||
|  | from sqlalchemy.orm import sessionmaker | ||||||
|  |  | ||||||
| from core.config import CoreConfig | from core.config import CoreConfig | ||||||
| from core.data.schema import * | from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata | ||||||
| from core.utils import Utils | from core.utils import MISSING, Utils | ||||||
|  |  | ||||||
|  |  | ||||||
| class Data: | class Data: | ||||||
|     engine = None |     engine: ClassVar[AsyncEngine] = MISSING | ||||||
|     session = None |     session: ClassVar[AsyncSession] = MISSING | ||||||
|     user = None |     user: ClassVar[UserData] = MISSING | ||||||
|     arcade = None |     arcade: ClassVar[ArcadeData] = MISSING | ||||||
|     card = None |     card: ClassVar[CardData] = MISSING | ||||||
|     base = None |     base: ClassVar[BaseData] = MISSING | ||||||
|  |  | ||||||
|     def __init__(self, cfg: CoreConfig) -> None: |     def __init__(self, cfg: CoreConfig) -> None: | ||||||
|         self.config = cfg |         self.config = cfg | ||||||
|  |  | ||||||
|         if self.config.database.sha2_password: |         if self.config.database.sha2_password: | ||||||
|             passwd = sha256(self.config.database.password.encode()).digest() |             passwd = sha256(self.config.database.password.encode()).digest() | ||||||
|             self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" |             self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" | ||||||
|         else: |         else: | ||||||
|             self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" |             self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" | ||||||
|  |  | ||||||
|         if Data.engine is None: |         if Data.engine is MISSING: | ||||||
|             Data.engine = create_engine(self.__url, pool_recycle=3600) |             Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT") | ||||||
|             self.__engine = Data.engine |             self.__engine = Data.engine | ||||||
|  |  | ||||||
|         if Data.session is None: |         if Data.session is MISSING: | ||||||
|             s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True) |             s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession) | ||||||
|             Data.session = scoped_session(s) |             Data.session = async_scoped_session(s, asyncio.current_task) | ||||||
|  |  | ||||||
|         if Data.user is None: |         if Data.user is MISSING: | ||||||
|             Data.user = UserData(self.config, self.session) |             Data.user = UserData(self.config, self.session) | ||||||
|          |          | ||||||
|         if Data.arcade is None: |         if Data.arcade is MISSING: | ||||||
|             Data.arcade = ArcadeData(self.config, self.session) |             Data.arcade = ArcadeData(self.config, self.session) | ||||||
|          |          | ||||||
|         if Data.card is None: |         if Data.card is MISSING: | ||||||
|             Data.card = CardData(self.config, self.session) |             Data.card = CardData(self.config, self.session) | ||||||
|          |          | ||||||
|         if Data.base is None: |         if Data.base is MISSING: | ||||||
|             Data.base = BaseData(self.config, self.session) |             Data.base = BaseData(self.config, self.session) | ||||||
|  |  | ||||||
|         self.logger = logging.getLogger("database") |         self.logger = logging.getLogger("database") | ||||||
| @ -94,40 +105,73 @@ class Data: | |||||||
|         alembic.config.main(argv=alembicArgs) |         alembic.config.main(argv=alembicArgs) | ||||||
|         os.chdir(old_dir) |         os.chdir(old_dir) | ||||||
|  |  | ||||||
|     def create_database(self): |     async def create_database(self): | ||||||
|         self.logger.info("Creating databases...") |         self.logger.info("Creating databases...") | ||||||
|         metadata.create_all( |  | ||||||
|             self.engine, |  | ||||||
|             checkfirst=True, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         for _, mod in Utils.get_all_titles().items(): |         with warnings.catch_warnings(): | ||||||
|             if hasattr(mod, "database"): |             # SQLAlchemy will generate a nice primary key constraint name, but in | ||||||
|                 mod.database(self.config) |             # MySQL/MariaDB the constraint name is always PRIMARY. Every time a | ||||||
|                 metadata.create_all( |             # custom primary key name is generated, a warning is emitted from pymysql, | ||||||
|                     self.engine, |             # which we don't care about. Other warnings may be helpful though, don't | ||||||
|                     checkfirst=True, |             # suppress everything.             | ||||||
|                 ) |             warnings.filterwarnings( | ||||||
|  |                 action="ignore", | ||||||
|  |                 message=r"Name '(.+)' ignored for PRIMARY key\.", | ||||||
|  |                 category=pymysql.err.Warning, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         # Stamp the end revision as if alembic had created it, so it can take off after this. |             async with self.engine.begin() as conn: | ||||||
|         self.__alembic_cmd( |                 await conn.run_sync(metadata.create_all, checkfirst=True) | ||||||
|             "stamp", |  | ||||||
|             "head", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def schema_upgrade(self, ver: str = None): |                 for _, mod in Utils.get_all_titles().items(): | ||||||
|         self.__alembic_cmd( |                     if hasattr(mod, "database"): | ||||||
|             "upgrade", |                         mod.database(self.config) | ||||||
|             "head" if not ver else ver, |  | ||||||
|         ) |                         await conn.run_sync(metadata.create_all, checkfirst=True) | ||||||
|  |  | ||||||
|  |             # Stamp the end revision as if alembic had created it, so it can take off after this. | ||||||
|  |             self.__alembic_cmd( | ||||||
|  |                 "stamp", | ||||||
|  |                 "head", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def schema_upgrade(self, ver: Optional[str] = None): | ||||||
|  |         with warnings.catch_warnings(): | ||||||
|  |             # SQLAlchemy will generate a nice primary key constraint name, but in | ||||||
|  |             # MySQL/MariaDB the constraint name is always PRIMARY. Every time a | ||||||
|  |             # custom primary key name is generated, a warning is emitted from pymysql, | ||||||
|  |             # which we don't care about. Other warnings may be helpful though, don't | ||||||
|  |             # suppress everything.             | ||||||
|  |             warnings.filterwarnings( | ||||||
|  |                 action="ignore", | ||||||
|  |                 message=r"Name '(.+)' ignored for PRIMARY key\.", | ||||||
|  |                 category=pymysql.err.Warning, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             self.__alembic_cmd( | ||||||
|  |                 "upgrade", | ||||||
|  |                 "head" if not ver else ver, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def schema_downgrade(self, ver: str): |     def schema_downgrade(self, ver: str): | ||||||
|         self.__alembic_cmd( |         with warnings.catch_warnings(): | ||||||
|             "downgrade", |             # SQLAlchemy will generate a nice primary key constraint name, but in | ||||||
|             ver, |             # MySQL/MariaDB the constraint name is always PRIMARY. Every time a | ||||||
|         ) |             # custom primary key name is generated, a warning is emitted from pymysql, | ||||||
|  |             # which we don't care about. Other warnings may be helpful though, don't | ||||||
|  |             # suppress everything.             | ||||||
|  |             warnings.filterwarnings( | ||||||
|  |                 action="ignore", | ||||||
|  |                 message=r"Name '(.+)' ignored for PRIMARY key\.", | ||||||
|  |                 category=pymysql.err.Warning, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None: |             self.__alembic_cmd( | ||||||
|  |                 "downgrade", | ||||||
|  |                 ver, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None: | ||||||
|         pw = "".join( |         pw = "".join( | ||||||
|             secrets.choice(string.ascii_letters + string.digits) for i in range(20) |             secrets.choice(string.ascii_letters + string.digits) for i in range(20) | ||||||
|         ) |         ) | ||||||
| @ -150,12 +194,12 @@ class Data: | |||||||
|     async def migrate(self) -> None: |     async def migrate(self) -> None: | ||||||
|         exist = await self.base.execute("SELECT * FROM alembic_version") |         exist = await self.base.execute("SELECT * FROM alembic_version") | ||||||
|         if exist is not None: |         if exist is not None: | ||||||
|             self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!") |             self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!") | ||||||
|             return |             return | ||||||
|          |          | ||||||
|         self.logger.info("Upgrading to latest with legacy system") |         self.logger.info("Upgrading to latest with legacy system") | ||||||
|         if not await self.legacy_upgrade(): |         if not await self.legacy_upgrade(): | ||||||
|             self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!") |             self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!") | ||||||
|             return |             return | ||||||
|         self.logger.info("Done") |         self.logger.info("Done") | ||||||
|          |          | ||||||
|  | |||||||
| @ -1,16 +1,16 @@ | |||||||
| from typing import Optional, Dict, List | import re | ||||||
| from sqlalchemy import Table, Column, and_, or_ | from typing import List, Optional | ||||||
| from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint |  | ||||||
| from sqlalchemy.types import Integer, String, Boolean, JSON | from sqlalchemy import Column, Table, and_, or_ | ||||||
| from sqlalchemy.sql import func, select |  | ||||||
| from sqlalchemy.dialects.mysql import insert | from sqlalchemy.dialects.mysql import insert | ||||||
| from sqlalchemy.engine import Row | from sqlalchemy.engine import Row | ||||||
| import re | from sqlalchemy.sql import func, select | ||||||
|  | from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint | ||||||
|  | from sqlalchemy.types import JSON, Boolean, Integer, String | ||||||
|  |  | ||||||
| from core.data.schema.base import BaseData, metadata | from core.data.schema.base import BaseData, metadata | ||||||
| from core.const import * |  | ||||||
|  |  | ||||||
| arcade = Table( | arcade: Table = Table( | ||||||
|     "arcade", |     "arcade", | ||||||
|     metadata, |     metadata, | ||||||
|     Column("id", Integer, primary_key=True, nullable=False), |     Column("id", Integer, primary_key=True, nullable=False), | ||||||
| @ -26,7 +26,7 @@ arcade = Table( | |||||||
|     mysql_charset="utf8mb4", |     mysql_charset="utf8mb4", | ||||||
| ) | ) | ||||||
|  |  | ||||||
| machine = Table( | machine: Table = Table( | ||||||
|     "machine", |     "machine", | ||||||
|     metadata, |     metadata, | ||||||
|     Column("id", Integer, primary_key=True, nullable=False), |     Column("id", Integer, primary_key=True, nullable=False), | ||||||
| @ -47,7 +47,7 @@ machine = Table( | |||||||
|     mysql_charset="utf8mb4", |     mysql_charset="utf8mb4", | ||||||
| ) | ) | ||||||
|  |  | ||||||
| arcade_owner = Table( | arcade_owner: Table = Table( | ||||||
|     "arcade_owner", |     "arcade_owner", | ||||||
|     metadata, |     metadata, | ||||||
|     Column( |     Column( | ||||||
| @ -69,7 +69,7 @@ arcade_owner = Table( | |||||||
|  |  | ||||||
|  |  | ||||||
| class ArcadeData(BaseData): | class ArcadeData(BaseData): | ||||||
|     async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]: |     async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]: | ||||||
|         if serial is not None: |         if serial is not None: | ||||||
|             serial = serial.replace("-", "") |             serial = serial.replace("-", "") | ||||||
|             if len(serial) == 11: |             if len(serial) == 11: | ||||||
| @ -98,8 +98,8 @@ class ArcadeData(BaseData): | |||||||
|         self, |         self, | ||||||
|         arcade_id: int, |         arcade_id: int, | ||||||
|         serial: str = "", |         serial: str = "", | ||||||
|         board: str = None, |         board: Optional[str] = None, | ||||||
|         game: str = None, |         game: Optional[str] = None, | ||||||
|         is_cab: bool = False, |         is_cab: bool = False, | ||||||
|     ) -> Optional[int]: |     ) -> Optional[int]: | ||||||
|         if not arcade_id: |         if not arcade_id: | ||||||
| @ -150,8 +150,8 @@ class ArcadeData(BaseData): | |||||||
|  |  | ||||||
|     async def create_arcade( |     async def create_arcade( | ||||||
|         self, |         self, | ||||||
|         name: str = None, |         name: Optional[str] = None, | ||||||
|         nickname: str = None, |         nickname: Optional[str] = None, | ||||||
|         country: str = "JPN", |         country: str = "JPN", | ||||||
|         country_id: int = 1, |         country_id: int = 1, | ||||||
|         state: str = "", |         state: str = "", | ||||||
|  | |||||||
| @ -1,22 +1,23 @@ | |||||||
|  | import asyncio | ||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| from random import randrange | from random import randrange | ||||||
| from typing import Any, Optional, Dict, List | from typing import Any, Dict, List, Optional | ||||||
|  |  | ||||||
|  | from sqlalchemy import Column, MetaData, Table | ||||||
| from sqlalchemy.engine import Row | from sqlalchemy.engine import Row | ||||||
| from sqlalchemy.engine.cursor import CursorResult | from sqlalchemy.engine.cursor import CursorResult | ||||||
| from sqlalchemy.engine.base import Connection |  | ||||||
| from sqlalchemy.sql import text, func, select |  | ||||||
| from sqlalchemy.exc import SQLAlchemyError | from sqlalchemy.exc import SQLAlchemyError | ||||||
| from sqlalchemy import MetaData, Table, Column | from sqlalchemy.ext.asyncio import AsyncSession | ||||||
| from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT |  | ||||||
| from sqlalchemy.schema import ForeignKey | from sqlalchemy.schema import ForeignKey | ||||||
| from sqlalchemy.dialects.mysql import insert | from sqlalchemy.sql import func, text | ||||||
|  | from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String | ||||||
|  |  | ||||||
| from core.config import CoreConfig | from core.config import CoreConfig | ||||||
|  |  | ||||||
| metadata = MetaData() | metadata = MetaData() | ||||||
|  |  | ||||||
| event_log = Table( | event_log: Table = Table( | ||||||
|     "event_log", |     "event_log", | ||||||
|     metadata, |     metadata, | ||||||
|     Column("id", Integer, primary_key=True, nullable=False), |     Column("id", Integer, primary_key=True, nullable=False), | ||||||
| @ -37,7 +38,7 @@ event_log = Table( | |||||||
|  |  | ||||||
|  |  | ||||||
| class BaseData: | class BaseData: | ||||||
|     def __init__(self, cfg: CoreConfig, conn: Connection) -> None: |     def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None: | ||||||
|         self.config = cfg |         self.config = cfg | ||||||
|         self.conn = conn |         self.conn = conn | ||||||
|         self.logger = logging.getLogger("database") |         self.logger = logging.getLogger("database") | ||||||
| @ -47,7 +48,7 @@ class BaseData: | |||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") |             self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") | ||||||
|             res = self.conn.execute(text(sql), opts) |             res = await self.conn.execute(text(sql), opts) | ||||||
|  |  | ||||||
|         except SQLAlchemyError as e: |         except SQLAlchemyError as e: | ||||||
|             self.logger.error(f"SQLAlchemy error {e}") |             self.logger.error(f"SQLAlchemy error {e}") | ||||||
| @ -59,7 +60,7 @@ class BaseData: | |||||||
|  |  | ||||||
|         except Exception: |         except Exception: | ||||||
|             try: |             try: | ||||||
|                 res = self.conn.execute(sql, opts) |                 res = await self.conn.execute(sql, opts) | ||||||
|  |  | ||||||
|             except SQLAlchemyError as e: |             except SQLAlchemyError as e: | ||||||
|                 self.logger.error(f"SQLAlchemy error {e}") |                 self.logger.error(f"SQLAlchemy error {e}") | ||||||
| @ -83,7 +84,7 @@ class BaseData: | |||||||
|  |  | ||||||
|     async def log_event( |     async def log_event( | ||||||
|         self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,  |         self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,  | ||||||
|         arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None |         arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None | ||||||
|     ) -> Optional[int]: |     ) -> Optional[int]: | ||||||
|         sql = event_log.insert().values( |         sql = event_log.insert().values( | ||||||
|             system=system, |             system=system, | ||||||
|  | |||||||
| @ -1,13 +1,14 @@ | |||||||
| from typing import Dict, List, Optional | from typing import Dict, List, Optional | ||||||
| from sqlalchemy import Table, Column, UniqueConstraint |  | ||||||
| from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR | from sqlalchemy import Column, Table, UniqueConstraint | ||||||
| from sqlalchemy.sql.schema import ForeignKey |  | ||||||
| from sqlalchemy.sql import func |  | ||||||
| from sqlalchemy.engine import Row | from sqlalchemy.engine import Row | ||||||
|  | from sqlalchemy.sql import func | ||||||
|  | from sqlalchemy.sql.schema import ForeignKey | ||||||
|  | from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String | ||||||
|  |  | ||||||
| from core.data.schema.base import BaseData, metadata | from core.data.schema.base import BaseData, metadata | ||||||
|  |  | ||||||
| aime_card = Table( | aime_card: Table = Table( | ||||||
|     "aime_card", |     "aime_card", | ||||||
|     metadata, |     metadata, | ||||||
|     Column("id", Integer, primary_key=True, nullable=False), |     Column("id", Integer, primary_key=True, nullable=False), | ||||||
|  | |||||||
| @ -1,15 +1,15 @@ | |||||||
| from typing import Optional, List | from typing import List, Optional | ||||||
| from sqlalchemy import Table, Column |  | ||||||
| from sqlalchemy.types import Integer, String, TIMESTAMP |  | ||||||
| from sqlalchemy.sql import func |  | ||||||
| from sqlalchemy.dialects.mysql import insert |  | ||||||
| from sqlalchemy.sql import func, select |  | ||||||
| from sqlalchemy.engine import Row |  | ||||||
| import bcrypt | import bcrypt | ||||||
|  | from sqlalchemy import Column, Table | ||||||
|  | from sqlalchemy.dialects.mysql import insert | ||||||
|  | from sqlalchemy.engine import Row | ||||||
|  | from sqlalchemy.sql import func, select | ||||||
|  | from sqlalchemy.types import TIMESTAMP, Integer, String | ||||||
|  |  | ||||||
| from core.data.schema.base import BaseData, metadata | from core.data.schema.base import BaseData, metadata | ||||||
|  |  | ||||||
| aime_user = Table( | aime_user: Table = Table( | ||||||
|     "aime_user", |     "aime_user", | ||||||
|     metadata, |     metadata, | ||||||
|     Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), |     Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), | ||||||
| @ -26,10 +26,10 @@ aime_user = Table( | |||||||
| class UserData(BaseData): | class UserData(BaseData): | ||||||
|     async def create_user( |     async def create_user( | ||||||
|         self, |         self, | ||||||
|         id: int = None, |         id: Optional[int] = None, | ||||||
|         username: str = None, |         username: Optional[str] = None, | ||||||
|         email: str = None, |         email: Optional[str] = None, | ||||||
|         password: str = None, |         password: Optional[str] = None, | ||||||
|         permission: int = 1, |         permission: int = 1, | ||||||
|     ) -> Optional[int]: |     ) -> Optional[int]: | ||||||
|         if id is None: |         if id is None: | ||||||
|  | |||||||
| @ -1,18 +1,47 @@ | |||||||
| from typing import Dict, Any, Optional |  | ||||||
| from types import ModuleType |  | ||||||
| from starlette.requests import Request |  | ||||||
| import logging |  | ||||||
| import importlib | import importlib | ||||||
| from os import walk | import logging | ||||||
| import jwt |  | ||||||
| from base64 import b64decode | from base64 import b64decode | ||||||
| from datetime import datetime, timezone | from datetime import datetime, timezone | ||||||
|  | from os import walk | ||||||
|  | from types import ModuleType | ||||||
|  | from typing import Any, Dict, Optional | ||||||
|  |  | ||||||
|  | import jwt | ||||||
|  | from starlette.requests import Request | ||||||
|  |  | ||||||
| from .config import CoreConfig | from .config import CoreConfig | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _MissingSentinel: | ||||||
|  |     __slots__: tuple[str, ...] = () | ||||||
|  |  | ||||||
|  |     def __eq__(self, other) -> bool: | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def __bool__(self) -> bool: | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def __hash__(self) -> int: | ||||||
|  |         return 0 | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "..." | ||||||
|  |  | ||||||
|  |  | ||||||
|  | MISSING: Any = _MissingSentinel() | ||||||
|  | """This is different from `None` in that its type is `Any`, and so it can be used | ||||||
|  | as a placeholder for values that are *definitely* going to be initialized, | ||||||
|  | so they don't have to be typed as `T | None`, which makes type checkers | ||||||
|  | angry when an attribute is accessed. | ||||||
|  |  | ||||||
|  | This can also be used for when `None` has actual meaning as a value, and so a | ||||||
|  | separate value is needed to mean "unset".""" | ||||||
|  |  | ||||||
|  |  | ||||||
| class Utils: | class Utils: | ||||||
|     real_title_port = None |     real_title_port = None | ||||||
|     real_title_port_ssl = None |     real_title_port_ssl = None | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_all_titles(cls) -> Dict[str, ModuleType]: |     def get_all_titles(cls) -> Dict[str, ModuleType]: | ||||||
|         ret: Dict[str, Any] = {} |         ret: Dict[str, Any] = {} | ||||||
| @ -39,24 +68,53 @@ class Utils: | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_title_port(cls, cfg: CoreConfig): |     def get_title_port(cls, cfg: CoreConfig): | ||||||
|         if cls.real_title_port is not None: return cls.real_title_port |         if cls.real_title_port is not None: | ||||||
|  |             return cls.real_title_port | ||||||
|  |  | ||||||
|         cls.real_title_port = cfg.server.proxy_port if cfg.server.is_using_proxy and cfg.server.proxy_port else cfg.server.port |         cls.real_title_port = ( | ||||||
|  |             cfg.server.proxy_port | ||||||
|  |             if cfg.server.is_using_proxy and cfg.server.proxy_port | ||||||
|  |             else cfg.server.port | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return cls.real_title_port |         return cls.real_title_port | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_title_port_ssl(cls, cfg: CoreConfig): |     def get_title_port_ssl(cls, cfg: CoreConfig): | ||||||
|         if cls.real_title_port_ssl is not None: return cls.real_title_port_ssl |         if cls.real_title_port_ssl is not None: | ||||||
|  |             return cls.real_title_port_ssl | ||||||
|  |  | ||||||
|         cls.real_title_port_ssl = cfg.server.proxy_port_ssl if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl else 443 |         cls.real_title_port_ssl = ( | ||||||
|  |             cfg.server.proxy_port_ssl | ||||||
|  |             if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl | ||||||
|  |             else 443 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return cls.real_title_port_ssl |         return cls.real_title_port_ssl | ||||||
|  |  | ||||||
| def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str, b64_secret: str, exp_seconds: int = 86400, err_logger: str = 'aimedb') -> Optional[str]: |  | ||||||
|  | def create_sega_auth_key( | ||||||
|  |     aime_id: int, | ||||||
|  |     game: str, | ||||||
|  |     place_id: int, | ||||||
|  |     keychip_id: str, | ||||||
|  |     b64_secret: str, | ||||||
|  |     exp_seconds: int = 86400, | ||||||
|  |     err_logger: str = "aimedb", | ||||||
|  | ) -> Optional[str]: | ||||||
|     logger = logging.getLogger(err_logger) |     logger = logging.getLogger(err_logger) | ||||||
|     try: |     try: | ||||||
|         return jwt.encode({ "aime_id": aime_id, "game": game, "place_id": place_id, "keychip_id": keychip_id, "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds }, b64decode(b64_secret), algorithm="HS256") |         return jwt.encode( | ||||||
|  |             { | ||||||
|  |                 "aime_id": aime_id, | ||||||
|  |                 "game": game, | ||||||
|  |                 "place_id": place_id, | ||||||
|  |                 "keychip_id": keychip_id, | ||||||
|  |                 "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds, | ||||||
|  |             }, | ||||||
|  |             b64decode(b64_secret), | ||||||
|  |             algorithm="HS256", | ||||||
|  |         ) | ||||||
|     except jwt.InvalidKeyError: |     except jwt.InvalidKeyError: | ||||||
|         logger.error("Failed to encode Sega Auth Key because the secret is invalid!") |         logger.error("Failed to encode Sega Auth Key because the secret is invalid!") | ||||||
|         return None |         return None | ||||||
| @ -64,10 +122,19 @@ def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str | |||||||
|         logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}") |         logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}") | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
| def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb') -> Optional[Dict]: |  | ||||||
|  | def decode_sega_auth_key( | ||||||
|  |     token: str, b64_secret: str, err_logger: str = "aimedb" | ||||||
|  | ) -> Optional[Dict]: | ||||||
|     logger = logging.getLogger(err_logger) |     logger = logging.getLogger(err_logger) | ||||||
|     try: |     try: | ||||||
|         return jwt.decode(token, "secret", b64decode(b64_secret), algorithms=["HS256"], options={"verify_signature": True}) |         return jwt.decode( | ||||||
|  |             token, | ||||||
|  |             "secret", | ||||||
|  |             b64decode(b64_secret), | ||||||
|  |             algorithms=["HS256"], | ||||||
|  |             options={"verify_signature": True}, | ||||||
|  |         ) | ||||||
|     except jwt.ExpiredSignatureError: |     except jwt.ExpiredSignatureError: | ||||||
|         logger.error("Sega Auth Key failed to validate due to an expired signature!") |         logger.error("Sega Auth Key failed to validate due to an expired signature!") | ||||||
|         return None |         return None | ||||||
| @ -83,4 +150,3 @@ def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb' | |||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}") |         logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}") | ||||||
|         return None |         return None | ||||||
|      |  | ||||||
							
								
								
									
										11
									
								
								dbutils.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								dbutils.py
									
									
									
									
									
								
							| @ -1,12 +1,13 @@ | |||||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||||
| import argparse | import argparse | ||||||
| import logging |  | ||||||
| from os import mkdir, path, access, W_OK, environ |  | ||||||
| import yaml |  | ||||||
| import asyncio | import asyncio | ||||||
|  | import logging | ||||||
|  | from os import W_OK, access, environ, mkdir, path | ||||||
|  |  | ||||||
|  | import yaml | ||||||
|  |  | ||||||
| from core.data import Data |  | ||||||
| from core.config import CoreConfig | from core.config import CoreConfig | ||||||
|  | from core.data import Data | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     parser = argparse.ArgumentParser(description="Database utilities") |     parser = argparse.ArgumentParser(description="Database utilities") | ||||||
| @ -46,7 +47,7 @@ if __name__ == "__main__": | |||||||
|     loop = asyncio.get_event_loop() |     loop = asyncio.get_event_loop() | ||||||
|  |  | ||||||
|     if args.action == "create": |     if args.action == "create": | ||||||
|         data.create_database() |         loop.run_until_complete(data.create_database()) | ||||||
|      |      | ||||||
|     elif args.action == "upgrade": |     elif args.action == "upgrade": | ||||||
|         data.schema_upgrade(args.version) |         data.schema_upgrade(args.version) | ||||||
|  | |||||||
							
								
								
									
										24
									
								
								read.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								read.py
									
									
									
									
									
								
							| @ -1,16 +1,16 @@ | |||||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||||
| import argparse | import argparse | ||||||
| import re |  | ||||||
| import os |  | ||||||
| import yaml |  | ||||||
| from os import path |  | ||||||
| import logging |  | ||||||
| import coloredlogs |  | ||||||
| import asyncio | import asyncio | ||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | import re | ||||||
| from logging.handlers import TimedRotatingFileHandler | from logging.handlers import TimedRotatingFileHandler | ||||||
|  | from os import path | ||||||
| from typing import List, Optional | from typing import List, Optional | ||||||
|  |  | ||||||
|  | import coloredlogs | ||||||
|  | import yaml | ||||||
|  |  | ||||||
| from core import CoreConfig, Utils | from core import CoreConfig, Utils | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -44,7 +44,7 @@ class BaseReader: | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | async def main(): | ||||||
|     parser = argparse.ArgumentParser(description="Import Game Information") |     parser = argparse.ArgumentParser(description="Import Game Information") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--game", |         "--game", | ||||||
| @ -140,8 +140,12 @@ if __name__ == "__main__": | |||||||
|     for dir, mod in titles.items(): |     for dir, mod in titles.items(): | ||||||
|         if args.game in mod.game_codes: |         if args.game in mod.game_codes: | ||||||
|             handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra) |             handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra) | ||||||
|             loop = asyncio.get_event_loop() |              | ||||||
|             loop.run_until_complete(handler.read()) |             await handler.read() | ||||||
|              |              | ||||||
|  |  | ||||||
|     logger.info("Done") |     logger.info("Done") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     asyncio.run(main()) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user