let black do it's magic

This commit is contained in:
2023-03-09 11:38:58 -05:00
parent fa7206848c
commit a76bb94eb1
150 changed files with 8474 additions and 4843 deletions

View File

@ -1,2 +1,2 @@
from core.data.database import Data
from core.data.cache import cached
from core.data.cache import cached

View File

@ -1,4 +1,3 @@
from typing import Any, Callable
from functools import wraps
import hashlib
@ -6,15 +5,17 @@ import pickle
import logging
from core.config import CoreConfig
cfg:CoreConfig = None # type: ignore
cfg: CoreConfig = None # type: ignore
# Make memcache optional
try:
import pylibmc # type: ignore
has_mc = True
except ModuleNotFoundError:
has_mc = False
def cached(lifetime: int=10, extra_key: Any=None) -> Callable:
def cached(lifetime: int = 10, extra_key: Any = None) -> Callable:
def _cached(func: Callable) -> Callable:
if has_mc:
hostname = "127.0.0.1"
@ -22,11 +23,10 @@ def cached(lifetime: int=10, extra_key: Any=None) -> Callable:
hostname = cfg.database.memcached_host
memcache = pylibmc.Client([hostname], binary=True)
memcache.behaviors = {"tcp_nodelay": True, "ketama": True}
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if lifetime is not None:
# Hash function args
items = kwargs.items()
hashable_args = (args[1:], sorted(list(items)))
@ -41,7 +41,7 @@ def cached(lifetime: int=10, extra_key: Any=None) -> Callable:
except pylibmc.Error as e:
logging.getLogger("database").error(f"Memcache failed: {e}")
result = None
if result is not None:
logging.getLogger("database").debug(f"Cache hit: {result}")
return result
@ -55,7 +55,9 @@ def cached(lifetime: int=10, extra_key: Any=None) -> Callable:
memcache.set(cache_key, result, lifetime)
return result
else:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

View File

