This commit is contained in:
2024-07-06 06:57:59 +07:00
parent 6f12f9988e
commit 5894831154
4 changed files with 64 additions and 53 deletions

View File

@ -176,7 +176,13 @@ class UserData(BaseData):
await self.execute(sql) await self.execute(sql)
async def extend_lock_for_game(self, user_id: int, game: str): async def extend_lock_for_game(self, user_id: int, game: str, extra: dict | None = None):
sql = game_locks.update().where((game_locks.c.user == user_id) & (game_locks.c.game == game)).values(expires_at=func.date_add(func.now(), text("INTERVAL 15 MINUTE"))) sql = (
insert(game_locks)
.values(user=user_id, game=game, extra=extra or {})
.on_duplicate_key_update(
expires_at=func.date_add(func.now(), text("INTERVAL 15 MINUTE")),
)
)
await self.execute(sql) await self.execute(sql)

View File

@ -45,7 +45,7 @@ def init_root_logger(cfg: "CoreConfig"):
def create_logger( def create_logger(
title: str, title: str,
level: Optional[str] = None, level: Optional[logging._Level] = None,
*, *,
logger_name: Optional[str] = None, logger_name: Optional[str] = None,
): ):

View File

@ -59,7 +59,7 @@ class OngekiCryptoConfig:
self.__config = parent_config self.__config = parent_config
@property @property
def keys(self) -> Dict: def keys(self) -> dict[int, list[str]]:
""" """
in the form of: in the form of:
internal_version: [key, iv] internal_version: [key, iv]

View File

