# Source code for cleo.base

"""Contains definitions for essential, base classes."""

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Tuple, Iterable
import datetime

from attrs import define, field
from brian2 import (
    NeuronGroup,
    Subgroup,
    Network,
    NetworkOperation,
    defaultclock,
    ms,
    Unit,
    Quantity,
)
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.artist import Artist
import neo

import cleo.utilities


class NeoExportable(ABC):
    """Mixin class for classes that can be exported to Neo objects"""

    @abstractmethod
    def to_neo(self) -> neo.core.BaseNeo:
        """Export this object's data as a Neo object.

        The concrete Neo type depends on the implementer. For example, a
        stimulator may return a data object such as an ``AnalogSignal``,
        a device may return a ``neo.core.Group``, and
        :class:`CLSimulator` returns a whole ``neo.Block``.

        Returns
        -------
        neo.core.BaseNeo
            Neo object representing exported data
        """
        pass


@define(eq=False)
class InterfaceDevice(ABC):
    """Base class for devices to be injected into the network"""

    brian_objects: set = field(factory=set, init=False)
    """All the Brian objects added to the network by this device.
    Must be kept up-to-date in :meth:`connect_to_neuron_group` and
    other functions so that those objects can be automatically added
    to the network when the device is injected.
    """
    sim: CLSimulator = field(init=False, default=None)
    """The simulator the device is injected into (``None`` until injected)"""
    name: str = field(kw_only=True)
    """Unique identifier for device, used in sampling, plotting, etc.
    Name of the class by default."""

    @name.default
    def _default_name(self) -> str:
        # Default the device name to its concrete class name.
        return self.__class__.__name__

    def init_for_simulator(self, simulator: CLSimulator) -> None:
        """Initialize device for simulator on initial injection

        This function is called only the first time a device is injected
        into a simulator and performs any operations that are independent
        of the individual neuron groups it is connected to.

        Parameters
        ----------
        simulator : CLSimulator
            simulator being injected into
        """
        pass

    def reset(self, **kwargs) -> None:
        """Reset the device to a neutral state"""
        pass

    @abstractmethod
    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
        """Connect device to given `neuron_group`.

        Called by :meth:`CLSimulator.inject` once for each neuron group
        passed to it.

        If your device introduces any objects which Brian must
        keep track of, such as a NeuronGroup, Synapses, or Monitor,
        make sure to add these to `self.brian_objects`.

        Parameters
        ----------
        neuron_group : NeuronGroup
        **kwparams : optional
            Keyword arguments passed through from :meth:`CLSimulator.inject`
        """
        pass

    def add_self_to_plot(
        self, ax: Axes3D, axis_scale_unit: Unit, **kwargs
    ) -> list[Artist]:
        """Add device to an existing plot

        Should only be called by :func:`~cleo.viz.plot`.

        Parameters
        ----------
        ax : Axes3D
            The existing matplotlib Axes object
        axis_scale_unit : Unit
            The unit used to label axes and define chart limits
        **kwargs : optional

        Returns
        -------
        list[Artist]
            A list of artists used to render the device. Needed for use
            in conjunction with :class:`~cleo.viz.VideoVisualizer`.
        """
        return []

    def update_artists(self, artists: list[Artist], *args, **kwargs) -> list[Artist]:
        """Update the artists used to render the device

        Used to set the artists' state at every frame of a video visualization.
        The current state would be passed in `*args` or `**kwargs`

        Parameters
        ----------
        artists : list[Artist]
            the artists used to render the device originally, i.e.,
            which were returned from the first :meth:`add_self_to_plot` call.

        Returns
        -------
        list[Artist]
            The artists that were actually updated. Needed for efficient
            blit rendering, where only updated artists are re-rendered.
        """
        return []
