forked from Hay1tsme/artemis
allow alembic to also connect with tls
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
import ssl
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
from typing_extensions import Optional
|
from typing_extensions import Optional
|
||||||
|
|
||||||
@ -222,7 +223,7 @@ class DatabaseConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ssl_verify_cert(self) -> Optional[bool]:
|
def ssl_verify_cert(self) -> Optional[Union[str, bool]]:
|
||||||
return CoreConfig.get_config_field(
|
return CoreConfig.get_config_field(
|
||||||
self.__config, "core", "database", "ssl_verify_cert", default=None
|
self.__config, "core", "database", "ssl_verify_cert", default=None
|
||||||
)
|
)
|
||||||
@ -259,6 +260,53 @@ class DatabaseConfig:
|
|||||||
self.__config, "core", "database", "memcached_host", default="localhost"
|
self.__config, "core", "database", "memcached_host", default="localhost"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_ssl_context_if_enabled(self):
|
||||||
|
if not self.ssl_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
no_ca = (
|
||||||
|
self.ssl_cafile is None
|
||||||
|
and self.ssl_capath is None
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = ssl.create_default_context(
|
||||||
|
cafile=self.ssl_cafile,
|
||||||
|
capath=self.ssl_capath,
|
||||||
|
)
|
||||||
|
ctx.check_hostname = not no_ca and self.ssl_verify_identity
|
||||||
|
|
||||||
|
if self.ssl_verify_cert is None:
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
||||||
|
elif isinstance(self.ssl_verify_cert, bool):
|
||||||
|
ctx.verify_mode = (
|
||||||
|
ssl.CERT_REQUIRED
|
||||||
|
if self.ssl_verify_cert
|
||||||
|
else ssl.CERT_NONE
|
||||||
|
)
|
||||||
|
elif isinstance(self.ssl_verify_cert, str):
|
||||||
|
value = self.ssl_verify_cert.lower()
|
||||||
|
|
||||||
|
if value in ("none", "0", "false", "no"):
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
elif value == "optional":
|
||||||
|
ctx.verify_mode = ssl.CERT_OPTIONAL
|
||||||
|
elif value in ("required", "1", "true", "yes"):
|
||||||
|
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
else:
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
||||||
|
|
||||||
|
if self.ssl_cert:
|
||||||
|
ctx.load_cert_chain(
|
||||||
|
self.ssl_cert,
|
||||||
|
self.ssl_key,
|
||||||
|
self.ssl_key_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.ssl_ciphers:
|
||||||
|
ctx.set_ciphers(self.ssl_ciphers)
|
||||||
|
|
||||||
|
return ctx
|
||||||
|
|
||||||
class FrontendConfig:
|
class FrontendConfig:
|
||||||
def __init__(self, parent_config: "CoreConfig") -> None:
|
def __init__(self, parent_config: "CoreConfig") -> None:
|
||||||
self.__config = parent_config
|
self.__config = parent_config
|
||||||
|
@ -1,14 +1,18 @@
|
|||||||
from __future__ import with_statement
|
from __future__ import with_statement
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
import yaml
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
from sqlalchemy.engine import Connection
|
from sqlalchemy.engine import Connection
|
||||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from core.config import CoreConfig
|
||||||
from core.data.schema.base import metadata
|
from core.data.schema.base import metadata
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
@ -74,8 +78,18 @@ async def run_async_migrations() -> None:
|
|||||||
for override in overrides:
|
for override in overrides:
|
||||||
ini_section[override] = overrides[override]
|
ini_section[override] = overrides[override]
|
||||||
|
|
||||||
|
core_config = CoreConfig()
|
||||||
|
|
||||||
|
with (Path("../../..") / os.environ["ARTEMIS_CFG_DIR"] / "core.yaml").open(encoding="utf-8") as f:
|
||||||
|
core_config.update(yaml.safe_load(f))
|
||||||
|
|
||||||
connectable = async_engine_from_config(
|
connectable = async_engine_from_config(
|
||||||
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
|
ini_section,
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
connect_args={
|
||||||
|
"charset": "utf8mb4",
|
||||||
|
"ssl": core_config.database.create_ssl_context_if_enabled(),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async with connectable.connect() as connection:
|
async with connectable.connect() as connection:
|
||||||
|
@ -41,48 +41,14 @@ class Data:
|
|||||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
|
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
|
||||||
|
|
||||||
if Data.engine is MISSING:
|
if Data.engine is MISSING:
|
||||||
connect_args: dict[str, Any] = {
|
|
||||||
"charset": "utf8mb4",
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.config.database.ssl_enabled:
|
|
||||||
no_ca = (
|
|
||||||
self.config.database.ssl_cafile is None
|
|
||||||
and self.config.database.ssl_capath is None
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx = ssl.create_default_context(
|
|
||||||
cafile=self.config.database.ssl_cafile,
|
|
||||||
capath=self.config.database.ssl_capath,
|
|
||||||
)
|
|
||||||
ctx.check_hostname = self.config.database.ssl_verify_identity
|
|
||||||
|
|
||||||
if self.config.database.ssl_verify_cert is None:
|
|
||||||
ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED
|
|
||||||
else:
|
|
||||||
ctx.verify_mode = (
|
|
||||||
ssl.CERT_REQUIRED
|
|
||||||
if self.config.database.ssl_verify_cert
|
|
||||||
else ssl.CERT_NONE
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.database.ssl_cert:
|
|
||||||
ctx.load_cert_chain(
|
|
||||||
self.config.database.ssl_cert,
|
|
||||||
self.config.database.ssl_key,
|
|
||||||
self.config.database.ssl_key_password,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.database.ssl_ciphers:
|
|
||||||
ctx.set_ciphers(self.config.database.ssl_ciphers)
|
|
||||||
|
|
||||||
connect_args["ssl"] = ctx
|
|
||||||
|
|
||||||
Data.engine = create_async_engine(
|
Data.engine = create_async_engine(
|
||||||
self.__url,
|
self.__url,
|
||||||
pool_recycle=3600,
|
pool_recycle=3600,
|
||||||
isolation_level="AUTOCOMMIT",
|
isolation_level="AUTOCOMMIT",
|
||||||
connect_args=connect_args,
|
connect_args={
|
||||||
|
"charset": "utf8mb4",
|
||||||
|
"ssl": self.config.database.create_ssl_context_if_enabled(),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
self.__engine = Data.engine
|
self.__engine = Data.engine
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user