Source code for cleo.base

"""Contains definitions for essential, base classes."""

from __future__ import annotations

import datetime
from abc import ABC, abstractmethod
from typing import Any, Tuple

import neo
from attrs import asdict, define, field, fields_dict
from brian2 import (
    BrianObjectException,
    Equations,
    Network,
    NetworkOperation,
    NeuronGroup,
    Quantity,
    Subgroup,
    Synapses,
    Unit,
    defaultclock,
    ms,
    np,
)
from matplotlib.artist import Artist
from mpl_toolkits.mplot3d import Axes3D

from cleo.registry import registry_for_sim
from cleo.utilities import add_to_neo_segment, analog_signal, brian_safe_name


class NeoExportable(ABC):
    """Mixin for classes whose data can be exported as Neo objects."""

    @abstractmethod
    def to_neo(self) -> neo.core.BaseNeo:
        """Export this object's data as a Neo object.

        Returns
        -------
        neo.core.BaseNeo
            Neo object representing the exported data
        """
        ...
@define(eq=False)
class InterfaceDevice(ABC):
    """Abstract base for every device that gets injected into the network."""

    brian_objects: set = field(init=False, repr=False, factory=set)
    """Every Brian object this device adds to the network.

    Keep this up to date in :meth:`connect_to_neuron_group` (and any other
    method creating Brian objects) so the objects can be added to the network
    automatically when the device is injected.
    """

    sim: CLSimulator = field(default=None, init=False, repr=False)
    """The simulator this device has been injected into"""

    name: str = field(kw_only=True)
    """Device identifier used in sampling, plotting, etc.

    Defaults to the class name. Must be unique among recorders and
    stimulators."""

    save_history: bool = field(kw_only=True, default=True)
    """Whether times and inputs/outputs are recorded. True by default.

    For stimulators, recording happens when :meth:`~Stimulator.update` is
    called. For recorders, when :meth:`~Recorder.get_state` is called."""

    @name.default
    def _default_name(self) -> str:
        # fall back on the name of the concrete class
        return self.__class__.__name__

    def init_for_simulator(self, simulator: CLSimulator) -> None:
        """One-time setup performed on first injection into a simulator.

        Called only the first time the device is injected; intended for work
        that does not depend on the individual neuron groups the device gets
        connected to. Does nothing by default.

        Parameters
        ----------
        simulator : CLSimulator
            simulator being injected into
        """
        return None

    def reset(self, **kwargs) -> None:
        """Return the device to a neutral state. Does nothing by default."""
        return None

    @abstractmethod
    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
        """Connect the device to `neuron_group`.

        Any object Brian must keep track of (e.g., a NeuronGroup, Synapses,
        or Monitor) that the device creates here must be added to
        :attr:`~cleo.InterfaceDevice.brian_objects`.

        Parameters
        ----------
        neuron_group : NeuronGroup
        **kwparams : optional
            Passed from `inject`
        """
        ...

    def add_self_to_plot(
        self, ax: Axes3D, axis_scale_unit: Unit, **kwargs
    ) -> list[Artist]:
        """Draw the device on an existing plot.

        Should only be called by :func:`~cleo.viz.plot`. The base
        implementation draws nothing.

        Parameters
        ----------
        ax : Axes3D
            The existing matplotlib Axes object
        axis_scale_unit : Unit
            The unit used to label axes and define chart limits
        **kwargs : optional

        Returns
        -------
        list[Artist]
            The artists rendering the device. Needed for use in conjunction
            with :class:`~cleo.viz.VideoVisualizer`.
        """
        no_artists = []
        return no_artists

    def update_artists(self, artists: list[Artist], *args, **kwargs) -> list[Artist]:
        """Refresh the artists rendering the device.

        Used to set the artists' state at every frame of a video
        visualization; the current state is passed in `*args` or `**kwargs`.
        The base implementation updates nothing.

        Parameters
        ----------
        artists : list[Artist]
            the artists originally used to render the device, i.e., those
            returned from the first :meth:`add_self_to_plot` call.

        Returns
        -------
        list[Artist]
            The artists that were actually updated. Needed for efficient blit
            rendering, where only updated artists are re-rendered.
        """
        nothing_updated = []
        return nothing_updated
