diff --git a/core/config.py b/core/config.py index 8c1caee..2b96157 100644 --- a/core/config.py +++ b/core/config.py @@ -186,9 +186,15 @@ class DatabaseConfig: ) @property - def ssl_ca(self) -> Optional[str]: + def ssl_cafile(self) -> Optional[str]: return CoreConfig.get_config_field( - self.__config, "core", "database", "ssl_ca", default=None + self.__config, "core", "database", "ssl_cafile", default=None + ) + + @property + def ssl_capath(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_capath", default=None ) @property @@ -209,6 +215,24 @@ class DatabaseConfig: self.__config, "core", "database", "ssl_key_password", default=None ) + @property + def ssl_verify_identity(self) -> bool: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_verify_identity", default=True + ) + + @property + def ssl_verify_cert(self) -> Optional[bool]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_verify_cert", default=None + ) + + @property + def ssl_ciphers(self) -> Optional[str]: + return CoreConfig.get_config_field( + self.__config, "core", "database", "ssl_ciphers", default=None + ) + @property def sha2_password(self) -> bool: return CoreConfig.get_config_field( diff --git a/core/data/database.py b/core/data/database.py index 1b01661..b36d65f 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -46,19 +46,37 @@ class Data: } if self.config.database.ssl_enabled: - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + no_ca = ( + self.config.database.ssl_cafile is None + and self.config.database.ssl_capath is None + ) - if self.config.database.ssl_ca: - ssl_context.load_verify_locations(self.config.database.ssl_ca) + ctx = ssl.create_default_context( + cafile=self.config.database.ssl_cafile, + capath=self.config.database.ssl_capath, + ) + ctx.check_hostname = self.config.database.ssl_verify_identity + + if self.config.database.ssl_verify_cert is None: + ctx.verify_mode = ssl.CERT_NONE if no_ca else ssl.CERT_REQUIRED + else: + ctx.verify_mode = ( + ssl.CERT_REQUIRED + if self.config.database.ssl_verify_cert + else ssl.CERT_NONE + ) if self.config.database.ssl_cert: - ssl_context.load_cert_chain( + ctx.load_cert_chain( self.config.database.ssl_cert, self.config.database.ssl_key, self.config.database.ssl_key_password, ) + + if self.config.database.ssl_ciphers: + ctx.set_ciphers(self.config.database.ssl_ciphers) - connect_args["ssl"] = ssl_context + connect_args["ssl"] = ctx Data.engine = create_async_engine( self.__url,