398 lines
13 KiB
Python
398 lines
13 KiB
Python
import traceback
|
|
import urllib.parse
|
|
import urllib
|
|
import sys
|
|
import os
|
|
|
|
from typing import Callable
|
|
from collections import defaultdict
|
|
|
|
from werkzeug.exceptions import HTTPException, MethodNotAllowed
|
|
from werkzeug.wrappers import Request, Response
|
|
from werkzeug.routing import Map, Rule
|
|
|
|
import eaapi
|
|
|
|
from . import exceptions as exc
|
|
from .context import CallContext
|
|
from .model import Model
|
|
from .controller import IController, ServicesController
|
|
from .const import SERVICE_NTP, SERVICE_KEEPALIVE
|
|
|
|
|
|
# Signature of a request handler: it receives the decoded CallContext and
# returns nothing. EAMServer._handle_request calls handler(ctx) and then
# encodes ctx.resp, so handlers presumably populate the response on the
# context rather than returning it.
Handler = Callable[[CallContext], None]
|
|
|
|
|
|
# HTTP headers of the e-amusement transport layer.
# X-Eamuse-Info carries the crypto key handed to eaapi.unwrap / produced by
# eaapi.crypt.get_key; X-Compress names the payload compression scheme.
HEADER_ENCRYPTION: str = "X-Eamuse-Info"
HEADER_COMPRESSION: str = "X-Compress"

# Default address used for every field of NetworkState.
PINGABLE_IP: str = "127.0.0.1"
|
|
|
|
|
|
class NetworkState:
    """Addresses advertised to clients through the keepalive service URL.

    Every address starts out as ``PINGABLE_IP``. The short property names
    (``pa``, ``ia``, ``ga``, ``ma``, ``t1``, ``t2``) are exactly the query
    parameter names emitted by :meth:`format_ka`.
    """

    def __init__(self):
        self._pa = PINGABLE_IP  # TODO: what does this one mean?
        self.router_ip = PINGABLE_IP
        self.gateway_ip = PINGABLE_IP
        self.center_ip = PINGABLE_IP

    def format_ka(self, base: str) -> str:
        """Return ``base`` with the keepalive query parameters appended.

        :param base: Keepalive URL; it may already contain a query string.
        :returns: ``base`` followed by ``pa``, ``ia``, ``ga``, ``ma``, ``t1``
            and ``t2`` as URL-encoded query parameters.
        """
        query = urllib.parse.urlencode({
            "pa": self.pa,
            "ia": self.ia,
            "ga": self.ga,
            "ma": self.ma,
            "t1": self.t1,
            "t2": self.t2,
        })
        # A user-supplied keepalive URL may already carry a query string;
        # starting a second "?" would produce a malformed URL.
        if "?" not in base:
            separator = "?"
        elif base.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{base}{separator}{query}"

    @property
    def pa(self) -> str:
        """Value sent as ``pa`` (meaning not yet identified)."""
        return self._pa

    @property
    def ia(self) -> str:
        """Value sent as ``ia``; mirrors :attr:`router_ip`."""
        return self.router_ip

    @property
    def ga(self) -> str:
        """Value sent as ``ga``; mirrors :attr:`gateway_ip`."""
        return self.gateway_ip

    @property
    def ma(self) -> str:
        """Value sent as ``ma``; mirrors :attr:`center_ip`."""
        return self.center_ip

    # TODO: Identify what these values are. Ping intervals?
    @property
    def t1(self) -> int:
        return 2

    @property
    def t2(self) -> int:
        return 10