@define
class IOProcessor(ABC):
    """Abstract interface for sampling, signal processing, and control.

    Users implement this for their closed-loop use case, though most will
    find the :func:`~processing.LatencyIOProcessor` class more useful, since
    delay handling is already defined there.
    """

    sample_period_ms: float = 1
    """How frequently the processor takes samples"""

    latest_ctrl_signal: dict = field(init=False, repr=False, factory=dict)
    """The most recent control signal returned by :meth:`get_ctrl_signals`"""

    @abstractmethod
    def is_sampling_now(self, time) -> bool:
        """Tell whether the processor takes a sample at this timestep.

        Parameters
        ----------
        time : Brian 2 temporal Unit
            Current timestep.

        Returns
        -------
        bool
        """
        ...

    @abstractmethod
    def put_state(self, state_dict: dict, sample_time_ms: float) -> None:
        """Deliver network state to the :class:`IOProcessor`.

        Parameters
        ----------
        state_dict : dict
            A dictionary of recorder measurements, as returned by
            :func:`~cleo.CLSimulator.get_state()`
        sample_time_ms: float
            The current simulation timestep. Essential for simulating
            control latency and for time-varying control.
        """
        ...

    @abstractmethod
    def get_ctrl_signals(self, query_time_ms: float) -> dict:
        """Get per-stimulator control signal from the :class:`~cleo.IOProcessor`.

        Parameters
        ----------
        query_time_ms : float
            Current simulation time.

        Returns
        -------
        dict
            A {'stimulator_name': ctrl_signal} dictionary for updating
            stimulators.
        """
        ...

    def get_stim_values(self, query_time_ms: float) -> dict:
        """Compute the values to send to stimulators at `query_time_ms`.

        Fetches new control signals, merges them into
        :attr:`latest_ctrl_signal`, and passes the merged signals through
        :meth:`preprocess_ctrl_signals`.

        Parameters
        ----------
        query_time_ms : float
            Current simulation time.

        Returns
        -------
        dict
            The newly fetched control signals, overridden by any
            preprocessed values for the same stimulator name.
        """
        fresh_signals = self.get_ctrl_signals(query_time_ms)
        latest = self.latest_ctrl_signal
        latest.update(fresh_signals)
        converted = self.preprocess_ctrl_signals(latest, query_time_ms)
        # preprocessed values take precedence over raw control signals
        return fresh_signals | converted

    def preprocess_ctrl_signals(
        self, latest_ctrl_signals: dict, query_time_ms: float
    ) -> dict:
        """Convert control signals into stimulator values between samples.

        E.g., if a control signal defines the frequency of a periodic light
        stimulus, this function computes the current intensity given the
        latest frequency and the current time.

        Called immediately after :meth:`get_ctrl_signals` to update the
        stimulator waveform between samples. Only needs to be implemented
        when a stimulus that varies between samples is desired; otherwise the
        control signal returned by :meth:`get_ctrl_signals` is used directly.
        If not all stimulators need this, return entries only for those that
        do; the unprocessed control signal is used for the others.

        Parameters
        ----------
        latest_ctrl_signals : dict
            The most recent control signal per stimulator.
        query_time_ms : float
            Current simulation time.

        Returns
        -------
        dict
            A {'stimulator_name': value} dictionary for updating stimulators.
            Empty by default.
        """
        return dict()

    def get_intersample_ctrl_signal(self, query_time_ms: float) -> dict:
        """Get per-stimulator control signal between samples.

        I.e., for implementing a time-varying waveform based on parameters
        from the last sample. Such parameters will need to be stored in the
        :class:`~cleo.IOProcessor`. Returns an empty dict by default."""
        return dict()

    def reset(self, **kwargs) -> None:
        """Reset the processor. Does nothing by default."""
        return None
