properly support connecting with SSL on aiomysql

This commit is contained in:
2024-11-16 09:14:44 +07:00
parent eba03e6b9b
commit d7d3dbac59
2 changed files with 59 additions and 7 deletions

View File

@ -1,6 +1,9 @@
import logging, os import logging
import os
from typing import Any from typing import Any
from typing_extensions import Optional
class ServerConfig: class ServerConfig:
def __init__(self, parent_config: "CoreConfig") -> None: def __init__(self, parent_config: "CoreConfig") -> None:
self.__config = parent_config self.__config = parent_config
@ -175,12 +178,36 @@ class DatabaseConfig:
return CoreConfig.get_config_field( return CoreConfig.get_config_field(
self.__config, "core", "database", "protocol", default="mysql" self.__config, "core", "database", "protocol", default="mysql"
) )
@property @property
def ssl_enabled(self) -> str: def ssl_enabled(self) -> bool:
return CoreConfig.get_config_field( return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_enabled", default=False self.__config, "core", "database", "ssl_enabled", default=False
) )
@property
def ssl_ca(self) -> Optional[str]:
return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_ca", default=None
)
@property
def ssl_cert(self) -> Optional[str]:
return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_cert", default=None
)
@property
def ssl_key(self) -> Optional[str]:
return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_key", default=None
)
@property
def ssl_key_password(self) -> Optional[str]:
return CoreConfig.get_config_field(
self.__config, "core", "database", "ssl_key_password", default=None
)
@property @property
def sha2_password(self) -> bool: def sha2_password(self) -> bool:

View File

@ -1,11 +1,12 @@
import logging import logging
import os import os
import secrets import secrets
import ssl
import string import string
import warnings import warnings
from hashlib import sha256 from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import ClassVar, Optional from typing import Any, ClassVar, Optional
import alembic.config import alembic.config
import bcrypt import bcrypt
@ -35,12 +36,36 @@ class Data:
if self.config.database.sha2_password: if self.config.database.sha2_password:
passwd = sha256(self.config.database.password.encode()).digest() passwd = sha256(self.config.database.password.encode()).digest()
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
else: else:
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4" self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
if Data.engine is MISSING: if Data.engine is MISSING:
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT") connect_args: dict[str, Any] = {
"charset": "utf8mb4",
}
if self.config.database.ssl_enabled:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if self.config.database.ssl_ca:
ssl_context.load_verify_locations(self.config.database.ssl_ca)
if self.config.database.ssl_cert:
ssl_context.load_cert_chain(
self.config.database.ssl_cert,
self.config.database.ssl_key,
self.config.database.ssl_key_password,
)
connect_args["ssl"] = ssl_context
Data.engine = create_async_engine(
self.__url,
pool_recycle=3600,
isolation_level="AUTOCOMMIT",
connect_args=connect_args,
)
self.__engine = Data.engine self.__engine = Data.engine
if Data.session is MISSING: if Data.session is MISSING: