forked from Hay1tsme/artemis
properly support connecting with SSL on aiomysql
This commit is contained in:
@ -1,6 +1,9 @@
|
|||||||
import logging, os
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import Optional
|
||||||
|
|
||||||
class ServerConfig:
|
class ServerConfig:
|
||||||
def __init__(self, parent_config: "CoreConfig") -> None:
|
def __init__(self, parent_config: "CoreConfig") -> None:
|
||||||
self.__config = parent_config
|
self.__config = parent_config
|
||||||
@ -175,12 +178,36 @@ class DatabaseConfig:
|
|||||||
return CoreConfig.get_config_field(
|
return CoreConfig.get_config_field(
|
||||||
self.__config, "core", "database", "protocol", default="mysql"
|
self.__config, "core", "database", "protocol", default="mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ssl_enabled(self) -> str:
|
def ssl_enabled(self) -> bool:
|
||||||
return CoreConfig.get_config_field(
|
return CoreConfig.get_config_field(
|
||||||
self.__config, "core", "database", "ssl_enabled", default=False
|
self.__config, "core", "database", "ssl_enabled", default=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_ca(self) -> Optional[str]:
|
||||||
|
return CoreConfig.get_config_field(
|
||||||
|
self.__config, "core", "database", "ssl_ca", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_cert(self) -> Optional[str]:
|
||||||
|
return CoreConfig.get_config_field(
|
||||||
|
self.__config, "core", "database", "ssl_cert", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_key(self) -> Optional[str]:
|
||||||
|
return CoreConfig.get_config_field(
|
||||||
|
self.__config, "core", "database", "ssl_key", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_key_password(self) -> Optional[str]:
|
||||||
|
return CoreConfig.get_config_field(
|
||||||
|
self.__config, "core", "database", "ssl_key_password", default=None
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sha2_password(self) -> bool:
|
def sha2_password(self) -> bool:
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
|
import ssl
|
||||||
import string
|
import string
|
||||||
import warnings
|
import warnings
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from logging.handlers import TimedRotatingFileHandler
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
from typing import ClassVar, Optional
|
from typing import Any, ClassVar, Optional
|
||||||
|
|
||||||
import alembic.config
|
import alembic.config
|
||||||
import bcrypt
|
import bcrypt
|
||||||
@ -35,12 +36,36 @@ class Data:
|
|||||||
|
|
||||||
if self.config.database.sha2_password:
|
if self.config.database.sha2_password:
|
||||||
passwd = sha256(self.config.database.password.encode()).digest()
|
passwd = sha256(self.config.database.password.encode()).digest()
|
||||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4"
|
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
|
||||||
else:
|
else:
|
||||||
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4"
|
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
|
||||||
|
|
||||||
if Data.engine is MISSING:
|
if Data.engine is MISSING:
|
||||||
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
|
connect_args: dict[str, Any] = {
|
||||||
|
"charset": "utf8mb4",
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.config.database.ssl_enabled:
|
||||||
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
|
||||||
|
if self.config.database.ssl_ca:
|
||||||
|
ssl_context.load_verify_locations(self.config.database.ssl_ca)
|
||||||
|
|
||||||
|
if self.config.database.ssl_cert:
|
||||||
|
ssl_context.load_cert_chain(
|
||||||
|
self.config.database.ssl_cert,
|
||||||
|
self.config.database.ssl_key,
|
||||||
|
self.config.database.ssl_key_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
connect_args["ssl"] = ssl_context
|
||||||
|
|
||||||
|
Data.engine = create_async_engine(
|
||||||
|
self.__url,
|
||||||
|
pool_recycle=3600,
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
connect_args=connect_args,
|
||||||
|
)
|
||||||
self.__engine = Data.engine
|
self.__engine = Data.engine
|
||||||
|
|
||||||
if Data.session is MISSING:
|
if Data.session is MISSING:
|
||||||
|
Reference in New Issue
Block a user