diff --git a/titles/idz/echo.py b/titles/idz/echo.py index 979fd19..a8ff3a2 100644 --- a/titles/idz/echo.py +++ b/titles/idz/echo.py @@ -4,16 +4,11 @@ import logging from core.config import CoreConfig from .config import IDZConfig +class IDZEcho: + def connection_made(self, transport): + self.transport = transport -class IDZEcho(DatagramProtocol): - def __init__(self, cfg: CoreConfig, game_cfg: IDZConfig) -> None: - super().__init__() - self.core_config = cfg - self.game_config = game_cfg - self.logger = logging.getLogger("idz") - - def datagramReceived(self, data, addr): - self.logger.debug( - f"Echo from from {addr[0]}:{addr[1]} -> {self.transport.getHost().port} - {data.hex()}" - ) - self.transport.write(data, addr) + def datagram_received(self, data, addr): + message = data.decode() + self.logger.debug(f'Received echo from {addr}') + self.transport.sendto(data, addr) diff --git a/titles/idz/handlers/load_server_info.py b/titles/idz/handlers/load_server_info.py index 9eb63ab..f0ace03 100644 --- a/titles/idz/handlers/load_server_info.py +++ b/titles/idz/handlers/load_server_info.py @@ -91,7 +91,7 @@ class IDZHandlerLoadServerInfo(IDZHandlerBase): ) struct.pack_into(" bytes: diff --git a/titles/idz/userdb.py b/titles/idz/userdb.py index b585bd2..cd6ea9c 100644 --- a/titles/idz/userdb.py +++ b/titles/idz/userdb.py @@ -1,15 +1,9 @@ -from twisted.internet.protocol import Factory, Protocol -import logging, coloredlogs +import logging from Crypto.Cipher import AES import struct from typing import Dict, Optional, List, Type -from twisted.web import server, resource -from twisted.internet import reactor, endpoints -from starlette.requests import Request -from routes import Mapper import random -from os import walk -import importlib +import asyncio from core.config import CoreConfig from .database import IDZData @@ -28,7 +22,7 @@ class IDZKey: self.hashN = hashN -class IDZUserDBProtocol(Protocol): +class IDZUserDB: def __init__( self, core_cfg: CoreConfig, @@ -45,6 +39,10 @@ class IDZUserDBProtocol(Protocol): self.version = None self.version_internal = None self.skip_next = False + + def start(self) -> None: + self.logger.info(f"Start on port {self.config.aimedb.port}") + asyncio.create_task(asyncio.start_server(self.connection_cb, self.config.server.listen_address, self.config.aimedb.port)) def append_padding(self, data: bytes): """Appends 0s to the end of the data until it's at the correct size""" @@ -52,43 +50,54 @@ class IDZUserDBProtocol(Protocol): padding_size = length[0] - len(data) data += bytes(padding_size) return data + + async def connection_cb(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self.logger.debug(f"Connection made from {writer.get_extra_info('peername')[0]}") + while True: + try: + base = 0 - def connectionMade(self) -> None: - self.logger.debug(f"{self.transport.getPeer().host} Connected") - base = 0 + for i in range(len(self.static_key) - 1): + shift = 8 * i + byte = self.static_key[i] - for i in range(len(self.static_key) - 1): - shift = 8 * i - byte = self.static_key[i] + base |= byte << shift - base |= byte << shift + rsa_key = random.choice(self.rsa_keys) + key_enc: int = pow(base, rsa_key.e, rsa_key.N) + result = ( + key_enc.to_bytes(0x40, "little") + + struct.pack(" None: - self.logger.debug( - f"{self.transport.getPeer().host} Disconnected - {reason.value}" - ) - - def dataReceived(self, data: bytes) -> None: + def dataReceived(self, data: bytes, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: self.logger.debug(f"Receive data {data.hex()}") + client_ip = writer.get_extra_info('peername')[0] crypt = AES.new(self.static_key, AES.MODE_ECB) try: data_dec = crypt.decrypt(data) except Exception as e: - self.logger.error(f"Failed to decrypt UserDB request from {self.transport.getPeer().host} because {e} - {data.hex()}") + self.logger.error(f"Failed to decrypt UserDB request from {client_ip} because {e} - {data.hex()}") self.logger.debug(f"Decrypt data {data_dec.hex()}") @@ -99,7 +108,7 @@ class IDZUserDBProtocol(Protocol): self.logger.info(f"Userdb serverbox request {data_dec.hex()}") self.skip_next = True - self.transport.write(b"\x00") + writer.write(b"\x00") return elif magic == 0x01020304: @@ -119,21 +128,21 @@ class IDZUserDBProtocol(Protocol): self.version_internal = None self.logger.debug( - f"Userdb v{self.version} handshake response from {self.transport.getPeer().host}" + f"Userdb v{self.version} handshake response from {client_ip}" ) return elif self.skip_next: self.skip_next = False - self.transport.write(b"\x00") + writer.write(b"\x00") return elif self.version is None: # We didn't get a handshake before, and this isn't one now, so we're up the creek self.logger.info( - f"Bad UserDB request from from {self.transport.getPeer().host}" + f"Bad UserDB request from from {client_ip}" ) - self.transport.write(b"\x00") + writer.write(b"\x00") return cmd = struct.unpack_from(" None: - self.core_config = cfg - self.game_config = game_cfg - self.keys = keys - self.handlers = handlers - - def buildProtocol(self, addr): - return IDZUserDBProtocol( - self.core_config, self.game_config, self.keys, self.handlers - ) - - -class IDZUserDBWeb(resource.Resource): - def __init__(self, core_cfg: CoreConfig, game_cfg: IDZConfig): - super().__init__() - self.isLeaf = True - self.core_config = core_cfg - self.game_config = game_cfg - self.logger = logging.getLogger("idz") - - def render_POST(self, request: Request) -> bytes: - self.logger.info( - f"IDZUserDBWeb POST from {request.getClientAddress().host} to {request.uri} with data {request.content.getvalue()}" - ) - return b"" - - def render_GET(self, request: Request) -> bytes: - self.logger.info( - f"IDZUserDBWeb GET from {request.getClientAddress().host} to {request.uri}" - ) - return b"" + writer.write(response_enc)