diff --git a/core/config.py b/core/config.py index dda43aa..8c1caee 100644 --- a/core/config.py +++ b/core/config.py @@ -1,6 +1,9 @@ -import logging, os +import logging +import os from typing import Any +from typing_extensions import Optional + class ServerConfig: def __init__(self, parent_config: "CoreConfig") -> None: self.__config = parent_config @@ -175,12 +178,36 @@ class DatabaseConfig: return CoreConfig.get_config_field( self.__config, "core", "database", "protocol", default="mysql" ) - + @property - def ssl_enabled(self) -> str: + def ssl_enabled(self) -> bool: return CoreConfig.get_config_field( self.__config, "core", "database", "ssl_enabled", default=False ) + + @property + def ssl_ca(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_ca", default=None + ) + + @property + def ssl_cert(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_cert", default=None + ) + + @property + def ssl_key(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_key", default=None + ) + + @property + def ssl_key_password(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_key_password", default=None + ) @property def sha2_password(self) -> bool: diff --git a/core/data/database.py b/core/data/database.py index 312c87c..1b01661 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -1,11 +1,12 @@ import logging import os import secrets +import ssl import string import warnings from hashlib import sha256 from logging.handlers import TimedRotatingFileHandler -from typing import ClassVar, Optional +from typing import Any, ClassVar, Optional import alembic.config import bcrypt @@ -35,12 +36,36 @@ class Data: if self.config.database.sha2_password: passwd = sha256(self.config.database.password.encode()).digest() - self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" + self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}" else: - self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" + self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}" if Data.engine is MISSING: - Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT") + connect_args: dict[str, Any] = { + "charset": "utf8mb4", + } + + if self.config.database.ssl_enabled: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + if self.config.database.ssl_ca: + ssl_context.load_verify_locations(self.config.database.ssl_ca) + + if self.config.database.ssl_cert: + ssl_context.load_cert_chain( + self.config.database.ssl_cert, + self.config.database.ssl_key, + self.config.database.ssl_key_password, + ) + + connect_args["ssl"] = ssl_context + + Data.engine = create_async_engine( + self.__url, + pool_recycle=3600, + isolation_level="AUTOCOMMIT", + connect_args=connect_args, + ) self.__engine = Data.engine if Data.session is MISSING: