Source code for domprob.dispatchers.basic

from collections.abc import Collection, Iterator
from typing import TypeVar, Any, ParamSpec

from domprob.dispatchers.dispatcher import (
    DispatcherException,
    DispatcherProtocol,
)
from domprob.announcements.method import AnnouncementMethod
from domprob.observations.observation import ObservationProtocol

_Instrument = TypeVar("_Instrument", bound=Any)


[docs] class InstrumentImpRegistry(Collection[_Instrument]): """Registry for instrument implementations, allowing lookup and caching. This class acts as a collection that stores instruments and supports: - Efficient retrieval of instruments by type. - Caching of previously looked-up instruments for performance optimization. Args: *instruments (`_Instrument`): Variable number of instrument instances to store. Example: >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> logger = LoggerInstrument() >>> analytics = AnalyticsInstrument() >>> >>> registry = InstrumentImpRegistry(logger, analytics) >>> logger_ = registry.get(LoggerInstrument) >>> >>> logger == logger_ True >>> print(registry.get(object)) None """ def __init__(self, *instruments: _Instrument) -> None: self._instrums = instruments self._cache: dict[type[_Instrument], _Instrument] = {}
[docs] def __contains__(self, item: object) -> bool: """Check if an instrument exists in the registry. Args: item: The instrument instance or class to check. Returns: bool: True if the instrument is present, otherwise False. Example: >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> logger = LoggerInstrument() >>> analytics = AnalyticsInstrument() >>> >>> registry = InstrumentImpRegistry(logger, analytics) >>> logger in registry True >>> object in registry False """ return item in self._instrums
def __hash__(self) -> int: return hash(self._instrums)
[docs] def __iter__(self) -> Iterator[_Instrument]: """Iterate over stored instruments. Returns: Iterator[_Instrument]: An iterator over the instruments. Example: >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> logger = LoggerInstrument() >>> analytics = AnalyticsInstrument() >>> >>> registry = InstrumentImpRegistry(logger, analytics) >>> >>> for instrument in registry: ... print(instrument.add()) ... Log message added! Analytics entry added! """ yield from self._instrums
[docs] def __len__(self) -> int: """Return the number of stored instruments. Returns: int: The number of instruments in the registry. Example: >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> logger = LoggerInstrument() >>> analytics = AnalyticsInstrument() >>> >>> registry = InstrumentImpRegistry(logger, analytics) >>> >>> len(registry) 2 """ return len(self._instrums)
[docs] @staticmethod def _is_hashable(obj: Any) -> bool: """Check if an object is hashable. Args: obj (`Any`): The object to check. Returns: `bool`: True if the object is hashable, False otherwise. """ try: hash(obj) except TypeError: return False return True
[docs] def get( self, instrument_cls: type[_Instrument], required: bool = False ) -> _Instrument | None: # pylint: disable=line-too-long """Retrieve an instrument instance by its class type. If the instrument class is hashable, results are cached for efficiency. Args: instrument_cls: The class type of the instrument to retrieve. required: If `True`, raises a `KeyError` if the instrument is not found. If `False`, returns `None`. Returns: _Instrument | None: The retrieved instrument instance or `None` if not found. Raises: KeyError: If `required` is `True` and the instrument is not found. Example: >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> logger = LoggerInstrument() >>> analytics = AnalyticsInstrument() >>> >>> registry = InstrumentImpRegistry(logger) >>> >>> registry.get(LoggerInstrument).add() 'Log message added!' >>> registry.get(AnalyticsInstrument, required=True) Traceback (most recent call last): ... KeyError: 'Instrument `AnalyticsInstrument` not found in available implementations: `<domprob.dispatchers.basic.LoggerInstrument object at 0x...>`' """ if self._is_hashable(instrument_cls) and instrument_cls in self._cache: return self._cache[instrument_cls] for instrum in self._instrums: # pylint: disable=unidiomatic-typecheck if type(instrum) is instrument_cls: if self._is_hashable(instrument_cls): self._cache[instrument_cls] = instrum return instrum if required: imp_str = ", ".join(f"`{repr(i)}`" for i in self._instrums) or None raise KeyError( f"Instrument `{instrument_cls.__name__}` not found in " f"available implementations: {imp_str}" ) return None
[docs] def __repr__(self) -> str: """Return a string representation of the registry. Returns: `str`: The string representation of the registry. """ return f"{self.__class__.__name__}(num_instruments={len(self)})"
_P = ParamSpec("_P") _R = TypeVar("_R", bound=Any)
[docs] class ReqInstrumException(DispatcherException): """Exception raised when a required instrument is missing an implementation of the same type for an observation announcement. An instrument is marked as required with the `required` flag in the `@announcement` decorator: >>> from domprob import announcement, BaseObservation >>> >>> class SomeObservation(BaseObservation): ... ... @announcement(..., required=True) ... def some_method(self, instrument: ...) -> None: ... ... ... Args: observation (_Obs): The observation instance where the missing instrument was required. announcement (_Ann): The announcement method that failed due to the missing instrument. req_supp_instr (type[_Instrument]): The instrument type that was expected but not found. *instrum_imps (_Instrument): The available instrument instances at the time of the failure. """ def __init__( self, observation: ObservationProtocol, announcement: AnnouncementMethod, req_supp_instrum: type[Any], *instrum_imps: Any, ) -> None: self.observation = observation self.announcement = announcement self.req_supp_instr = req_supp_instrum self.instrum_imps = instrum_imps super().__init__(self.msg) @property def msg(self) -> str: """Constructs a descriptive error message for the exception. Returns: str: A formatted string detailing the missing instrument, the observation method where it was required, and the available instrument implementations. """ req_name = self.req_supp_instr.__name__ meth_name = self.announcement.meth.__name__ obs_meth = f"{self.observation.__class__.__name__}.{meth_name}(...)" imps_str = ", ".join([f"`{repr(i)}`" for i in self.instrum_imps]) return ( f"Required instrument `{req_name}` in `{obs_meth}` is " f"missing from available implementations: {imps_str or None}" )
[docs] class BasicDispatcher(DispatcherProtocol): # pylint: disable=line-too-long """Dispatches observations to registered instruments. This class manages: - Finding the appropriate instrument for a given observation. - Dispatching announcements to the relevant instruments. Args: *instruments (`_Instrument`): Variable number of instrument instances. Example: >>> from abc import ABC, abstractmethod >>> >>> class BaseInstrument(ABC): ... @abstractmethod ... def add(self): ... pass ... >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> dispatcher = BasicDispatcher(LoggerInstrument(), AnalyticsInstrument()) >>> dispatcher BasicDispatcher(instruments=('<domprob.dispatchers.basic.LoggerInstrument object at 0x...>', '<domprob.dispatchers.basic.AnalyticsInstrument object at 0x...>')) >>> >>> from domprob import announcement, BaseObservation >>> >>> class SomeObservation(BaseObservation): ... @announcement(LoggerInstrument) ... @announcement(AnalyticsInstrument) ... def foo(self, instrument: BaseInstrument) -> None: ... print(instrument.add()) ... >>> obs = SomeObservation() >>> dispatcher.dispatch(obs) Analytics entry added! Log message added! """ def __init__(self, *instruments: _Instrument) -> None: self.instrums = InstrumentImpRegistry(*instruments)
[docs] def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return False return (type(self) is type(other)) and ( tuple(self.instrums) == tuple(other.instrums) )
def __hash__(self) -> int: return hash(self.instrums)
[docs] @staticmethod def _dispatch_instrum_ann( observation: ObservationProtocol, announcement: AnnouncementMethod, instrument: Any, ) -> None: """Invoke an announcement method on an instrument. This method triggers the specified announcement method on the given instrument instance if it is available. Args: observation (_Obs): The observation being processed. announcement (_Ann): The announcement method to invoke. instrument (_Instrument | None): The target instrument instance, if available. """ if instrument is not None: announcement.meth(observation, instrument)
[docs] def _dispatch_ann( self, observation: ObservationProtocol, announcement: AnnouncementMethod, ) -> None: """Process an announcement by identifying the required instrument. This method retrieves the correct instrument instance based on the announcement and calls `_instrum_announce`. Args: observation (_Obs): The observation being processed. announcement (_Ann): The announcement method to invoke. """ for supp_instrum, required in announcement.supp_instrums: try: instrum_imp = self.instrums.get(supp_instrum, required) except KeyError as e: raise ReqInstrumException( observation, announcement, supp_instrum, *self.instrums ) from e self._dispatch_instrum_ann(observation, announcement, instrum_imp)
[docs] def dispatch(self, observation: ObservationProtocol) -> None: # pylint: disable=line-too-long """Dispatch an observation to all applicable instruments. This method retrieves all announcements from the observation and processes them. Args: observation (_Obs): The observation to process. Example: >>> from abc import ABC, abstractmethod >>> >>> class BaseInstrument(ABC): ... @abstractmethod ... def add(self): ... pass ... >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> dispatcher = BasicDispatcher(LoggerInstrument(), AnalyticsInstrument()) >>> dispatcher BasicDispatcher(instruments=('<domprob.dispatchers.basic.LoggerInstrument object at 0x...>', '<domprob.dispatchers.basic.AnalyticsInstrument object at 0x...>')) >>> >>> from domprob import announcement, BaseObservation >>> >>> class SomeObservation(BaseObservation): ... @announcement(LoggerInstrument) ... @announcement(AnalyticsInstrument) ... def foo(self, instrument: BaseInstrument) -> None: ... print(instrument.add()) ... >>> obs = SomeObservation() >>> dispatcher.dispatch(obs) Analytics entry added! Log message added! """ for ann in observation.announcements(): self._dispatch_ann(observation, ann)
[docs] def __repr__(self) -> str: # pylint: disable=line-too-long """Return a string representation of the dispatcher. Returns: str: A string representation of the dispatcher and its instruments. Example: >>> from abc import ABC, abstractmethod >>> >>> class BaseInstrument(ABC): ... @abstractmethod ... def add(self): ... pass ... >>> class LoggerInstrument: ... @staticmethod ... def add(): ... return "Log message added!" ... >>> class AnalyticsInstrument: ... @staticmethod ... def add(): ... return "Analytics entry added!" ... >>> dispatcher = BasicDispatcher(LoggerInstrument(), AnalyticsInstrument()) >>> repr(dispatcher) "BasicDispatcher(instruments=('<domprob.dispatchers.basic.LoggerInstrument object at 0x...>', '<domprob.dispatchers.basic.AnalyticsInstrument object at 0x...>'))" """ instrum_imps_str = tuple(repr(i) for i in self.instrums) return f"{self.__class__.__name__}(instruments={instrum_imps_str})"