class IOProcessor(ABC):
    """Abstract class for implementing sampling, signal processing and control

    This must be implemented by the user with their desired closed-loop
    use case, though most users will find the
    :class:`~processing.LatencyIOProcessor` class more useful, since delay
    handling is already defined.

    :class:`CLSimulator` calls these methods from a Brian NetworkOperation
    and passes the current simulation time as a unitless number of
    milliseconds (``t / ms``), not as a Brian Quantity.
    """

    sample_period_ms: float
    """Determines how frequently the processor takes samples"""

    @abstractmethod
    def is_sampling_now(self, time) -> bool:
        """Determines whether the processor will take a sample at this timestep.

        Parameters
        ----------
        time : float
            Current simulation time in milliseconds (unitless).

        Returns
        -------
        bool
        """
        pass

    @abstractmethod
    def put_state(self, state_dict: dict, time) -> None:
        """Deliver network state to the :class:`IOProcessor`.

        Parameters
        ----------
        state_dict : dict
            A dictionary of recorder measurements, as returned by
            :func:`~cleo.CLSimulator.get_state()`
        time : float
            The current simulation time in milliseconds (unitless).
            Essential for simulating control latency and for
            time-varying control.
        """
        pass

    @abstractmethod
    def get_ctrl_signal(self, time) -> dict:
        """Get per-stimulator control signal from the :class:`~cleo.IOProcessor`.

        Parameters
        ----------
        time : float
            Current simulation time in milliseconds (unitless).

        Returns
        -------
        dict
            A {'stimulator_name': value} dictionary for updating stimulators.
            :meth:`CLSimulator.update_stimulators` also accepts ``None``,
            which results in no stimulator updates.
        """
        pass

    def reset(self, **kwargs) -> None:
        """Reset the processor to a neutral state.

        Does nothing by default. :meth:`CLSimulator.reset` calls it with
        the keyword arguments given to that method.
        """
        pass
@define(eq=False)
class Recorder(InterfaceDevice):
    """Device that takes measurements of the network.

    :meth:`CLSimulator.get_state` collects the result of each recorder's
    :meth:`get_state`, keyed by the recorder's name.
    """

    @abstractmethod
    def get_state(self) -> Any:
        """Return current measurement."""
        ...
@define(eq=False)
class Stimulator(InterfaceDevice, NeoExportable):
    """Device for manipulating the network"""

    value: Any = field(init=False, default=None)
    """The current value of the stimulator device"""
    default_value: Any = 0
    """The default value of the device---used on initialization and on :meth:`~reset`"""
    t_ms: list[float] = field(factory=list, init=False, repr=False)
    """Times stimulator was updated, stored if :attr:`save_history`"""
    values: list[Any] = field(factory=list, init=False, repr=False)
    """Values taken by the stimulator at each :meth:`~update` call,
    stored if :attr:`save_history`"""
    save_history: bool = True
    """Determines whether :attr:`t_ms` and :attr:`values` are recorded"""

    def __attrs_post_init__(self):
        # Start out at the default value.
        self.value = self.default_value

    def _init_saved_vars(self):
        """Clear the recorded history (only when history is being saved)."""
        if self.save_history:
            self.t_ms = []
            self.values = []

    def update(self, ctrl_signal) -> None:
        """Set the stimulator value.

        By default this simply sets `value` to `ctrl_signal` and, if
        :attr:`save_history` is set, records the current network time
        (in ms) and the new value. :meth:`CLSimulator.update_stimulators`
        calls this with the IOProcessor's output for this stimulator.

        Override this method if your stimulator requires additional logic,
        and call ``super().update(ctrl_signal)`` from the override to keep
        the :attr:`value` and history bookkeeping.

        Parameters
        ----------
        ctrl_signal : any
            The value the stimulator is to take.
        """
        self.value = ctrl_signal
        if self.save_history:
            # Requires the device to have been injected (self.sim set).
            self.t_ms.append(self.sim.network.t / ms)
            self.values.append(self.value)

    def reset(self, **kwargs) -> None:
        """Reset the stimulator device to a neutral state

        Restores :attr:`value` to :attr:`default_value` and clears the
        recorded history if :attr:`save_history` is set.
        """
        self.value = self.default_value
        self._init_saved_vars()

    def to_neo(self) -> neo.core.BaseNeo:
        """Export the recorded stimulator history as a Neo signal.

        Built with ``cleo.utilities.analog_signal`` from :attr:`t_ms` and
        :attr:`values` using dimensionless units. The signal is named
        after the device and annotated with the export time.

        Returns
        -------
        neo.core.BaseNeo
            The signal returned by ``cleo.utilities.analog_signal``.
        """
        signal = cleo.utilities.analog_signal(self.t_ms, self.values, "dimensionless")
        signal.name = self.name
        signal.description = "Exported from Cleo stimulator device"
        signal.annotate(export_datetime=datetime.datetime.now())
        return signal
