Automatically determine the database used.

This commit is contained in:
Floatin 2023-10-22 09:28:32 +00:00
parent 7af88d7baa
commit ebb5ee8dd0
1 changed files with 43 additions and 24 deletions

View File

@ -10,7 +10,6 @@ from sqlalchemy.exc import SQLAlchemyError
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import yaml import yaml
import yaml
import argparse import argparse
import logging import logging
import coloredlogs import coloredlogs
@ -27,13 +26,13 @@ from titles.ongeki.const import OngekiConstants
class AquaData: class AquaData:
def __init__(self, aqua_db_path: str) -> None: def __init__(self, aqua_db_path: str) -> None:
if use_mysql: if '@' in aqua_db_path:
self.__engine = create_engine("mysql+pymysql://" + aqua_db_path, echo=False) self.__url = f"mysql+pymysql://{aqua_db_path}"
else: else:
self.__url = f"sqlite:///{aqua_db_path}" self.__url = f"sqlite:///{aqua_db_path}"
self.__engine = create_engine(self.__url, pool_recycle=3600) self.__engine = create_engine(self.__url, pool_recycle=3600, echo=False)
# self.inspector = reflection.Inspector.from_engine(self.__engine) # self.inspector = reflection.Inspector.from_engine(self.__engine)
session = sessionmaker(bind=self.__engine) session = sessionmaker(bind=self.__engine)
self.inspect = inspect(self.__engine) self.inspect = inspect(self.__engine)
@ -117,14 +116,31 @@ class Importer:
coloredlogs.install(level="INFO", logger=self.logger, fmt=log_fmt_str) coloredlogs.install(level="INFO", logger=self.logger, fmt=log_fmt_str)
self.logger.initialized = True self.logger.initialized = True
if use_mysql: if not os.path.isfile(f'{aqua_folder}/application.properties'):
self.aqua = AquaData(aqua_folder) self.logger.error("Could not locate AQUA application.properties file!")
else: exit(1)
with open(f'{aqua_folder}/application.properties') as file:
lines = file.readlines()
properties = {}
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue
parts = line.split('=')
if len(parts) >= 2:
key = parts[0].strip()
value = '='.join(parts[1:]).strip()
properties[key] = value
db_driver = properties.get('spring.datasource.driver-class-name')
if 'sqlite' in db_driver:
aqua_db_path = None aqua_db_path = None
if os.path.exists(aqua_folder): db_url = properties.get('spring.datasource.url').split('sqlite:')[1]
temp = os.path.join(aqua_folder, "db.sqlite") temp = os.path.join(f'{aqua_folder}/{db_url}')
if os.path.isfile(temp): if os.path.isfile(temp):
aqua_db_path = temp aqua_db_path = temp
if not aqua_db_path: if not aqua_db_path:
self.logger.error("Could not locate AQUA db.sqlite file!") self.logger.error("Could not locate AQUA db.sqlite file!")
@ -132,6 +148,16 @@ class Importer:
self.aqua = AquaData(aqua_db_path) self.aqua = AquaData(aqua_db_path)
elif 'mysql' in db_driver or 'mariadb' in db_driver:
self.use_mysql = True
db_username = properties.get('spring.datasource.username')
db_password = properties.get('spring.datasource.password')
db_url = properties.get('spring.datasource.url').split('?')[0].split('//')[1]
self.aqua = AquaData(f'{db_username}:{db_password}@{db_url}')
else:
self.logger.error("Unknown database type!")
def get_user_id(self, luid: str): def get_user_id(self, luid: str):
user_id = self.data.card.get_user_id_from_card(access_code=luid) user_id = self.data.card.get_user_id_from_card(access_code=luid)
if user_id is not None: if user_id is not None:
@ -170,8 +196,7 @@ class Importer:
card_id: int, card_id: int,
) -> Dict: ) -> Dict:
row = row._asdict() row = row._asdict()
if not self.use_mysql:
if use_mysql is False:
for column in datetime_columns: for column in datetime_columns:
ts = row[column["name"]] ts = row[column["name"]]
if ts is None: if ts is None:
@ -712,25 +737,19 @@ def main():
"--config", "-c", type=str, help="Config directory to use", default="config" "--config", "-c", type=str, help="Config directory to use", default="config"
) )
parser.add_argument( parser.add_argument(
"aqua_data_path", "aqua_folder_path",
type=str, type=str,
help="Absolute folder path to AQUA /data folder, where db.sqlite is located in. You can also enter the Mysql database connection address (<user>:<password>@<host>:<port>/<database>)", help="The absolute folder path to the folder where AQUA is located, where the data folder and the application.properties file should be located.",
) )
args = parser.parse_args() args = parser.parse_args()
core_cfg = CoreConfig() core_cfg = CoreConfig()
core_cfg.update(yaml.safe_load(open(f"{args.config}/core.yaml"))) core_cfg.update(yaml.safe_load(open(f"{args.config}/core.yaml")))
global use_mysql importer = Importer(core_cfg, args.config, args.aqua_folder_path)
use_mysql = False
if '@' in args.aqua_data_path:
use_mysql = True
importer = Importer(core_cfg, args.config, args.aqua_data_path)
importer.import_chuni() importer.import_chuni()
importer.import_ongeki() importer.import_ongeki()
if __name__ == "__main__": if __name__ == "__main__":
main() main()