from typing import Optional, List

from sqlalchemy import Table, Column, text, UniqueConstraint
from sqlalchemy.dialects import mysql
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, select
from sqlalchemy.engine import Row
import bcrypt

from core.data.schema.base import BaseData, metadata


def _lock_expiry_expr():
    """Return a fresh SQL expression for ``NOW() + INTERVAL 15 MINUTE``.

    Used as the default and refresh value for ``aime_user_game_locks.expires_at``
    so the lock lifetime is defined in one place.
    """
    return func.date_add(func.now(), text("INTERVAL 15 MINUTE"))


aime_user = Table(
    "aime_user",
    metadata,
    Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
    Column("username", String(25), unique=True),
    Column("email", String(255), unique=True),
    Column("password", String(255)),
    Column("permissions", Integer),
    Column("created_date", TIMESTAMP, server_default=func.now()),
    Column("last_login_date", TIMESTAMP, onupdate=func.now()),
    Column("suspend_expire_time", TIMESTAMP),
    mysql_charset="utf8mb4",
)

game_locks = Table(
    "aime_user_game_locks",
    metadata,
    Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
    Column(
        "user",
        ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"),
        nullable=False,
    ),
    Column("game", String(4), nullable=False),
    Column("expires_at", TIMESTAMP, server_default=_lock_expiry_expr()),
    Column("extra", JSON),
    # At most one lock row per (user, game); acquire_lock_for_game relies on
    # this to refresh an existing row via ON DUPLICATE KEY UPDATE.
    UniqueConstraint("user", "game", name="aime_user_title_locks"),
    mysql_charset="utf8mb4",
)


