# eaapi/eaapi/server/server.py
#
# (source-viewer header: 398 lines, 13 KiB, Python)

import traceback
import urllib.parse
import urllib
import sys
import os
from typing import Callable
from collections import defaultdict
from werkzeug.exceptions import HTTPException, MethodNotAllowed
from werkzeug.wrappers import Request, Response
from werkzeug.routing import Map, Rule
import eaapi
from . import exceptions as exc
from .context import CallContext
from .model import Model
from .controller import IController, ServicesController
from .const import SERVICE_NTP, SERVICE_KEEPALIVE
# Callable that takes a CallContext and returns nothing.
# NOTE(review): presumably the type of the handlers returned by
# IController.get_handler — confirm against the controller module.
Handler = Callable[[CallContext], None]
# HTTP header whose value is handed to eaapi.unwrap() on requests and, when
# present on the request, regenerated and set on the response.
HEADER_ENCRYPTION = "X-Eamuse-Info"
# HTTP header naming the payload compression; compared against the
# eaapi.Compression enum values when decoding a request.
HEADER_COMPRESSION = "X-Compress"
# Default address assigned to every NetworkState field.
PINGABLE_IP = "127.0.0.1"
class NetworkState:
    """Addresses and timing values encoded into the keepalive service URL.

    Every address defaults to ``PINGABLE_IP``. The short property names
    (``pa``, ``ia``, ``ga``, ``ma``, ``t1``, ``t2``) are the query-string keys
    emitted by :meth:`format_ka`.
    """

    def __init__(self):
        self._pa = PINGABLE_IP  # TODO: what does this one mean?
        self.router_ip = PINGABLE_IP
        self.gateway_ip = PINGABLE_IP
        self.center_ip = PINGABLE_IP

    def format_ka(self, base: str) -> str:
        """Return ``base`` with the network state appended as a query string.

        Args:
            base: Keepalive URL to extend. It may already carry a query
                string, in which case the parameters are appended with ``&``
                rather than a second ``?``.

        Returns:
            The URL followed by the ``pa``/``ia``/``ga``/``ma``/``t1``/``t2``
            parameters, in that order.
        """
        query = urllib.parse.urlencode({
            "pa": self.pa,
            "ia": self.ia,
            "ga": self.ga,
            "ma": self.ma,
            "t1": self.t1,
            "t2": self.t2,
        })
        # Pick the separator so an existing query string stays well-formed.
        if "?" not in base:
            separator = "?"
        elif base.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{base}{separator}{query}"

    @property
    def pa(self) -> str:
        """Value sent as ``pa`` (meaning not yet identified)."""
        return self._pa

    @property
    def ia(self) -> str:
        """Value sent as ``ia``: the router address."""
        return self.router_ip

    @property
    def ga(self) -> str:
        """Value sent as ``ga``: the gateway address."""
        return self.gateway_ip

    @property
    def ma(self) -> str:
        """Value sent as ``ma``: the center address."""
        return self.center_ip

    # TODO: Identify what these values are. Ping intervals?
    @property
    def t1(self) -> int:
        """Value sent as ``t1``; fixed at 2."""
        return 2

    @property
    def t2(self) -> int:
        """Value sent as ``t2``; fixed at 10."""
        return 10
