# forked from Hay1tsme/artemis (236 lines, 8.1 KiB, Python)
# Standard library
import json
import logging
import sqlite3
from random import randrange
from typing import Any, Dict, List, Optional

# Third-party (SQLAlchemy)
# NOTE: the redundant `from sqlalchemy.engine.cursor import CursorResult`
# was dropped — it duplicated the `sqlalchemy.engine` import below.
import sqlalchemy.dialects.mysql
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import CursorResult, Row
from sqlalchemy.engine.base import Connection
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.sql import func, insert, select, text
from sqlalchemy.types import JSON, TIMESTAMP, Integer, String

# Local
from core.config import CoreConfig
# Shared SQLAlchemy metadata registry for the core tables defined below.
metadata = MetaData()

# One row per game code, recording which schema migration version is applied.
schema_ver = Table(
    "schema_versions",
    metadata,
    # 4-character game code; "CORE" is used for the core schema itself
    # (see BaseData.set_schema_ver's default).
    Column("game", String(4), primary_key=True, nullable=False),
    Column("version", Integer, nullable=False, server_default="1"),
    mysql_charset="utf8mb4",
)

# Append-only event/audit log, written by BaseData.log_event.
event_log = Table(
    "event_log",
    metadata,
    Column("id", Integer, primary_key=True, nullable=False),
    Column("system", String(255), nullable=False),
    Column("type", String(255), nullable=False),
    Column("severity", Integer, nullable=False),
    Column("message", String(1000), nullable=False),
    # JSON-serialized extra context (stored via json.dumps in log_event).
    Column("details", JSON, nullable=False),
    # Defaults to the database server's current timestamp.
    Column("when_logged", TIMESTAMP, nullable=False, server_default=func.now()),
    mysql_charset="utf8mb4",
)
class BaseData:
    """Common database helper shared by all schema classes.

    Wraps an open SQLAlchemy connection with logged statement execution,
    dialect-aware upsert support, and accessors for the core bookkeeping
    tables (``schema_versions`` and ``event_log``).
    """

    def __init__(self, cfg: "CoreConfig", conn: "Connection") -> None:
        self.config = cfg
        self.conn = conn
        # All database classes share the one "database" logger.
        self.logger = logging.getLogger("database")

    def execute(
        self, sql: str, opts: Optional[Dict[str, Any]] = None
    ) -> Optional["CursorResult"]:
        """Run *sql* with bound parameters *opts*.

        *sql* may be a raw SQL string (wrapped in ``text()``) or a
        SQLAlchemy statement object; if the ``text()`` path raises an
        unexpected error the statement is retried unwrapped.

        Returns the cursor result, or None when the statement fails.
        """
        # Mutable-default fix: the original signature used `opts: Dict = {}`.
        if opts is None:
            opts = {}
        res = None

        try:
            # Collapse newlines so each statement logs on a single line.
            self.logger.info("SQL Execute: %s", "".join(str(sql).splitlines()))
            res = self.conn.execute(text(sql), opts)

        except SQLAlchemyError as e:
            self.logger.error(f"SQLAlchemy error {e}")
            return None

        except UnicodeEncodeError as e:
            self.logger.error(f"UnicodeEncodeError error {e}")
            return None

        except Exception:
            # text() only accepts strings; retry statement objects as-is.
            try:
                res = self.conn.execute(sql, opts)

            except SQLAlchemyError as e:
                self.logger.error(f"SQLAlchemy error {e}")
                return None

            except UnicodeEncodeError as e:
                self.logger.error(f"UnicodeEncodeError error {e}")
                return None

            except Exception:
                # Was a placeholder-free f-string; plain literal is enough.
                self.logger.error("Unknown error")
                raise

        return res

    def generate_id(self) -> int:
        """Generate a random 5-7 digit id."""
        return randrange(10000, 9999999)

    def get_all_schema_vers(self) -> Optional[List["Row"]]:
        """Return every (game, version) row, or None on query failure."""
        result = self.execute(select(schema_ver))
        if result is None:
            return None
        return result.fetchall()

    def get_schema_ver(self, game: str) -> Optional[int]:
        """Return the stored schema version for *game*, or None when the
        query fails or no row exists."""
        result = self.execute(select(schema_ver).where(schema_ver.c.game == game))
        if result is None:
            return None

        row = result.fetchone()
        if row is None:
            return None

        return row["version"]

    def upsert(self, table: "Table", data: Dict) -> "CursorResult | None":
        """Insert *data* into *table*, updating the existing row on a
        unique/primary-key conflict.

        MySQL and PostgreSQL use their native conflict clauses; SQLite
        attempts the insert and falls back to an UPDATE keyed on the
        unique columns when it hits an IntegrityError.

        Returns the cursor result, or None on failure.
        """
        sql = self.dialect_insert(table).values(data)
        unique_columns = get_unique_columns(table)
        dialect = self.conn.bind.name

        if not unique_columns:
            # Nothing to conflict on: plain insert.
            result = self.execute(sql)
        elif dialect == "sqlite":
            # Execute directly (not via self.execute) so the IntegrityError
            # reaches us instead of being swallowed and logged there.
            try:
                result = self.conn.execute(sql)
            except (IntegrityError, sqlite3.IntegrityError):
                from functools import reduce
                from operator import and_

                # Update the existing row identified by the unique columns.
                # reduce() handles one or many conditions alike, fixing the
                # possibly-unbound `where_clause` in the original.
                conditions = [
                    getattr(table.c, col) == data[col] for col in unique_columns
                ]
                where_clause = reduce(and_, conditions)

                upd = table.update().where(where_clause).values(data)
                # Was logged at ERROR severity; this is normal operation.
                self.logger.info(
                    f"SQL post-conflict update ({', '.join(unique_columns)}): {str(upd)}"
                )
                result = self.execute(upd)
        elif dialect == "mysql":
            result = self.execute(sql.on_duplicate_key_update(**data))
        else:
            # dialect_insert() already raised for anything but postgresql.
            result = self.execute(
                sql.on_conflict_do_update(index_elements=unique_columns, set_=data)
            )

        if result is None:
            self.logger.error(f"Failed to upsert data into {table.name}")
            return None
        return result

    def dialect_insert(self, table: "Table"):
        """Return the dialect-specific insert() construct for the connected
        backend; raises for unsupported dialects."""
        dialect = self.conn.bind.name
        if dialect == "mysql":
            return sqlalchemy.dialects.mysql.insert(table)
        if dialect == "postgresql":
            return sqlalchemy.dialects.postgresql.insert(table)
        if dialect == "sqlite":
            return sqlalchemy.dialects.sqlite.insert(table)
        raise Exception("Unknown dialect")

    def touch_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]:
        """Upsert the schema version row for *game*. Identical to
        set_schema_ver (the original bodies were copy-paste duplicates);
        kept as a separate entry point for API compatibility."""
        return self.set_schema_ver(ver, game)

    def set_schema_ver(self, ver: int, game: str = "CORE") -> Optional[int]:
        """Record that *game* is now at schema version *ver*.

        Returns the inserted/updated row id, or None on failure."""
        result = self.upsert(schema_ver, dict(game=game, version=ver))
        if result is None:
            self.logger.error(
                f"Failed to update schema version for game {game} (v{ver})"
            )
            return None
        return result.lastrowid

    def log_event(
        self,
        system: str,
        type: str,
        severity: int,
        message: str,
        details: Optional[Dict] = None,
    ) -> Optional[int]:
        """Insert a row into event_log.

        Returns the new row id, or None on failure."""
        sql = event_log.insert().values(
            system=system,
            type=type,
            severity=severity,
            message=message,
            # Mutable-default fix: `details` previously defaulted to {}.
            details=json.dumps(details if details is not None else {}),
        )
        result = self.execute(sql)

        if result is None:
            self.logger.error(
                f"{__name__}: Failed to insert event into event log! system = {system}, type = {type}, severity = {severity}, message = {message}"
            )
            return None

        return result.lastrowid

    def get_event_log(self, entries: int = 100) -> Optional[List[Dict]]:
        """Return up to *entries* rows from event_log, or None on failure."""
        # BUG FIX: the original called .all() on the Select construct, which
        # is not a Select method and raised AttributeError before execution.
        sql = event_log.select().limit(entries)
        result = self.execute(sql)

        if result is None:
            return None
        return result.fetchall()

    def fix_bools(self, data: Dict) -> Dict:
        """Convert "true"/"false" strings (any case) in *data* to real
        booleans, in place.

        userName/teamName are skipped so player-chosen names survive
        untouched. Returns *data*."""
        for key, value in data.items():
            if key in ("userName", "teamName"):
                continue
            # isinstance() instead of the original `type(v) == str` checks.
            if isinstance(value, str):
                lowered = value.lower()
                if lowered == "true":
                    data[key] = True
                elif lowered == "false":
                    data[key] = False

        return data
# Per-table-name memo of the unique-column lists computed below.
unique_columns_cache: Dict[str, List[str]] = {}


def get_unique_columns(table: "Table") -> List[str]:
    """Return the column names that identify a conflicting row in *table*.

    Prefers explicit UNIQUE constraints; falls back to the primary key
    when the table declares none. Results are memoized per table name.
    """
    # `global` is unnecessary here: the cache is mutated, never rebound.
    cached = unique_columns_cache.get(table.name)
    if cached is not None:
        return cached

    from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint

    has_unique_constraint = any(
        isinstance(constraint, UniqueConstraint) for constraint in table.constraints
    )

    unique_columns: List[str] = []
    for constraint in table.constraints:
        # If there is no unique constraint, use the primary key instead.
        if isinstance(constraint, UniqueConstraint) or (
            not has_unique_constraint and isinstance(constraint, PrimaryKeyConstraint)
        ):
            unique_columns.extend(column.name for column in constraint.columns)

    unique_columns_cache[table.name] = unique_columns
    # Leftover debug `print(unique_columns)` removed.
    return unique_columns