class UserData(BaseData):
    """Data access for ``aime_user`` accounts and per-game user locks.

    Methods return None when ``self.execute`` returns None, which the rest of
    this class treats as a failed query (see ``delete_user``).
    """

    async def create_user(
        self,
        id: Optional[int] = None,
        username: Optional[str] = None,
        email: Optional[str] = None,
        password: Optional[str] = None,
        permission: int = 1,
    ) -> Optional[int]:
        """Insert a user, or overwrite the existing row on a duplicate key.

        If ``id`` is given it is used as the primary key. Otherwise the
        database assigns one. If the insert collides with the primary key or
        with the unique ``username``/``email`` columns, the existing row's
        username, email, password and permissions are overwritten with the
        supplied values (MySQL ON DUPLICATE KEY UPDATE).

        Returns the cursor's ``lastrowid``, or None if the query failed.
        """
        values = {
            "username": username,
            "email": email,
            "password": password,
            "permissions": permission,
        }
        if id is not None:
            values["id"] = id

        sql = insert(aime_user).values(**values)
        conflict = sql.on_duplicate_key_update(
            username=username, email=email, password=password, permissions=permission
        )
        result = await self.execute(conflict)
        if result is None:
            return None
        return result.lastrowid

    async def get_user(self, user_id: int) -> Optional[Row]:
        """Return the ``aime_user`` row with id ``user_id``.

        Returns None if no such user exists or if the query failed.
        """
        sql = select(aime_user).where(aime_user.c.id == user_id)
        result = await self.execute(sql)
        if result is None:
            # None rather than False, so callers' ``is None`` checks (e.g.
            # check_password) treat a failed query the same as "no user".
            return None
        return result.fetchone()

    async def check_password(self, user_id: int, passwd: Optional[bytes] = None) -> bool:
        """Return True if ``passwd`` matches the stored bcrypt hash for the user.

        Returns False if ``passwd`` is None or empty, the user does not
        exist, or the user has no password set (not registered with the
        web UI).
        """
        if not passwd:
            return False

        usr = await self.get_user(user_id)
        if not usr:
            return False

        # _mapping gives name-based column access on both SQLAlchemy 1.4 and 2.0 rows.
        hashed = usr._mapping["password"]
        if hashed is None:
            return False

        return bcrypt.checkpw(passwd, hashed.encode())

    async def delete_user(self, user_id: int) -> None:
        """Delete the user with id ``user_id``, logging an error if the query fails.

        Rows in ``aime_user_game_locks`` that reference the user are removed
        by the table's ON DELETE CASCADE foreign key.
        """
        sql = aime_user.delete().where(aime_user.c.id == user_id)
        result = await self.execute(sql)
        if result is None:
            self.logger.error(f"Failed to delete user with id {user_id}")

    async def get_unregistered_users(self) -> Optional[List[Row]]:
        """
        Returns a list of users who have not registered with the webui.
        They may or may not have cards.

        A user counts as unregistered when their password column is NULL.
        Returns None if the query failed.
        """
        sql = select(aime_user).where(aime_user.c.password.is_(None))
        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchall()

    async def find_user_by_email(self, email: str) -> Optional[Row]:
        """Return the user whose email exactly equals ``email``.

        Returns None if no user matches or the query failed.
        """
        sql = select(aime_user).where(aime_user.c.email == email)
        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchone()

    async def find_user_by_username(self, username: str) -> Optional[List[Row]]:
        """Return all users whose username contains ``username`` as a substring.

        ``%`` and ``_`` in the search term are matched literally instead of
        acting as LIKE wildcards. Returns None if the query failed.
        """
        sql = select(aime_user).where(
            aime_user.c.username.contains(username, autoescape=True)
        )
        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchall()

    async def change_password(self, user_id: int, new_passwd: str) -> bool:
        """Set the stored password value for ``user_id``.

        ``new_passwd`` is written as-is. Hashing is presumably done by the
        caller; verify against call sites, since check_password expects a
        bcrypt hash. Returns True if the query executed.
        """
        sql = (
            aime_user.update()
            .where(aime_user.c.id == user_id)
            .values(password=new_passwd)
        )
        result = await self.execute(sql)
        return result is not None

    async def change_username(self, user_id: int, new_name: str) -> bool:
        """Set the username for ``user_id``. Returns True if the query executed."""
        sql = (
            aime_user.update()
            .where(aime_user.c.id == user_id)
            .values(username=new_name)
        )
        result = await self.execute(sql)
        return result is not None

    async def get_user_by_username(self, username: str) -> Optional[Row]:
        """Return the user whose username exactly equals ``username``.

        Returns None if no user matches or the query failed.
        """
        sql = select(aime_user).where(aime_user.c.username == username)
        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchone()

    async def check_lock_for_game(self, user_id: int, game: str) -> Optional[Row]:
        """Return the unexpired lock row for (``user_id``, ``game``), if any.

        A lock whose ``expires_at`` is not still in the future is treated as
        absent. Returns None if there is no active lock or the query failed.
        """
        sql = select(game_locks).where(
            (game_locks.c.user == user_id)
            & (game_locks.c.game == game)
            # Only count the lock if expires_at is at least one second in the future.
            & (func.timestampdiff(text("SECOND"), func.now(), game_locks.c.expires_at) > 0)
        )
        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchone()

    async def acquire_lock_for_game(
        self, user_id: int, game: str, extra: dict | None = None
    ) -> Optional[Row]:
        """Try to take the lock for (``user_id``, ``game``).

        If an unexpired lock already exists, the lock is not taken and the
        existing row is returned. Otherwise a lock row is inserted, or an
        expired row for the same pair is refreshed through the unique
        constraint. Its expiry is set to now + 15 minutes and ``extra`` is
        stored. In that case the method returns None.

        NOTE(review): the check and the insert are separate statements, so
        two concurrent callers can both "acquire". Confirm whether that
        matters for callers.
        """
        existing = await self.check_lock_for_game(user_id, game)
        if existing is not None:
            return existing

        sql = (
            insert(game_locks)
            .values(user=user_id, game=game, extra=extra)
            .on_duplicate_key_update(
                expires_at=_lock_expiry_expr(),
                extra=extra,
            )
        )
        await self.execute(sql)
        return None

    async def release_lock_for_game(self, user_id: int, game: str) -> None:
        """Delete the lock row for (``user_id``, ``game``), whether or not it has expired."""
        sql = game_locks.delete().where(
            (game_locks.c.user == user_id) & (game_locks.c.game == game)
        )
        await self.execute(sql)

    async def extend_lock_for_game(self, user_id: int, game: str) -> None:
        """Push the expiry of the lock for (``user_id``, ``game``) to now + 15 minutes."""
        # where() takes column expressions, not keyword arguments; the
        # keyword form raised TypeError before any SQL was built.
        sql = (
            game_locks.update()
            .where((game_locks.c.user == user_id) & (game_locks.c.game == game))
            .values(expires_at=_lock_expiry_expr())
        )
        await self.execute(sql)