@ -1,6 +1,6 @@
from starlette.requests import Request from starlette.requests import Request
from starlette.routing import Route from starlette.routing import Route
from starlette.responses import Response from starlette.responses import PlainTextResponse, Response
import json import json
import inflection import inflection
import yaml import yaml
@ -27,7 +27,6 @@ from .red import OngekiRed
from .redplus import OngekiRedPlus from .redplus import OngekiRedPlus
from .bright import OngekiBright from .bright import OngekiBright
from .brightmemory import OngekiBrightMemory from .brightmemory import OngekiBrightMemory
from .brightmemory2 import OngekiBrightMemoryAct2
from .brightmemory3 import OngekiBrightMemoryAct3 from .brightmemory3 import OngekiBrightMemoryAct3
@ -35,7 +34,7 @@ class OngekiServlet(BaseServlet):
def __init__(self, core_cfg: CoreConfig, cfg_dir: str) -> None: def __init__(self, core_cfg: CoreConfig, cfg_dir: str) -> None:
super().__init__(core_cfg, cfg_dir) super().__init__(core_cfg, cfg_dir)
self.game_cfg = OngekiConfig() self.game_cfg = OngekiConfig()
self.hash_table: Dict[Dict[str, str]] = {} self.hash_table: dict[int, dict[str, str]] = {}
if path.exists(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"): if path.exists(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"):
self.game_cfg.update( self.game_cfg.update(
yaml.safe_load(open(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}")) yaml.safe_load(open(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"))
@ -51,8 +50,7 @@ class OngekiServlet(BaseServlet):
OngekiRedPlus(core_cfg, self.game_cfg), OngekiRedPlus(core_cfg, self.game_cfg),
OngekiBright(core_cfg, self.game_cfg), OngekiBright(core_cfg, self.game_cfg),
OngekiBrightMemory(core_cfg, self.game_cfg), OngekiBrightMemory(core_cfg, self.game_cfg),
# OngekiBrightMemoryAct2(core_cfg, self.game_cfg), None, # used to be act2 but apparently it wasn't a version of its own
None,
OngekiBrightMemoryAct3(core_cfg, self.game_cfg), OngekiBrightMemoryAct3(core_cfg, self.game_cfg),
] ]
@ -65,25 +63,24 @@ class OngekiServlet(BaseServlet):
method_list = [ method_list = [
method method
for method in dir(self.versions[version]) for method in dir(self.versions[version])
if not method.startswith("__") if method.startswith("handle_") and method.endswith("_request")
] ]
salt = bytes.fromhex(keys[2])
for method in method_list: for method in method_list:
method_fixed = inflection.camelize(method)[6:-7] method_fixed = inflection.camelize(method)[6:-7]
# number of iterations is 64 on Bright Memory
iter_count = 64
hash = PBKDF2( hash = PBKDF2(
method_fixed, method_fixed,
bytes.fromhex(keys[2]), salt,
128, 16,
count=iter_count, count=64, # bright memory
hmac_hash_module=SHA1, hmac_hash_module=SHA1,
) )
hashed_name = hash.hex()[:32] # truncate unused bytes like the game does hashed_name = hash.hex()
self.hash_table[version][hashed_name] = method_fixed self.hash_table[version][hashed_name] = method_fixed
self.logger.debug( self.logger.debug(
f"Hashed v{version} method {method_fixed} with {bytes.fromhex(keys[2])} to get {hash.hex()[:32]}" f"Hashed v{version} method {method_fixed} with {salt} to get {hashed_name}"
) )
@classmethod @classmethod
@ -121,7 +118,7 @@ class OngekiServlet(BaseServlet):
f"{self.core_cfg.server.hostname}{t_port}/", f"{self.core_cfg.server.hostname}{t_port}/",
) )
async def render_POST(self, request: Request) -> bytes: async def render_POST(self, request: Request) -> Response:
endpoint: str = request.path_params.get('endpoint', '') endpoint: str = request.path_params.get('endpoint', '')
version: int = request.path_params.get('version', 0) version: int = request.path_params.get('version', 0)
if endpoint.lower() == "ping": if endpoint.lower() == "ping":
@ -151,21 +148,18 @@ class OngekiServlet(BaseServlet):
elif version >= 145: elif version >= 145:
internal_ver = OngekiConstants.VER_ONGEKI_BRIGHT_MEMORY_ACT3 internal_ver = OngekiConstants.VER_ONGEKI_BRIGHT_MEMORY_ACT3
if all(c in string.hexdigits for c in endpoint) and len(endpoint) == 32: if request.headers.get("ongeki-encoding") is not None:
# If we get a 32 character long hex string, it's a hash and we're
# doing encrypted. The likelyhood of false positives is low but
# technically not 0
if internal_ver not in self.hash_table: if internal_ver not in self.hash_table:
self.logger.error( self.logger.error(
f"v{version} does not support encryption or no keys entered" f"v{version} does not support encryption or no keys entered"
) )
return Response(zlib.compress(b'{"stat": "0"}')) return Response(zlib.compress(b'{"returnCode": "0"}'))
elif endpoint.lower() not in self.hash_table[internal_ver]: elif endpoint.lower() not in self.hash_table[internal_ver]:
self.logger.error( self.logger.error(
f"No hash found for v{version} endpoint {endpoint}" f"No hash found for v{version} endpoint {endpoint}"
) )
return Response(zlib.compress(b'{"stat": "0"}')) return Response(zlib.compress(b'{"returnCode": "0"}'))
endpoint = self.hash_table[internal_ver][endpoint.lower()] endpoint = self.hash_table[internal_ver][endpoint.lower()]
@ -179,8 +173,9 @@ class OngekiServlet(BaseServlet):
req_raw = crypt.decrypt(req_raw) req_raw = crypt.decrypt(req_raw)
except Exception as e: except Exception as e:
self.logger.error( self.logger.exception(
f"Failed to decrypt v{version} request to {endpoint} -> {e}" f"Failed to decrypt v{version} request to {endpoint}",
exc_info=e,
) )
return Response(zlib.compress(b'{"returnCode": "0"}')) return Response(zlib.compress(b'{"returnCode": "0"}'))
@ -196,51 +191,61 @@ class OngekiServlet(BaseServlet):
) )
return Response(zlib.compress(b'{"returnCode": "0"}')) return Response(zlib.compress(b'{"returnCode": "0"}'))
if version < 105: if request.headers.get("content-encoding") == "deflate":
# O.N.G.E.K.I base don't use zlib
req_data = json.loads(req_raw)
else:
try: try:
unzip = zlib.decompress(req_raw) unzip = zlib.decompress(req_raw)
except zlib.error as e: except zlib.error as e:
self.logger.error( self.logger.exception(
f"Failed to decompress v{version} {endpoint} request -> {e}" f"Failed to decompress v{version} {endpoint} request",
exc_info=e
) )
return Response(zlib.compress(b'{"returnCode": "0"}')) return Response(zlib.compress(b'{"returnCode": "0"}'))
req_data = json.loads(unzip) req_data = json.loads(unzip)
else:
req_data = json.loads(req_raw)
self.logger.info( self.logger.debug(
f"v{version} {endpoint} request from {client_ip}" "Received request v%d %s from %s.",
version, endpoint, client_ip,
extra={
"body": req_data,
},
) )
self.logger.debug(req_data)
func_to_find = "handle_" + inflection.underscore(endpoint) + "_request" func_to_find = "handle_" + inflection.underscore(endpoint) + "_request"
handler = getattr(self.versions[internal_ver], func_to_find, None)
response = None
if not hasattr(self.versions[internal_ver], func_to_find): if handler is not None:
self.logger.warning(f"Unhandled v{version} request {endpoint}") try:
return Response(zlib.compress(b'{"returnCode": 1}')) response = await handler(req_data)
if response is None:
response = {"returnCode": 1}
except Exception as e:
self.logger.exception(f"Error handling v{version} {endpoint}", exc_info=e)
try: if not self.core_cfg.server.is_develop:
handler = getattr(self.versions[internal_ver], func_to_find) return PlainTextResponse("Internal Server Error", status_code=500)
resp = await handler(req_data)
response = {"returnCode": 0}
elif self.core_cfg.server.is_develop:
self.logger.warning(f"Stubbing unhandled v{version} {endpoint} request because server is in develop mode.")
response = {"returnCode": 1}
else:
self.logger.error(f"Unhandled v{version} {endpoint} request.")
return PlainTextResponse("Not Found", status_code=404)
except Exception as e: self.logger.info(f"(v{version} {endpoint}) Returned 200.", extra=response)
self.logger.error(f"Error handling v{version} method {endpoint} - {e}")
resp = {"returnCode": 0}
if resp is None: resp_raw = json.dumps(response, ensure_ascii=False).encode("utf-8")
resp = {"returnCode": 1}
self.logger.debug(f"Response {resp}") if request.headers.get("content-encoding") != "deflate":
return Response(resp_raw)
resp_raw = json.dumps(resp, ensure_ascii=False).encode("utf-8")
zipped = zlib.compress(resp_raw) zipped = zlib.compress(resp_raw)
if not encrtped or version < 120: if not encrtped:
if version < 105:
return Response(resp_raw)
return Response(zipped) return Response(zipped)
padded = pad(zipped, 16) padded = pad(zipped, 16)