211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
from typing import Callable
|
|
from collections import defaultdict
|
|
from abc import ABC
|
|
|
|
from .context import CallContext
|
|
from .model import ModelMatcher
|
|
from .const import RESERVED_MODULES, RESERVED_SERVICES, TRIVIAL_SERVICES, MODULE_SERVICES, METHOD_SERVICES_GET
|
|
|
|
import eaapi
|
|
|
|
|
|
# Signature of a call handler: it receives the CallContext of the call and returns
# nothing. Presumably any response is written through the context (services_get
# below appends to ``ctx.resp``) — confirm for other handlers.
Handler = Callable[[CallContext], None]
|
|
|
|
|
|
class IController(ABC):
    """Interface implemented by every controller the server dispatches to.

    The three methods below are placeholders: this base class only raises
    ``NotImplementedError`` and concrete controllers override all of them.
    """

    # Identifier of the controller; the concrete subclasses assign it in __init__.
    _name: str

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return the handler that should serve ``ctx``, or ``None`` if not served here."""
        raise NotImplementedError()

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Return the service-name -> URL routes this controller advertises.

        ``ctx`` may be ``None``; presumably implementations then report their
        routes without any model filtering — confirm per implementation.
        """
        raise NotImplementedError()

    def serviced_prefixes(self) -> list[str]:
        """Return the endpoint prefixes this controller is responsible for."""
        raise NotImplementedError()
