diff --git a/core/config.py b/core/config.py index 2b96157..e05323b 100644 --- a/core/config.py +++ b/core/config.py @@ -1,6 +1,7 @@ import logging import os -from typing import Any +import ssl +from typing import Any, Union from typing_extensions import Optional @@ -222,7 +223,7 @@ class DatabaseConfig: ) @property - def ssl_verify_cert(self) -> Optional[bool]: + def ssl_verify_cert(self) -> Optional[Union[str, bool]]: return CoreConfig.get_config_field( self.__config, "core", "database", "ssl_verify_cert", default=None ) @@ -259,6 +260,53 @@ class DatabaseConfig: self.__config, "core", "database", "memcached_host", default="localhost" ) + def create_ssl_context_if_enabled(self): + if not self.ssl_enabled: + return + + no_ca = ( + self.ssl_cafile is None + and self.ssl_capath is None + ) + + ctx = ssl.create_default_context( + cafile=self.ssl_cafile, + capath=self.ssl_capath, + ) + ctx.check_hostname = not no_ca and self.ssl_verify_identity + + if self.ssl_verify_cert is None: + ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED + elif isinstance(self.ssl_verify_cert, bool): + ctx.verify_mode = ( + ssl.CERT_REQUIRED + if self.ssl_verify_cert + else ssl.CERT_NONE + ) + elif isinstance(self.ssl_verify_cert, str): + value = self.ssl_verify_cert.lower() + + if value in ("none", "0", "false", "no"): + ctx.verify_mode = ssl.CERT_NONE + elif value == "optional": + ctx.verify_mode = ssl.CERT_OPTIONAL + elif value in ("required", "1", "true", "yes"): + ctx.verify_mode = ssl.CERT_REQUIRED + else: + ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED + + if self.ssl_cert: + ctx.load_cert_chain( + self.ssl_cert, + self.ssl_key, + self.ssl_key_password, + ) + + if self.ssl_ciphers: + ctx.set_ciphers(self.ssl_ciphers) + + return ctx + class FrontendConfig: def __init__(self, parent_config: "CoreConfig") -> None: self.__config = parent_config diff --git a/core/data/alembic/env.py b/core/data/alembic/env.py index f2a8182..b175ee6 100644 --- a/core/data/alembic/env.py +++ b/core/data/alembic/env.py @@ -1,14 +1,18 @@ from __future__ import with_statement import asyncio +import os +from pathlib import Path import threading from logging.config import fileConfig +import yaml from alembic import context from sqlalchemy import pool from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import async_engine_from_config +from core.config import CoreConfig from core.data.schema.base import metadata # this is the Alembic Config object, which provides @@ -74,8 +78,18 @@ async def run_async_migrations() -> None: for override in overrides: ini_section[override] = overrides[override] + core_config = CoreConfig() + + with (Path("../../..") / os.environ["ARTEMIS_CFG_DIR"] / "core.yaml").open(encoding="utf-8") as f: + core_config.update(yaml.safe_load(f)) + connectable = async_engine_from_config( - ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool + ini_section, + poolclass=pool.NullPool, + connect_args={ + "charset": "utf8mb4", + "ssl": core_config.database.create_ssl_context_if_enabled(), + } ) async with connectable.connect() as connection: diff --git a/core/data/database.py b/core/data/database.py index b36d65f..170665e 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -41,48 +41,14 @@ class Data: self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}" if Data.engine is MISSING: - connect_args: dict[str, Any] = { - "charset": "utf8mb4", - } - - if self.config.database.ssl_enabled: - no_ca = ( - self.config.database.ssl_cafile is None - and self.config.database.ssl_capath is None - ) - - ctx = ssl.create_default_context( - cafile=self.config.database.ssl_cafile, - capath=self.config.database.ssl_capath, - ) - ctx.check_hostname = self.config.database.ssl_verify_identity - - if self.config.database.ssl_verify_cert is None: - ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED - else: - ctx.verify_mode = ( - ssl.CERT_REQUIRED - if self.config.database.ssl_verify_cert - else ssl.CERT_NONE - ) - - if self.config.database.ssl_cert: - ctx.load_cert_chain( - self.config.database.ssl_cert, - self.config.database.ssl_key, - self.config.database.ssl_key_password, - ) - - if self.config.database.ssl_ciphers: - ctx.set_ciphers(self.config.database.ssl_ciphers) - - connect_args["ssl"] = ctx - Data.engine = create_async_engine( self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT", - connect_args=connect_args, + connect_args={ + "charset": "utf8mb4", + "ssl": self.config.database.create_ssl_context_if_enabled(), + }, ) self.__engine = Data.engine