fix: make database async
This commit is contained in:
@ -1,54 +1,65 @@
|
||||
import logging, coloredlogs
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy import create_engine
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import secrets, string
|
||||
import bcrypt
|
||||
import secrets
|
||||
import string
|
||||
import warnings
|
||||
from hashlib import sha256
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import alembic.config
|
||||
import glob
|
||||
import bcrypt
|
||||
import coloredlogs
|
||||
import pymysql.err
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.config import CoreConfig
|
||||
from core.data.schema import *
|
||||
from core.utils import Utils
|
||||
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
|
||||
from core.utils import MISSING, Utils
|
||||
|
||||
|
||||
class Data:
|
||||
engine = None
|
||||
session = None
|
||||
user = None
|
||||
arcade = None
|
||||
card = None
|
||||
base = None
|
||||
engine: ClassVar[AsyncEngine] = MISSING
|
||||
session: ClassVar[AsyncSession] = MISSING
|
||||
user: ClassVar[UserData] = MISSING
|
||||
arcade: ClassVar[ArcadeData] = MISSING
|
||||
card: ClassVar[CardData] = MISSING
|
||||
base: ClassVar[BaseData] = MISSING
|
||||
|
||||
def __init__(self, cfg: CoreConfig) -> None:
|
||||
self.config = cfg
|
||||
|
||||
if self.config.database.sha2_password:
|
||||
passwd = sha256(self.config.database.password.encode()).digest()
|
||||
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||
else:
|
||||
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||
|
||||
if Data.engine is None:
|
||||
Data.engine = create_engine(self.__url, pool_recycle=3600)
|
||||
if Data.engine is MISSING:
|
||||
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
|
||||
self.__engine = Data.engine
|
||||
|
||||
if Data.session is None:
|
||||
s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True)
|
||||
Data.session = scoped_session(s)
|
||||
if Data.session is MISSING:
|
||||
s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
|
||||
Data.session = async_scoped_session(s, asyncio.current_task)
|
||||
|
||||
if Data.user is None:
|
||||
if Data.user is MISSING:
|
||||
Data.user = UserData(self.config, self.session)
|
||||
|
||||
if Data.arcade is None:
|
||||
if Data.arcade is MISSING:
|
||||
Data.arcade = ArcadeData(self.config, self.session)
|
||||
|
||||
if Data.card is None:
|
||||
if Data.card is MISSING:
|
||||
Data.card = CardData(self.config, self.session)
|
||||
|
||||
if Data.base is None:
|
||||
if Data.base is MISSING:
|
||||
Data.base = BaseData(self.config, self.session)
|
||||
|
||||
self.logger = logging.getLogger("database")
|
||||
@ -94,40 +105,73 @@ class Data:
|
||||
alembic.config.main(argv=alembicArgs)
|
||||
os.chdir(old_dir)
|
||||
|
||||
def create_database(self):
|
||||
async def create_database(self):
|
||||
self.logger.info("Creating databases...")
|
||||
metadata.create_all(
|
||||
self.engine,
|
||||
checkfirst=True,
|
||||
)
|
||||
|
||||
for _, mod in Utils.get_all_titles().items():
|
||||
if hasattr(mod, "database"):
|
||||
mod.database(self.config)
|
||||
metadata.create_all(
|
||||
self.engine,
|
||||
checkfirst=True,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||
# which we don't care about. Other warnings may be helpful though, don't
|
||||
# suppress everything.
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||
category=pymysql.err.Warning,
|
||||
)
|
||||
|
||||
# Stamp the end revision as if alembic had created it, so it can take off after this.
|
||||
self.__alembic_cmd(
|
||||
"stamp",
|
||||
"head",
|
||||
)
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(metadata.create_all, checkfirst=True)
|
||||
|
||||
def schema_upgrade(self, ver: str = None):
|
||||
self.__alembic_cmd(
|
||||
"upgrade",
|
||||
"head" if not ver else ver,
|
||||
)
|
||||
for _, mod in Utils.get_all_titles().items():
|
||||
if hasattr(mod, "database"):
|
||||
mod.database(self.config)
|
||||
|
||||
await conn.run_sync(metadata.create_all, checkfirst=True)
|
||||
|
||||
# Stamp the end revision as if alembic had created it, so it can take off after this.
|
||||
self.__alembic_cmd(
|
||||
"stamp",
|
||||
"head",
|
||||
)
|
||||
|
||||
def schema_upgrade(self, ver: Optional[str] = None):
|
||||
with warnings.catch_warnings():
|
||||
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||
# which we don't care about. Other warnings may be helpful though, don't
|
||||
# suppress everything.
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||
category=pymysql.err.Warning,
|
||||
)
|
||||
|
||||
self.__alembic_cmd(
|
||||
"upgrade",
|
||||
"head" if not ver else ver,
|
||||
)
|
||||
|
||||
def schema_downgrade(self, ver: str):
|
||||
self.__alembic_cmd(
|
||||
"downgrade",
|
||||
ver,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||
# which we don't care about. Other warnings may be helpful though, don't
|
||||
# suppress everything.
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||
category=pymysql.err.Warning,
|
||||
)
|
||||
|
||||
async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None:
|
||||
self.__alembic_cmd(
|
||||
"downgrade",
|
||||
ver,
|
||||
)
|
||||
|
||||
async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None:
|
||||
pw = "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
|
||||
)
|
||||
@ -150,12 +194,12 @@ class Data:
|
||||
async def migrate(self) -> None:
|
||||
exist = await self.base.execute("SELECT * FROM alembic_version")
|
||||
if exist is not None:
|
||||
self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||
self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||
return
|
||||
|
||||
self.logger.info("Upgrading to latest with legacy system")
|
||||
if not await self.legacy_upgrade():
|
||||
self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||
self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||
return
|
||||
self.logger.info("Done")
|
||||
|
||||
|
Reference in New Issue
Block a user