forked from Dniel97/artemis
		
	
		
			
				
	
	
		
			161 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			161 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging, coloredlogs
 | |
| from typing import Optional
 | |
| from sqlalchemy.orm import scoped_session, sessionmaker
 | |
| from sqlalchemy import create_engine
 | |
| from logging.handlers import TimedRotatingFileHandler
 | |
| import os
 | |
| import secrets, string
 | |
| import bcrypt
 | |
| from hashlib import sha256
 | |
| import alembic.config
 | |
| 
 | |
| from core.config import CoreConfig
 | |
| from core.data.schema import *
 | |
| from core.utils import Utils
 | |
| 
 | |
| 
 | |
| class Data:
 | |
|     engine = None
 | |
|     session = None
 | |
|     user = None
 | |
|     arcade = None
 | |
|     card = None
 | |
|     base = None
 | |
|     def __init__(self, cfg: CoreConfig) -> None:
 | |
|         self.config = cfg
 | |
| 
 | |
|         if self.config.database.sha2_password:
 | |
|             passwd = sha256(self.config.database.password.encode()).digest()
 | |
|             self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4"
 | |
|         else:
 | |
|             self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4"
 | |
| 
 | |
|         if Data.engine is None:
 | |
|             Data.engine = create_engine(self.__url, pool_recycle=3600)
 | |
|             self.__engine = Data.engine
 | |
| 
 | |
|         if Data.session is None:
 | |
|             s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True)
 | |
|             Data.session = scoped_session(s)
 | |
| 
 | |
|         if Data.user is None:
 | |
|             Data.user = UserData(self.config, self.session)
 | |
|         
 | |
|         if Data.arcade is None:
 | |
|             Data.arcade = ArcadeData(self.config, self.session)
 | |
|         
 | |
|         if Data.card is None:
 | |
|             Data.card = CardData(self.config, self.session)
 | |
|         
 | |
|         if Data.base is None:
 | |
|             Data.base = BaseData(self.config, self.session)
 | |
| 
 | |
|         self.logger = logging.getLogger("database")
 | |
| 
 | |
|         # Prevent the logger from adding handlers multiple times
 | |
|         if not getattr(self.logger, "handler_set", None):            
 | |
|             log_fmt_str = "[%(asctime)s] %(levelname)s | Database | %(message)s"
 | |
|             log_fmt = logging.Formatter(log_fmt_str)
 | |
|             fileHandler = TimedRotatingFileHandler(
 | |
|                 "{0}/{1}.log".format(self.config.server.log_dir, "db"),
 | |
|                 encoding="utf-8",
 | |
|                 when="d",
 | |
|                 backupCount=10,
 | |
|             )
 | |
|             fileHandler.setFormatter(log_fmt)
 | |
| 
 | |
|             consoleHandler = logging.StreamHandler()
 | |
|             consoleHandler.setFormatter(log_fmt)
 | |
| 
 | |
|             self.logger.addHandler(fileHandler)
 | |
|             self.logger.addHandler(consoleHandler)
 | |
| 
 | |
|             self.logger.setLevel(self.config.database.loglevel)
 | |
|             coloredlogs.install(
 | |
|                 cfg.database.loglevel, logger=self.logger, fmt=log_fmt_str
 | |
|             )
 | |
|             self.logger.handler_set = True  # type: ignore
 | |
| 
 | |
|     def __alembic_cmd(self, command: str, *args: str) -> None:
 | |
|         old_dir = os.path.abspath(os.path.curdir)
 | |
|         base_dir = os.path.join(os.path.abspath(os.path.curdir), 'core', 'data', 'alembic')
 | |
|         alembicArgs = [
 | |
|             "-c",
 | |
|             os.path.join(base_dir, "alembic.ini"),
 | |
|             "-x",
 | |
|             f"script_location={base_dir}",
 | |
|             "-x",
 | |
|             f"sqlalchemy.url={self.__url}",
 | |
|             command,
 | |
|         ]
 | |
|         alembicArgs.extend(args)
 | |
|         os.chdir(base_dir)
 | |
|         alembic.config.main(argv=alembicArgs)
 | |
|         os.chdir(old_dir)
 | |
| 
 | |
|     def create_database(self):
 | |
|         self.logger.info("Creating databases...")
 | |
|         metadata.create_all(
 | |
|             self.engine,
 | |
|             checkfirst=True,
 | |
|         )
 | |
| 
 | |
|         for _, mod in Utils.get_all_titles().items():
 | |
|             if hasattr(mod, "database"):
 | |
|                 mod.database(self.config)
 | |
|                 metadata.create_all(
 | |
|                     self.engine,
 | |
|                     checkfirst=True,
 | |
|                 )
 | |
| 
 | |
|         # Stamp the end revision as if alembic had created it, so it can take off after this.
 | |
|         self.__alembic_cmd(
 | |
|             "stamp",
 | |
|             "head",
 | |
|         )
 | |
| 
 | |
|     def schema_upgrade(self, ver: str = None):
 | |
|         self.__alembic_cmd(
 | |
|             "upgrade",
 | |
|             "head",
 | |
|         )
 | |
| 
 | |
|     async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None:
 | |
|         pw = "".join(
 | |
|             secrets.choice(string.ascii_letters + string.digits) for i in range(20)
 | |
|         )
 | |
|         hash = bcrypt.hashpw(pw.encode(), bcrypt.gensalt())
 | |
| 
 | |
|         user_id = await self.user.create_user("sysowner", email, hash.decode(), 255)
 | |
|         if user_id is None:
 | |
|             self.logger.error(f"Failed to create owner with email {email}")
 | |
|             return
 | |
| 
 | |
|         card_id = await self.card.create_card(user_id, code)
 | |
|         if card_id is None:
 | |
|             self.logger.error(f"Failed to create card for owner with id {user_id}")
 | |
|             return
 | |
| 
 | |
|         self.logger.warning(
 | |
|             f"Successfully created owner with email {email}, access code {code}, and password {pw} Make sure to change this password and assign a real card ASAP!"
 | |
|         )
 | |
|     
 | |
|     async def migrate(self) -> None:
 | |
|         exist = await self.base.execute("SELECT * FROM alembic_version")
 | |
|         if exist is not None:
 | |
|             self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
 | |
|             return
 | |
|         
 | |
|         self.logger.info("Stamp with initial revision")
 | |
|         self.__alembic_cmd(
 | |
|             "stamp",
 | |
|             "835b862f9bf0",
 | |
|         )
 | |
| 
 | |
|         self.logger.info("Upgrade")
 | |
|         self.__alembic_cmd(
 | |
|             "upgrade",
 | |
|             "head",
 | |
|         )
 | |
| 
 |