from typing import Callable
from collections import defaultdict
from abc import ABC

from .context import CallContext
from .model import ModelMatcher
from .const import (
    RESERVED_MODULES,
    RESERVED_SERVICES,
    TRIVIAL_SERVICES,
    MODULE_SERVICES,
    METHOD_SERVICES_GET,
)

import eaapi


# A request handler: called with the call context and returns None; handlers
# produce output by mutating the context (e.g. appending to ``ctx.resp``).
Handler = Callable[[CallContext], None]


class IController(ABC):
    """Interface for objects that resolve request handlers and advertise routes.

    Subclasses must override every method below; the base implementations only
    raise ``NotImplementedError``.
    """

    # Dotted name identifying this controller; assigned by subclasses.
    _name: str

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return the handler that should serve *ctx*, or None if there is none."""
        raise NotImplementedError

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """Return a mapping of service name to route URL.

        *ctx* may be None to request routes independently of any specific call.
        """
        raise NotImplementedError

    def serviced_prefixes(self) -> list[str]:
        """Return the endpoint prefixes this controller serves."""
        raise NotImplementedError


class Controller(IController):
    """Handler registry keyed by ``(module, method)`` and selected by model matcher.

    The controller registers itself in ``server.controllers`` when constructed,
    advertises its services under *endpoint*, and, if *matcher* is given, only
    answers calls whose model that matcher accepts.
    """

    def __init__(self, server, endpoint="", matcher: None | ModelMatcher = None):
        # Local imports: EAMServer is only needed for the annotation (and
        # importing it at module level would presumably be circular).
        from .server import EAMServer
        import inspect

        self._server: EAMServer = server

        # Name the controller after the module that constructed it.  The
        # caller of __init__ is one frame up; currentframe().f_back finds it
        # without inspect.stack() reading source context for every frame.
        frame = inspect.currentframe()
        try:
            caller_frame = frame.f_back if frame is not None else None
            caller = (
                inspect.getmodule(caller_frame) if caller_frame is not None else None
            )
        finally:
            # Drop the reference to our own frame to avoid a reference cycle.
            del frame
        assert caller is not None
        self._name = caller.__name__

        # (module, method) -> [(matcher, handler), ...], kept sorted by matcher.
        self._handlers: dict[
            tuple[str | None, str | None], list[tuple[ModelMatcher, Handler]]
        ] = defaultdict(list)
        self._endpoint = endpoint
        self._matcher = matcher
        # (matcher, service name) pairs advertised by get_service_routes.
        self._services: set[tuple[ModelMatcher, str]] = set()
        # Wrappers applied, in order, to every handler returned by get_handler.
        self._pre_handler: list[Callable[[CallContext, Handler], Handler]] = []

        # Register only once fully initialised, so the server never holds a
        # half-built controller if anything above raises.
        server.controllers.append(self)

    @property
    def server(self):
        """The server this controller was registered with."""
        return self._server

    def on_pre(self, callback):
        """Register a wrapper applied to each handler that get_handler returns.

        *callback* is called as ``callback(ctx, handler)`` and must return the
        handler to use.  Wrappers apply in registration order.  Registering the
        same callback twice has no effect.  Returns *callback* so this method can
        be used as a decorator.
        """
        if callback not in self._pre_handler:
            self._pre_handler.append(callback)
        return callback

    def add_dummy_service(
        self,
        service: str,
        matcher: ModelMatcher | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Advertise *service* at this controller's endpoint without a handler.

        Raises:
            KeyError: if *service* is reserved and the bypass flag is not set.
        """
        if not unsafe_force_bypass_reserved and service in RESERVED_SERVICES:
            raise KeyError(
                f"{service} is a reserved service provided by default.\n"
                "Pass unsafe_force_bypass_reserved=True to override this implementation"
            )
        if matcher is None:
            matcher = ModelMatcher()
        self._services.add((matcher, service))

    def register_handler(
        self,
        handler: Handler,
        module: str,
        method: str,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Register *handler* for ``module.method`` under *matcher*.

        A None *matcher* registers the default handler, stored as a fresh
        ``ModelMatcher()``.  If *service* is None it defaults to *module*, which
        must then be listed in ``TRIVIAL_SERVICES``.

        Raises:
            KeyError: if *module* or *service* is reserved and the bypass flag
                is not set.
            ValueError: if no service can be inferred, or a handler with an
                equal matcher is already registered for ``module.method``.
        """
        if not unsafe_force_bypass_reserved:
            if module in RESERVED_MODULES:
                raise KeyError(
                    f"{module} is a reserved module provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )
            if service is not None and service in RESERVED_SERVICES:
                raise KeyError(
                    f"{service} is a reserved service provided by default.\n"
                    "Pass unsafe_force_bypass_reserved=True to override this implementation"
                )
        if service is None:
            if module not in TRIVIAL_SERVICES:
                raise ValueError(f"Unable to identify service for {module}")
            service = module

        # Normalise first: stored matchers are never None, so duplicates
        # (including duplicate defaults) must be found by comparing against
        # the matcher that would actually be stored.
        matcher_ = matcher if matcher is not None else ModelMatcher()
        handlers = self._handlers[(module, method)]
        for existing, _ in handlers:
            if existing == matcher_:
                if matcher is None:
                    raise ValueError(f"Duplicate default handler for {module}.{method}")
                raise ValueError(f"Duplicate handler for {module}.{method} ({matcher})")
        handlers.append((matcher_, handler))
        handlers.sort(key=lambda entry: entry[0])
        self._services.add((matcher_, service))

    def handler(
        self,
        module: str,
        method: str | None = None,
        matcher: ModelMatcher | None = None,
        service: str | None = None,
        unsafe_force_bypass_reserved: bool = False
    ):
        """Return a decorator that registers the decorated function as a handler.

        If *method* is omitted, returns a function taking the method name that
        in turn returns the decorator, i.e. ``handler(module)(method)``.  All
        other arguments, including the bypass flag, are forwarded unchanged.
        """
        if method is None:
            def h2(method):
                return self.handler(
                    module, method, matcher, service, unsafe_force_bypass_reserved
                )
            return h2

        # The parameter is deliberately left unannotated (it was
        # `handler: Handler`) for the MachineCallContext bodge.
        def decorator(handler):
            self.register_handler(
                handler, module, method, matcher, service, unsafe_force_bypass_reserved
            )
            return handler
        return decorator

    def get_handler(self, ctx):
        """Return the first handler whose matcher accepts ``ctx.model``.

        The handler is wrapped by every ``on_pre`` callback.  Returns None if
        this controller's own matcher rejects the model or nothing matches.
        """
        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return None
        # .get() rather than indexing: indexing the defaultdict would insert an
        # empty entry for every unknown (module, method) pair a client sends.
        for matcher, handler in self._handlers.get((ctx.module, ctx.method), ()):
            if matcher.matches(ctx.model):
                for wrap in self._pre_handler:
                    handler = wrap(ctx, handler)
                return handler
        return None

    def get_service_routes(self, ctx: CallContext | None):
        """Map each advertised service to this controller's expanded endpoint URL.

        With *ctx* None every registered service is returned.  Otherwise only
        services whose matcher accepts ``ctx.model`` are returned, and nothing
        is returned if this controller's own matcher rejects the model.
        """
        endpoint = self._server.expand_url(self._endpoint)
        if ctx is None:
            return {service: endpoint for _, service in self._services}
        if self._matcher is not None and not self._matcher.matches(ctx.model):
            return {}
        return {
            service: endpoint
            for matcher, service in self._services
            if matcher.matches(ctx.model)
        }

    def serviced_prefixes(self) -> list[str]:
        """Return this controller's endpoint as its only serviced prefix."""
        return [self._endpoint]


class ServicesController(IController):
    """Built-in controller answering the services-get call with all service routes."""

    def __init__(self, server, services_mode: eaapi.const.ServicesMode):
        from .server import EAMServer
        self._server: EAMServer = server
        self.services_mode = services_mode
        self._name = __name__ + "." + self.__class__.__name__

    def service_routes_for(self, ctx: CallContext):
        """Return service -> expanded URL for *ctx*.

        Services the server does not list fall back to ``server.public_url``.
        """
        services = defaultdict(lambda: self._server.public_url)
        # The server returns a mapping; iterate pairs, not bare keys.
        for service, route in self._server.get_service_routes(ctx).items():
            services[service] = self._server.expand_url(route)
        return services

    def get_handler(self, ctx: CallContext) -> Handler | None:
        """Return services_get for the services-get call, otherwise None."""
        if ctx.module == MODULE_SERVICES and ctx.method == METHOD_SERVICES_GET:
            return self.services_get
        return None

    def service_route(self, for_service: str, ctx: CallContext):
        """Return the URL for *for_service*, defaulting to the public URL."""
        routes = self.service_routes_for(ctx)
        return routes[for_service]

    def services_get(self, ctx: CallContext):
        """Append the services node, with one ``item`` per route, to ``ctx.resp``."""
        services = ctx.resp.append(
            MODULE_SERVICES,
            expire="600",
            mode=self.services_mode.value,
            status="0"
        )
        for service, url in self._server.get_service_routes(ctx).items():
            services.append("item", name=service, url=url)

    def get_service_routes(self, ctx: CallContext | None) -> dict[str, str]:
        """This controller advertises no services of its own."""
        return {}

    def serviced_prefixes(self) -> list[str]:
        """Serve the root prefix."""
        return [""]