class EAMServer:
    """WSGI application that decodes XRPC requests and hands them to controllers.

    A request is matched against the routes from :meth:`build_rules_map`,
    decoded into a :class:`CallContext`, and offered to each registered
    controller in order; the first controller that returns a handler produces
    the response. Failures are rendered as encoded ``response`` error
    documents (see :meth:`_make_error`).

    NOTE(review): ``_einfo_ctx`` and ``_einfo_controller`` keep the current
    request's error context on the instance, so a single instance looks
    unsuitable for serving overlapping requests — confirm the deployment
    model before running it multi-threaded.
    """

    def __init__(
        self,
        public_url: str,
        prioritise_params: bool = False,
        verbose_errors: bool = False,
        services_mode: eaapi.const.ServicesMode = eaapi.const.ServicesMode.Operation,
        ntp_server: str = "ntp://pool.ntp.org/",
        keepalive_server: str | None = None,
        no_keepalive_route: bool = False,
        disable_routes: bool = False,
        no_services_handler: bool = False,
    ):
        """Configure the server.

        Args:
            public_url: Base URL of this server. Controller prefixes are
                expanded against it and it is the fallback service route.
            prioritise_params: When true, ``model``/``module``/``method``
                taken from the query string win over the ones in the URL path.
            verbose_errors: When true, error responses carry details and call
                context, and GET requests on XRPC routes get a plain-text
                status instead of ``405``.
            services_mode: Passed to the built-in ``ServicesController``.
            ntp_server: URL advertised for the NTP service.
            keepalive_server: Base URL advertised for the keepalive service;
                defaults to ``{public_url}/keepalive``.
            no_keepalive_route: When true, no ``/keepalive`` route is added.
            disable_routes: When true, :meth:`build_rules_map` returns an
                empty map.
            no_services_handler: When true, the built-in
                ``ServicesController`` is not registered.
        """
        self.network = NetworkState()
        self.verbose_errors = verbose_errors
        self._prioritise_params = prioritise_params
        self._public_url = public_url
        self.disable_routes = disable_routes
        self._no_keepalive_route = no_keepalive_route
        self.ntp = ntp_server
        self.keepalive = keepalive_server or f"{public_url}/keepalive"
        self._prng = eaapi.crypt.new_prng()

        # Lifecycle callbacks, see on_setup / on_pre_handlers_check / on_teardown.
        self._setup = []
        self._pre_handlers_check = []
        self._teardown = []

        # Context of the request being processed; only read when building
        # verbose error responses.
        self._einfo_ctx: CallContext | None = None
        self._einfo_controller: str | None = None

        self.controllers: list[IController] = []
        if not no_services_handler:
            self.controllers.append(ServicesController(self, services_mode))

    def on_setup(self, callback):
        """Register ``callback`` to run, without arguments, before each request.

        Registering the same callback twice has no effect.
        """
        if callback not in self._setup:
            self._setup.append(callback)

    def on_pre_handlers_check(self, callback):
        """Register ``callback`` to run with the ``CallContext`` before handler lookup.

        Registering the same callback twice has no effect.
        """
        if callback not in self._pre_handlers_check:
            self._pre_handlers_check.append(callback)

    def on_teardown(self, callback):
        """Register ``callback`` to run after each request.

        The callback receives ``None`` on success or the exception that
        escaped the WSGI app. Registering the same callback twice has no
        effect.
        """
        if callback not in self._teardown:
            self._teardown.append(callback)

    def build_rules_map(self) -> Map:
        """Build the URL map from the prefixes the controllers service.

        Returns:
            An empty map when routes are disabled. Otherwise a map with, for
            ``/`` and every controller prefix that lies under the public URL,
            the bare prefix route and the ``<model>/<module>/<method>`` route,
            plus ``/keepalive`` unless that route was disabled.
        """
        if self.disable_routes:
            return Map([])

        rules = Map([], strict_slashes=False, merge_slashes=False)

        prefixes = {"/"}
        for controller in self.controllers:
            for prefix in controller.serviced_prefixes():
                expanded = self.expand_url(prefix)
                # Prefixes that resolve outside of our public URL are not ours to route.
                if not expanded.startswith(self._public_url):
                    continue
                prefixes.add(expanded[len(self._public_url):] or "/")

        for prefix in prefixes:
            rules.add(Rule(f"{prefix}/<model>/<module>/<method>", endpoint="xrpc_request"))
            # WSGI flattens the // at the start
            if prefix == "/":
                rules.add(Rule("/<model>/<module>/<method>", endpoint="xrpc_request"))
            rules.add(Rule(prefix, endpoint="xrpc_request"))

        if not self._no_keepalive_route:
            rules.add(Rule("/keepalive", endpoint="keepalive_request"))
        return rules

    def expand_url(self, url: str) -> str:
        """Resolve ``url`` against the public URL using ``urljoin``."""
        return urllib.parse.urljoin(self._public_url, url)

    @property
    def public_url(self) -> str:
        """Base URL this server was configured with."""
        return self._public_url

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Collect the service-name to URL mapping advertised to clients.

        The NTP and keepalive entries are set first, then each controller's
        routes are merged in registration order (later controllers override
        earlier entries). The result is a ``defaultdict`` that yields the
        public URL for any service not listed.
        """
        services: dict[str, str] = defaultdict(lambda: self.public_url)
        services[SERVICE_NTP] = self.ntp
        services[SERVICE_KEEPALIVE] = self.network.format_ka(self.keepalive)
        for controller in self.controllers:
            services.update(controller.get_service_routes(ctx))
        return services

    def _decode_request(self, request: Request) -> CallContext:
        """Unwrap and decode the request body into a ``CallContext``.

        Raises:
            exc.UnknownCompression: The compression header is neither the
                Lz77 nor the None_ value.
            exc.InvalidPacket: The decoder rejected the payload.
        """
        ea_info = request.headers.get(HEADER_ENCRYPTION)
        compression = request.headers.get(HEADER_COMPRESSION)

        # NOTE(review): a request without the compression header yields None
        # here and is rejected unless Compression.None_.value is itself None —
        # confirm that is intended.
        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression == eaapi.Compression.None_.value:
            compressed = False
        else:
            raise exc.UnknownCompression

        payload = eaapi.unwrap(request.data, ea_info, compressed)
        decoder = eaapi.Decoder(payload)
        try:
            call = decoder.unpack()
        except eaapi.EAAPIException as e:
            raise exc.InvalidPacket from e
        return CallContext(request, decoder, call, ea_info, compressed)

    def _encode_response(self, ctx: CallContext) -> Response:
        """Encode ``ctx.resp`` into an HTTP 200 response.

        The response mirrors the request: it is compressed if the request
        was, and a fresh encryption key is generated and sent only if the
        request carried the encryption header.
        """
        if ctx._eainfo is None:
            ea_info = None
        else:
            ea_info = eaapi.crypt.get_key(self._prng)

        encoded = eaapi.Encoder.encode(ctx.resp, ctx.was_xml_string)
        wrapped = eaapi.wrap(encoded, ea_info, ctx.was_compressed)

        response = Response(wrapped, 200)
        if ea_info:
            response.headers[HEADER_ENCRYPTION] = ea_info
        if ctx.was_compressed:
            compression = eaapi.Compression.Lz77
        else:
            compression = eaapi.Compression.None_
        response.headers[HEADER_COMPRESSION] = compression.value
        return response

    def _create_ctx(
        self,
        url_slash: bool,
        request: Request,
        model: Model | None,
        module: str,
        method: str
    ) -> CallContext:
        """Decode ``request`` and attach the routing information to the context.

        Raises:
            exc.ModelMissmatch: The model inside the packet differs from the
                one taken from the URL or query string.
        """
        ctx = self._decode_request(request)
        ctx._module = module
        ctx._method = method
        ctx._url_slash = url_slash
        # Stored before the model check so a mismatch error can still report context.
        self._einfo_ctx = ctx
        if ctx.model != model:
            raise exc.ModelMissmatch
        return ctx

    def _handle_request(self, ctx: CallContext) -> Response:
        """Run the first handler a controller offers for ``ctx`` and encode the result.

        Raises:
            exc.NoMethodHandler: No controller returned a handler.
        """
        for controller in self.controllers:
            handler = controller.get_handler(ctx)
            if handler is None:
                continue
            self._einfo_controller = f"{controller._name}"
            handler(ctx)
            return self._encode_response(ctx)
        raise exc.NoMethodHandler

    def on_xrpc_other(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Answer a non-POST request on an XRPC route.

        Returns a plain-text status for GET requests when verbose errors are
        enabled.

        Raises:
            MethodNotAllowed: For any other method, or when verbose errors
                are disabled.
        """
        if request.method != "GET" or not self.verbose_errors:
            raise MethodNotAllowed
        return Response(
            f"XRPC running. model {model}, call {module}.{method} ({service})"
        )

    def keepalive_request(self) -> Response:
        """Return the empty response used for keepalive pings."""
        return Response(None)

    def on_keepalive_request(self, request: Request) -> Response:
        """Endpoint for the ``/keepalive`` route.

        ``dispatch_request`` looks endpoints up as ``on_<endpoint>`` and calls
        them with the request, so the route needs this adapter around
        :meth:`keepalive_request`.
        """
        return self.keepalive_request()

    def parse_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Work out which model, module and method a request addresses.

        Values may come from the URL path (the arguments) or from the query
        string (``model``, ``module``, ``method``, or ``f=module.method``,
        which overrides the separate ``module``/``method`` parameters).
        ``prioritise_params`` decides which source wins when both are given.

        Returns:
            ``(url_slash, service, model_obj, module, method)`` where
            ``url_slash`` tells whether the URL path supplied all of model,
            module and method, and ``model_obj`` is the parsed ``Model`` or
            ``None`` when no model was given.

        Raises:
            exc.ModuleMethodMissing: No (or an empty) module or method.
            exc.InvalidModel: The model string could not be parsed.
        """
        # Computed from the path values only, before query parameters are merged in.
        url_slash = bool(model and module and method)

        model_param = request.args.get("model", None)
        module_param = request.args.get("module", None)
        method_param = request.args.get("method", None)
        if "f" in request.args:
            module_param, _, method_param = request.args.get("f", "").partition(".")

        if self._prioritise_params:
            model = model_param or model
            module = module_param or module
            method = method_param or method
        else:
            model = model or model_param
            module = module or module_param
            method = method or method_param

        # ``f=name`` without a dot (or ``f=.name``) leaves an empty string
        # behind, so emptiness has to be rejected as well as absence.
        if not module or not method:
            raise exc.ModuleMethodMissing

        if model is None:
            model_obj = None
        else:
            try:
                model_obj = Model.from_model_str(model)
            except eaapi.exception.InvalidModel as e:
                raise exc.InvalidModel from e
        return url_slash, service, model_obj, module, method

    def on_xrpc_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Endpoint for XRPC routes: parse, decode, pre-check and dispatch.

        Non-POST requests are passed to :meth:`on_xrpc_other` after the
        target has been parsed (so a missing module/method is still an error).
        """
        url_slash, service, model_obj, module, method = self.parse_request(
            request, service, model, module, method
        )
        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)

        ctx = self._create_ctx(url_slash, request, model_obj, module, method)
        for check in self._pre_handlers_check:
            check(ctx)
        return self._handle_request(ctx)

    def _make_error(self, status: int | None = None, message: str | None = None) -> Response:
        """Build an encoded, unencrypted, uncompressed error response.

        Args:
            status: Written to the ``status`` attribute and used as the HTTP
                status; the HTTP status falls back to 500 when it is falsy.
            message: Added as ``details`` when verbose errors are enabled.
        """
        node = eaapi.XMLNode.void("response")
        if status is not None:
            node["status"] = str(status)

        if self.verbose_errors:
            if message:
                node.append("details", eaapi.Type.Str, message)
            context = node.append("context")
            if self._einfo_ctx is not None:
                context.append("module", eaapi.Type.Str, self._einfo_ctx.module)
                context.append("method", eaapi.Type.Str, self._einfo_ctx.method)
                context.append("game", eaapi.Type.Str, str(self._einfo_ctx.model))
            if self._einfo_controller is not None:
                context.append("controller", eaapi.Type.Str, self._einfo_controller)

        encoded = eaapi.Encoder.encode(node, False)
        wrapped = eaapi.wrap(encoded, None, False)
        response = Response(wrapped, status or 500)
        response.headers[HEADER_COMPRESSION] = eaapi.Compression.None_.value
        return response

    def _eamhttp_error(self, exc: exc.EAMHTTPException) -> Response:
        """Render an ``EAMHTTPException`` using its code and description."""
        return self._make_error(exc.code, exc.eam_description)

    def _structure_error(self, e: eaapi.exception.XMLStrutureError) -> Response:
        """Render an XML structure error as a 400 with a trimmed traceback."""
        summary = traceback.extract_tb(e.__traceback__)
        for frame_summary in summary:
            try:
                frame_summary.filename = os.path.relpath(frame_summary.filename)
            except ValueError:
                # relpath cannot relate paths on different Windows drives;
                # keep the absolute path rather than failing inside the
                # error handler.
                pass
        # The first three entries are within the controller, and the last one is us
        summary = summary[3:-1]
        tb = "".join(traceback.format_list(traceback.StackSummary.from_list(summary)))
        tb += f"{e.__module__}.{e.__class__.__name__}"
        return self._make_error(400, tb)

    def _generic_error(self, exc: Exception) -> Response:
        """Render any other exception as a 500 carrying its string form."""
        return self._make_error(500, str(exc))

    def dispatch_request(self, request):
        """Route ``request`` to its ``on_<endpoint>`` method and map failures to responses.

        Werkzeug HTTP exceptions are returned as-is; structure errors and
        unexpected exceptions are also printed to stderr.
        """
        # Reset the per-request error context.
        self._einfo_ctx = None
        self._einfo_controller = None

        adapter = self.build_rules_map().bind_to_environ(request.environ)
        try:
            endpoint, values = adapter.match()
            return getattr(self, f"on_{endpoint}")(request, **values)
        except exc.EAMHTTPException as e:
            return self._eamhttp_error(e)
        except HTTPException as e:
            return e
        except eaapi.exception.XMLStrutureError as e:
            traceback.print_exc(file=sys.stderr)
            return self._structure_error(e)
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            return self._generic_error(e)

    def wsgi_app(self, environ, start_response):
        """Plain WSGI entry point, without the setup/teardown callbacks."""
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        """WSGI entry point: run setup callbacks, the app, then teardown callbacks.

        Teardown callbacks receive ``None`` on success or the exception that
        escaped the app, which is then re-raised.
        """
        for callback in self._setup:
            callback()

        try:
            response = self.wsgi_app(environ, start_response)
        except Exception as e:
            for callback in self._teardown:
                callback(e)
            raise

        # Kept outside the try block so that a failing teardown callback does
        # not cause every teardown callback to run a second time.
        for callback in self._teardown:
            callback(None)
        return response

    def run(self, host="127.0.0.1", port=5000, debug=False):
        """Serve this application with werkzeug's development server.

        ``debug`` enables both the debugger and the reloader.
        """
        from werkzeug.serving import run_simple
        run_simple(host, port, self, use_debugger=debug, use_reloader=debug)