import base64
import urllib.parse
import traceback

import requests
import eaapi
from werkzeug.routing import Rule
from werkzeug.wrappers import Response

from .. import exceptions as exc
from ..model import ModelMatcher
from ..context import ResponseContext
from ..server import (
    METHOD_SERVICES_GET, MODULE_SERVICES, SERVICE_KEEPALIVE, SERVICE_NTP, EAMServer,
    HEADER_ENCRYPTION, HEADER_COMPRESSION, TRIVIAL_SERVICES
)


# Upstream response headers that are NOT copied onto the relayed response.
# Hop-by-hop headers (RFC 7230 section 6.1) describe only the proxy<->upstream
# connection. Content-Encoding / Content-Length describe the raw wire body,
# but ``requests.Response.content`` is already decoded, and werkzeug sets
# Content-Length itself for the body we actually send.
_UNFORWARDED_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    "content-encoding",
    "content-length",
})


class EAMProxyServer(EAMServer):
    """EAMServer that relays requests it cannot handle to an upstream server.

    Requests that no local handler claims are POSTed to an upstream server and
    the upstream reply is returned to the client (minus the headers listed in
    ``_UNFORWARDED_HEADERS``). Registered *taps* are called with each relayed
    request together with its decoded upstream response.

    The upstream for a request is either the *upstream* given to ``__init__``
    (the fallback) or one embedded, base64-encoded, in a ``__tap/`` URL built
    by :meth:`tap_url`.
    """

    def __init__(self, upstream, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Used whenever a request does not carry its own upstream.
        self._upstream_fallback = upstream
        # (module, method, matcher, handler) tuples registered via tap().
        self._taps = []
        # (matcher, service name) pairs registered via tap().
        self._tapped_services = set()

        if not kwargs.get("disable_routes"):
            # NOTE(review): these patterns look like they lost their
            # ``<placeholder>`` segments (on_tap_request expects an
            # ``upstream`` argument plus service/model/module/method);
            # confirm against the intended route definitions.
            self._rules.add(Rule("/__tap/////", endpoint="tap_request"))
            self._rules.add(Rule("/__tap/", endpoint="tap_request"))

    def tap_url(self, ctx, url):
        """Return a URL on this proxy that relays traffic to *url*.

        *url* is encoded as URL-safe base64 into a single path segment under
        ``__tap/`` and resolved against ``self._public_url``. *ctx* is
        currently unused.
        """
        # URL-safe alphabet: standard base64 can emit "/" (e.g. for "???"),
        # which would split the encoded upstream across path segments.
        encoded = base64.urlsafe_b64encode(url.encode("latin-1")).decode()
        return urllib.parse.urljoin(self._public_url, f"__tap/{encoded}/")

    def _call_upstream(self, ctx, upstream):
        """POST the incoming request to *upstream* and return the raw reply.

        Uses the fallback upstream when *upstream* is falsy. Model, module and
        method are appended as path segments when ``ctx.url_slash`` is true,
        otherwise as query parameters. The incoming headers and body are
        forwarded unchanged.

        Raises:
            requests.RequestException: if the upstream request fails.
        """
        base = upstream or self._upstream_fallback
        if ctx.url_slash:
            url = f"{base}/{ctx.model}/{ctx.module}/{ctx.method}"
        else:
            url = f"{base}?model={ctx.model}&module={ctx.module}&method={ctx.method}"
        return requests.post(url, headers=ctx.request.headers, data=ctx.request.data)

    def tap(self, module, method=None, matcher=None, service=None):
        """Register a tap on relayed upstream traffic (decorator factory).

        ``@server.tap(module, method)`` registers the decorated function,
        which is later called as ``handler(ctx, resp_ctx)`` after a request
        for *module*/*method* whose model satisfies *matcher* has been relayed
        upstream and the upstream reply decoded. The function is returned
        unchanged.

        When *method* is omitted, returns a function that takes the method
        name and returns the decorator.

        Args:
            module: module name the tap applies to.
            method: method name the tap applies to.
            matcher: model matcher (``.matches(model)``); defaults to
                ``ModelMatcher()``.
            service: service name to record for the tap; may be omitted only
                when *module* is in ``TRIVIAL_SERVICES``.

        Raises:
            ValueError: at decoration time, if *service* is omitted and
                *module* is not in ``TRIVIAL_SERVICES``.
        """
        if method is None:
            def bind_method(method):
                # Recurse into tap() (not handler()) so the decorated
                # function is registered as a tap.
                return self.tap(module, method, matcher, service)
            return bind_method

        def decorator(handler):
            matcher_ = matcher or ModelMatcher()
            if service is None:
                if module not in TRIVIAL_SERVICES:
                    raise ValueError(f"Unable to identify service for {module}")
                tapped_service = module
            else:
                tapped_service = service
            # Register only after validation so a rejected tap leaves no
            # half-registered state behind.
            self._taps.append((module, method, matcher_, handler))
            self._tapped_services.add((matcher_, tapped_service))
            return handler

        return decorator

    def _decode_us_request(self, resp):
        """Decode an upstream ``requests`` response into a ResponseContext.

        Reads the encryption and compression headers, unwraps the body with
        ``eaapi.unwrap`` and unpacks it with ``eaapi.Decoder``.

        Raises:
            exc.UnknownCompression: if the compression header matches
                neither LZ77 nor "none".
            exc.InvalidPacket: if the payload cannot be unpacked.
        """
        ea_info = resp.headers.get(HEADER_ENCRYPTION)
        compression = resp.headers.get(HEADER_COMPRESSION)

        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression == eaapi.Compression.None_.value:
            compressed = False
        else:
            raise exc.UnknownCompression

        payload = eaapi.unwrap(resp.content, ea_info, compressed)
        decoder = eaapi.Decoder(payload)
        try:
            response = decoder.unpack()
        except eaapi.exception.EAAPIException as err:
            raise exc.InvalidPacket from err

        return ResponseContext(resp, decoder, response, compressed)

    def call_upstream(self, ctx, upstream):
        """Relay *ctx*'s request to *upstream* and return the decoded reply.

        Raises:
            requests.RequestException: if the upstream request fails.
            exc.UnknownCompression, exc.InvalidPacket: if the reply cannot be
                decoded.
        """
        return self._decode_us_request(self._call_upstream(ctx, upstream))

    def _services_handler(self, ctx, upstream=None):
        """Answer a services request by merging local and upstream services.

        Local services from the base handler win on name clashes, except NTP
        and keepalive, whose local entries are dropped so the upstream ones
        are used. Upstream-only services are rewritten to tap URLs so their
        traffic passes through this proxy; NTP and keepalive URLs are passed
        through unchanged.

        Raises:
            exc.UpstreamFailed: if the upstream request fails.
        """
        try:
            upstream_ctx = self.call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err

        super()._services_handler(ctx)
        services = ctx.resp.xpath("services")

        # Never proxy NTP or keepalive.
        passthrough = (SERVICE_NTP, SERVICE_KEEPALIVE)
        services.children = [
            i for i in services.children
            if i.get("name") not in passthrough
        ]
        added = set(i.get("name") for i in services)

        for item in upstream_ctx.response.xpath("services"):
            name = item.get("name")
            if name in added:
                continue
            url = item.get("url")
            # TODO: tap only services registered in self._tapped_services
            # whose matcher matches ctx.model; for now every relayed service
            # except NTP/keepalive is tapped.
            if name not in passthrough:
                url = self.tap_url(ctx, url)
            services.append("item", name=name, url=url)

    def _handle_request(self, upstream, url_slash, request, model, module, method):
        """Serve a request locally if possible, otherwise relay it upstream.

        Services requests are always answered locally (merged with the
        upstream list). Other requests go to the base server's handlers; if
        none claims the request, it is relayed to the upstream, matching taps
        are run on the decoded reply, and the upstream reply is returned.
        Decode or tap failures are printed and never affect the reply.

        Raises:
            exc.NoMethodHandler: for a services request whose method is not
                ``METHOD_SERVICES_GET``.
            exc.UpstreamFailed: if the upstream request fails.
        """
        ctx = self._create_ctx(url_slash, request, model, module, method)

        if ctx.module == MODULE_SERVICES:
            if ctx.method != METHOD_SERVICES_GET:
                raise exc.NoMethodHandler
            self._services_handler(ctx, upstream)
            return self._encode_response(ctx)

        try:
            return super()._handle_request(ctx)
        except exc.NoMethodHandler:
            pass  # No local handler: fall through to the upstream.

        try:
            us_resp = self._call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err

        self._run_taps(ctx, module, method, us_resp)
        return self._relay_response(us_resp)

    def _run_taps(self, ctx, module, method, us_resp):
        """Decode *us_resp* and pass it to every tap matching the request.

        Best-effort: failures are printed with a traceback and swallowed so
        neither a bad upstream payload nor a broken tap breaks the relayed
        reply. Each tap is isolated, so one failing tap does not stop the rest.
        """
        try:
            resp_ctx = self._decode_us_request(us_resp)
        except Exception:
            traceback.print_exc()
            return

        for tap_module, tap_method, matcher, handler in self._taps:
            if tap_module != module or tap_method != method:
                continue
            try:
                if matcher.matches(ctx.model):
                    handler(ctx, resp_ctx)
            except Exception:
                traceback.print_exc()

    @staticmethod
    def _relay_response(us_resp):
        """Build a werkzeug Response carrying the upstream status, body and headers.

        Headers in ``_UNFORWARDED_HEADERS`` are skipped (compared
        case-insensitively).
        """
        response = Response(us_resp.content, us_resp.status_code)
        for name in us_resp.headers:
            if name.lower() not in _UNFORWARDED_HEADERS:
                response.headers[name] = us_resp.headers[name]
        return response

    def on_xrpc_request(self, request, service=None, model=None, module=None, method=None):
        """Handle a plain (non-tap) request; it relays to the fallback upstream."""
        return self.on_tap_request(request, None, service, model, module, method)

    def on_tap_request(self, request, upstream, service=None, model=None, module=None,
                       method=None):
        """Handle a request, relaying to the base64-encoded *upstream* if given.

        Non-POST requests are delegated to ``on_xrpc_other``.

        Raises:
            exc.InvalidUpstream: if *upstream* is not valid base64.
        """
        url_slash, service, model, module, method = self.parse_request(
            request, service, model, module, method
        )
        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)

        if upstream is not None:
            try:
                # urlsafe_b64decode accepts both the URL-safe and the
                # standard base64 alphabets.
                upstream = base64.urlsafe_b64decode(upstream).decode("latin-1")
            except (ValueError, TypeError) as err:
                raise exc.InvalidUpstream from err

        return self._handle_request(upstream, url_slash, request, model, module, method)