@define(eq=False)
class Recorder(InterfaceDevice):
    """Device that takes measurements of the network."""

    @abstractmethod
    def get_state(self) -> Any:
        """Return the current measurement."""
        ...
@define(eq=False)
class Stimulator(InterfaceDevice, NeoExportable):
    """Device for manipulating the network"""

    value: Any = field(init=False, default=None)
    """The current value of the stimulator device"""

    default_value: Any = 0
    """The default value of the device---used on initialization and on :meth:`~reset`"""

    t_ms: list[float] = field(factory=list, init=False, repr=False)
    """Times stimulator was updated, stored if :attr:`~cleo.InterfaceDevice.save_history`"""

    values: list[Any] = field(factory=list, init=False, repr=False)
    """Values taken by the stimulator at each :meth:`~update` call, stored if
    :attr:`~cleo.InterfaceDevice.save_history`"""

    def __attrs_post_init__(self):
        self.value = self.default_value
        self._init_saved_vars()

    def _current_time_ms(self):
        """Return the network time in ms, or 0 if not yet injected.

        Shared by :meth:`_init_saved_vars` and :meth:`update` so both treat a
        device without a simulator the same way.
        """
        if self.sim is None:
            return 0
        return self.sim.network.t / ms

    def _init_saved_vars(self):
        """Start the saved history over with the current time and value."""
        if self.save_history:
            self.t_ms = [self._current_time_ms()]
            self.values = [self.value]

    def update(self, ctrl_signal) -> None:
        """Set the stimulator value.

        By default this sets :attr:`value` to ``ctrl_signal`` and, if
        :attr:`~cleo.InterfaceDevice.save_history` is set, appends the
        current time and value to :attr:`t_ms` and :attr:`values`. Override
        this method if your stimulator requires additional logic, calling
        ``super().update(ctrl_signal)`` to preserve the ``self.value`` and
        ``save_history`` logic.

        Parameters
        ----------
        ctrl_signal : any
            The value the stimulator is to take.
        """
        self.value = ctrl_signal
        if self.save_history:
            # tolerate a not-yet-injected device instead of failing on
            # ``self.sim`` being None
            self.t_ms.append(self._current_time_ms())
            self.values.append(self.value)

    def reset(self, **kwargs) -> None:
        """Reset the stimulator device to a neutral state"""
        self.value = self.default_value
        self._init_saved_vars()

    def to_neo(self):
        """Export the saved update times and values as a Neo signal.

        Returns
        -------
        The object built by ``analog_signal`` from :attr:`t_ms` and
        :attr:`values`, named after the device and annotated with the export
        datetime.
        """
        signal = analog_signal(self.t_ms, self.values, "dimensionless")
        signal.name = self.name
        signal.description = "Exported from Cleo stimulator device"
        signal.annotate(export_datetime=datetime.datetime.now())
        return signal