@ -13,6 +13,7 @@ from core.config import CoreConfig
from core.data.schema import *
from core.utils import Utils
class Data:
def __init__(self, cfg: CoreConfig) -> None:
self.config = cfg
@ -22,7 +23,7 @@ class Data:
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}/{self.config.database.name}?charset=utf8mb4"
else:
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}/{self.config.database.name}?charset=utf8mb4"
self.__engine = create_engine(self.__url, pool_recycle=3600)
session = sessionmaker(bind=self.__engine, autoflush=True, autocommit=True)
self.session = scoped_session(session)
@ -38,11 +39,15 @@ class Data:
self.logger = logging.getLogger("database")
# Prevent the logger from adding handlers multiple times
if not getattr(self.logger, 'handler_set', None):
fileHandler = TimedRotatingFileHandler("{0}/{1}.log".format(self.config.server.log_dir, "db"), encoding="utf-8",
when="d", backupCount=10)
if not getattr(self.logger, "handler_set", None):
fileHandler = TimedRotatingFileHandler(
"{0}/{1}.log".format(self.config.server.log_dir, "db"),
encoding="utf-8",
when="d",
backupCount=10,
)
fileHandler.setFormatter(log_fmt)
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(log_fmt)
@ -50,8 +55,10 @@ class Data:
self.logger.addHandler(consoleHandler)
self.logger.setLevel(self.config.database.loglevel)
coloredlogs.install(cfg.database.loglevel, logger=self.logger, fmt=log_fmt_str)
self.logger.handler_set = True # type: ignore
coloredlogs.install(
cfg.database.loglevel, logger=self.logger, fmt=log_fmt_str
)
self.logger.handler_set = True # type: ignore
def create_database(self):
self.logger.info("Creating databases...")
@ -60,24 +67,32 @@ class Data:
except SQLAlchemyError as e:
self.logger.error(f"Failed to create databases! {e}")
return
games = Utils.get_all_titles()
for game_dir, game_mod in games.items():
try:
title_db = game_mod.database(self.config)
metadata.create_all(self.__engine.connect())
self.base.set_schema_ver(game_mod.current_schema_version, game_mod.game_codes[0])
self.base.set_schema_ver(
game_mod.current_schema_version, game_mod.game_codes[0]
)
except Exception as e:
self.logger.warning(f"Could not load database schema from {game_dir} - {e}")
self.logger.warning(
f"Could not load database schema from {game_dir} - {e}"
)
self.logger.info(f"Setting base_schema_ver to {self.schema_ver_latest}")
self.base.set_schema_ver(self.schema_ver_latest)
self.logger.info(f"Setting user auto_incrememnt to {self.config.database.user_table_autoincrement_start}")
self.user.reset_autoincrement(self.config.database.user_table_autoincrement_start)
self.logger.info(
f"Setting user auto_incrememnt to {self.config.database.user_table_autoincrement_start}"
)
self.user.reset_autoincrement(
self.config.database.user_table_autoincrement_start
)
def recreate_database(self):
self.logger.info("Dropping all databases...")
self.base.execute("SET FOREIGN_KEY_CHECKS=0")
@ -86,61 +101,79 @@ class Data:
except SQLAlchemyError as e:
self.logger.error(f"Failed to drop databases! {e}")
return
for root, dirs, files in os.walk("./titles"):
for dir in dirs:
for dir in dirs:
if not dir.startswith("__"):
try:
mod = importlib.import_module(f"titles.{dir}")
try:
title_db = mod.database(self.config)
metadata.drop_all(self.__engine.connect())
except Exception as e:
self.logger.warning(f"Could not load database schema from {dir} - {e}")
self.logger.warning(
f"Could not load database schema from {dir} - {e}"
)
except ImportError as e:
self.logger.warning(f"Failed to load database schema dir {dir} - {e}")
self.logger.warning(
f"Failed to load database schema dir {dir} - {e}"
)
break
self.base.execute("SET FOREIGN_KEY_CHECKS=1")
self.create_database()
def migrate_database(self, game: str, version: int, action: str) -> None:
old_ver = self.base.get_schema_ver(game)
sql = ""
if old_ver is None:
self.logger.error(f"Schema for game {game} does not exist, did you run the creation script?")
return
if old_ver == version:
self.logger.info(f"Schema for game {game} is already version {old_ver}, nothing to do")
return
if not os.path.exists(f"core/data/schema/versions/{game.upper()}_{version}_{action}.sql"):
self.logger.error(f"Could not find {action} script {game.upper()}_{version}_{action}.sql in core/data/schema/versions folder")
self.logger.error(
f"Schema for game {game} does not exist, did you run the creation script?"
)
return
with open(f"core/data/schema/versions/{game.upper()}_{version}_{action}.sql", "r", encoding="utf-8") as f:
if old_ver == version:
self.logger.info(
f"Schema for game {game} is already version {old_ver}, nothing to do"
)
return
if not os.path.exists(
f"core/data/schema/versions/{game.upper()}_{version}_{action}.sql"
):
self.logger.error(
f"Could not find {action} script {game.upper()}_{version}_{action}.sql in core/data/schema/versions folder"
)
return
with open(
f"core/data/schema/versions/{game.upper()}_{version}_{action}.sql",
"r",
encoding="utf-8",
) as f:
sql = f.read()
result = self.base.execute(sql)
if result is None:
self.logger.error("Error execuing sql script!")
return None
result = self.base.set_schema_ver(version, game)
if result is None:
self.logger.error("Error setting version in schema_version table!")
return None
self.logger.info(f"Successfully migrated {game} to schema version {version}")
def create_owner(self, email: Optional[str] = None) -> None:
pw = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(20))
pw = "".join(
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
)
hash = bcrypt.hashpw(pw.encode(), bcrypt.gensalt())
user_id = self.user.create_user(email=email, permission=255, password=hash)
@ -153,32 +186,38 @@ class Data:
self.logger.error(f"Failed to create card for owner with id {user_id}")
return
self.logger.warn(f"Successfully created owner with email {email}, access code 00000000000000000000, and password {pw} Make sure to change this password and assign a real card ASAP!")
self.logger.warn(
f"Successfully created owner with email {email}, access code 00000000000000000000, and password {pw} Make sure to change this password and assign a real card ASAP!"
)
def migrate_card(self, old_ac: str, new_ac: str, should_force: bool) -> None:
if old_ac == new_ac:
self.logger.error("Both access codes are the same!")
return
new_card = self.card.get_card_by_access_code(new_ac)
if new_card is None:
self.card.update_access_code(old_ac, new_ac)
return
if not should_force:
self.logger.warn(f"Card already exists for access code {new_ac} (id {new_card['id']}). If you wish to continue, rerun with the '--force' flag."\
f" All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}.")
self.logger.warn(
f"Card already exists for access code {new_ac} (id {new_card['id']}). If you wish to continue, rerun with the '--force' flag."
f" All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}."
)
return
self.logger.info(f"All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}.")
self.logger.info(
f"All exiting data on the target card {new_ac} will be perminently erased and replaced with data from card {old_ac}."
)
self.card.delete_card(new_card["id"])
self.card.update_access_code(old_ac, new_ac)
hanging_user = self.user.get_user(new_card["user"])
if hanging_user["password"] is None:
self.logger.info(f"Delete hanging user {hanging_user['id']}")
self.user.delete_user(hanging_user['id'])
self.user.delete_user(hanging_user["id"])
def delete_hanging_users(self) -> None:
"""
Finds and deletes users that have not registered for the webui that have no cards assocated with them.
@ -186,13 +225,13 @@ class Data:
unreg_users = self.user.get_unregistered_users()
if unreg_users is None:
self.logger.error("Error occoured finding unregistered users")
for user in unreg_users:
cards = self.card.get_user_cards(user['id'])
cards = self.card.get_user_cards(user["id"])
if cards is None:
self.logger.error(f"Error getting cards for user {user['id']}")
continue
if not cards:
self.logger.info(f"Delete hanging user {user['id']}")
self.user.delete_user(user['id'])
self.user.delete_user(user["id"])

View File

@ -3,4 +3,4 @@ from core.data.schema.card import CardData
from core.data.schema.base import BaseData, metadata
from core.data.schema.arcade import ArcadeData
__all__ = ["UserData", "CardData", "BaseData", "metadata", "ArcadeData"]
__all__ = ["UserData", "CardData", "BaseData", "metadata", "ArcadeData"]

View File

@ -14,131 +14,186 @@ arcade = Table(
metadata,
Column("id", Integer, primary_key=True, nullable=False),
Column("name", String(255)),
Column("nickname", String(255)),
Column("nickname", String(255)),
Column("country", String(3)),
Column("country_id", Integer),
Column("state", String(255)),
Column("city", String(255)),
Column("region_id", Integer),
Column("timezone", String(255)),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
machine = Table(
"machine",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
Column("arcade", ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), nullable=False),
Column(
"arcade",
ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"),
nullable=False,
),
Column("serial", String(15), nullable=False),
Column("board", String(15)),
Column("game", String(4)),
Column("country", String(3)), # overwrites if not null
Column("country", String(3)), # overwrites if not null
Column("timezone", String(255)),
Column("ota_enable", Boolean),
Column("is_cab", Boolean),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
arcade_owner = Table(
'arcade_owner',
"arcade_owner",
metadata,
Column('user', Integer, ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), nullable=False),
Column('arcade', Integer, ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"), nullable=False),
Column('permissions', Integer, nullable=False),
PrimaryKeyConstraint('user', 'arcade', name='arcade_owner_pk'),
mysql_charset='utf8mb4'
Column(
"user",
Integer,
ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"),
nullable=False,
),
Column(
"arcade",
Integer,
ForeignKey("arcade.id", ondelete="cascade", onupdate="cascade"),
nullable=False,
),
Column("permissions", Integer, nullable=False),
PrimaryKeyConstraint("user", "arcade", name="arcade_owner_pk"),
mysql_charset="utf8mb4",
)
class ArcadeData(BaseData):
def get_machine(self, serial: str = None, id: int = None) -> Optional[Dict]:
if serial is not None:
serial = serial.replace("-", "")
if len(serial) == 11:
sql = machine.select(machine.c.serial.like(f"{serial}%"))
elif len(serial) == 15:
sql = machine.select(machine.c.serial == serial)
else:
self.logger.error(f"{__name__ }: Malformed serial {serial}")
return None
elif id is not None:
sql = machine.select(machine.c.id == id)
else:
else:
self.logger.error(f"{__name__ }: Need either serial or ID to look up!")
return None
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.fetchone()
def put_machine(self, arcade_id: int, serial: str = "", board: str = None, game: str = None, is_cab: bool = False) -> Optional[int]:
def put_machine(
self,
arcade_id: int,
serial: str = "",
board: str = None,
game: str = None,
is_cab: bool = False,
) -> Optional[int]:
if arcade_id:
self.logger.error(f"{__name__ }: Need arcade id!")
return None
sql = machine.insert().values(arcade = arcade_id, keychip = serial, board = board, game = game, is_cab = is_cab)
result = self.execute(sql)
if result is None: return None
return result.lastrowid
def set_machine_serial(self, machine_id: int, serial: str) -> None:
result = self.execute(machine.update(machine.c.id == machine_id).values(keychip = serial))
if result is None:
self.logger.error(f"Failed to update serial for machine {machine_id} -> {serial}")
return result.lastrowid
def set_machine_boardid(self, machine_id: int, boardid: str) -> None:
result = self.execute(machine.update(machine.c.id == machine_id).values(board = boardid))
if result is None:
self.logger.error(f"Failed to update board id for machine {machine_id} -> {boardid}")
def get_arcade(self, id: int) -> Optional[Dict]:
sql = arcade.select(arcade.c.id == id)
result = self.execute(sql)
if result is None: return None
return result.fetchone()
def put_arcade(self, name: str, nickname: str = None, country: str = "JPN", country_id: int = 1,
state: str = "", city: str = "", regional_id: int = 1) -> Optional[int]:
if nickname is None: nickname = name
sql = arcade.insert().values(name = name, nickname = nickname, country = country, country_id = country_id,
state = state, city = city, regional_id = regional_id)
result = self.execute(sql)
if result is None: return None
return result.lastrowid
def get_arcade_owners(self, arcade_id: int) -> Optional[Dict]:
sql = select(arcade_owner).where(arcade_owner.c.arcade==arcade_id)
result = self.execute(sql)
if result is None: return None
return result.fetchall()
def add_arcade_owner(self, arcade_id: int, user_id: int) -> None:
sql = insert(arcade_owner).values(
arcade=arcade_id,
user=user_id
sql = machine.insert().values(
arcade=arcade_id, keychip=serial, board=board, game=game, is_cab=is_cab
)
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.lastrowid
def format_serial(self, platform_code: str, platform_rev: int, serial_num: int, append: int = 4152) -> str:
return f"{platform_code}{platform_rev:02d}A{serial_num:04d}{append:04d}" # 0x41 = A, 0x52 = R
def set_machine_serial(self, machine_id: int, serial: str) -> None:
result = self.execute(
machine.update(machine.c.id == machine_id).values(keychip=serial)
)
if result is None:
self.logger.error(
f"Failed to update serial for machine {machine_id} -> {serial}"
)
return result.lastrowid
def set_machine_boardid(self, machine_id: int, boardid: str) -> None:
result = self.execute(
machine.update(machine.c.id == machine_id).values(board=boardid)
)
if result is None:
self.logger.error(
f"Failed to update board id for machine {machine_id} -> {boardid}"
)
def get_arcade(self, id: int) -> Optional[Dict]:
sql = arcade.select(arcade.c.id == id)
result = self.execute(sql)
if result is None:
return None
return result.fetchone()
def put_arcade(
self,
name: str,
nickname: str = None,
country: str = "JPN",
country_id: int = 1,
state: str = "",
city: str = "",
regional_id: int = 1,
) -> Optional[int]:
if nickname is None:
nickname = name
sql = arcade.insert().values(
name=name,
nickname=nickname,
country=country,
country_id=country_id,
state=state,
city=city,
regional_id=regional_id,
)
result = self.execute(sql)
if result is None:
return None
return result.lastrowid
def get_arcade_owners(self, arcade_id: int) -> Optional[Dict]:
sql = select(arcade_owner).where(arcade_owner.c.arcade == arcade_id)
result = self.execute(sql)
if result is None:
return None
return result.fetchall()
def add_arcade_owner(self, arcade_id: int, user_id: int) -> None:
sql = insert(arcade_owner).values(arcade=arcade_id, user=user_id)
result = self.execute(sql)
if result is None:
return None
return result.lastrowid
def format_serial(
self, platform_code: str, platform_rev: int, serial_num: int, append: int = 4152
) -> str:
return f"{platform_code}{platform_rev:02d}A{serial_num:04d}{append:04d}" # 0x41 = A, 0x52 = R
def validate_keychip_format(self, serial: str) -> bool:
serial = serial.replace("-", "")
if len(serial) != 11 or len(serial) != 15:
self.logger.error(f"Serial validate failed: Incorrect length for {serial} (len {len(serial)})")
self.logger.error(
f"Serial validate failed: Incorrect length for {serial} (len {len(serial)})"
)
return False
platform_code = serial[:4]
platform_rev = serial[4:6]
const_a = serial[6]
@ -150,11 +205,15 @@ class ArcadeData(BaseData):
return False
if len(append) != 0 or len(append) != 4:
self.logger.error(f"Serial validate failed: {serial} had malformed append {append}")
self.logger.error(
f"Serial validate failed: {serial} had malformed append {append}"
)
return False
if len(num) != 4:
self.logger.error(f"Serial validate failed: {serial} had malformed number {num}")
self.logger.error(
f"Serial validate failed: {serial} had malformed number {num}"
)
return False
return True

View File

@ -19,7 +19,7 @@ schema_ver = Table(
metadata,
Column("game", String(4), primary_key=True, nullable=False),
Column("version", Integer, nullable=False, server_default="1"),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
event_log = Table(
@ -32,16 +32,17 @@ event_log = Table(
Column("message", String(1000), nullable=False),
Column("details", JSON, nullable=False),
Column("when_logged", TIMESTAMP, nullable=False, server_default=func.now()),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
class BaseData():
class BaseData:
def __init__(self, cfg: CoreConfig, conn: Connection) -> None:
self.config = cfg
self.conn = conn
self.logger = logging.getLogger("database")
def execute(self, sql: str, opts: Dict[str, Any]={}) -> Optional[CursorResult]:
def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]:
res = None
try:
@ -51,7 +52,7 @@ class BaseData():
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
@ -63,7 +64,7 @@ class BaseData():
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
@ -73,58 +74,71 @@ class BaseData():
raise
return res
def generate_id(self) -> int:
"""
Generate a random 5-7 digit id
"""
return randrange(10000, 9999999)
def get_schema_ver(self, game: str) -> Optional[int]:
sql = select(schema_ver).where(schema_ver.c.game == game)
result = self.execute(sql)
if result is None:
return None
row = result.fetchone()
if row is None:
return None
return row["version"]
def set_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]:
sql = insert(schema_ver).values(game = game, version = ver)
conflict = sql.on_duplicate_key_update(version = ver)
sql = insert(schema_ver).values(game=game, version=ver)
conflict = sql.on_duplicate_key_update(version=ver)
result = self.execute(conflict)
if result is None:
self.logger.error(f"Failed to update schema version for game {game} (v{ver})")
self.logger.error(
f"Failed to update schema version for game {game} (v{ver})"
)
return None
return result.lastrowid
def log_event(self, system: str, type: str, severity: int, message: str, details: Dict = {}) -> Optional[int]:
sql = event_log.insert().values(system = system, type = type, severity = severity, message = message, details = json.dumps(details))
def log_event(
self, system: str, type: str, severity: int, message: str, details: Dict = {}
) -> Optional[int]:
sql = event_log.insert().values(
system=system,
type=type,
severity=severity,
message=message,
details=json.dumps(details),
)
result = self.execute(sql)
if result is None:
self.logger.error(f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}")
self.logger.error(
f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}"
)
return None
return result.lastrowid
def get_event_log(self, entries: int = 100) -> Optional[List[Dict]]:
sql = event_log.select().limit(entries).all()
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.fetchall()
def fix_bools(self, data: Dict) -> Dict:
for k,v in data.items():
for k, v in data.items():
if type(v) == str and v.lower() == "true":
data[k] = True
elif type(v) == str and v.lower() == "false":
data[k] = False
return data

View File

@ -8,55 +8,67 @@ from sqlalchemy.engine import Row
from core.data.schema.base import BaseData, metadata
aime_card = Table(
'aime_card',
"aime_card",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
Column("user", ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"), nullable=False),
Column(
"user",
ForeignKey("aime_user.id", ondelete="cascade", onupdate="cascade"),
nullable=False,
),
Column("access_code", String(20)),
Column("created_date", TIMESTAMP, server_default=func.now()),
Column("last_login_date", TIMESTAMP, onupdate=func.now()),
Column("is_locked", Boolean, server_default="0"),
Column("is_banned", Boolean, server_default="0"),
UniqueConstraint("user", "access_code", name="aime_card_uk"),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
class CardData(BaseData):
def get_card_by_access_code(self, access_code: str) -> Optional[Row]:
sql = aime_card.select(aime_card.c.access_code == access_code)
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.fetchone()
def get_card_by_id(self, card_id: int) -> Optional[Row]:
sql = aime_card.select(aime_card.c.id == card_id)
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.fetchone()
def update_access_code(self, old_ac: str, new_ac: str) -> None:
sql = aime_card.update(aime_card.c.access_code == old_ac).values(access_code = new_ac)
sql = aime_card.update(aime_card.c.access_code == old_ac).values(
access_code=new_ac
)
result = self.execute(sql)
if result is None:
self.logger.error(f"Failed to change card access code from {old_ac} to {new_ac}")
if result is None:
self.logger.error(
f"Failed to change card access code from {old_ac} to {new_ac}"
)
def get_user_id_from_card(self, access_code: str) -> Optional[int]:
"""
Given a 20 digit access code as a string, get the user id associated with that card
"""
card = self.get_card_by_access_code(access_code)
if card is None: return None
if card is None:
return None
return int(card["user"])
def delete_card(self, card_id: int) -> None:
sql = aime_card.delete(aime_card.c.id == card_id)
result = self.execute(sql)
if result is None:
if result is None:
self.logger.error(f"Failed to delete card with id {card_id}")
def get_user_cards(self, aime_id: int) -> Optional[List[Row]]:
@ -65,17 +77,18 @@ class CardData(BaseData):
"""
sql = aime_card.select(aime_card.c.user == aime_id)
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.fetchall()
def create_card(self, user_id: int, access_code: str) -> Optional[int]:
"""
Given a aime_user id and a 20 digit access code as a string, create a card and return the ID if successful
"""
sql = aime_card.insert().values(user=user_id, access_code=access_code)
result = self.execute(sql)
if result is None: return None
if result is None:
return None
return result.lastrowid
def to_access_code(self, luid: str) -> str:
@ -88,4 +101,4 @@ class CardData(BaseData):
"""
Given a 20 digit access code as a string, return the 16 hex character luid
"""
return f'{int(access_code):0{16}x}'
return f"{int(access_code):0{16}x}"

