# Source code for cleo.base

"""Contains definitions for essential, base classes."""

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

from brian2 import (
    NeuronGroup,
    Subgroup,
    Network,
    NetworkOperation,
    defaultclock,
    ms,
    Unit,
    Quantity,
)
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.artist import Artist


class InterfaceDevice(ABC):
    """Base class for devices to be injected into the network"""

    name: str
    """Unique identifier for device.
    Used as a key in output/input dicts
    """
    brian_objects: set
    """All the Brian objects added to the network by this device.
    Must be kept up-to-date in :meth:`connect_to_neuron_group` and
    other functions so that those objects can be automatically added
    to the network when the device is injected.
    """
    sim: CLSimulator
    """The simulator the device is injected into
    """

    def __init__(self, name: str) -> None:
        """
        Parameters
        ----------
        name : str
            Unique identifier for the device.
        """
        self.name = name
        self.brian_objects = set()
        # No simulator is associated with the device at construction time.
        self.sim = None

    def init_for_simulator(self, simulator: CLSimulator) -> None:
        """Initialize device for simulator on initial injection

        This function is called only the first time a device is
        injected into a simulator and performs any operations
        that are independent of the individual neuron groups it
        is connected to. The base implementation does nothing.

        Parameters
        ----------
        simulator : CLSimulator
            simulator being injected into
        """
        pass

    @abstractmethod
    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
        """Connect device to given `neuron_group`.

        If your device introduces any objects which Brian must
        keep track of, such as a NeuronGroup, Synapses, or Monitor,
        make sure to add these to `self.brian_objects`.

        Parameters
        ----------
        neuron_group : NeuronGroup
        **kwparams : optional,
            passed from `inject_recorder` or `inject_stimulator`
        """
        pass

    def add_self_to_plot(
        self, ax: Axes3D, axis_scale_unit: Unit, **kwargs
    ) -> list[Artist]:
        """Add device to an existing plot

        Should only be called by :func:`~cleo.viz.plot`.
        The base implementation draws nothing.

        Parameters
        ----------
        ax : Axes3D
            The existing matplotlib Axes object
        axis_scale_unit : Unit
            The unit used to label axes and define chart limits
        **kwargs : optional

        Returns
        -------
        list[Artist]
            A list of artists used to render the device. Needed for use
            in conjunction with :class:`~cleo.viz.VideoVisualizer`.
        """
        return []

    # NOTE: `self` was missing from the original signature, which made the
    # device itself bind to `artists` on a normal instance call.
    def update_artists(self, artists: list[Artist], *args, **kwargs) -> list[Artist]:
        """Update the artists used to render the device

        Used to set the artists' state at every frame of a video visualization.
        The current state would be passed in `*args` or `**kwargs`.
        The base implementation updates nothing.

        Parameters
        ----------
        artists : list[Artist]
            the artists used to render the device originally, i.e.,
            which were returned from the first :meth:`add_self_to_plot` call.

        Returns
        -------
        list[Artist]
            The artists that were actually updated. Needed for efficient
            blit rendering, where only updated artists are re-rendered.
        """
        return []
class IOProcessor(ABC):
    """Abstract interface for sampling, signal processing and control

    Users supply a concrete implementation for their closed-loop
    use case. Most will find the :func:`~processing.LatencyIOProcessor`
    class more useful, since delay handling is already defined there.
    """

    sample_period_ms: float
    """Determines how frequently the processor takes samples"""

    @abstractmethod
    def is_sampling_now(self, time) -> bool:
        """Tell whether the processor takes a sample at this timestep.

        Parameters
        ----------
        time : Brian 2 temporal Unit
            Current timestep.

        Returns
        -------
        bool
        """

    @abstractmethod
    def put_state(self, state_dict: dict, time) -> None:
        """Hand the network state over to the :class:`IOProcessor`.

        Parameters
        ----------
        state_dict : dict
            A dictionary of recorder measurements, as returned by
            :func:`~cleo.CLSimulator.get_state()`
        time : brian2 temporal Unit
            The current simulation timestep. Essential for simulating
            control latency and for time-varying control.
        """

    @abstractmethod
    def get_ctrl_signal(self, time) -> dict:
        """Fetch the per-stimulator control signal from the :class:`~cleo.IOProcessor`.

        Parameters
        ----------
        time : Brian 2 temporal Unit
            Current timestep

        Returns
        -------
        dict
            A {'stimulator_name': value} dictionary for updating stimulators.
        """

    def reset(self, **kwargs) -> None:
        """Reset hook; the base implementation does nothing."""
        return None
class Recorder(InterfaceDevice):
    """Device for taking measurements of the network."""

    @abstractmethod
    def get_state(self) -> Any:
        """Return current measurement."""

    def reset(self, **kwargs) -> None:
        """Reset the recording device to a neutral state

        The base implementation does nothing; any keyword arguments
        are accepted and ignored.
        """
        return None