@define(eq=False)
class CLSimulator(NeoExportable):
    """The centerpiece of cleo. Integrates simulation components and runs."""

    network: Network = field(repr=False)
    """The Brian network forming the core model"""

    io_processor: IOProcessor = field(default=None, init=False)
    """The current IOProcessor, set via :meth:`set_io_processor`"""

    recorders: dict[str, Recorder] = field(factory=dict, init=False, repr=False)
    """Injected recorders, keyed by device name"""

    stimulators: dict[str, Stimulator] = field(factory=dict, init=False, repr=False)
    """Injected stimulators, keyed by device name"""

    devices: set[InterfaceDevice] = field(factory=set, init=False)
    """All injected devices"""

    # NetworkOperation driving communication with the IOProcessor, if any
    _processing_net_op: NetworkOperation = field(default=None, init=False, repr=False)
    # name under which network snapshots are stored for reset()
    _net_store_name: str = field(default="cleo default", init=False, repr=False)

    def inject(
        self, device: InterfaceDevice, *neuron_groups: NeuronGroup, **kwparams: Any
    ) -> CLSimulator:
        """Inject InterfaceDevice into the network, connecting to specified neurons.

        Calls :meth:`~InterfaceDevice.connect_to_neuron_group` for each group
        with kwparams and adds the device's
        :attr:`~InterfaceDevice.brian_objects` to the simulator's
        :attr:`network`. All validation happens before anything is modified,
        so a failed injection leaves simulator and device untouched.

        Parameters
        ----------
        device : InterfaceDevice
            Device to inject
        *neuron_groups : NeuronGroup
            Groups (or subgroups) to connect the device to; at least one
        **kwparams : optional
            Passed on to :meth:`~InterfaceDevice.connect_to_neuron_group`

        Returns
        -------
        CLSimulator
            self

        Raises
        ------
        Exception
            If no neuron groups are given, a group is not part of the
            network, or the device already belongs to another simulator.
        ValueError
            If a different Recorder/Stimulator with the same name has
            already been injected.
        """
        if len(neuron_groups) == 0:
            raise Exception("Injecting stimulator for no neuron groups is meaningless.")

        # --- validate everything up front, before mutating any state ---
        for ng in neuron_groups:
            if isinstance(ng, NeuronGroup):
                if ng not in self.network.objects:
                    raise Exception(
                        f"Attempted to connect device {device.name} to neuron group "
                        f"{ng.name}, which is not part of the simulator's network."
                    )
            elif isinstance(ng, Subgroup):
                # must look at sorted_objects because ng.source is unhashable
                if ng.source not in self.network.sorted_objects:
                    raise Exception(
                        f"Attempted to connect device {device.name} to neuron group "
                        f"{ng.source.name}, which is not part of the simulator's network."
                    )

        if device.sim is not None and device.sim is not self:
            raise Exception(
                f"Attempted to inject device {device.name} into {self}, "
                f"but it was previously injected into {device.sim}. "
                "Each device can only be injected into one CLSimulator."
            )

        is_recorder = isinstance(device, Recorder)
        is_stimulator = isinstance(device, Stimulator)
        if is_recorder and self.recorders.get(device.name, device) is not device:
            raise ValueError(
                f"Another Recorder with name {device.name} has already been injected"
            )
        if is_stimulator and self.stimulators.get(device.name, device) is not device:
            raise ValueError(
                f"Another Stimulator with name {device.name} has already been injected"
            )

        # --- all checks passed: connect and register ---
        if device.sim is None:
            device.sim = self
            device.init_for_simulator(self)
        for ng in neuron_groups:
            device.connect_to_neuron_group(ng, **kwparams)

        for brian_object in device.brian_objects:
            if brian_object not in self.network.objects:
                self.network.add(brian_object)
        # snapshot the network so reset() restores to this point
        self.network.store(self._net_store_name)

        if is_recorder:
            self.recorders[device.name] = device
        if is_stimulator:
            self.stimulators[device.name] = device
        self.devices.add(device)
        return self

    def get_state(self) -> dict:
        """Return current recorder measurements.

        Returns
        -------
        dict
            A dictionary of `name`: `state` pairs for all recorders in the
            simulator.
        """
        return {name: recorder.get_state() for name, recorder in self.recorders.items()}

    def update_stimulators(self, stim_values: dict[str, Any]) -> None:
        """Update stimulators with output from the :class:`IOProcessor`

        Parameters
        ----------
        stim_values : dict
            {`stimulator_name`: `stim_value`} dictionary with values to
            update each stimulator.
        """
        for name, value in stim_values.items():
            self.stimulators[name].update(value)

    def set_io_processor(
        self, io_processor: IOProcessor, communication_period=None
    ) -> CLSimulator:
        """Set simulator IO processor

        Will replace any previous IOProcessor so there is only one at a time.
        A Brian NetworkOperation is created to govern communication between
        the Network and the IOProcessor. Passing ``None`` only removes the
        previous processor's NetworkOperation.

        Parameters
        ----------
        io_processor : IOProcessor
        communication_period : optional
            ``dt`` of the NetworkOperation communicating with the processor.
            Defaults to ``defaultclock.dt``.

        Returns
        -------
        CLSimulator
            self
        """
        self.io_processor = io_processor
        # remove previous NetworkOperation
        if self._processing_net_op is not None:
            self.network.remove(self._processing_net_op)
            self._processing_net_op = None

        if io_processor is None:
            # still return self so the documented fluent interface holds
            return self

        def communicate_with_io_proc(t):
            # assuming no one will have timesteps shorter than nanoseconds...
            now_ms = round(t / ms, 6)
            if io_processor.is_sampling_now(now_ms):
                io_processor.put_state(self.get_state(), now_ms)
            stim_values = io_processor.get_stim_values(now_ms)
            self.update_stimulators(stim_values)

        # communication should be at every timestep. The IOProcessor
        # decides when to sample and deliver results.
        if communication_period is None:
            communication_period = defaultclock.dt
        self._processing_net_op = NetworkOperation(
            communicate_with_io_proc, dt=communication_period
        )
        self.network.add(self._processing_net_op)
        self.network.store(self._net_store_name)
        return self

    def run(self, duration: Quantity, **kwparams) -> None:
        """Run simulation.

        Parameters
        ----------
        duration : brian2 temporal Quantity
            Length of simulation
        **kwparams : additional arguments passed to brian2.run()
            level has a default value of 1
        """
        kwparams.setdefault("level", 1)
        self.network.run(duration, **kwparams)

    def reset(self, **kwargs):
        """Reset the simulator to a neutral state

        Restores the Brian Network to where it was when the CLSimulator was
        last modified (last injection, IOProcessor change). Calls reset() on
        :attr:`devices` and :attr:`io_processor`, passing ``**kwargs`` to each.
        """
        self.network.restore(self._net_store_name)
        for device in self.devices:
            device.reset(**kwargs)
        if self.io_processor is not None:
            self.io_processor.reset(**kwargs)

    def to_neo(self) -> neo.core.Block:
        """Exports simulator data to a Neo Block

        Returns
        -------
        neo.core.Block
            Neo Block containing signals representing each device's data

        Raises
        ------
        TypeError
            If a device's ``to_neo()`` returns something other than a Neo
            Group or data object.
        """
        block = neo.Block(
            description="Exported from Cleo simulation",
        )
        block.annotate(export_datetime=datetime.datetime.now())
        seg = neo.Segment()
        block.segments.append(seg)
        for device in self.devices:
            if not isinstance(device, NeoExportable):
                continue
            dev_neo = device.to_neo()
            if isinstance(dev_neo, neo.core.Group):
                data_objects = dev_neo.data_children_recur
                block.groups.append(dev_neo)
            elif isinstance(dev_neo, neo.core.dataobject.DataObject):
                data_objects = [dev_neo]
            else:
                # without this branch data_objects would be unbound or, worse,
                # silently reuse the previous device's data
                raise TypeError(
                    f"to_neo() of device {device.name} returned unsupported type "
                    f"{type(dev_neo).__name__}; expected a Neo Group or data object."
                )
            add_to_neo_segment(seg, *data_objects)
        return block