|
|
|
|
|
|
class EAMServer:
    """WSGI application that routes e-amusement XRPC calls to controllers.

    Requests are matched against rules built from the prefixes each
    controller services (:meth:`build_rules_map`), decoded into a
    :class:`CallContext`, handed to the first controller that returns a
    handler for it, and ``ctx.resp`` is then encoded back into an XRPC
    response.

    NOTE(review): ``_einfo_ctx`` / ``_einfo_controller`` are per-request
    state stored on the shared instance; if the WSGI server handles requests
    concurrently, error context may belong to another request.
    """

    def __init__(
        self,
        public_url: str,
        prioritise_params: bool = False,
        verbose_errors: bool = False,
        services_mode: eaapi.const.ServicesMode = eaapi.const.ServicesMode.Operation,
        ntp_server: str = "ntp://pool.ntp.org/",
        keepalive_server: str | None = None,
        no_keepalive_route: bool = False,
        disable_routes: bool = False,
        no_services_handler: bool = False,
    ):
        """
        :param public_url: Base URL clients reach this server at; controller
            prefixes are resolved against it.
        :param prioritise_params: When true, ``model``/``module``/``method``
            (or ``f``) query parameters take precedence over URL path values.
        :param verbose_errors: Add details and call context to error
            responses, and enable the GET diagnostic page.
        :param services_mode: Passed to the built-in ``ServicesController``.
        :param ntp_server: URL advertised for the NTP service.
        :param keepalive_server: URL advertised for keepalive; defaults to
            ``{public_url}/keepalive``.
        :param no_keepalive_route: Do not register the ``/keepalive`` route.
        :param disable_routes: Use an empty routing map (nothing matches).
        :param no_services_handler: Do not install ``ServicesController``.
        """
        self.network = NetworkState()

        self.verbose_errors = verbose_errors

        self._prioritise_params = prioritise_params
        self._public_url = public_url

        self.disable_routes = disable_routes
        self._no_keepalive_route = no_keepalive_route

        self.ntp = ntp_server
        self.keepalive = keepalive_server or f"{public_url}/keepalive"

        self._prng = eaapi.crypt.new_prng()

        self._setup = []
        self._pre_handlers_check = []
        self._teardown = []

        # Context of the request currently being handled, used by _make_error.
        self._einfo_ctx: CallContext | None = None
        self._einfo_controller: str | None = None

        self.controllers: list[IController] = []
        if not no_services_handler:
            self.controllers.append(ServicesController(self, services_mode))

    def on_setup(self, callback):
        """Register ``callback()`` to run before every WSGI call.

        Duplicate registrations are ignored. Returns ``callback`` so this can
        be used as a decorator.
        """
        if callback not in self._setup:
            self._setup.append(callback)
        return callback

    def on_pre_handlers_check(self, callback):
        """Register ``callback(ctx)`` to run on each decoded XRPC call before
        its handler is looked up.

        Duplicate registrations are ignored. Returns ``callback`` so this can
        be used as a decorator.
        """
        if callback not in self._pre_handlers_check:
            self._pre_handlers_check.append(callback)
        return callback

    def on_teardown(self, callback):
        """Register ``callback(exc)`` to run after every WSGI call, with
        ``None`` on success or the raised exception on failure.

        Duplicate registrations are ignored. Returns ``callback`` so this can
        be used as a decorator.
        """
        if callback not in self._teardown:
            self._teardown.append(callback)
        return callback

    def build_rules_map(self) -> Map:
        """Build the werkzeug routing map for the current controllers.

        Each prefix a controller services is resolved against the public URL.
        Prefixes outside the public URL are ignored. For every remaining path
        prefix, XRPC rules are added both with and without the
        ``/<model>/<module>/<method>`` suffix.
        """
        if self.disable_routes:
            return Map([])

        rules = Map([], strict_slashes=False, merge_slashes=False)

        # Compare against the public URL without a trailing slash so that
        # "http://host" and "http://host/" behave the same and every derived
        # prefix keeps its leading "/" (werkzeug rejects rules without one).
        base = self._public_url.rstrip("/")

        prefixes = {"/"}
        for controller in self.controllers:
            for prefix in controller.serviced_prefixes():
                prefix = self.expand_url(prefix)
                # Only accept matches on a path boundary, so "http://host2/x"
                # is not treated as living under "http://host".
                if prefix != base and not prefix.startswith(base + "/"):
                    continue
                prefixes.add(prefix[len(base):] or "/")

        for prefix in prefixes:
            rules.add(Rule(f"{prefix}/<model>/<module>/<method>", endpoint="xrpc_request"))
            # WSGI flattens the // at the start
            if prefix == "/":
                rules.add(Rule("/<model>/<module>/<method>", endpoint="xrpc_request"))
            rules.add(Rule(prefix, endpoint="xrpc_request"))

        if not self._no_keepalive_route:
            rules.add(Rule("/keepalive", endpoint="keepalive_request"))

        return rules

    def expand_url(self, url: str) -> str:
        """Resolve ``url`` relative to the public URL."""
        return urllib.parse.urljoin(self._public_url, url)

    @property
    def public_url(self) -> str:
        return self._public_url

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Return the service-name -> URL map advertised to clients.

        Unknown service names fall back to the public URL (the result is a
        ``defaultdict``). NTP and keepalive are pre-filled; controller routes
        are merged in order and override earlier entries.
        """
        services: dict[str, str] = defaultdict(lambda: self.public_url)
        services[SERVICE_NTP] = self.ntp
        services[SERVICE_KEEPALIVE] = self.network.format_ka(self.keepalive)

        for controller in self.controllers:
            services.update(controller.get_service_routes(ctx))
        return services

    def _decode_request(self, request: Request) -> CallContext:
        """Unwrap and decode an XRPC request body into a CallContext.

        :raises exc.UnknownCompression: ``X-Compress`` matches neither the
            LZ77 nor the "none" compression value. NOTE(review): a missing
            header is also rejected unless ``Compression.None_.value`` is
            None — confirm this is intended.
        :raises exc.InvalidPacket: the payload cannot be decoded.
        """
        ea_info = request.headers.get(HEADER_ENCRYPTION)
        compression = request.headers.get(HEADER_COMPRESSION)

        compressed = False
        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression != eaapi.Compression.None_.value:
            raise exc.UnknownCompression

        payload = eaapi.unwrap(request.data, ea_info, compressed)
        decoder = eaapi.Decoder(payload)
        try:
            call = decoder.unpack()
        except eaapi.EAAPIException as e:
            raise exc.InvalidPacket from e

        return CallContext(request, decoder, call, ea_info, compressed)

    def _encode_response(self, ctx: CallContext) -> Response:
        """Encode ``ctx.resp`` into an HTTP response.

        Encryption and compression mirror the request: a fresh key is
        generated only if the request carried one.
        """
        if ctx._eainfo is None:
            ea_info = None
        else:
            ea_info = eaapi.crypt.get_key(self._prng)

        encoded = eaapi.Encoder.encode(ctx.resp, ctx.was_xml_string)
        wrapped = eaapi.wrap(encoded, ea_info, ctx.was_compressed)
        response = Response(wrapped, 200)
        if ea_info:
            response.headers[HEADER_ENCRYPTION] = ea_info
        response.headers[HEADER_COMPRESSION] = (
            eaapi.Compression.Lz77 if ctx.was_compressed
            else eaapi.Compression.None_
        ).value

        return response

    def _create_ctx(
        self,
        url_slash: bool,
        request: Request,
        model: Model | None,
        module: str,
        method: str
    ) -> CallContext:
        """Decode ``request`` and attach the routing information to it.

        :raises exc.ModelMissmatch: the model in the packet differs from the
            one taken from the URL/query.
        """
        ctx = self._decode_request(request)
        ctx._module = module
        ctx._method = method
        ctx._url_slash = url_slash
        # Record before validation so error responses can report this call.
        self._einfo_ctx = ctx

        if ctx.model != model:
            raise exc.ModelMissmatch
        return ctx

    def _handle_request(self, ctx: CallContext) -> Response:
        """Run the first controller handler that accepts ``ctx``.

        :raises exc.NoMethodHandler: no controller returned a handler.
        """
        for controller in self.controllers:
            if (handler := controller.get_handler(ctx)) is not None:
                self._einfo_controller = f"{controller._name}"
                handler(ctx)
                break
        else:
            raise exc.NoMethodHandler

        return self._encode_response(ctx)

    def on_xrpc_other(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Diagnostic page for non-POST XRPC requests.

        Only GET with ``verbose_errors`` enabled is answered; anything else
        raises ``MethodNotAllowed``.
        """
        if request.method != "GET" or not self.verbose_errors:
            raise MethodNotAllowed

        return Response(
            f"XRPC running. model {model}, call {module}.{method} ({service})"
        )

    def keepalive_request(self) -> Response:
        """Answer a keepalive probe with an empty response."""
        return Response(None)

    def on_keepalive_request(self, request: Request, **values) -> Response:
        """Routing entry point for the ``keepalive_request`` endpoint.

        :meth:`dispatch_request` invokes ``on_<endpoint>(request, **values)``;
        without this adapter the ``/keepalive`` route resolved to a missing
        attribute and every probe was answered with a 500.
        """
        return self.keepalive_request()

    def parse_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Merge URL path values with query parameters.

        Query parameters ``model``/``module``/``method`` (or ``f`` as
        ``module.method``) are used as fallbacks, or as overrides when
        ``prioritise_params`` is set.

        :returns: ``(url_slash, service, model_obj, module, method)`` where
            ``url_slash`` tells whether the call was addressed through the
            ``/<model>/<module>/<method>`` path form.
        :raises exc.ModuleMethodMissing: module or method missing or empty.
        :raises exc.InvalidModel: the model string cannot be parsed.
        """
        url_slash = bool(model and module and method)

        model_param = request.args.get("model", None)
        module_param = request.args.get("module", None)
        method_param = request.args.get("method", None)
        if "f" in request.args:
            module_param, _, method_param = request.args.get("f", "").partition(".")

        if self._prioritise_params:
            model = model_param or model
            module = module_param or module
            method = method_param or method
        else:
            model = model or model_param
            module = module or module_param
            method = method or method_param

        # Empty strings (e.g. "?f=foo" without a ".") count as missing too.
        if not module or not method:
            raise exc.ModuleMethodMissing

        if model is None:
            model_obj = None
        else:
            try:
                model_obj = Model.from_model_str(model)
            except eaapi.exception.InvalidModel as e:
                raise exc.InvalidModel from e

        return url_slash, service, model_obj, module, method

    def on_xrpc_request(
        self,
        request: Request,
        service: str | None = None,
        model: str | None = None,
        module: str | None = None,
        method: str | None = None
    ):
        """Routing entry point for the ``xrpc_request`` endpoint.

        Non-POST requests are forwarded to :meth:`on_xrpc_other`; POST
        requests are decoded, passed through the pre-handler checks and
        dispatched to a controller.
        """
        url_slash, service, model_obj, module, method = self.parse_request(request, service, model, module, method)

        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)

        ctx = self._create_ctx(url_slash, request, model_obj, module, method)
        for check in self._pre_handlers_check:
            check(ctx)
        return self._handle_request(ctx)

    def _make_error(self, status: int | None = None, message: str | None = None) -> Response:
        """Build an unencrypted, uncompressed XRPC error response.

        The body is a ``<response>`` node with ``status`` set (when given).
        With ``verbose_errors`` it also carries ``details`` and a ``context``
        node describing the call being handled. The HTTP status is
        ``status``, or 500 when none is given.
        """
        node = eaapi.XMLNode.void("response")
        if status is not None:
            node["status"] = str(status)

        if self.verbose_errors:
            if message:
                node.append("details", eaapi.Type.Str, message)

            context = node.append("context")
            if self._einfo_ctx is not None:
                context.append("module", eaapi.Type.Str, self._einfo_ctx.module)
                context.append("method", eaapi.Type.Str, self._einfo_ctx.method)
                context.append("game", eaapi.Type.Str, str(self._einfo_ctx.model))
            if self._einfo_controller is not None:
                context.append("controller", eaapi.Type.Str, self._einfo_controller)

        encoded = eaapi.Encoder.encode(node, False)
        wrapped = eaapi.wrap(encoded, None, False)
        response = Response(wrapped, status or 500)
        response.headers[HEADER_COMPRESSION] = eaapi.Compression.None_.value

        return response

    def _eamhttp_error(self, exc: exc.EAMHTTPException) -> Response:
        """Convert a project HTTP exception into an XRPC error response."""
        return self._make_error(exc.code, exc.eam_description)

    def _structure_error(self, e: eaapi.exception.XMLStrutureError) -> Response:
        """Report an XML structure error as a 400 with a trimmed traceback."""
        summary = traceback.extract_tb(e.__traceback__)
        for frame_summary in summary:
            try:
                frame_summary.filename = os.path.relpath(frame_summary.filename)
            except ValueError:
                # relpath raises on Windows when the file is on a different
                # drive than the CWD; keep the original path in that case.
                pass

        # The traceback starts at dispatch_request; drop the three server
        # frames (dispatch_request -> on_xrpc_request -> _handle_request) and
        # the innermost frame (presumably the raise site inside eaapi).
        summary = summary[3:-1]
        tb = "".join(traceback.format_list(traceback.StackSummary.from_list(summary)))
        tb += f"{e.__module__}.{e.__class__.__name__}"

        return self._make_error(400, tb)

    def _generic_error(self, exc: Exception) -> Response:
        """Report an unexpected exception as a 500 XRPC error."""
        return self._make_error(500, str(exc))

    def dispatch_request(self, request):
        """Route ``request`` and turn any failure into an HTTP response.

        Project HTTP exceptions and XML structure errors become XRPC error
        responses, werkzeug HTTP exceptions are returned as-is, and anything
        else becomes a 500 XRPC error (logged to stderr).
        """
        self._einfo_ctx = None
        self._einfo_controller = None

        adapter = self.build_rules_map().bind_to_environ(request.environ)
        try:
            endpoint, values = adapter.match()
            return getattr(self, f"on_{endpoint}")(request, **values)
        except exc.EAMHTTPException as e:
            return self._eamhttp_error(e)
        except HTTPException as e:
            return e
        except eaapi.exception.XMLStrutureError as e:
            traceback.print_exc(file=sys.stderr)
            return self._structure_error(e)
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            return self._generic_error(e)

    def wsgi_app(self, environ, start_response):
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        """WSGI entry point.

        Runs every setup callback, handles the request, then runs every
        teardown callback exactly once: with ``None`` on success, or with the
        exception before re-raising it.
        """
        for callback in self._setup:
            callback()

        try:
            response = self.wsgi_app(environ, start_response)
        except Exception as e:
            for callback in self._teardown:
                callback(e)
            raise

        # Outside the ``try`` so a failing teardown callback is not mistaken
        # for a request failure (which would re-run every teardown).
        for callback in self._teardown:
            callback(None)
        return response

    def run(self, host="127.0.0.1", port=5000, debug=False):
        """Serve the app with werkzeug's development server."""
        from werkzeug.serving import run_simple
        run_simple(host, port, self, use_debugger=debug, use_reloader=debug)
|