forked from Hay1tsme/artemis
things
This commit is contained in:
@ -176,7 +176,13 @@ class UserData(BaseData):
|
|||||||
|
|
||||||
await self.execute(sql)
|
await self.execute(sql)
|
||||||
|
|
||||||
async def extend_lock_for_game(self, user_id: int, game: str):
|
async def extend_lock_for_game(self, user_id: int, game: str, extra: dict | None = None):
|
||||||
sql = game_locks.update().where((game_locks.c.user == user_id) & (game_locks.c.game == game)).values(expires_at=func.date_add(func.now(), text("INTERVAL 15 MINUTE")))
|
sql = (
|
||||||
|
insert(game_locks)
|
||||||
|
.values(user=user_id, game=game, extra=extra or {})
|
||||||
|
.on_duplicate_key_update(
|
||||||
|
expires_at=func.date_add(func.now(), text("INTERVAL 15 MINUTE")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
await self.execute(sql)
|
await self.execute(sql)
|
||||||
|
@ -45,7 +45,7 @@ def init_root_logger(cfg: "CoreConfig"):
|
|||||||
|
|
||||||
def create_logger(
|
def create_logger(
|
||||||
title: str,
|
title: str,
|
||||||
level: Optional[str] = None,
|
level: Optional[logging._Level] = None,
|
||||||
*,
|
*,
|
||||||
logger_name: Optional[str] = None,
|
logger_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
@ -59,7 +59,7 @@ class OngekiCryptoConfig:
|
|||||||
self.__config = parent_config
|
self.__config = parent_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self) -> Dict:
|
def keys(self) -> dict[int, list[str]]:
|
||||||
"""
|
"""
|
||||||
in the form of:
|
in the form of:
|
||||||
internal_version: [key, iv]
|
internal_version: [key, iv]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
from starlette.responses import Response
|
from starlette.responses import PlainTextResponse, Response
|
||||||
import json
|
import json
|
||||||
import inflection
|
import inflection
|
||||||
import yaml
|
import yaml
|
||||||
@ -27,7 +27,6 @@ from .red import OngekiRed
|
|||||||
from .redplus import OngekiRedPlus
|
from .redplus import OngekiRedPlus
|
||||||
from .bright import OngekiBright
|
from .bright import OngekiBright
|
||||||
from .brightmemory import OngekiBrightMemory
|
from .brightmemory import OngekiBrightMemory
|
||||||
from .brightmemory2 import OngekiBrightMemoryAct2
|
|
||||||
from .brightmemory3 import OngekiBrightMemoryAct3
|
from .brightmemory3 import OngekiBrightMemoryAct3
|
||||||
|
|
||||||
|
|
||||||
@ -35,7 +34,7 @@ class OngekiServlet(BaseServlet):
|
|||||||
def __init__(self, core_cfg: CoreConfig, cfg_dir: str) -> None:
|
def __init__(self, core_cfg: CoreConfig, cfg_dir: str) -> None:
|
||||||
super().__init__(core_cfg, cfg_dir)
|
super().__init__(core_cfg, cfg_dir)
|
||||||
self.game_cfg = OngekiConfig()
|
self.game_cfg = OngekiConfig()
|
||||||
self.hash_table: Dict[Dict[str, str]] = {}
|
self.hash_table: dict[int, dict[str, str]] = {}
|
||||||
if path.exists(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"):
|
if path.exists(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"):
|
||||||
self.game_cfg.update(
|
self.game_cfg.update(
|
||||||
yaml.safe_load(open(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"))
|
yaml.safe_load(open(f"{cfg_dir}/{OngekiConstants.CONFIG_NAME}"))
|
||||||
@ -51,8 +50,7 @@ class OngekiServlet(BaseServlet):
|
|||||||
OngekiRedPlus(core_cfg, self.game_cfg),
|
OngekiRedPlus(core_cfg, self.game_cfg),
|
||||||
OngekiBright(core_cfg, self.game_cfg),
|
OngekiBright(core_cfg, self.game_cfg),
|
||||||
OngekiBrightMemory(core_cfg, self.game_cfg),
|
OngekiBrightMemory(core_cfg, self.game_cfg),
|
||||||
# OngekiBrightMemoryAct2(core_cfg, self.game_cfg),
|
None, # used to be act2 but apparently it wasn't a version of its own
|
||||||
None,
|
|
||||||
OngekiBrightMemoryAct3(core_cfg, self.game_cfg),
|
OngekiBrightMemoryAct3(core_cfg, self.game_cfg),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -65,25 +63,24 @@ class OngekiServlet(BaseServlet):
|
|||||||
method_list = [
|
method_list = [
|
||||||
method
|
method
|
||||||
for method in dir(self.versions[version])
|
for method in dir(self.versions[version])
|
||||||
if not method.startswith("__")
|
if method.startswith("handle_") and method.endswith("_request")
|
||||||
]
|
]
|
||||||
|
salt = bytes.fromhex(keys[2])
|
||||||
for method in method_list:
|
for method in method_list:
|
||||||
method_fixed = inflection.camelize(method)[6:-7]
|
method_fixed = inflection.camelize(method)[6:-7]
|
||||||
# number of iterations is 64 on Bright Memory
|
|
||||||
iter_count = 64
|
|
||||||
hash = PBKDF2(
|
hash = PBKDF2(
|
||||||
method_fixed,
|
method_fixed,
|
||||||
bytes.fromhex(keys[2]),
|
salt,
|
||||||
128,
|
16,
|
||||||
count=iter_count,
|
count=64, # bright memory
|
||||||
hmac_hash_module=SHA1,
|
hmac_hash_module=SHA1,
|
||||||
)
|
)
|
||||||
|
|
||||||
hashed_name = hash.hex()[:32] # truncate unused bytes like the game does
|
hashed_name = hash.hex()
|
||||||
self.hash_table[version][hashed_name] = method_fixed
|
self.hash_table[version][hashed_name] = method_fixed
|
||||||
|
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Hashed v{version} method {method_fixed} with {bytes.fromhex(keys[2])} to get {hash.hex()[:32]}"
|
f"Hashed v{version} method {method_fixed} with {salt} to get {hashed_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -121,7 +118,7 @@ class OngekiServlet(BaseServlet):
|
|||||||
f"{self.core_cfg.server.hostname}{t_port}/",
|
f"{self.core_cfg.server.hostname}{t_port}/",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def render_POST(self, request: Request) -> bytes:
|
async def render_POST(self, request: Request) -> Response:
|
||||||
endpoint: str = request.path_params.get('endpoint', '')
|
endpoint: str = request.path_params.get('endpoint', '')
|
||||||
version: int = request.path_params.get('version', 0)
|
version: int = request.path_params.get('version', 0)
|
||||||
if endpoint.lower() == "ping":
|
if endpoint.lower() == "ping":
|
||||||
@ -151,21 +148,18 @@ class OngekiServlet(BaseServlet):
|
|||||||
elif version >= 145:
|
elif version >= 145:
|
||||||
internal_ver = OngekiConstants.VER_ONGEKI_BRIGHT_MEMORY_ACT3
|
internal_ver = OngekiConstants.VER_ONGEKI_BRIGHT_MEMORY_ACT3
|
||||||
|
|
||||||
if all(c in string.hexdigits for c in endpoint) and len(endpoint) == 32:
|
if request.headers.get("ongeki-encoding") is not None:
|
||||||
# If we get a 32 character long hex string, it's a hash and we're
|
|
||||||
# doing encrypted. The likelyhood of false positives is low but
|
|
||||||
# technically not 0
|
|
||||||
if internal_ver not in self.hash_table:
|
if internal_ver not in self.hash_table:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"v{version} does not support encryption or no keys entered"
|
f"v{version} does not support encryption or no keys entered"
|
||||||
)
|
)
|
||||||
return Response(zlib.compress(b'{"stat": "0"}'))
|
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
||||||
|
|
||||||
elif endpoint.lower() not in self.hash_table[internal_ver]:
|
elif endpoint.lower() not in self.hash_table[internal_ver]:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"No hash found for v{version} endpoint {endpoint}"
|
f"No hash found for v{version} endpoint {endpoint}"
|
||||||
)
|
)
|
||||||
return Response(zlib.compress(b'{"stat": "0"}'))
|
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
||||||
|
|
||||||
endpoint = self.hash_table[internal_ver][endpoint.lower()]
|
endpoint = self.hash_table[internal_ver][endpoint.lower()]
|
||||||
|
|
||||||
@ -179,8 +173,9 @@ class OngekiServlet(BaseServlet):
|
|||||||
req_raw = crypt.decrypt(req_raw)
|
req_raw = crypt.decrypt(req_raw)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(
|
self.logger.exception(
|
||||||
f"Failed to decrypt v{version} request to {endpoint} -> {e}"
|
f"Failed to decrypt v{version} request to {endpoint}",
|
||||||
|
exc_info=e,
|
||||||
)
|
)
|
||||||
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
||||||
|
|
||||||
@ -196,51 +191,61 @@ class OngekiServlet(BaseServlet):
|
|||||||
)
|
)
|
||||||
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
||||||
|
|
||||||
if version < 105:
|
if request.headers.get("content-encoding") == "deflate":
|
||||||
# O.N.G.E.K.I base don't use zlib
|
|
||||||
req_data = json.loads(req_raw)
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
unzip = zlib.decompress(req_raw)
|
unzip = zlib.decompress(req_raw)
|
||||||
|
|
||||||
except zlib.error as e:
|
except zlib.error as e:
|
||||||
self.logger.error(
|
self.logger.exception(
|
||||||
f"Failed to decompress v{version} {endpoint} request -> {e}"
|
f"Failed to decompress v{version} {endpoint} request",
|
||||||
|
exc_info=e
|
||||||
)
|
)
|
||||||
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
return Response(zlib.compress(b'{"returnCode": "0"}'))
|
||||||
|
|
||||||
req_data = json.loads(unzip)
|
req_data = json.loads(unzip)
|
||||||
|
else:
|
||||||
|
req_data = json.loads(req_raw)
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.debug(
|
||||||
f"v{version} {endpoint} request from {client_ip}"
|
"Received request v%d %s from %s.",
|
||||||
|
version, endpoint, client_ip,
|
||||||
|
extra={
|
||||||
|
"body": req_data,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
self.logger.debug(req_data)
|
|
||||||
|
|
||||||
func_to_find = "handle_" + inflection.underscore(endpoint) + "_request"
|
func_to_find = "handle_" + inflection.underscore(endpoint) + "_request"
|
||||||
|
handler = getattr(self.versions[internal_ver], func_to_find, None)
|
||||||
|
response = None
|
||||||
|
|
||||||
if not hasattr(self.versions[internal_ver], func_to_find):
|
if handler is not None:
|
||||||
self.logger.warning(f"Unhandled v{version} request {endpoint}")
|
|
||||||
return Response(zlib.compress(b'{"returnCode": 1}'))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler = getattr(self.versions[internal_ver], func_to_find)
|
response = await handler(req_data)
|
||||||
resp = await handler(req_data)
|
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
response = {"returnCode": 1}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error handling v{version} method {endpoint} - {e}")
|
self.logger.exception(f"Error handling v{version} {endpoint}", exc_info=e)
|
||||||
resp = {"returnCode": 0}
|
|
||||||
|
|
||||||
if resp is None:
|
if not self.core_cfg.server.is_develop:
|
||||||
resp = {"returnCode": 1}
|
return PlainTextResponse("Internal Server Error", status_code=500)
|
||||||
|
|
||||||
self.logger.debug(f"Response {resp}")
|
response = {"returnCode": 0}
|
||||||
|
elif self.core_cfg.server.is_develop:
|
||||||
|
self.logger.warning(f"Stubbing unhandled v{version} {endpoint} request because server is in develop mode.")
|
||||||
|
response = {"returnCode": 1}
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Unhandled v{version} {endpoint} request.")
|
||||||
|
return PlainTextResponse("Not Found", status_code=404)
|
||||||
|
|
||||||
|
self.logger.info(f"(v{version} {endpoint}) Returned 200.", extra=response)
|
||||||
|
|
||||||
|
resp_raw = json.dumps(response, ensure_ascii=False).encode("utf-8")
|
||||||
|
|
||||||
|
if request.headers.get("content-encoding") != "deflate":
|
||||||
|
return Response(resp_raw)
|
||||||
|
|
||||||
resp_raw = json.dumps(resp, ensure_ascii=False).encode("utf-8")
|
|
||||||
zipped = zlib.compress(resp_raw)
|
zipped = zlib.compress(resp_raw)
|
||||||
|
|
||||||
if not encrtped or version < 120:
|
if not encrtped:
|
||||||
if version < 105:
|
|
||||||
return Response(resp_raw)
|
|
||||||
return Response(zipped)
|
return Response(zipped)
|
||||||
|
|
||||||
padded = pad(zipped, 16)
|
padded = pad(zipped, 16)
|
||||||
|
Reference in New Issue
Block a user