9 Commits

2 changed files with 21 additions and 23 deletions

View File

@ -1,11 +1,12 @@
import logging import logging
import os import os
import secrets import secrets
import ssl
import string import string
import warnings import warnings
from hashlib import sha256 from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import ClassVar, Optional from typing import Any, ClassVar, Optional
import alembic.config import alembic.config
import bcrypt import bcrypt
@ -16,7 +17,6 @@ from sqlalchemy.ext.asyncio import (
AsyncSession, AsyncSession,
create_async_engine, create_async_engine,
) )
from sqlalchemy.orm import sessionmaker
from core.config import CoreConfig from core.config import CoreConfig
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
@ -25,7 +25,7 @@ from core.utils import MISSING, Utils
class Data: class Data:
engine: ClassVar[AsyncEngine] = MISSING engine: ClassVar[AsyncEngine] = MISSING
session: ClassVar["sessionmaker[AsyncSession]"] = MISSING session: ClassVar[AsyncSession] = MISSING
user: ClassVar[UserData] = MISSING user: ClassVar[UserData] = MISSING
arcade: ClassVar[ArcadeData] = MISSING arcade: ClassVar[ArcadeData] = MISSING
card: ClassVar[CardData] = MISSING card: ClassVar[CardData] = MISSING
@ -53,7 +53,7 @@ class Data:
self.__engine = Data.engine self.__engine = Data.engine
if Data.session is MISSING: if Data.session is MISSING:
Data.session = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession) Data.session = AsyncSession(Data.engine, expire_on_commit=False)
if Data.user is MISSING: if Data.user is MISSING:
Data.user = UserData(self.config, self.session) Data.user = UserData(self.config, self.session)

View File

@ -9,7 +9,6 @@ from sqlalchemy.engine import Row
from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import ForeignKey
from sqlalchemy.sql import func, text from sqlalchemy.sql import func, text
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
@ -39,7 +38,7 @@ event_log: Table = Table(
class BaseData: class BaseData:
def __init__(self, cfg: CoreConfig, conn: "sessionmaker[AsyncSession]") -> None: def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
self.config = cfg self.config = cfg
self.conn = conn self.conn = conn
self.logger = logging.getLogger("database") self.logger = logging.getLogger("database")
@ -47,10 +46,21 @@ class BaseData:
async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]: async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]:
res = None res = None
async with self.conn() as session: try:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = await self.conn.execute(text(sql), opts)
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
except Exception:
try: try:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") res = await self.conn.execute(sql, opts)
res = await session.execute(text(sql), opts)
except SQLAlchemyError as e: except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}") self.logger.error(f"SQLAlchemy error {e}")
@ -61,20 +71,8 @@ class BaseData:
return None return None
except Exception: except Exception:
try: self.logger.error(f"Unknown error")
res = await session.execute(sql, opts) raise
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
except Exception:
self.logger.error(f"Unknown error")
raise
return res return res