|
|
|
|
|
|
class Controller(IController):
    """Dispatch calls to handlers registered per ``(module, method)``.

    A controller registers itself with ``server`` on construction and serves
    ``endpoint``. The optional controller-wide ``matcher`` restricts which
    models the controller answers for at all; individual handlers and
    services may additionally carry their own ``ModelMatcher``.
    """

    def __init__(self, server, endpoint="", matcher: None | ModelMatcher = None):
        from .server import EAMServer

        self._server: EAMServer = server
        server.controllers.append(self)

        # Name the controller after the module that constructed it.
        import inspect

        caller = inspect.getmodule(inspect.stack()[1][0])
        assert caller is not None
        self._name = caller.__name__

        # (module, method) -> [(matcher, handler), ...], kept sorted by matcher.
        self._handlers: dict[
            tuple[str | None, str | None],
            list[tuple[ModelMatcher, Handler]]
        ] = defaultdict(list)
        self._endpoint = endpoint
        self._matcher = matcher

        # (matcher, service name) pairs reported by get_service_routes.
        self._services: set[tuple[ModelMatcher, str]] = set()

        # Wrappers applied, in registration order, to the handler chosen in get_handler.
        self._pre_handler: list[Callable[[CallContext, Handler], Handler]] = []

    @property
    def server(self):
        """The server this controller was registered with."""
        return self._server

    def on_pre(self, callback):
        """Register ``callback(ctx, handler) -> handler`` to wrap selected handlers.

        Registering the same callback twice has no effect.
        """
        if callback not in self._pre_handler:
            self._pre_handler.append(callback)

    def add_dummy_service(
        self,
        service: str,
        matcher: ModelMatcher | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Advertise ``service`` at this controller's endpoint without a handler.

        Raises:
            KeyError: ``service`` is in RESERVED_SERVICES and
                ``unsafe_force_bypass_reserved`` is not set.
        """
        if not unsafe_force_bypass_reserved:
            if service in RESERVED_SERVICES:
                raise KeyError(
                    f"{service} is a reserved service provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )
        if matcher is None:
            matcher = ModelMatcher()
        self._services.add((matcher, service))

    def register_handler(
        self,
        handler: Handler,
        module: str,
        method: str,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Register ``handler`` for ``module.method`` and advertise its service.

        Args:
            handler: Callable invoked with the CallContext of a matching call.
            module: Module name the handler serves.
            method: Method name the handler serves.
            matcher: Restricts the handler to matching models; ``None`` means
                a default ``ModelMatcher()``.
            service: Service to advertise; defaults to ``module`` when the
                module is in TRIVIAL_SERVICES.
            unsafe_force_bypass_reserved: Skip the reserved module/service checks.

        Raises:
            KeyError: ``module`` or ``service`` is reserved and the bypass is not set.
            ValueError: No service can be derived from ``module``, or a handler
                with an equal matcher is already registered for ``module.method``.
        """
        if not unsafe_force_bypass_reserved:
            if module in RESERVED_MODULES:
                raise KeyError(
                    f"{module} is a reserved module provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )

            if service is not None and service in RESERVED_SERVICES:
                raise KeyError(
                    f"{service} is a reserved service provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )

        if service is None:
            if module not in TRIVIAL_SERVICES:
                raise ValueError(f"Unable to identify service for {module}")
            service = module

        # Normalise before the duplicate check. Stored matchers are never None,
        # so comparing the raw ``matcher`` against them could never detect a
        # second default handler.
        matcher_ = ModelMatcher() if matcher is None else matcher

        handlers = self._handlers[(module, method)]
        for existing, _ in handlers:
            if existing == matcher_:
                if matcher is None:
                    raise ValueError(f"Duplicate default handler for {module}.{method}")
                raise ValueError(f"Duplicate handler for {module}.{method} ({matcher})")
        handlers.append((matcher_, handler))
        handlers.sort(key=lambda x: x[0])

        self._services.add((matcher_, service))

    def handler(
        self,
        module: str,
        method: str | None = None,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Decorator form of :meth:`register_handler`.

        ``handler(module, method, ...)`` returns a decorator that registers the
        decorated function. ``handler(module)`` (no method) returns a function
        that takes the method name and returns that decorator.
        """
        if method is None:
            def h2(method):
                # Forward every option, including the reserved-name bypass.
                return self.handler(
                    module, method, matcher, service, unsafe_force_bypass_reserved
                )

            return h2

        # NOTE: ``handler`` is deliberately left unannotated (rather than
        # ``Handler``) for the MachineCallContext bodge.
        def decorator(handler):
            self.register_handler(
                handler, module, method, matcher, service, unsafe_force_bypass_reserved
            )
            return handler

        return decorator

    def get_handler(self, ctx):
        """Return the (pre-wrapped) handler for ``ctx``, or ``None`` if none matches."""
        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return None

        # Use .get() rather than indexing: indexing the defaultdict would insert
        # an empty entry for every unknown (module, method) pair that is looked
        # up. Those values presumably come from incoming calls.
        handlers = self._handlers.get((ctx.module, ctx.method))
        if not handlers:
            return None

        for matcher, handler in handlers:
            if matcher.matches(ctx.model):
                for wrap in self._pre_handler:
                    handler = wrap(ctx, handler)
                return handler

        return None

    def get_service_routes(self, ctx: CallContext | None):
        """Map each advertised service to this controller's expanded endpoint.

        With ``ctx`` of ``None`` every registered service is returned. Otherwise
        only services whose matchers (and the controller-wide matcher) accept
        ``ctx.model`` are returned.
        """
        endpoint = self._server.expand_url(self._endpoint)

        if ctx is None:
            return {service: endpoint for _, service in self._services}

        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return {}

        return {
            service: endpoint
            for matcher, service in self._services
            if matcher.matches(ctx.model)
        }

    def serviced_prefixes(self) -> list[str]:
        """Return the single endpoint prefix this controller serves."""
        return [self._endpoint]
|
|
|
|
|
|
class ServicesController(IController):
    """Built-in controller that answers the services/get call.

    It collects the route table from the server and writes it into the
    response.
    """

    def __init__(self, server, services_mode: eaapi.const.ServicesMode):
        from .server import EAMServer

        self._server: EAMServer = server
        self.services_mode = services_mode
        self._name = f"{__name__}.{type(self).__name__}"

    def service_routes_for(self, ctx: CallContext):
        """Return the expanded route for every service the server reports for ``ctx``.

        The result is a ``defaultdict``: a service without a reported route
        falls back to the server's ``public_url``.
        """
        services = defaultdict(lambda: self._server.public_url)
        # The server's routes are a mapping (services_get indexes it by key), so
        # iterate its items; iterating the mapping itself yields only the keys.
        for service, route in self._server.get_service_routes(ctx).items():
            services[service] = self._server.expand_url(route)

        return services

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return services_get for the services/get call, else ``None``."""
        if ctx.module == MODULE_SERVICES and ctx.method == METHOD_SERVICES_GET:
            return self.services_get
        return None

    def service_route(self, for_service: str, ctx: CallContext):
        """Return the URL for ``for_service`` (``public_url`` if it has no route)."""
        routes = self.service_routes_for(ctx)
        return routes[for_service]

    def services_get(self, ctx: CallContext):
        """Append a services node listing every service and its URL to ``ctx.resp``."""
        services = ctx.resp.append(
            MODULE_SERVICES, expire="600", mode=self.services_mode.value, status="0"
        )

        routes = self._server.get_service_routes(ctx)
        for service, url in routes.items():
            services.append("item", name=service, url=url)

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """This controller advertises no routes of its own."""
        return {}

    def serviced_prefixes(self) -> list[str]:
        """Serve the root (empty) prefix."""
        return [""]
|