# Source listing: 163 lines, 6.0 KiB, Python.
import base64
import logging
import urllib.parse

import eaapi
import requests
from werkzeug.routing import Rule
from werkzeug.wrappers import Response

from .. import exceptions as exc
from ..context import ResponseContext
from ..model import ModelMatcher
from ..server import (
    HEADER_COMPRESSION,
    HEADER_ENCRYPTION,
    METHOD_SERVICES_GET,
    MODULE_SERVICES,
    SERVICE_KEEPALIVE,
    SERVICE_NTP,
    TRIVIAL_SERVICES,
    EAMServer,
)


class EAMProxyServer(EAMServer):
    """EAM server that forwards requests it cannot handle to an upstream server.

    Requests with a local handler are served by :class:`EAMServer`. Every other
    request is POSTed to an upstream server and the upstream reply is relayed
    back to the client. Handlers registered with :meth:`tap` receive the
    decoded upstream reply for matching requests; they observe the traffic but
    do not change what is relayed.

    The ``services`` reply is merged: the local services are kept, and each
    remaining upstream service is rewritten (via :meth:`tap_url`) to point at
    this server's ``__tap/<base64 upstream URL>/`` route.
    """

    # Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection
    # and must not be forwarded by a proxy.
    _HOP_BY_HOP_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    })
    # Client request headers not sent upstream: "host" names this proxy rather
    # than the upstream server, and requests computes "content-length" itself.
    _REQUEST_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
    # Upstream reply headers not relayed: requests has already undone any
    # Content-Encoding in ``Response.content``, and werkzeug sets
    # Content-Length for the body it actually sends.
    _RESPONSE_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}

    # Seconds to wait for the upstream server (connect and read) before the
    # request fails with UpstreamFailed; set to None to wait indefinitely.
    upstream_timeout = 30

    def __init__(self, upstream, *args, **kwargs):
        """Create the proxy.

        :param upstream: base URL used when a request does not carry its own
            upstream (i.e. requests arriving through :meth:`on_xrpc_request`).
        :param args: forwarded to :class:`EAMServer`.
        :param kwargs: forwarded to :class:`EAMServer`; a truthy
            ``disable_routes`` also skips registering the ``/__tap`` routes.
        """
        super().__init__(*args, **kwargs)
        self._upstream_fallback = upstream

        # (module, method, matcher, handler) tuples registered through tap().
        self._taps = []
        # (matcher, service name) pairs for services that have a tap.
        self._tapped_services = set()

        if not kwargs.get("disable_routes"):
            # tap_url() hands out URLs ending in "/", so appending
            # "/<model>/<module>/<method>" to one produces the "//" below.
            self._rules.add(Rule("/__tap/<upstream>//<model>/<module>/<method>", endpoint="tap_request"))
            self._rules.add(Rule("/__tap/<upstream>", endpoint="tap_request"))

    def tap_url(self, ctx, url):
        """Return the URL on this server that proxies ``url`` through the tap route.

        The upstream URL is encoded with the URL-safe base64 alphabet, so the
        token never contains "/" and always fits in the single ``<upstream>``
        path segment of the tap rules.

        :param ctx: request context (currently unused).
        :param url: upstream service URL to wrap.
        :returns: absolute URL under ``self._public_url``.
        """
        token = base64.urlsafe_b64encode(url.encode("latin-1")).decode()
        return urllib.parse.urljoin(self._public_url, f"__tap/{token}/")

    def _call_upstream(self, ctx, upstream):
        """POST the client's request to the upstream server and return the raw reply.

        :param ctx: request context; ``url_slash`` selects the path-style or
            query-style URL, and its request body and (filtered) headers are
            forwarded.
        :param upstream: base URL to forward to, or ``None``/empty to use the
            fallback given to the constructor.
        :returns: the :class:`requests.Response` from upstream.
        :raises requests.RequestException: the upstream call failed or timed out.
        """
        base = upstream or self._upstream_fallback
        if ctx.url_slash:
            url = f"{base}/{ctx.model}/{ctx.module}/{ctx.method}"
        else:
            url = f"{base}?model={ctx.model}&module={ctx.module}&method={ctx.method}"

        headers = {
            key: value
            for key, value in ctx.request.headers.items()
            if key.lower() not in self._REQUEST_SKIP_HEADERS
        }
        return requests.post(url, headers=headers, data=ctx.request.data, timeout=self.upstream_timeout)

    def tap(self, module, method=None, matcher=None, service=None):
        """Register a tap that observes upstream replies for ``module.method``.

        Usable as ``@server.tap(module, method)`` or, when ``method`` is
        omitted, as ``server.tap(module)(method)`` which returns the decorator.
        The decorated handler is called as ``handler(ctx, resp_ctx)`` with the
        request context and the decoded upstream :class:`ResponseContext`.

        :param module: module name the tap applies to.
        :param method: method name the tap applies to.
        :param matcher: model filter; defaults to a new :class:`ModelMatcher`.
        :param service: service the module belongs to; may be omitted only
            for modules listed in ``TRIVIAL_SERVICES``.
        :raises ValueError: (from the decorator) the service cannot be inferred.
        """
        if method is None:
            def bind_method(method):
                return self.tap(module, method, matcher, service)
            return bind_method

        def decorator(handler):
            tap_matcher = matcher if matcher is not None else ModelMatcher()

            tap_service = service
            if tap_service is None:
                if module not in TRIVIAL_SERVICES:
                    raise ValueError(f"Unable to identify service for {module}")
                tap_service = module

            # Register only after validation so a failed call leaves no tap behind.
            self._taps.append((module, method, tap_matcher, handler))
            self._tapped_services.add((tap_matcher, tap_service))
            return handler

        return decorator

    def _decode_us_request(self, resp):
        """Decode an upstream reply into a :class:`ResponseContext`.

        Despite the name, this decodes the upstream *response*: the body is
        unwrapped according to the encryption and compression headers, then
        unpacked with :class:`eaapi.Decoder`.

        :param resp: the :class:`requests.Response` from upstream.
        :raises exc.UnknownCompression: the compression header matches neither
            the LZ77 nor the ``None_`` compression value (including when the
            header is absent).
        :raises exc.InvalidPacket: the unwrapped payload could not be unpacked.
        """
        ea_info = resp.headers.get(HEADER_ENCRYPTION)
        compression = resp.headers.get(HEADER_COMPRESSION)

        if compression == eaapi.Compression.Lz77.value:
            compressed = True
        elif compression == eaapi.Compression.None_.value:
            compressed = False
        else:
            raise exc.UnknownCompression

        payload = eaapi.unwrap(resp.content, ea_info, compressed)
        decoder = eaapi.Decoder(payload)
        try:
            response = decoder.unpack()
        except eaapi.exception.EAAPIException as err:
            raise exc.InvalidPacket from err

        return ResponseContext(resp, decoder, response, compressed)

    def call_upstream(self, ctx, upstream):
        """Forward ``ctx`` upstream and return the decoded :class:`ResponseContext`.

        :raises requests.RequestException: the upstream call failed.
        :raises exc.UnknownCompression, exc.InvalidPacket: see :meth:`_decode_us_request`.
        """
        resp = self._call_upstream(ctx, upstream)
        return self._decode_us_request(resp)

    def _services_handler(self, ctx, upstream=None):
        """Build the ``services`` reply by merging local and upstream services.

        Local services (from :class:`EAMServer`) are kept, except NTP and
        keepalive, which always come from upstream. Upstream services not
        provided locally are appended; all of them except NTP and keepalive
        are rewritten to go through the tap route.

        :raises exc.UpstreamFailed: the upstream services request failed.
        """
        try:
            upstream_ctx = self.call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err

        super()._services_handler(ctx)
        services = ctx.resp.xpath("services")
        never_proxied = (SERVICE_NTP, SERVICE_KEEPALIVE)

        # Never proxy NTP or keepalive: drop the local entries so the
        # upstream ones are used unchanged.
        services.children = [
            child for child in services.children
            if child.get("name") not in never_proxied
        ]
        local_names = {child.get("name") for child in services}

        for item in upstream_ctx.response.xpath("services"):
            name = item.get("name")
            if name in local_names:
                continue

            url = item.get("url")
            # NOTE: every other upstream service is routed through the tap
            # endpoint, not only those recorded in self._tapped_services.
            if name not in never_proxied:
                url = self.tap_url(ctx, url)

            services.append("item", name=name, url=url)

    def _dispatch_taps(self, ctx, us_resp, module, method):
        """Feed the decoded upstream reply to every tap registered for this call.

        Best-effort: failures are logged and never affect the reply relayed to
        the client. The reply is only decoded when at least one tap matches.
        """
        log = logging.getLogger(__name__)
        try:
            handlers = [
                handler
                for tap_module, tap_method, matcher, handler in self._taps
                if tap_module == module and tap_method == method and matcher.matches(ctx.model)
            ]
            if not handlers:
                return
            resp_ctx = self._decode_us_request(us_resp)
        except Exception:
            log.exception("Could not prepare taps for %s.%s", module, method)
            return

        for handler in handlers:
            # Isolate taps from each other: one failing tap must not stop the rest.
            try:
                handler(ctx, resp_ctx)
            except Exception:
                log.exception("Tap %r failed for %s.%s", handler, module, method)

    def _relay_upstream_response(self, us_resp):
        """Copy an upstream ``requests`` reply into a werkzeug :class:`Response`.

        The body and status are relayed unchanged; headers are copied except
        those in ``_RESPONSE_SKIP_HEADERS``.
        """
        response = Response(us_resp.content, us_resp.status_code)
        for key, value in us_resp.headers.items():
            if key.lower() not in self._RESPONSE_SKIP_HEADERS:
                response.headers[key] = value
        return response

    def _handle_request(self, upstream, url_slash, request, model, module, method):
        """Serve one POSTed XRPC call, locally when possible, otherwise via upstream.

        :param upstream: decoded upstream base URL, or ``None`` for the fallback.
        :returns: a werkzeug response for the client.
        :raises exc.NoMethodHandler: a ``services`` call other than
            ``METHOD_SERVICES_GET``.
        :raises exc.UpstreamFailed: the request had to be forwarded and the
            upstream call failed.
        """
        ctx = self._create_ctx(url_slash, request, model, module, method)

        if ctx.module == MODULE_SERVICES:
            if ctx.method != METHOD_SERVICES_GET:
                raise exc.NoMethodHandler
            self._services_handler(ctx, upstream)
            return self._encode_response(ctx)

        try:
            return super()._handle_request(ctx)
        except exc.NoMethodHandler:
            pass  # No local handler: forward the call upstream below.

        try:
            us_resp = self._call_upstream(ctx, upstream)
        except requests.RequestException as err:
            raise exc.UpstreamFailed from err

        self._dispatch_taps(ctx, us_resp, module, method)
        # Return the response the upstream sent.
        return self._relay_upstream_response(us_resp)

    def on_xrpc_request(self, request, service=None, model=None, module=None, method=None):
        """Handle a request on the regular XRPC routes using the fallback upstream."""
        return self.on_tap_request(request, None, service, model, module, method)

    def on_tap_request(self, request, upstream, service=None, model=None, module=None, method=None):
        """Handle a request, optionally on a ``/__tap/<upstream>`` route.

        :param upstream: base64-encoded upstream URL taken from the tap route,
            or ``None`` to use the fallback upstream.
        :raises exc.InvalidUpstream: ``upstream`` is not valid base64.
        """
        url_slash, service, model, module, method = self.parse_request(request, service, model, module, method)

        if request.method != "POST":
            return self.on_xrpc_other(request, service, model, module, method)

        if upstream is not None:
            try:
                # Accept both the URL-safe alphabet produced by tap_url() and
                # the standard "+" / "/" alphabet.
                raw = base64.b64decode(upstream.replace("-", "+").replace("_", "/"))
            except (TypeError, ValueError) as err:
                raise exc.InvalidUpstream from err
            upstream = raw.decode("latin-1")

        return self._handle_request(upstream, url_slash, request, model, module, method)
|