# Source code for cleo.recorders
"""Contains basic recorders."""
from typing import Any
from brian2 import PopulationRateMonitor, StateMonitor, SpikeMonitor, Quantity
import numpy as np
from nptyping import NDArray
from cleo.base import Recorder
class RateRecorder(Recorder):
    """Records firing rate from a single neuron.

    Firing rate comes from Brian's :class:`~brian2.monitors.ratemonitor.PopulationRateMonitor`
    """

    def __init__(self, name: str, index: int):
        """
        Parameters
        ----------
        name : str
            Unique device name
        index : int
            index of neuron to record
        """
        super().__init__(name)
        self.i = index
        # Created in connect_to_neuron_group; None means "not connected yet".
        self.mon = None

    def connect_to_neuron_group(self, neuron_group):
        """Attach a rate monitor to neuron ``self.i`` of ``neuron_group``.

        Parameters
        ----------
        neuron_group : NeuronGroup
            Group containing the neuron to record.

        Raises
        ------
        UserWarning
            If this recorder is already connected to a neuron group.
        """
        # Same guard as the other recorders: reconnecting would replace
        # self.mon while the old monitor stays registered in brian_objects.
        if self.mon is not None:
            raise UserWarning(
                "Recorder was already connected to a neuron group. "
                "Can only record from one at a time."
            )
        # neuron_group[self.i] selects only the neuron being recorded, so the
        # "population" rate is that single neuron's rate.
        self.mon = PopulationRateMonitor(neuron_group[self.i])
        self.brian_objects.add(self.mon)

    def get_state(self) -> Quantity:
        """
        Returns
        -------
        Quantity
            Most recent firing-rate sample of the recorded neuron, or None if
            the recorder is not connected or nothing has been recorded yet
        """
        if self.mon is None:
            return None
        try:
            return self.mon.rate[-1]
        except IndexError:
            # No time step recorded yet (empty rate array).
            return None
class VoltageRecorder(Recorder):
    """Records the voltage of a single neuron group."""

    def __init__(self, name: str, voltage_var_name: str = "v"):
        """
        Parameters
        ----------
        name : str
            Unique device name
        voltage_var_name : str, optional
            Name of variable representing membrane voltage, by default "v"
        """
        super().__init__(name)
        self.voltage_var_name = voltage_var_name
        # Created in connect_to_neuron_group; None means "not connected yet".
        self.mon = None

    def connect_to_neuron_group(self, neuron_group):
        """Start recording ``voltage_var_name`` for every neuron in the group.

        Parameters
        ----------
        neuron_group : NeuronGroup
            Group whose voltage is recorded.

        Raises
        ------
        UserWarning
            If this recorder is already connected to a neuron group.
        """
        if self.mon is not None:
            raise UserWarning(
                "Recorder was already connected to a neuron group. "
                "Can only record from one at a time."
            )
        self.mon = StateMonitor(neuron_group, self.voltage_var_name, record=True)
        self.brian_objects.add(self.mon)

    def get_state(self) -> Quantity:
        """
        Returns
        -------
        Quantity
            Latest recorded voltage sample for each neuron of the target
            neuron group, or None if the recorder is not connected or
            nothing has been recorded yet
        """
        if self.mon is None:
            return None
        # Look up the configured variable; hard-coding ``.v`` broke any
        # recorder constructed with a different voltage_var_name, since the
        # monitor only records the variable it was asked for.
        voltage = getattr(self.mon, self.voltage_var_name)
        try:
            # Last column = most recent time step.
            return voltage[:, -1]
        except IndexError:
            # No time step recorded yet.
            return None
class GroundTruthSpikeRecorder(Recorder):
    """Reports the number of spikes seen since last queried for each neuron.

    This amounts effectively to the number of spikes per control period.
    Note: this will only work for one neuron group at the moment.
    """

    def __init__(self, name):
        super().__init__(name)
        # Created in connect_to_neuron_group; None means "not connected yet".
        self.mon = None
        # Number of spikes (from the start of the monitor's record) already
        # reported by previous get_state() calls.
        self.num_spikes_seen = 0

    def connect_to_neuron_group(self, neuron_group):
        """Attach a spike monitor to ``neuron_group``.

        Parameters
        ----------
        neuron_group : NeuronGroup
            Group whose spikes are counted.

        Raises
        ------
        UserWarning
            If this recorder is already connected to a neuron group.
        """
        if self.mon is not None:
            raise UserWarning(
                "Recorder was already connected to a neuron group. "
                "Can only record from one at a time."
            )
        self.mon = SpikeMonitor(neuron_group)
        self.brian_objects.add(self.mon)
        # One slot per neuron; its length fixes the size of get_state()'s
        # output.
        self.out_template = np.zeros(len(neuron_group))

    def get_state(self) -> NDArray[(Any,), np.uint]:
        """
        Returns
        -------
        NDArray[(n_neurons,), np.uint]
            n_neurons-length array with spike counts over the latest
            control period. Counts are non-negative integers (returned
            with a signed integer dtype for every group size).
        """
        all_indices = np.asarray(self.mon.i, dtype=np.intp)
        # Slice from the first unreported spike. The previous form,
        # ``i[-num_new_spikes:]``, became ``i[0:]`` when there were no new
        # spikes and re-counted the entire spike history.
        new_indices = all_indices[self.num_spikes_seen :]
        self.num_spikes_seen = len(all_indices)
        # bincount tallies spikes per neuron index; minlength guarantees one
        # entry per neuron even for neurons that did not spike.
        return np.bincount(new_indices, minlength=len(self.out_template))