From 6c155a5e489791858f298f94c48929ce93db0c8c Mon Sep 17 00:00:00 2001 From: Kevin Trocolli Date: Sat, 8 Jul 2023 00:01:52 -0400 Subject: [PATCH] database: add static variables to prevent having multiple sessions --- core/data/database.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/core/data/database.py b/core/data/database.py index ffbefc0..1688812 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -15,6 +15,13 @@ from core.utils import Utils class Data: + current_schema_version = 4 + engine = None + session = None + user = None + arcade = None + card = None + base = None def __init__(self, cfg: CoreConfig) -> None: self.config = cfg @@ -24,22 +31,32 @@ class Data: else: self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}/{self.config.database.name}?charset=utf8mb4" - self.__engine = create_engine(self.__url, pool_recycle=3600) - session = sessionmaker(bind=self.__engine, autoflush=True, autocommit=True) - self.session = scoped_session(session) + if Data.engine is None: + Data.engine = create_engine(self.__url, pool_recycle=3600) + self.__engine = Data.engine - self.user = UserData(self.config, self.session) - self.arcade = ArcadeData(self.config, self.session) - self.card = CardData(self.config, self.session) - self.base = BaseData(self.config, self.session) - self.current_schema_version = 4 + if Data.session is None: + s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True) + Data.session = scoped_session(s) + + if Data.user is None: + Data.user = UserData(self.config, self.session) + + if Data.arcade is None: + Data.arcade = ArcadeData(self.config, self.session) + + if Data.card is None: + Data.card = CardData(self.config, self.session) + + if Data.base is None: + Data.base = BaseData(self.config, self.session) - log_fmt_str = "[%(asctime)s] %(levelname)s | Database | %(message)s" - log_fmt = logging.Formatter(log_fmt_str) self.logger = logging.getLogger("database") # Prevent the logger from adding handlers multiple times - if not getattr(self.logger, "handler_set", None): + if not getattr(self.logger, "handler_set", None): + log_fmt_str = "[%(asctime)s] %(levelname)s | Database | %(message)s" + log_fmt = logging.Formatter(log_fmt_str) fileHandler = TimedRotatingFileHandler( "{0}/{1}.log".format(self.config.server.log_dir, "db"), encoding="utf-8",