class Stimulator(InterfaceDevice):
    """Device for manipulating the network"""

    value: Any
    """The current value of the stimulator device"""
    default_value: Any
    """The default value of the device---used on initialization and on :meth:`~reset`"""
    t_ms: list[float]
    """Times stimulator was updated, stored if :attr:`save_history`"""
    values: list[Any]
    """Values taken by the stimulator at each :meth:`~update` call,
    stored if :attr:`save_history`"""
    save_history: bool
    """Determines whether :attr:`t_ms` and :attr:`values` are recorded"""

    def __init__(self, name: str, default_value, save_history: bool = False) -> None:
        """
        Parameters
        ----------
        name : str
            Unique device name used in :meth:`CLSimulator.update_stimulators`
        default_value : any
            The stimulator's default value
        save_history : bool, optional
            Determines whether :attr:`t_ms` and :attr:`values` are recorded,
            by default False
        """
        super().__init__(name)
        self.value = default_value
        self.default_value = default_value
        self.save_history = save_history

    def init_for_simulator(self, simulator: CLSimulator) -> None:
        """Initialize the device on first injection and start the history.

        Parameters
        ----------
        simulator : CLSimulator
            simulator being injected into
        """
        super().init_for_simulator(simulator)
        self._init_saved_vars()

    def _init_saved_vars(self):
        """Start a fresh history at the current network time, if enabled.

        Reads ``self.sim.network.t``, so :attr:`sim` must already be set
        when :attr:`save_history` is True.
        """
        if self.save_history:
            self.t_ms = [self.sim.network.t / ms]
            self.values = [self.default_value]

    def update(self, ctrl_signal) -> None:
        """Set the stimulator value.

        By default this simply sets `value` to `ctrl_signal`.
        You will want to implement this method if your
        stimulator requires additional logic. Use
        ``super().update(ctrl_signal)`` to preserve the `self.value`
        attribute logic

        Parameters
        ----------
        ctrl_signal : any
            The value the stimulator is to take.
        """
        self.value = ctrl_signal
        if self.save_history:
            self.t_ms.append(self.sim.network.t / ms)
            self.values.append(self.value)

    def reset(self, **kwargs) -> None:
        """Reset the stimulator device to a neutral state

        Restores :attr:`value` to :attr:`default_value` and restarts the
        saved history, so the first history entry matches the live value.
        """
        # Previously only the history was re-seeded with the default while
        # `value` kept whatever the last control signal had set.
        self.value = self.default_value
        self._init_saved_vars()