@define(eq=False)
class CLSimulator(NeoExportable):
    """The centerpiece of cleo. Integrates simulation components and runs."""

    network: Network = field(repr=False)
    """The Brian network forming the core model"""

    # The current IOProcessor, or None (set via set_io_processor)
    io_processor: IOProcessor = field(default=None, init=False)
    # Injected recorders / stimulators, keyed by device name
    recorders: dict[str, Recorder] = field(factory=dict, init=False, repr=False)
    stimulators: dict[str, Stimulator] = field(factory=dict, init=False, repr=False)
    # Every injected device
    devices: set[InterfaceDevice] = field(factory=set, init=False)
    # NetworkOperation connecting the network to the IOProcessor, if any
    _processing_net_op: NetworkOperation = field(default=None, init=False, repr=False)
    # Name under which the network state is stored and later restored by reset()
    _net_store_name: str = field(default="cleo default", init=False, repr=False)

    def _check_group_in_network(self, device: InterfaceDevice, ng: NeuronGroup) -> None:
        """Raise if `ng` (or the source of a Subgroup) is not in :attr:`network`.

        Groups that are neither NeuronGroups nor Subgroups are not checked.
        """
        if isinstance(ng, NeuronGroup):
            if ng not in self.network.objects:
                raise ValueError(
                    f"Attempted to connect device {device.name} to neuron group "
                    f"{ng.name}, which is not part of the simulator's network."
                )
        elif isinstance(ng, Subgroup):
            # must look at sorted_objects because ng.source is unhashable
            if ng.source not in self.network.sorted_objects:
                raise ValueError(
                    f"Attempted to connect device {device.name} to neuron group "
                    f"{ng.source.name}, which is not part of the simulator's network."
                )

    def inject(
        self, device: InterfaceDevice, *neuron_groups: NeuronGroup, **kwparams: Any
    ) -> CLSimulator:
        """Inject InterfaceDevice into the network, connecting to specified neurons.

        Calls :meth:`~InterfaceDevice.connect_to_neuron_group` for each group with
        kwparams and adds the device's :attr:`~InterfaceDevice.brian_objects`
        to the simulator's :attr:`network`. All groups are validated before
        anything is connected, so a bad group leaves the device untouched.

        Parameters
        ----------
        device : InterfaceDevice
            Device to inject
        *neuron_groups : NeuronGroup
            Groups (or Subgroups) to connect the device to; at least one.
        **kwparams : Any
            Passed to each ``connect_to_neuron_group`` call.

        Returns
        -------
        CLSimulator
            self

        Raises
        ------
        ValueError
            If no groups are given, a group is not part of :attr:`network`,
            or the device was already injected into another simulator.
        """
        if len(neuron_groups) == 0:
            raise ValueError("Injecting stimulator for no neuron groups is meaningless.")

        # Validate everything up front so a failure can't leave the device
        # half-connected or bound to this simulator.
        for ng in neuron_groups:
            self._check_group_in_network(device, ng)
        if device.sim is not None and device.sim is not self:
            raise ValueError(
                f"Attempted to inject device {device.name} into {self}, "
                f"but it was previously injected into {device.sim}. "
                "Each device can only be injected into one CLSimulator."
            )

        if device.sim is None:
            device.sim = self
            device.init_for_simulator(self)
        for ng in neuron_groups:
            device.connect_to_neuron_group(ng, **kwparams)

        for brian_object in device.brian_objects:
            if brian_object not in self.network.objects:
                self.network.add(brian_object)
        # Snapshot used by reset()
        self.network.store(self._net_store_name)

        if isinstance(device, Recorder):
            self.recorders[device.name] = device
        if isinstance(device, Stimulator):
            self.stimulators[device.name] = device
        self.devices.add(device)
        return self

    def get_state(self) -> dict:
        """Return current recorder measurements.

        Returns
        -------
        dict
            A dictionary of `name`: `state` pairs for
            all recorders in the simulator.
        """
        return {name: recorder.get_state() for name, recorder in self.recorders.items()}

    def update_stimulators(self, ctrl_signals) -> None:
        """Update stimulators with output from the :class:`IOProcessor`

        Parameters
        ----------
        ctrl_signals : dict or None
            {`stimulator_name`: `ctrl_signal`} dictionary with values
            to update each stimulator. ``None`` means no update.
        """
        if ctrl_signals is None:
            return
        for name, signal in ctrl_signals.items():
            self.stimulators[name].update(signal)

    def set_io_processor(self, io_processor, communication_period=None) -> CLSimulator:
        """Set simulator IO processor

        Will replace any previous IOProcessor so there is only one at a time.
        A Brian NetworkOperation is created to govern communication between
        the Network and the IOProcessor.

        Parameters
        ----------
        io_processor : IOProcessor or None
            The new processor, or ``None`` to remove the current one.
        communication_period : brian2 temporal Quantity, optional
            Period of the communicating NetworkOperation. Defaults to
            ``defaultclock.dt`` as evaluated at call time.

        Returns
        -------
        CLSimulator
            self
        """
        self.io_processor = io_processor
        # remove previous NetworkOperation
        if self._processing_net_op is not None:
            self.network.remove(self._processing_net_op)
            self._processing_net_op = None

        if io_processor is not None:

            def communicate_with_io_proc(t):
                # The IOProcessor API takes times as unitless milliseconds.
                if io_processor.is_sampling_now(t / ms):
                    io_processor.put_state(self.get_state(), t / ms)
                ctrl_signal = io_processor.get_ctrl_signal(t / ms)
                self.update_stimulators(ctrl_signal)

            # communication should be at every timestep. The IOProcessor
            # decides when to sample and deliver results.
            if communication_period is None:
                communication_period = defaultclock.dt
            self._processing_net_op = NetworkOperation(
                communicate_with_io_proc, dt=communication_period
            )
            self.network.add(self._processing_net_op)

        # Store in both branches. Otherwise, after removing a processor, the
        # snapshot restored by reset() would still include the removed
        # NetworkOperation.
        self.network.store(self._net_store_name)
        return self

    def run(self, duration: Quantity, **kwparams) -> None:
        """Run simulation.

        Parameters
        ----------
        duration : brian2 temporal Quantity
            Length of simulation
        **kwparams : additional arguments passed to brian2.run()
            ``level`` defaults to 1
        """
        # Brian's `level` says how many stack frames up to look for the
        # caller's namespace; 1 accounts for this wrapper's extra frame.
        kwparams.setdefault("level", 1)
        self.network.run(duration, **kwparams)

    def reset(self, **kwargs):
        """Reset the simulator to a neutral state

        Restores the Brian Network to where it was when the
        CLSimulator was last modified (last injection, IOProcessor change).
        Calls reset() on devices and IOProcessor.
        """
        # kwargs passed to stimulators, recorders, and io_processor reset
        self.network.restore(self._net_store_name)
        for device in self.devices:
            device.reset(**kwargs)
        if self.io_processor is not None:
            self.io_processor.reset(**kwargs)

    def to_neo(self) -> neo.core.Block:
        """Export data from all Neo-exportable devices as one ``neo.Block``.

        The block has a single Segment. A device result that is a
        ``neo.core.Group`` is appended to ``block.groups`` and its recursive
        data children are added to the segment. A ``DataObject`` result is
        added to the segment directly. Any other result is skipped.

        Returns
        -------
        neo.core.Block
        """
        block = neo.Block(
            description="Exported from Cleo simulation",
        )
        block.annotate(export_datetime=datetime.datetime.now())
        seg = neo.Segment()
        block.segments.append(seg)
        for device in self.devices:
            if not isinstance(device, NeoExportable):
                continue
            dev_neo = device.to_neo()
            if isinstance(dev_neo, neo.core.Group):
                data_objects = dev_neo.data_children_recur
                block.groups.append(dev_neo)
            elif isinstance(dev_neo, neo.core.dataobject.DataObject):
                data_objects = [dev_neo]
            else:
                # Nothing addable to the segment. Skipping also prevents
                # re-adding the previous device's objects.
                continue
            cleo.utilities.add_to_neo_segment(seg, *data_objects)
        return block