import traceback
import urllib.parse
import urllib
import sys
import os
from typing import Callable
from collections import defaultdict

from werkzeug.exceptions import HTTPException, MethodNotAllowed
from werkzeug.wrappers import Request, Response
from werkzeug.routing import Map, Rule

import eaapi

from . import exceptions as exc
from .context import CallContext
from .model import Model
from .controller import IController, ServicesController
from .const import SERVICE_NTP, SERVICE_KEEPALIVE


Handler = Callable[[CallContext], None]

HEADER_ENCRYPTION = "X-Eamuse-Info"
HEADER_COMPRESSION = "X-Compress"

PINGABLE_IP = "127.0.0.1"


class NetworkState:
    """Addresses and timing values advertised to clients via the keepalive URL.

    Every address starts out as ``PINGABLE_IP``; ``format_ka`` serialises them
    into the keepalive service URL's query string.
    """

    def __init__(self):
        self._pa = PINGABLE_IP  # TODO: what does this one mean?
        self.router_ip = PINGABLE_IP
        self.gateway_ip = PINGABLE_IP
        self.center_ip = PINGABLE_IP

    def format_ka(self, base):
        """Return *base* with the ``pa/ia/ga/ma/t1/t2`` query string appended."""
        return base + "?" + urllib.parse.urlencode({
            "pa": self.pa,
            "ia": self.ia,
            "ga": self.ga,
            "ma": self.ma,
            "t1": self.t1,
            "t2": self.t2,
        })

    @property
    def pa(self) -> str:
        return self._pa

    @property
    def ia(self) -> str:
        return self.router_ip

    @property
    def ga(self) -> str:
        return self.gateway_ip

    @property
    def ma(self) -> str:
        return self.center_ip

    # TODO: Identify what these values are. Ping intervals?
    @property
    def t1(self) -> int:
        return 2

    @property
    def t2(self) -> int:
        return 10