@define(eq=False)
class SynapseDevice(InterfaceDevice):
    """Base class for devices that record from/stimulate neurons via
    a Synapses object with device-specific model. Used for opsin and
    indicator classes"""

    model: str = field(init=False)
    """Basic Brian model equations string.

    Should contain a `rho_rel` term reflecting relative expression levels.
    Will likely also contain special NeuronGroup-dependent symbols such as
    V_VAR_NAME to be replaced on injection in
    :meth:`modify_model_and_params_for_ng`."""

    on_pre: str = field(default="", init=False)
    """Model string for :class:`brian2.synapses.synapses.Synapses`
    reacting to spikes."""

    synapses: dict[str, Synapses] = field(init=False, repr=False, factory=dict)
    """The synapse objects implementing the model, connecting from source
    (light aggregator neurons or the target group itself) to target neuron
    groups, of form ``{target_ng.name: synapses}``."""

    source_ngs: dict[str, NeuronGroup] = field(init=False, repr=False, factory=dict)
    """``{target_ng.name: source_ng}`` dict of source neuron groups.

    The source is the target itself by default or light aggregator neurons
    for :class:`~cleo.light.LightDependent`."""

    per_ng_unit_replacements: list[Tuple[str, str]] = field(
        init=False, repr=False, factory=list
    )
    """List of (UNIT_NAME, neuron_group_specific_unit_name) tuples to be
    substituted in the model string on injection and before checking
    required variables."""

    required_vars: list[Tuple[str, Unit]] = field(init=False, repr=False, factory=list)
    """Default names of state variables required in the neuron group, along
    with units, e.g., [('Iopto', amp)].

    It is assumed that non-default values can be passed in on injection as a
    keyword argument ``[default_name]_var_name=[non_default_name]`` and that
    these are found in the model string as ``[DEFAULT_NAME]_VAR_NAME`` before
    replacement."""

    extra_namespace: dict = field(repr=False, factory=dict)
    """Additional items (beyond parameters) to be added to the opto synapse
    namespace"""

    def _get_source_for_synapse(
        self, target_ng: NeuronGroup, i_targets: list[int]
    ) -> Tuple[NeuronGroup, list[int]]:
        """Pick the source group and source indices for the Synapses.

        Parameters
        ----------
        target_ng : NeuronGroup
            The target neuron group.
        i_targets : list[int]
            The indices of the target neurons in the target neuron group.

        Returns
        -------
        Tuple[NeuronGroup, list[int]]
            The source neuron group and the indices to use in Synapses. By
            default these are the target group and target indices themselves.
        """
        return (target_ng, i_targets)

    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
        """Transfect neuron group with device.

        Parameters
        ----------
        neuron_group : NeuronGroup
            The neuron group to transform

        Keyword args
        ------------
        p_expression : float
            Probability (0 <= p <= 1) that a given neuron in the group
            will express the protein. 1 by default.
        i_targets : array-like
            Indices of neurons in the group to transfect. recommended for
            efficiency when stimulating or imaging a small subset of the
            group. Incompatible with ``p_expression``.
        rho_rel : float
            The expression level, relative to the standard model fit,
            of the protein. 1 by default. For heterogeneous expression,
            this would have to be modified in the synapse post-injection,
            e.g., ``opsin.synapses["neuron_group_name"].rho_rel = ...``
        [default_name]_var_name : str
            See :attr:`~required_vars`. Allows for custom variable names.

        Raises
        ------
        ValueError
            If already connected to `neuron_group`, or if both
            ``p_expression`` and ``i_targets`` are given.
        """
        ng_name = neuron_group.name
        if ng_name in self.source_ngs:
            assert ng_name in self.synapses
            raise ValueError(
                f"{self.__class__.__name__} {self.name} already connected to neuron group"
                f" {neuron_group.name}"
            )

        # synapse model/params with names and units specified for this group
        syn_eqs, syn_namespace = self.modify_model_and_params_for_ng(
            neuron_group, kwparams
        )

        # work out which neurons express the protein
        has_p_expression = "p_expression" in kwparams
        has_i_targets = "i_targets" in kwparams
        if has_p_expression and has_i_targets:
            raise ValueError("p_expression and i_targets are incompatible")
        if has_p_expression:
            expressing = np.random.rand(neuron_group.N) < kwparams["p_expression"]
            i_targets = np.flatnonzero(expressing)
        elif has_i_targets:
            i_targets = kwparams["i_targets"]
        else:
            i_targets = list(range(neuron_group.N))
        if len(i_targets) == 0:
            return

        source_ng, i_sources = self._get_source_for_synapse(neuron_group, i_targets)
        syn_name = f"syn_{brian_safe_name(self.name)}_{neuron_group.name}"
        syn = Synapses(
            source_ng,
            neuron_group,
            model=syn_eqs,
            on_pre=self.on_pre,
            namespace=syn_namespace,
            name=syn_name,
        )
        syn.namespace.update(self.extra_namespace)
        syn.connect(i=i_sources, j=i_targets)
        self.init_syn_vars(syn)
        # relative protein density
        syn.rho_rel = kwparams.get("rho_rel", 1)

        # bookkeeping comes last, once every check has passed
        self.source_ngs[ng_name] = source_ng
        self.brian_objects.add(source_ng)
        self.synapses[ng_name] = syn
        self.brian_objects.add(syn)

        registry_for_sim(self.sim).register(self, neuron_group)

    def modify_model_and_params_for_ng(
        self, neuron_group: NeuronGroup, injct_params: dict
    ) -> Tuple[Equations, dict]:
        """Adapt model for given neuron group on injection

        This enables the specification of variable names differently for each
        neuron group, allowing for custom names and avoiding conflicts.

        Parameters
        ----------
        neuron_group : NeuronGroup
            NeuronGroup this model is being connected to
        injct_params : dict
            kwargs passed in on injection, could contain variable
            names to plug into the model

        Returns
        -------
        Equations, dict
            A tuple containing an Equations object and a parameter
            dictionary, constructed from :attr:`~model` and :attr:`~params`,
            respectively, with modified names for use in :attr:`synapses`

        Raises
        ------
        BrianObjectException
            If a required variable is missing from `neuron_group` or has
            the wrong dimensions.
        """
        model_str = self.model

        # unit substitutions come first
        for placeholder_unit, ng_unit in self.per_ng_unit_replacements:
            model_str = model_str.replace(placeholder_unit, ng_unit)

        # verify required variables/units, then fill in placeholder names
        ng_vars = neuron_group.variables
        for default_name, unit in self.required_vars:
            kwarg_key = f"{default_name}_var_name"
            var_name = injct_params.get(kwarg_key, default_name)
            var_ok = var_name in ng_vars and ng_vars[
                var_name
            ].unit.has_same_dimensions(unit)
            if not var_ok:
                raise BrianObjectException(
                    (
                        f"{var_name} : {unit.name} needed in the model of NeuronGroup "
                        f"{neuron_group.name} to connect Opsin {self.name}."
                    ),
                    neuron_group,
                )
            # the synapse model refers to the variable via an upper-case placeholder
            model_str = model_str.replace(kwarg_key.upper(), var_name)

        # Synapse variable and parameter names cannot be the same as any
        # neuron group variable name
        return self._fix_name_conflicts(model_str, neuron_group)

    @property
    def params(self) -> dict:
        """Returns a dictionary of parameters for the model"""
        candidates = asdict(self, recurse=False)
        # drop generic fields that are not parameters, assuming we only want
        # fields declared in the last class of the hierarchy
        direct_parent = type(self).__mro__[1]
        for inherited_name in fields_dict(direct_parent):
            candidates.pop(inherited_name)
        # private attributes are not parameters either
        return {
            key: val for key, val in candidates.items() if not key.startswith("_")
        }

    def _fix_name_conflicts(
        self, modified_model: str, neuron_group: NeuronGroup
    ) -> Tuple[Equations, dict]:
        """Append ``_syn`` to model variables and parameters whose names
        clash with variables of `neuron_group`."""
        own_params = self.params
        renamed_params = own_params.copy()

        def with_suffix(name):
            return f"{name}_syn"

        eqs = Equations(modified_model)
        ng_vars = neuron_group.variables

        # conflicting model variables...
        substitutions = {var: with_suffix(var) for var in eqs.names if var in ng_vars}
        # ...and conflicting parameters
        for param in own_params:
            if param not in ng_vars:
                continue
            new_name = with_suffix(param)
            substitutions[param] = new_name
            renamed_params[new_name] = renamed_params[param]
            del renamed_params[param]

        return eqs.substitute(**substitutions), renamed_params

    def reset(self, **kwargs):
        """Re-initialize the variables of every stored Synapses object."""
        for syn in self.synapses.values():
            self.init_syn_vars(syn)

    def init_syn_vars(self, syn: Synapses) -> None:
        """Initializes appropriate variables in Synapses implementing the model

        Also called on :meth:`reset`. Does nothing by default.

        Parameters
        ----------
        syn : Synapses
            The synapses object implementing this model
        """
        return None