class CLSimulator:
    """The centerpiece of cleo. Integrates simulation components and runs."""

    io_processor: IOProcessor
    network: Network
    # Both registries are dicts keyed by device name (see `inject_recorder`
    # and `inject_stimulator`). They were previously declared by assigning
    # the strings "set[Recorder]" / "set[Stimulator]" as class attributes.
    recorders: dict[str, Recorder]
    stimulators: dict[str, Stimulator]
    _processing_net_op: NetworkOperation
    _net_store_name: str = "cleo default"

    def __init__(self, brian_network: Network) -> None:
        """
        Parameters
        ----------
        brian_network : Network
            The Brian network forming the core model
        """
        self.network = brian_network
        self.stimulators = {}
        self.recorders = {}
        self.io_processor = None
        self._processing_net_op = None

    def inject_device(
        self, device: InterfaceDevice, *neuron_groups: NeuronGroup, **kwparams: Any
    ) -> None:
        """Inject InterfaceDevice into the network, connecting to specified neurons.

        Calls :meth:`~InterfaceDevice.connect_to_neuron_group` for each group with
        kwparams and adds the device's :attr:`~InterfaceDevice.brian_objects` to
        the simulator's :attr:`network`.

        Automatically called by :meth:`inject_recorder` and :meth:`inject_stimulator`.

        Parameters
        ----------
        device : InterfaceDevice
            Device to inject
        *neuron_groups : NeuronGroup
            The groups to connect the device to; at least one is required
        **kwparams : any
            Passed on to :meth:`~InterfaceDevice.connect_to_neuron_group`

        Raises
        ------
        Exception
            If no neuron groups are given, if a group (or a subgroup's source)
            is not part of the simulator's network, or if the device was
            previously injected into a different simulator.
        """
        if len(neuron_groups) == 0:
            raise Exception("Injecting stimulator for no neuron groups is meaningless.")
        for ng in neuron_groups:
            # isinstance (rather than an exact type match) so that subclasses
            # of NeuronGroup/Subgroup are validated as well.
            if isinstance(ng, NeuronGroup):
                if ng not in self.network.objects:
                    raise Exception(
                        f"Attempted to connect device {device.name} to neuron group "
                        f"{ng.name}, which is not part of the simulator's network."
                    )
            elif isinstance(ng, Subgroup):
                # must look at sorted_objects because ng.source is unhashable
                if ng.source not in self.network.sorted_objects:
                    raise Exception(
                        f"Attempted to connect device {device.name} to neuron group "
                        f"{ng.source.name}, which is not part of the simulator's network."
                    )
            if device.sim not in [None, self]:
                raise Exception(
                    f"Attempted to inject device {device.name} into {self}, "
                    f"but it was previously injected into {device.sim}. "
                    "Each device can only be injected into one CLSimulator."
                )
            if device.sim is None:
                # First injection: claim the device, then run its one-time setup.
                device.sim = self
                device.init_for_simulator(self)
            device.connect_to_neuron_group(ng, **kwparams)
            # Register every Brian object the device has accumulated so far.
            for brian_object in device.brian_objects:
                self.network.add(brian_object)
        # Snapshot the network so `reset` can return to this point.
        self.network.store(self._net_store_name)

    def inject_stimulator(
        self, stimulator: Stimulator, *neuron_groups: NeuronGroup, **kwparams
    ) -> None:
        """Inject stimulator into given neuron groups.

        :meth:`Stimulator.connect_to_neuron_group` is called for each `group`.

        Parameters
        ----------
        stimulator : Stimulator
            The stimulator to inject
        *neuron_groups : NeuronGroup
            The groups to inject the stimulator into
        **kwparams : any
            Passed on to :meth:`Stimulator.connect_to_neuron_group` function.
            Necessary for parameters that can vary by neuron group, such
            as opsin expression levels.
        """
        self.inject_device(stimulator, *neuron_groups, **kwparams)
        self.stimulators[stimulator.name] = stimulator

    def inject_recorder(
        self, recorder: Recorder, *neuron_groups: NeuronGroup, **kwparams
    ) -> None:
        """Inject recorder into given neuron groups.

        :meth:`Recorder.connect_to_neuron_group` is called for each `group`.

        Parameters
        ----------
        recorder : Recorder
            The recorder to inject into the simulation
        *neuron_groups : NeuronGroup
            The groups to inject the recorder into
        **kwparams : any
            Passed on to :meth:`Recorder.connect_to_neuron_group` function.
            Necessary for parameters that can vary by neuron group, such
            as inhibitory/excitatory cell type
        """
        self.inject_device(recorder, *neuron_groups, **kwparams)
        self.recorders[recorder.name] = recorder

    def get_state(self) -> dict:
        """Return current recorder measurements.

        Returns
        -------
        dict
            A dictionary of `name`: `state` pairs for
            all recorders in the simulator.
        """
        return {name: recorder.get_state() for name, recorder in self.recorders.items()}

    def update_stimulators(self, ctrl_signals) -> None:
        """Update stimulators with output from the :class:`IOProcessor`

        Parameters
        ----------
        ctrl_signals : dict
            {`stimulator_name`: `ctrl_signal`} dictionary with values
            to update each stimulator. ``None`` is accepted and ignored.
        """
        if ctrl_signals is None:
            return
        for name, signal in ctrl_signals.items():
            self.stimulators[name].update(signal)

    def set_io_processor(self, io_processor, communication_period=None) -> None:
        """Set simulator IO processor

        Will replace any previous IOProcessor so there is only one at a time.
        A Brian NetworkOperation is created to govern communication between
        the Network and the IOProcessor.

        Parameters
        ----------
        io_processor : IOProcessor
            The processor to install. Passing ``None`` removes the current
            NetworkOperation and installs nothing.
        communication_period : optional
            Passed as ``dt`` to the NetworkOperation; when ``None``,
            ``defaultclock.dt`` is used.
        """
        self.io_processor = io_processor
        # remove previous NetworkOperation
        if self._processing_net_op is not None:
            self.network.remove(self._processing_net_op)
            self._processing_net_op = None

        if io_processor is None:
            return

        def communicate_with_io_proc(t):
            if io_processor.is_sampling_now(t / ms):
                io_processor.put_state(self.get_state(), t / ms)
            ctrl_signal = io_processor.get_ctrl_signal(t / ms)
            self.update_stimulators(ctrl_signal)

        # communication should be at every timestep. The IOProcessor
        # decides when to sample and deliver results.
        if communication_period is None:
            communication_period = defaultclock.dt
        self._processing_net_op = NetworkOperation(
            communicate_with_io_proc, dt=communication_period
        )
        self.network.add(self._processing_net_op)
        self.network.store(self._net_store_name)

    def run(self, duration: Quantity, **kwparams) -> None:
        """Run simulation.

        Parameters
        ----------
        duration : brian2 temporal Quantity
            Length of simulation
        **kwparams : additional arguments passed to brian2.run()
            level has a default value of 1
        """
        kwparams.setdefault("level", 1)
        self.network.run(duration, **kwparams)

    def reset(self, **kwargs):
        """Reset the simulator to a neutral state

        Restores the Brian Network to where it was when the
        CLSimulator was last modified (last injection, IOProcessor change).
        Calls reset() on stimulators, recorders, and IOProcessor.
        """
        # kwargs passed to stimulators, recorders, and io_processor reset
        self.network.restore(self._net_store_name)
        for stim in self.stimulators.values():
            stim.reset(**kwargs)
        for rec in self.recorders.values():
            rec.reset(**kwargs)
        if self.io_processor is not None:
            self.io_processor.reset(**kwargs)