1
0
Fork 0
artemis/core/data/schema/base.py

236 lines
8.1 KiB
Python

import json
import logging
import sqlite3
from random import randrange
from typing import Any, Optional, Dict, List
import sqlalchemy.dialects.mysql
from sqlalchemy.engine import Row, CursorResult
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql import text, func, select, insert
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from sqlalchemy import MetaData, Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON
from core.config import CoreConfig
# Shared MetaData registry that the tables declared below attach themselves to.
metadata = MetaData()

# schema_versions: one row per game code, carrying that game's schema version.
schema_ver = Table(
    "schema_versions",
    metadata,
    Column("game", String(4), nullable=False, primary_key=True),
    Column("version", Integer, server_default="1", nullable=False),
    mysql_charset="utf8mb4",
)

# event_log: rows written by BaseData.log_event(); when_logged is filled in by
# the database server through its now() default.
event_log = Table(
    "event_log",
    metadata,
    Column("id", Integer, nullable=False, primary_key=True),
    Column("system", String(255), nullable=False),
    Column("type", String(255), nullable=False),
    Column("severity", Integer, nullable=False),
    Column("message", String(1000), nullable=False),
    Column("details", JSON, nullable=False),
    Column("when_logged", TIMESTAMP, server_default=func.now(), nullable=False),
    mysql_charset="utf8mb4",
)
class BaseData:
    """
    Common helpers around a database connection: logged statement execution,
    a dialect-aware upsert, schema-version bookkeeping and the event log.
    """

    def __init__(self, cfg: CoreConfig, conn: Connection) -> None:
        self.config = cfg
        # NOTE(review): annotated as Connection, but upsert()/dialect_insert()
        # read ``conn.bind.name``, so whatever is passed in must expose
        # ``bind`` -- confirm what callers actually hand over.
        self.conn = conn
        self.logger = logging.getLogger("database")

    def execute(
        self, sql: Any, opts: Optional[Dict[str, Any]] = None
    ) -> Optional[CursorResult]:
        """
        Log a statement and run it on the connection.

        The statement is first wrapped in ``text()``. If that attempt raises
        anything other than SQLAlchemyError / UnicodeEncodeError, the statement
        is retried unwrapped.
        NOTE(review): the other methods of this class pass already-built
        select/insert/update objects here and appear to rely on that retry
        path -- confirm before simplifying it.

        Args:
            sql: Raw SQL string or a statement object.
            opts: Bind parameters; ``None`` is treated as an empty dict.

        Returns:
            The execution result, or None when SQLAlchemyError or
            UnicodeEncodeError was raised (the error is logged).

        Raises:
            Exception: Any other error raised by the unwrapped retry, after
                logging "Unknown error".
        """
        # A fresh dict per call instead of a shared mutable default argument.
        params = {} if opts is None else opts
        res = None

        try:
            self.logger.info(f"SQL Execute: {''.join(str(sql).splitlines())}")
            res = self.conn.execute(text(sql), params)

        except SQLAlchemyError as e:
            self.logger.error(f"SQLAlchemy error {e}")
            return None

        except UnicodeEncodeError as e:
            self.logger.error(f"UnicodeEncodeError error {e}")
            return None

        except Exception:
            # Second attempt: hand the statement over exactly as received.
            try:
                res = self.conn.execute(sql, params)

            except SQLAlchemyError as e:
                self.logger.error(f"SQLAlchemy error {e}")
                return None

            except UnicodeEncodeError as e:
                self.logger.error(f"UnicodeEncodeError error {e}")
                return None

            except Exception:
                self.logger.error(f"Unknown error")
                raise

        return res

    def generate_id(self) -> int:
        """
        Generate a random 5-7 digit id (10000 through 9999999 inclusive)
        """
        # randrange() excludes its stop value, so stop one past the largest
        # 7-digit number to make 9999999 reachable.
        return randrange(10000, 10000000)

    def get_all_schema_vers(self) -> Optional[List[Row]]:
        """
        Return every row of schema_versions, or None if the query failed.
        """
        result = self.execute(select(schema_ver))
        if result is None:
            return None
        return result.fetchall()

    def get_schema_ver(self, game: str) -> Optional[int]:
        """
        Return the stored schema version for ``game``.

        Returns None when the query failed or no row exists for ``game``.
        """
        sql = select(schema_ver).where(schema_ver.c.game == game)

        result = self.execute(sql)
        if result is None:
            return None

        row = result.fetchone()
        if row is None:
            return None

        # Row._mapping supports name-based access on SQLAlchemy 1.4 and 2.x;
        # plain row["version"] is rejected by 2.x rows.
        return row._mapping["version"]

    def upsert(self, table: Table, data: Dict) -> Optional[CursorResult]:
        """
        Insert ``data`` into ``table``, updating the existing row on conflict.

        Conflict targets come from get_unique_columns(). MySQL uses
        ON DUPLICATE KEY UPDATE, PostgreSQL uses ON CONFLICT DO UPDATE, and
        SQLite tries a plain insert and falls back to an UPDATE when an
        integrity error is raised. A table with no unique columns gets a
        plain insert.

        Returns:
            The execution result, or None if the statement failed.
        """
        stmt = self.dialect_insert(table).values(data)
        unique_columns = get_unique_columns(table)

        if not unique_columns:
            result = self.execute(stmt)
        else:
            dialect = self.conn.bind.name
            if dialect == "sqlite":
                result = self._sqlite_upsert(table, stmt, data, unique_columns)
            elif dialect == "mysql":
                result = self.execute(stmt.on_duplicate_key_update(**data))
            elif dialect == "postgresql":
                result = self.execute(
                    stmt.on_conflict_do_update(
                        index_elements=unique_columns,
                        set_=data,
                    )
                )
            else:
                # dialect_insert() above already rejects other dialects; kept
                # so an unexpected name fails with the same message.
                raise Exception("Unknown dialect")

        if result is None:
            self.logger.error(f"Failed to upsert data into {table.name}")
            return None

        return result

    def _sqlite_upsert(
        self, table: Table, stmt: Any, data: Dict, unique_columns: List[str]
    ) -> Optional[CursorResult]:
        """
        SQLite branch of upsert(): insert, and on an integrity error update
        the row(s) matching ``data`` on every column in ``unique_columns``.
        """
        from functools import reduce
        from operator import and_

        try:
            # Deliberately bypasses self.execute(): that wrapper swallows
            # SQLAlchemy errors, and the IntegrityError must be seen here.
            return self.conn.execute(stmt)
        except (IntegrityError, sqlite3.IntegrityError):
            conditions = [getattr(table.c, col) == data[col] for col in unique_columns]
            # ``&`` on column expressions builds an SQL AND; reduce() also
            # handles the single-condition case.
            where_clause = reduce(and_, conditions)
            upd = table.update().where(where_clause).values(data)
            # Routine fallback rather than a failure, so logged at info level.
            self.logger.info(
                f"SQL post-conflict update ({', '.join(unique_columns)}): {str(upd)}"
            )
            return self.execute(upd)

    def dialect_insert(self, table: Table):
        """
        Return the dialect-specific insert() construct for ``table``.

        Raises:
            Exception: If the bound dialect is not mysql, postgresql or sqlite.
        """
        dialect = self.conn.bind.name
        # Each dialect package is imported where it is used, so the attribute
        # lookup does not depend on another module having imported it first.
        if dialect == "mysql":
            from sqlalchemy.dialects import mysql

            return mysql.insert(table)
        if dialect == "postgresql":
            from sqlalchemy.dialects import postgresql

            return postgresql.insert(table)
        if dialect == "sqlite":
            from sqlalchemy.dialects import sqlite

            return sqlite.insert(table)
        raise Exception("Unknown dialect")

    def touch_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]:
        """
        Upsert the schema version row for ``game``; same as set_schema_ver().

        NOTE(review): the name suggests "create only if missing", but an
        existing row's version is overwritten with ``ver`` -- confirm that
        this is intended.
        """
        return self.set_schema_ver(ver, game)

    def set_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]:
        """
        Upsert the schema version row for ``game`` to ``ver``.

        Returns:
            ``lastrowid`` of the result, or None if the upsert failed.
        """
        result = self.upsert(schema_ver, dict(game=game, version=ver))
        if result is None:
            self.logger.error(
                f"Failed to update schema version for game {game} (v{ver})"
            )
            return None
        return result.lastrowid

    def log_event(
        self,
        system: str,
        type: str,
        severity: int,
        message: str,
        details: Optional[Dict] = None,
    ) -> Optional[int]:
        """
        Insert one row into event_log.

        Args:
            details: Extra data, stored as a JSON-encoded string; ``None`` is
                treated as an empty dict.

        Returns:
            ``lastrowid`` of the insert, or None if it failed.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        payload = {} if details is None else details
        sql = event_log.insert().values(
            system=system,
            type=type,
            severity=severity,
            message=message,
            details=json.dumps(payload),
        )
        result = self.execute(sql)

        if result is None:
            self.logger.error(
                f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}"
            )
            return None

        return result.lastrowid

    def get_event_log(self, entries: int = 100) -> Optional[List[Row]]:
        """
        Return up to ``entries`` rows from event_log, or None on failure.

        No ORDER BY is applied, so which rows come back is up to the database.
        """
        # A Select has no .all(); rows are only available from the executed
        # result, so the statement is built here and fetched below.
        sql = event_log.select().limit(entries)

        result = self.execute(sql)
        if result is None:
            return None
        return result.fetchall()

    def fix_bools(self, data: Dict) -> Dict:
        """
        Replace the strings "true"/"false" (any letter case) with real bools.

        ``data`` is modified in place and also returned. The "userName" and
        "teamName" keys are never converted.
        """
        # Only existing keys are reassigned, so the dict keeps its size and
        # iterating over it while assigning is safe.
        for key, value in data.items():
            if key in ("userName", "teamName"):
                continue
            if not isinstance(value, str):
                continue

            lowered = value.lower()
            if lowered == "true":
                data[key] = True
            elif lowered == "false":
                data[key] = False

        return data
# Maps table name -> list of column names used as the upsert conflict target.
unique_columns_cache = {}


def get_unique_columns(table: Table) -> List[str]:
    """
    Return the column names that identify a row of ``table`` for upserts.

    Columns are gathered from every UniqueConstraint on the table; if the
    table has none, the primary-key columns are used instead. Each name is
    listed once. Results are cached per ``table.name``.
    NOTE(review): the cache key is the bare table name, so two tables sharing
    a name would share one entry -- confirm this cannot happen here.
    """
    # Serve the cached answer when this table has been inspected before.
    if table.name in unique_columns_cache:
        return unique_columns_cache[table.name]

    from sqlalchemy import UniqueConstraint, PrimaryKeyConstraint

    constraints = list(table.constraints)
    has_unique_constraint = any(
        isinstance(constraint, UniqueConstraint) for constraint in constraints
    )
    # Without any unique constraint, fall back to the primary key.
    wanted = UniqueConstraint if has_unique_constraint else PrimaryKeyConstraint

    names = []
    for constraint in constraints:
        if isinstance(constraint, wanted):
            names.extend(column.name for column in constraint.columns)

    # A column can sit in more than one unique constraint; keep the first
    # occurrence of each name, preserving order.
    unique_columns = list(dict.fromkeys(names))

    unique_columns_cache[table.name] = unique_columns
    return unique_columns