forked from Hay1tsme/artemis
allow alembic to also connect with tls
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
import ssl
|
||||
from typing import Any, Union
|
||||
|
||||
from typing_extensions import Optional
|
||||
|
||||
@ -222,7 +223,7 @@ class DatabaseConfig:
|
||||
)
|
||||
|
||||
@property
|
||||
def ssl_verify_cert(self) -> Optional[bool]:
|
||||
def ssl_verify_cert(self) -> Optional[Union[str, bool]]:
|
||||
return CoreConfig.get_config_field(
|
||||
self.__config, "core", "database", "ssl_verify_cert", default=None
|
||||
)
|
||||
@ -259,6 +260,53 @@ class DatabaseConfig:
|
||||
self.__config, "core", "database", "memcached_host", default="localhost"
|
||||
)
|
||||
|
||||
def create_ssl_context_if_enabled(self):
|
||||
if not self.ssl_enabled:
|
||||
return
|
||||
|
||||
no_ca = (
|
||||
self.ssl_cafile is None
|
||||
and self.ssl_capath is None
|
||||
)
|
||||
|
||||
ctx = ssl.create_default_context(
|
||||
cafile=self.ssl_cafile,
|
||||
capath=self.ssl_capath,
|
||||
)
|
||||
ctx.check_hostname = not no_ca and self.ssl_verify_identity
|
||||
|
||||
if self.ssl_verify_cert is None:
|
||||
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
||||
elif isinstance(self.ssl_verify_cert, bool):
|
||||
ctx.verify_mode = (
|
||||
ssl.CERT_REQUIRED
|
||||
if self.ssl_verify_cert
|
||||
else ssl.CERT_NONE
|
||||
)
|
||||
elif isinstance(self.ssl_verify_cert, str):
|
||||
value = self.ssl_verify_cert.lower()
|
||||
|
||||
if value in ("none", "0", "false", "no"):
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
elif value == "optional":
|
||||
ctx.verify_mode = ssl.CERT_OPTIONAL
|
||||
elif value in ("required", "1", "true", "yes"):
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
||||
|
||||
if self.ssl_cert:
|
||||
ctx.load_cert_chain(
|
||||
self.ssl_cert,
|
||||
self.ssl_key,
|
||||
self.ssl_key_password,
|
||||
)
|
||||
|
||||
if self.ssl_ciphers:
|
||||
ctx.set_ciphers(self.ssl_ciphers)
|
||||
|
||||
return ctx
|
||||
|
||||
class FrontendConfig:
|
||||
def __init__(self, parent_config: "CoreConfig") -> None:
|
||||
self.__config = parent_config
|
||||
|
@ -1,14 +1,18 @@
|
||||
from __future__ import with_statement
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from logging.config import fileConfig
|
||||
|
||||
import yaml
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from core.config import CoreConfig
|
||||
from core.data.schema.base import metadata
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
@ -74,8 +78,18 @@ async def run_async_migrations() -> None:
|
||||
for override in overrides:
|
||||
ini_section[override] = overrides[override]
|
||||
|
||||
core_config = CoreConfig()
|
||||
|
||||
with (Path("../../..") / os.environ["ARTEMIS_CFG_DIR"] / "core.yaml").open(encoding="utf-8") as f:
|
||||
core_config.update(yaml.safe_load(f))
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
|
||||
ini_section,
|
||||
poolclass=pool.NullPool,
|
||||
connect_args={
|
||||
"charset": "utf8mb4",
|
||||
"ssl": core_config.database.create_ssl_context_if_enabled(),
|
||||
}
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
|
@ -41,48 +41,14 @@ class Data:
|
||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
|
||||
|
||||
if Data.engine is MISSING:
|
||||
connect_args: dict[str, Any] = {
|
||||
"charset": "utf8mb4",
|
||||
}
|
||||
|
||||
if self.config.database.ssl_enabled:
|
||||
no_ca = (
|
||||
self.config.database.ssl_cafile is None
|
||||
and self.config.database.ssl_capath is None
|
||||
)
|
||||
|
||||
ctx = ssl.create_default_context(
|
||||
cafile=self.config.database.ssl_cafile,
|
||||
capath=self.config.database.ssl_capath,
|
||||
)
|
||||
ctx.check_hostname = self.config.database.ssl_verify_identity
|
||||
|
||||
if self.config.database.ssl_verify_cert is None:
|
||||
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
||||
else:
|
||||
ctx.verify_mode = (
|
||||
ssl.CERT_REQUIRED
|
||||
if self.config.database.ssl_verify_cert
|
||||
else ssl.CERT_NONE
|
||||
)
|
||||
|
||||
if self.config.database.ssl_cert:
|
||||
ctx.load_cert_chain(
|
||||
self.config.database.ssl_cert,
|
||||
self.config.database.ssl_key,
|
||||
self.config.database.ssl_key_password,
|
||||
)
|
||||
|
||||
if self.config.database.ssl_ciphers:
|
||||
ctx.set_ciphers(self.config.database.ssl_ciphers)
|
||||
|
||||
connect_args["ssl"] = ctx
|
||||
|
||||
Data.engine = create_async_engine(
|
||||
self.__url,
|
||||
pool_recycle=3600,
|
||||
isolation_level="AUTOCOMMIT",
|
||||
connect_args=connect_args,
|
||||
connect_args={
|
||||
"charset": "utf8mb4",
|
||||
"ssl": self.config.database.create_ssl_context_if_enabled(),
|
||||
},
|
||||
)
|
||||
self.__engine = Data.engine
|
||||
|
||||
|
Reference in New Issue
Block a user