eaapi/eaapi/server/demo/proxy.py

163 lines
6.0 KiB
Python

import base64
import urllib.parse
import requests
import eaapi
from werkzeug.routing import Rule
from werkzeug.wrappers import Response
from .. import exceptions as exc
from ..model import ModelMatcher
from ..context import ResponseContext
from ..server import (
METHOD_SERVICES_GET, MODULE_SERVICES, SERVICE_KEEPALIVE, SERVICE_NTP, EAMServer, HEADER_ENCRYPTION,
HEADER_COMPRESSION, TRIVIAL_SERVICES
)
class EAMProxyServer(EAMServer):
    """EAM server that relays requests it cannot handle to an upstream server.

    Requests with a local handler are served locally. All others are POSTed
    to the upstream, and the upstream's response is returned to the client
    with body and status unchanged. Registered taps (see :meth:`tap`) can
    observe decoded upstream responses but cannot alter them.

    Service URLs advertised by the upstream are rewritten (see
    :meth:`tap_url`) to point at ``/__tap/<encoded upstream url>/`` on this
    server, so the client's follow-up calls also pass through the proxy.
    """

    # Seconds to wait for the upstream before failing. ``requests`` waits
    # forever by default, which would hang the client's request indefinitely.
    # Override on a subclass or an instance to tune it.
    upstream_timeout = 30

    # Hop-by-hop headers (RFC 7230 section 6.1). They describe one connection
    # and must not be copied between the client and upstream connections.
    _HOP_BY_HOP_HEADERS = frozenset((
        "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
        "te", "trailer", "trailers", "transfer-encoding", "upgrade",
    ))
    # Request headers not forwarded upstream. ``requests`` derives Host from
    # the upstream URL and Content-Length from the body it sends.
    _SKIP_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
    # Response headers not copied back to the client. ``requests`` has already
    # decoded any gzip/deflate Content-Encoding in ``resp.content``, and
    # werkzeug computes Content-Length for the body we return.
    _SKIP_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}
    # Services that are never proxied: local entries are dropped in favour of
    # the upstream's, and the upstream's URLs are passed through untapped.
    _UNPROXIED_SERVICES = (SERVICE_NTP, SERVICE_KEEPALIVE)

    def __init__(self, upstream, *args, **kwargs):
        """Create a proxy that falls back to ``upstream``.

        :param upstream: Base URL used when a request carries no tapped
            upstream URL of its own.

        Remaining arguments go to :class:`EAMServer`. Unless
        ``disable_routes`` is truthy, the ``/__tap/...`` routes are added.
        """
        super().__init__(*args, **kwargs)
        self._upstream_fallback = upstream
        # (module, method, matcher, handler) tuples registered via tap().
        self._taps = []
        # (matcher, service name) pairs recorded via tap(). NOTE: nothing in
        # this class consults it yet. _services_handler taps every upstream
        # service except NTP/keepalive.
        self._tapped_services = set()
        if not kwargs.get("disable_routes"):
            self._rules.add(Rule("/__tap/<upstream>//<model>/<module>/<method>", endpoint="tap_request"))
            self._rules.add(Rule("/__tap/<upstream>", endpoint="tap_request"))

    def tap_url(self, ctx, url):
        """Return a URL on this server that relays requests to ``url``.

        The upstream URL is embedded as a single path segment in URL-safe
        base64. The standard alphabet can emit ``/``, which would split the
        segment and stop the ``<upstream>`` route from matching.
        """
        encoded = base64.urlsafe_b64encode(url.encode("latin-1")).decode()
        return urllib.parse.urljoin(self._public_url, f"__tap/{encoded}/")

    @staticmethod
    def _decode_tapped_upstream(encoded):
        """Inverse of :meth:`tap_url` for the embedded upstream URL.

        :raises exc.InvalidUpstream: if ``encoded`` is not valid base64.
        """
        try:
            return base64.urlsafe_b64decode(encoded).decode("latin-1")
        except (TypeError, ValueError) as err:  # binascii.Error is a ValueError
            raise exc.InvalidUpstream from err

    def _call_upstream(self, ctx, upstream):
        """POST the client's request body to the upstream server.

        The URL is rebuilt in the same form the client used: path segments
        when ``ctx.url_slash`` is set, otherwise a query string. A falsy
        ``upstream`` selects the fallback upstream.

        :returns: the raw :class:`requests.Response`.
        :raises requests.RequestException: on network failure or timeout.
        """
        base = upstream or self._upstream_fallback
        if ctx.url_slash:
            url = f"{base}/{ctx.model}/{ctx.module}/{ctx.method}"
        else:
            url = f"{base}?model={ctx.model}&module={ctx.module}&method={ctx.method}"
        # Host and framing headers belong to the client->proxy connection.
        # Forwarding the proxy's Host would misroute name-based upstreams.
        headers = {
            name: value
            for name, value in ctx.request.headers.items()
            if name.lower() not in self._SKIP_REQUEST_HEADERS
        }
        return requests.post(
            url, headers=headers, data=ctx.request.data, timeout=self.upstream_timeout
        )

    def tap(self, module, method=None, matcher=None, service=None):
        """Register an observer of upstream responses for ``module.method``.

        Used as ``@server.tap(module, method)``. The decorated function is
        called as ``handler(ctx, resp_ctx)`` with the decoded upstream
        response whenever a relayed request matches ``module``, ``method``
        and ``matcher``. It is returned unchanged.

        When ``method`` is None, returns a callable that accepts the method
        name and returns the decorator.

        :param matcher: :class:`ModelMatcher` filtering on the client model;
            defaults to a fresh ``ModelMatcher()``.
        :param service: service name the tap belongs to. If omitted, ``module``
            itself must be one of ``TRIVIAL_SERVICES``.
        :raises ValueError: at decoration time, if no service can be
            identified. In that case nothing is registered.
        """
        if method is None:
            def with_method(method):
                # Re-enter tap(), not handler(): this must register a tap,
                # never a local handler.
                return self.tap(module, method, matcher, service)
            return with_method

        def decorator(handler):
            matcher_ = matcher or ModelMatcher()
            if service is not None:
                tapped_service = service
            elif module in TRIVIAL_SERVICES:
                tapped_service = module
            else:
                raise ValueError(f"Unable to identify service for {module}")
            # Validate before registering so a rejected tap leaves no trace.
            self._taps.append((module, method, matcher_, handler))
            self._tapped_services.add((matcher_, tapped_service))
            return handler
        return decorator

    def _decode_us_request(self, resp):
        """Decode an upstream HTTP response into a :class:`ResponseContext`.

        :param resp: :class:`requests.Response` from :meth:`_call_upstream`.
        :raises exc.UnknownCompression: if the compression header is neither
            LZ77 nor "none".
        :raises exc.InvalidPacket: if eaapi fails to unwrap or unpack the body.
        """
        ea_info = resp.headers.get(HEADER_ENCRYPTION)
        compression = resp.headers.get(HEADER_COMPRESSION)
        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression == eaapi.Compression.None_.value:
            compressed = False
        else:
            raise exc.UnknownCompression
        try:
            payload = eaapi.unwrap(resp.content, ea_info, compressed)
            decoder = eaapi.Decoder(payload)
            response = decoder.unpack()
        except eaapi.exception.EAAPIException as err:
            raise exc.InvalidPacket from err
        return ResponseContext(resp, decoder, response, compressed)

    def call_upstream(self, ctx, upstream):
        """Relay ``ctx`` to ``upstream`` and return the decoded response context.

        Network errors propagate as :class:`requests.RequestException`.
        """
        return self._decode_us_request(self._call_upstream(ctx, upstream))

    def _services_handler(self, ctx, upstream=None):
        """Build the ``services.get`` response from local and upstream services.

        Local services are listed first, except NTP and keepalive, which are
        never proxied. Each upstream service not provided locally is then
        appended. Its URL is routed through :meth:`tap_url`, except for NTP
        and keepalive, whose URLs are passed through unchanged.

        :raises exc.UpstreamFailed: if the upstream cannot be reached.
        """
        try:
            upstream_ctx = self.call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err
        super()._services_handler(ctx)
        services = ctx.resp.xpath("services")
        services.children = [
            item for item in services.children
            if item.get("name") not in self._UNPROXIED_SERVICES
        ]
        local_names = {item.get("name") for item in services}
        for item in upstream_ctx.response.xpath("services"):
            name = item.get("name")
            if name in local_names:
                continue
            url = item.get("url")
            if name not in self._UNPROXIED_SERVICES:
                url = self.tap_url(ctx, url)
            services.append("item", name=name, url=url)

    def _run_taps(self, ctx, us_resp, module, method):
        """Decode ``us_resp`` and pass it to every tap matching this request.

        Best-effort: any failure (undecodable payload, a tap raising) is
        printed and otherwise ignored, so the client always receives the
        upstream's response. Decoding is skipped when no tap matches.
        """
        import traceback
        try:
            matching = [
                handler
                for tap_module, tap_method, matcher, handler in self._taps
                if tap_module == module and tap_method == method and matcher.matches(ctx.model)
            ]
            if not matching:
                return
            resp_ctx = self._decode_us_request(us_resp)
            for handler in matching:
                handler(ctx, resp_ctx)
        except Exception:
            traceback.print_exc()

    def _relay_response(self, us_resp):
        """Build a werkzeug :class:`Response` mirroring ``us_resp``.

        Body and status are copied verbatim. Headers are copied except those
        in ``_SKIP_RESPONSE_HEADERS``, which no longer describe
        ``us_resp.content`` once ``requests`` has decoded it.
        """
        response = Response(us_resp.content, us_resp.status_code)
        for name, value in us_resp.headers.items():
            if name.lower() not in self._SKIP_RESPONSE_HEADERS:
                response.headers[name] = value
        return response

    def _handle_request(self, upstream, url_slash, request, model, module, method):
        """Serve a request locally if possible, otherwise relay it upstream.

        ``services.get`` is always answered locally, merging in the
        upstream's services. Any other call goes to the local handlers
        first. If none exists, it is relayed to ``upstream``, matching taps
        observe the reply, and the reply is returned to the client.

        :raises exc.NoMethodHandler: for a ``services`` call other than ``get``.
        :raises exc.UpstreamFailed: if the upstream cannot be reached.
        """
        ctx = self._create_ctx(url_slash, request, model, module, method)
        if ctx.module == MODULE_SERVICES:
            if ctx.method != METHOD_SERVICES_GET:
                raise exc.NoMethodHandler
            self._services_handler(ctx, upstream)
            return self._encode_response(ctx)

        try:
            return super()._handle_request(ctx)
        except exc.NoMethodHandler:
            pass  # No local handler: fall through to the upstream.

        try:
            us_resp = self._call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err

        # Taps are observers only and never affect what the client receives.
        self._run_taps(ctx, us_resp, module, method)
        return self._relay_response(us_resp)

    def on_xrpc_request(self, request, service=None, model=None, module=None, method=None):
        """Handle a request on the plain XRPC routes using the fallback upstream."""
        return self.on_tap_request(request, None, service, model, module, method)

    def on_tap_request(self, request, upstream, service=None, model=None, module=None, method=None):
        """Handle a request on the ``/__tap/<upstream>`` routes.

        :param upstream: path segment produced by :meth:`tap_url`, or None to
            use the fallback upstream.

        Non-POST requests are delegated to :meth:`on_xrpc_other`.

        :raises exc.InvalidUpstream: if ``upstream`` cannot be decoded.
        """
        url_slash, service, model, module, method = self.parse_request(request, service, model, module, method)
        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)
        if upstream is not None:
            upstream = self._decode_tapped_upstream(upstream)
        return self._handle_request(upstream, url_slash, request, model, module, method)