View File

@ -17,72 +17,81 @@ aime_user = Table(
Column("username", String(25), unique=True),
Column("email", String(255), unique=True),
Column("password", String(255)),
Column("permissions", Integer),
Column("permissions", Integer),
Column("created_date", TIMESTAMP, server_default=func.now()),
Column("last_login_date", TIMESTAMP, onupdate=func.now()),
Column("suspend_expire_time", TIMESTAMP),
mysql_charset='utf8mb4'
mysql_charset="utf8mb4",
)
class PermissionBits(Enum):
PermUser = 1
PermMod = 2
PermSysAdmin = 4
class UserData(BaseData):
def create_user(self, id: int = None, username: str = None, email: str = None, password: str = None, permission: int = 1) -> Optional[int]:
def create_user(
self,
id: int = None,
username: str = None,
email: str = None,
password: str = None,
permission: int = 1,
) -> Optional[int]:
if id is None:
sql = insert(aime_user).values(
username=username,
email=email,
password=password,
permissions=permission
username=username,
email=email,
password=password,
permissions=permission,
)
else:
sql = insert(aime_user).values(
id=id,
username=username,
email=email,
password=password,
permissions=permission
id=id,
username=username,
email=email,
password=password,
permissions=permission,
)
conflict = sql.on_duplicate_key_update(
username=username,
email=email,
password=password,
permissions=permission
username=username, email=email, password=password, permissions=permission
)
result = self.execute(conflict)
if result is None: return None
if result is None:
return None
return result.lastrowid
def get_user(self, user_id: int) -> Optional[Row]:
sql = select(aime_user).where(aime_user.c.id == user_id)
result = self.execute(sql)
if result is None: return False
if result is None:
return False
return result.fetchone()
def check_password(self, user_id: int, passwd: bytes = None) -> bool:
usr = self.get_user(user_id)
if usr is None: return False
if usr['password'] is None:
if usr is None:
return False
return bcrypt.checkpw(passwd, usr['password'].encode())
if usr["password"] is None:
return False
return bcrypt.checkpw(passwd, usr["password"].encode())
def reset_autoincrement(self, ai_value: int) -> None:
# ALTER TABLE isn't in sqlalchemy so we do this the ugly way
sql = f"ALTER TABLE aime_user AUTO_INCREMENT={ai_value}"
self.execute(sql)
def delete_user(self, user_id: int) -> None:
sql = aime_user.delete(aime_user.c.id == user_id)
result = self.execute(sql)
if result is None:
if result is None:
self.logger.error(f"Failed to delete user with id {user_id}")
def get_unregistered_users(self) -> List[Row]:
@ -92,6 +101,6 @@ class UserData(BaseData):
sql = select(aime_user).where(aime_user.c.password == None)
result = self.execute(sql)
if result is None:
if result is None:
return None
return result.fetchall()
return result.fetchall()