allow alembic to also connect with tls

This commit is contained in:
2024-11-16 10:58:13 +07:00
parent 5f3d62d84a
commit 6f4e5b0fa3
3 changed files with 69 additions and 41 deletions

View File

@ -1,6 +1,7 @@
import logging
import os
from typing import Any
import ssl
from typing import Any, Union
from typing_extensions import Optional
@ -222,7 +223,7 @@ class DatabaseConfig:
)
@property
def ssl_verify_cert(self) -> Optional[bool]:
def ssl_verify_cert(self) -> Optional[Union[str, bool]]:
return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_verify_cert", default=None
)
@ -259,6 +260,53 @@ class DatabaseConfig:
self.__config, "core", "database", "memcached_host", default="localhost"
)
def create_ssl_context_if_enabled(self):
if not self.ssl_enabled:
return
no_ca = (
self.ssl_cafile is None
and self.ssl_capath is None
)
ctx = ssl.create_default_context(
cafile=self.ssl_cafile,
capath=self.ssl_capath,
)
ctx.check_hostname = not no_ca and self.ssl_verify_identity
if self.ssl_verify_cert is None:
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
elif isinstance(self.ssl_verify_cert, bool):
ctx.verify_mode = (
ssl.CERT_REQUIRED
if self.ssl_verify_cert
else ssl.CERT_NONE
)
elif isinstance(self.ssl_verify_cert, str):
value = self.ssl_verify_cert.lower()
if value in ("none", "0", "false", "no"):
ctx.verify_mode = ssl.CERT_NONE
elif value == "optional":
ctx.verify_mode = ssl.CERT_OPTIONAL
elif value in ("required", "1", "true", "yes"):
ctx.verify_mode = ssl.CERT_REQUIRED
else:
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
if self.ssl_cert:
ctx.load_cert_chain(
self.ssl_cert,
self.ssl_key,
self.ssl_key_password,
)
if self.ssl_ciphers:
ctx.set_ciphers(self.ssl_ciphers)
return ctx
class FrontendConfig:
def __init__(self, parent_config: "CoreConfig") -> None:
self.__config = parent_config

View File

@ -1,14 +1,18 @@
from __future__ import with_statement
import asyncio
import os
from pathlib import Path
import threading
from logging.config import fileConfig
import yaml
from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from core.config import CoreConfig
from core.data.schema.base import metadata
# this is the Alembic Config object, which provides
@ -74,8 +78,18 @@ async def run_async_migrations() -> None:
for override in overrides:
ini_section[override] = overrides[override]
core_config = CoreConfig()
with (Path("../../..") / os.environ["ARTEMIS_CFG_DIR"] / "core.yaml").open(encoding="utf-8") as f:
core_config.update(yaml.safe_load(f))
connectable = async_engine_from_config(
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
ini_section,
poolclass=pool.NullPool,
connect_args={
"charset": "utf8mb4",
"ssl": core_config.database.create_ssl_context_if_enabled(),
}
)
async with connectable.connect() as connection:

View File

@ -41,48 +41,14 @@ class Data:
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
if Data.engine is MISSING:
connect_args: dict[str, Any] = {
"charset": "utf8mb4",
}
if self.config.database.ssl_enabled:
no_ca = (
self.config.database.ssl_cafile is None
and self.config.database.ssl_capath is None
)
ctx = ssl.create_default_context(
cafile=self.config.database.ssl_cafile,
capath=self.config.database.ssl_capath,
)
ctx.check_hostname = self.config.database.ssl_verify_identity
if self.config.database.ssl_verify_cert is None:
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
else:
ctx.verify_mode = (
ssl.CERT_REQUIRED
if self.config.database.ssl_verify_cert
else ssl.CERT_NONE
)
if self.config.database.ssl_cert:
ctx.load_cert_chain(
self.config.database.ssl_cert,
self.config.database.ssl_key,
self.config.database.ssl_key_password,
)
if self.config.database.ssl_ciphers:
ctx.set_ciphers(self.config.database.ssl_ciphers)
connect_args["ssl"] = ctx
Data.engine = create_async_engine(
self.__url,
pool_recycle=3600,
isolation_level="AUTOCOMMIT",
connect_args=connect_args,
connect_args={
"charset": "utf8mb4",
"ssl": self.config.database.create_ssl_context_if_enabled(),
},
)
self.__engine = Data.engine