class EAMServer:
    """WSGI application that decodes XRPC calls and dispatches them to controllers.

    Each request is matched against a werkzeug rule map built from the
    controllers' serviced prefixes, decoded into a ``CallContext``, handed to the
    first controller that supplies a handler, and the context's response is
    encoded back. Exceptions are converted into encoded ``<response>`` error
    packets (or returned directly when they are werkzeug ``HTTPException``s).
    """

    def __init__(
        self,
        public_url: str,
        prioritise_params: bool = False,
        verbose_errors: bool = False,
        services_mode: eaapi.const.ServicesMode = eaapi.const.ServicesMode.Operation,
        ntp_server: str = "ntp://pool.ntp.org/",
        keepalive_server: str | None = None,
        no_keepalive_route: bool = False,
        disable_routes: bool = False,
        no_services_handler: bool = False,
    ):
        """Configure the server.

        public_url: base URL clients use; route prefixes are computed relative to it.
        prioritise_params: when true, query-string model/module/method override
            values from the URL path; otherwise path values win.
        verbose_errors: include details/context in error packets and allow the
            plain-text status page for non-POST requests.
        services_mode: passed to the default ``ServicesController``.
        ntp_server: URL advertised for the NTP service.
        keepalive_server: URL advertised for keepalive; defaults to
            ``f"{public_url}/keepalive"``.
        no_keepalive_route: do not register the local ``/keepalive`` route.
        disable_routes: build an empty rule map, so no request matches.
        no_services_handler: do not install the default ``ServicesController``.
        """
        self.network = NetworkState()
        self.verbose_errors = verbose_errors
        self._prioritise_params = prioritise_params
        self._public_url = public_url
        self.disable_routes = disable_routes
        self._no_keepalive_route = no_keepalive_route

        self.ntp = ntp_server
        self.keepalive = keepalive_server or f"{public_url}/keepalive"

        self._prng = eaapi.crypt.new_prng()

        self._setup = []
        self._pre_handlers_check = []
        self._teardown = []

        # NOTE(review): per-request error context is stored on the shared server
        # instance, so concurrent requests under a threaded server would overwrite
        # each other's values. Confirm the deployment serves one request at a time.
        self._einfo_ctx: CallContext | None = None
        self._einfo_controller: str | None = None

        self.controllers: list[IController] = []
        if not no_services_handler:
            self.controllers.append(ServicesController(self, services_mode))

    def on_setup(self, callback):
        """Register *callback* to run (with no arguments) before each WSGI call."""
        if callback not in self._setup:
            self._setup.append(callback)

    def on_pre_handlers_check(self, callback):
        """Register *callback(ctx)* to run after decoding, before handler lookup."""
        if callback not in self._pre_handlers_check:
            self._pre_handlers_check.append(callback)

    def on_teardown(self, callback):
        """Register *callback(exc_or_None)* to run after each WSGI call."""
        if callback not in self._teardown:
            self._teardown.append(callback)

    def build_rules_map(self) -> Map:
        """Build the werkzeug rule map for every prefix served by the controllers.

        Controller prefixes are expanded against the public URL. Prefixes that do
        not live under the public URL are skipped; the rest become paths relative
        to this application. Called on every request by ``dispatch_request``.
        """
        if self.disable_routes:
            return Map([])

        rules = Map([], strict_slashes=False, merge_slashes=False)

        prefixes = {"/"}
        for controller in self.controllers:
            for prefix in controller.serviced_prefixes():
                prefix = self.expand_url(prefix)
                if not prefix.startswith(self._public_url):
                    continue
                prefix = prefix[len(self._public_url):]
                # werkzeug's Rule requires a leading slash. The remainder lacks one
                # when public_url itself ends in "/" (and is "" for the root).
                if not prefix.startswith("/"):
                    prefix = "/" + prefix
                prefixes.add(prefix)

        # NOTE(review): these rule strings contain no ``<converter>`` placeholders,
        # so adapter.match() never supplies service/model/module/method path values
        # even though on_xrpc_request accepts them. The "///" patterns look like
        # ``<...>`` segments were stripped (e.g. f"{i}<service>/<model>/<module>/<method>").
        # Confirm against the upstream source before relying on path routing.
        for i in prefixes:
            rules.add(Rule(f"{i}///", endpoint="xrpc_request"))
            # WSGI flattens the // at the start
            if i == "/":
                rules.add(Rule("///", endpoint="xrpc_request"))
            rules.add(Rule(f"{i}", endpoint="xrpc_request"))

        if not self._no_keepalive_route:
            # Dispatched to on_keepalive_request by dispatch_request.
            rules.add(Rule("/keepalive", endpoint="keepalive_request"))

        return rules

    def expand_url(self, url: str) -> str:
        """Resolve *url* relative to the public URL (``urllib.parse.urljoin``)."""
        return urllib.parse.urljoin(self._public_url, url)

    @property
    def public_url(self) -> str:
        return self._public_url

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Return the service-name -> URL map advertised to clients.

        Unknown services default to the public URL. NTP and keepalive entries are
        set first, so controllers' own routes override them.
        """
        services: dict[str, str] = defaultdict(lambda: self.public_url)
        services[SERVICE_NTP] = self.ntp
        services[SERVICE_KEEPALIVE] = self.network.format_ka(self.keepalive)

        for controller in self.controllers:
            services.update(controller.get_service_routes(ctx))

        return services

    def _decode_request(self, request: Request) -> CallContext:
        """Unwrap (decrypt/decompress) and decode the request body into a CallContext.

        Raises exc.UnknownCompression for an unrecognised X-Compress value and
        exc.InvalidPacket when the payload cannot be decoded.
        """
        ea_info = request.headers.get(HEADER_ENCRYPTION)
        compression = request.headers.get(HEADER_COMPRESSION)

        compressed = False
        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression != eaapi.Compression.None_.value:
            raise exc.UnknownCompression

        payload = eaapi.unwrap(request.data, ea_info, compressed)
        decoder = eaapi.Decoder(payload)
        try:
            call = decoder.unpack()
        except eaapi.EAAPIException as err:
            raise exc.InvalidPacket from err

        return CallContext(request, decoder, call, ea_info, compressed)

    def _encode_response(self, ctx: CallContext) -> Response:
        """Encode ``ctx.resp`` mirroring the request's encryption and compression."""
        if ctx._eainfo is None:
            ea_info = None
        else:
            # The request was encrypted, so encrypt the reply with a fresh key.
            ea_info = eaapi.crypt.get_key(self._prng)

        encoded = eaapi.Encoder.encode(ctx.resp, ctx.was_xml_string)
        wrapped = eaapi.wrap(encoded, ea_info, ctx.was_compressed)
        response = Response(wrapped, 200)
        if ea_info:
            response.headers[HEADER_ENCRYPTION] = ea_info
        response.headers[HEADER_COMPRESSION] = (
            eaapi.Compression.Lz77 if ctx.was_compressed
            else eaapi.Compression.None_
        ).value

        return response

    def _create_ctx(
        self,
        url_slash: bool,
        request: Request,
        model: Model | None,
        module: str,
        method: str
    ) -> CallContext:
        """Decode the request and attach the routing information to the context.

        Raises exc.ModelMissmatch when the model in the packet differs from the
        one resolved from the URL/query string.
        """
        ctx = self._decode_request(request)
        ctx._module = module
        ctx._method = method
        ctx._url_slash = url_slash

        # Recorded before the model check so error packets can include it.
        self._einfo_ctx = ctx

        if ctx.model != model:
            raise exc.ModelMissmatch

        return ctx

    def _handle_request(self, ctx: CallContext) -> Response:
        """Run the first controller handler that accepts *ctx* and encode the result.

        Raises exc.NoMethodHandler when no controller supplies a handler.
        """
        for controller in self.controllers:
            if (handler := controller.get_handler(ctx)) is not None:
                self._einfo_controller = f"{controller._name}"
                handler(ctx)
                break
        else:
            raise exc.NoMethodHandler

        return self._encode_response(ctx)

    def on_xrpc_other(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Serve non-POST XRPC requests: a status line for GET when verbose, else 405."""
        if request.method != "GET" or not self.verbose_errors:
            raise MethodNotAllowed

        return Response(
            f"XRPC running. model {model}, call {module}.{method} ({service})"
        )

    def keepalive_request(self) -> Response:
        """Return an empty response for keepalive probes."""
        return Response(None)

    def on_keepalive_request(self, request: Request) -> Response:
        """Endpoint wrapper for the ``/keepalive`` rule.

        dispatch_request resolves handlers as ``on_<endpoint>(request, **values)``,
        so the ``keepalive_request`` endpoint needs this adapter.
        """
        return self.keepalive_request()

    def parse_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Resolve model/module/method from the URL path and query string.

        Query parameters ``model``, ``module``, ``method`` (or ``f=module.method``)
        are merged with the path values; ``prioritise_params`` decides which wins.

        Returns (url_slash, service, model_obj, module, method), where url_slash
        is true when the call was fully specified in the URL path.
        Raises exc.ModuleMethodMissing or exc.InvalidModel.
        """
        url_slash = bool(model and module and method)

        model_param = request.args.get("model", None)
        module_param = request.args.get("module", None)
        method_param = request.args.get("method", None)
        if "f" in request.args:
            module_param, _, method_param = request.args.get("f", "").partition(".")

        if self._prioritise_params:
            model = model_param or model
            module = module_param or module
            method = method_param or method
        else:
            model = model or model_param
            module = module or module_param
            method = method or method_param

        # Empty strings (e.g. "?f=foo" with no ".") are as unusable as missing values.
        if not module or not method:
            raise exc.ModuleMethodMissing

        if model is None:
            model_obj = None
        else:
            try:
                model_obj = Model.from_model_str(model)
            except eaapi.exception.InvalidModel as err:
                raise exc.InvalidModel from err

        return url_slash, service, model_obj, module, method

    def on_xrpc_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Endpoint for XRPC calls: parse, decode, run pre-checks, then dispatch."""
        url_slash, service, model_obj, module, method = self.parse_request(
            request, service, model, module, method
        )

        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)

        ctx = self._create_ctx(url_slash, request, model_obj, module, method)
        for check in self._pre_handlers_check:
            check(ctx)

        return self._handle_request(ctx)

    def _make_error(self, status: int | None = None, message: str | None = None) -> Response:
        """Build an encoded ``<response>`` error packet with HTTP status ``status or 500``.

        With ``verbose_errors`` the packet also carries *message* and the call
        context (module, method, game, controller) recorded for this request.
        """
        packet = eaapi.XMLNode.void("response")
        if status is not None:
            packet["status"] = str(status)

        if self.verbose_errors:
            if message:
                packet.append("details", eaapi.Type.Str, message)

            context = packet.append("context")
            if self._einfo_ctx is not None:
                context.append("module", eaapi.Type.Str, self._einfo_ctx.module)
                context.append("method", eaapi.Type.Str, self._einfo_ctx.method)
                context.append("game", eaapi.Type.Str, str(self._einfo_ctx.model))
            if self._einfo_controller is not None:
                context.append("controller", eaapi.Type.Str, self._einfo_controller)

        encoded = eaapi.Encoder.encode(packet, False)
        wrapped = eaapi.wrap(encoded, None, False)
        response = Response(wrapped, status or 500)
        response.headers[HEADER_COMPRESSION] = eaapi.Compression.None_.value

        return response

    def _eamhttp_error(self, exc: exc.EAMHTTPException) -> Response:
        """Convert a project EAMHTTPException into an error packet."""
        return self._make_error(exc.code, exc.eam_description)

    def _structure_error(self, e: eaapi.exception.XMLStrutureError) -> Response:
        """Convert an XML structure error into a 400 packet carrying a trimmed traceback."""
        summary = traceback.extract_tb(e.__traceback__)
        for frame_summary in summary:
            try:
                frame_summary.filename = os.path.relpath(frame_summary.filename)
            except ValueError:
                # relpath raises on Windows when the file is on a different drive
                # than the CWD; keep the absolute path rather than failing here.
                pass

        # Drop the first three frames (presumably this server's dispatch_request ->
        # on_xrpc_request -> _handle_request) and the last one (presumably the eaapi
        # frame that raised). TODO: confirm for errors raised outside handlers.
        summary = summary[3:-1]
        tb = "".join(traceback.format_list(traceback.StackSummary.from_list(summary)))
        tb += f"{e.__module__}.{e.__class__.__name__}"

        return self._make_error(400, tb)

    def _generic_error(self, exc: Exception) -> Response:
        """Convert any other exception into a 500 packet with its message."""
        return self._make_error(500, str(exc))

    def dispatch_request(self, request):
        """Match *request* to a rule and call ``on_<endpoint>(request, **values)``.

        Exceptions are mapped to responses: EAMHTTPException -> its packet,
        werkzeug HTTPException -> returned as-is, XML structure errors -> 400,
        anything else -> 500 (the last two are also printed to stderr).
        """
        self._einfo_ctx = None
        self._einfo_controller = None

        adapter = self.build_rules_map().bind_to_environ(request.environ)
        try:
            endpoint, values = adapter.match()
            return getattr(self, f"on_{endpoint}")(request, **values)
        except exc.EAMHTTPException as e:
            return self._eamhttp_error(e)
        except HTTPException as e:
            return e
        except eaapi.exception.XMLStrutureError as e:
            traceback.print_exc(file=sys.stderr)
            return self._structure_error(e)
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            return self._generic_error(e)

    def wsgi_app(self, environ, start_response):
        """Wrap *environ* in a Request, dispatch it and run the resulting response."""
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        """WSGI entry point: run setup callbacks, the app, then teardown callbacks.

        Teardown callbacks receive the exception when the app raises (which is
        then re-raised) and None on success.
        """
        for callback in self._setup:
            callback()

        try:
            response = self.wsgi_app(environ, start_response)
        except Exception as e:
            for callback in self._teardown:
                callback(e)
            raise

        # Kept outside the try so that a teardown callback failing here is not
        # followed by a second teardown pass from the error branch above.
        for callback in self._teardown:
            callback(None)
        return response

    def run(self, host="127.0.0.1", port=5000, debug=False):
        """Serve with werkzeug's development server (debugger/reloader when *debug*)."""
        from werkzeug.serving import run_simple

        run_simple(host, port, self, use_debugger=debug, use_reloader=debug)