eaapi/eaapi/server/controller.py

211 lines
6.9 KiB
Python

from typing import Callable
from collections import defaultdict
from abc import ABC
from .context import CallContext
from .model import ModelMatcher
from .const import RESERVED_MODULES, RESERVED_SERVICES, TRIVIAL_SERVICES, MODULE_SERVICES, METHOD_SERVICES_GET
import eaapi
# Type of a request handler: a callable invoked with the CallContext of an
# incoming call (annotated as returning None).
Handler = Callable[[CallContext], None]
class IController(ABC):
    """Interface every controller registered with the server provides.

    Each method here only raises ``NotImplementedError``; concrete
    controllers override all three.
    """

    # Identifier of the controller; concrete implementations set it in __init__.
    _name: str

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return the handler for the call described by ``ctx``, or None."""
        raise NotImplementedError

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Return a ``{service name: url}`` mapping for ``ctx`` (or for no context)."""
        raise NotImplementedError

    def serviced_prefixes(self) -> list[str]:
        """Return the URL prefixes this controller serves."""
        raise NotImplementedError
class Controller(IController):
    """Routes ``(module, method)`` calls to handlers registered on this controller.

    Constructing a controller registers it with ``server`` (it is appended to
    ``server.controllers``) and names it after the module that called the
    constructor.

    Args:
        server: The owning ``EAMServer``.
        endpoint: URL path this controller is served under; it is passed
            through ``server.expand_url`` when service routes are built.
        matcher: Optional controller-wide filter. When set, calls whose model
            it does not match get neither a handler nor any service routes.
    """

    def __init__(self, server, endpoint="", matcher: None | ModelMatcher = None):
        # Function-scope imports, as in the original (presumably to avoid an
        # import cycle with .server).
        from .server import EAMServer
        import inspect

        self._server: EAMServer = server

        # Name the controller after the module that instantiated it.
        # context=0 skips reading source lines for every frame on the stack.
        caller = inspect.getmodule(inspect.stack(0)[1].frame)
        assert caller is not None
        self._name = caller.__name__

        # (module, method) -> [(matcher, handler), ...], kept sorted by matcher.
        self._handlers: dict[
            tuple[str | None, str | None],
            list[tuple[ModelMatcher, Handler]]
        ] = defaultdict(list)
        self._endpoint = endpoint
        self._matcher = matcher
        # (matcher, service) pairs advertised through get_service_routes().
        self._services: set[tuple[ModelMatcher, str]] = set()
        # Callbacks that wrap a handler before get_handler() returns it.
        self._pre_handler: list[Callable[[CallContext, Handler], Handler]] = []

        # Register only once fully initialised, so a failure above does not
        # leave a half-built controller in the server's list.
        server.controllers.append(self)

    @property
    def server(self):
        """The server this controller was registered with."""
        return self._server

    def on_pre(self, callback):
        """Register ``callback`` to wrap handlers returned by :meth:`get_handler`.

        ``callback(ctx, handler)`` must return the handler to use. Callbacks
        are applied in registration order, so the first one registered ends
        up innermost. Registering the same callback twice is a no-op.

        Returns ``callback`` unchanged, so this can also be used as a decorator.
        """
        if callback not in self._pre_handler:
            self._pre_handler.append(callback)
        return callback

    def add_dummy_service(
        self,
        service: str,
        matcher: ModelMatcher | None = None,
        unsafe_force_bypass_reserved: bool = False
    ) -> None:
        """Advertise ``service`` at this controller's endpoint without a handler.

        Args:
            service: Service name to advertise.
            matcher: Models the service is advertised to; None matches all.
            unsafe_force_bypass_reserved: Allow names in RESERVED_SERVICES.

        Raises:
            KeyError: ``service`` is reserved and the bypass flag is not set.
        """
        if not unsafe_force_bypass_reserved and service in RESERVED_SERVICES:
            raise KeyError(
                f"{service} is a reserved service provided by default.\n"
                "Pass unsafe_force_bypass_reserved=True to override this implementation"
            )
        if matcher is None:
            matcher = ModelMatcher()
        self._services.add((matcher, service))

    def register_handler(
        self,
        handler: Handler,
        module: str,
        method: str,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ) -> None:
        """Register ``handler`` for ``module.method`` and advertise its service.

        Args:
            handler: Callable invoked with the CallContext of a matching call.
            module: Request module name.
            method: Request method name.
            matcher: Models this handler applies to; None registers the
                default (match-all) handler.
            service: Service to advertise; defaults to ``module`` when the
                module is listed in TRIVIAL_SERVICES.
            unsafe_force_bypass_reserved: Allow reserved modules/services.

        Raises:
            KeyError: ``module`` or ``service`` is reserved and the bypass flag
                is not set.
            ValueError: No service was given and none can be derived from
                ``module``, or an equal matcher is already registered for
                ``module.method``.
        """
        if not unsafe_force_bypass_reserved:
            if module in RESERVED_MODULES:
                raise KeyError(
                    f"{module} is a reserved module provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )
            if service is not None and service in RESERVED_SERVICES:
                raise KeyError(
                    f"{service} is a reserved service provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )
        if service is None:
            if module not in TRIVIAL_SERVICES:
                raise ValueError(f"Unable to identify service for {module}")
            service = module

        # Stored matchers are never None (the default handler gets a match-all
        # ModelMatcher), so duplicates must be detected on the normalised value.
        # NOTE(review): relies on ModelMatcher equality being value-based.
        matcher_ = matcher if matcher is not None else ModelMatcher()
        handlers = self._handlers[(module, method)]
        for existing, _ in handlers:
            if existing == matcher_:
                if matcher is None:
                    raise ValueError(f"Duplicate default handler for {module}.{method}")
                raise ValueError(f"Duplicate handler for {module}.{method} ({matcher})")
        handlers.append((matcher_, handler))
        # get_handler() returns the first match, so keep the list in matcher order.
        handlers.sort(key=lambda entry: entry[0])
        self._services.add((matcher_, service))

    def handler(
        self,
        module: str,
        method: str | None = None,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Decorator form of :meth:`register_handler`.

        ``handler(module, method, ...)`` returns a decorator that registers the
        decorated function and returns it unchanged. When ``method`` is
        omitted, a function of ``method`` is returned instead. Calling it
        yields that decorator, with all other arguments (including
        ``unsafe_force_bypass_reserved``) carried over.
        """
        if method is None:
            def h2(method):
                return self.handler(
                    module, method, matcher, service, unsafe_force_bypass_reserved
                )
            return h2

        # NOTE: `func` is intentionally unannotated (was `Handler`) for the
        # MachineCallContext bodge.
        def decorator(func):
            self.register_handler(
                func, module, method, matcher, service, unsafe_force_bypass_reserved
            )
            return func
        return decorator

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return the first handler whose matcher accepts ``ctx.model``.

        The handler is wrapped by every callback registered with
        :meth:`on_pre`. Returns None if the controller-wide matcher rejects
        the model or no registered handler matches.
        """
        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return None
        # .get() rather than indexing: indexing the defaultdict would insert an
        # empty entry for every unknown (module, method) a client sends.
        for matcher, handler in self._handlers.get((ctx.module, ctx.method), ()):
            if matcher.matches(ctx.model):
                for wrap in self._pre_handler:
                    handler = wrap(ctx, handler)
                return handler
        return None

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Map each advertised service to this controller's expanded endpoint.

        With ``ctx`` None, every registered service is included. Otherwise
        only services whose matcher accepts ``ctx.model`` are included, and
        none at all if the controller-wide matcher rejects it.
        """
        endpoint = self._server.expand_url(self._endpoint)
        if ctx is None:
            return {service: endpoint for _, service in self._services}
        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return {}
        return {
            service: endpoint
            for matcher, service in self._services
            if matcher.matches(ctx.model)
        }

    def serviced_prefixes(self) -> list[str]:
        """Return the single URL prefix this controller serves (its endpoint)."""
        return [self._endpoint]
class ServicesController(IController):
    """Built-in controller that answers ``services.get`` with the server's routes.

    Args:
        server: The owning ``EAMServer``.
        services_mode: Reported in the response's ``mode`` attribute (its
            ``.value``).
    """

    def __init__(self, server, services_mode: eaapi.const.ServicesMode):
        from .server import EAMServer
        self._server: EAMServer = server
        self.services_mode = services_mode
        self._name = __name__ + "." + self.__class__.__name__

    def service_routes_for(self, ctx: CallContext):
        """Return the service routes for ``ctx`` with a ``public_url`` fallback.

        Looking up a service that has no route yields
        ``self._server.public_url``. That value is read at lookup time, not
        when this method is called.
        """
        services = defaultdict(lambda: self._server.public_url)
        # The server's routes are a mapping (services_get indexes them by
        # service name). Iterate its items: iterating the mapping directly
        # yields only the keys.
        # NOTE(review): Controller.get_service_routes already expands its
        # endpoint; confirm expand_url is idempotent for the re-expansion here.
        for service, route in self._server.get_service_routes(ctx).items():
            services[service] = self._server.expand_url(route)
        return services

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return :meth:`services_get` for a services.get call, else None."""
        if ctx.module == MODULE_SERVICES and ctx.method == METHOD_SERVICES_GET:
            return self.services_get
        return None

    def service_route(self, for_service: str, ctx: CallContext):
        """Return the route for ``for_service`` (``public_url`` if it has none)."""
        routes = self.service_routes_for(ctx)
        return routes[for_service]

    def services_get(self, ctx: CallContext) -> None:
        """Append a services node listing every route the server reports for ``ctx``."""
        services = ctx.resp.append(
            MODULE_SERVICES, expire="600", mode=self.services_mode.value, status="0"
        )
        for service, url in self._server.get_service_routes(ctx).items():
            services.append("item", name=service, url=url)

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """This controller advertises no services of its own."""
        return {}

    def serviced_prefixes(self) -> list[str]:
        """Serve the root prefix."""
        return [""]