Source code for cleo.ioproc

"""Basic processor definitions and control/estimation functions"""
from __future__ import annotations

from abc import abstractmethod
from collections import deque
from typing import Tuple

import numpy as np
from attrs import define, field, fields
from brian2 import Quantity, ms
from jaxtyping import UInt

from cleo.base import IOProcessor
from cleo.utilities import unit_safe_append


@define
class LatencyIOProcessor(IOProcessor):
    """IOProcessor capable of delivering stimulation some time after measurement.

    Each call to :meth:`put_state` runs :meth:`process` and queues the result,
    together with its output time, in :attr:`out_buffer`.
    :meth:`get_ctrl_signals` hands out the oldest queued output once the
    queried time has reached that output time.

    Note
    ----
    It doesn't make much sense to combine parallel computation
    with "when idle" sampling, because "when idle" sampling only produces
    one sample at a time to process.

    Raises
    ------
    ValueError
        For invalid `sampling` or `processing` kwargs
    """

    t_samp: Quantity = field(factory=lambda: [] * ms, init=False, repr=False)
    """Record of sampling times---each time :meth:`~put_state` is called."""

    sampling: str = field(default="fixed")
    """Sampling scheme: "fixed" or "when idle".

    "fixed" sampling means samples are taken on a fixed schedule,
    with no exceptions.

    "when idle" sampling means no samples are taken before the previous
    sample's output has been delivered. A sample is taken ASAP
    after an over-period computation: otherwise remains on schedule.
    """

    @sampling.validator
    def _validate_sampling(self, attribute, value):
        """Reject any sampling scheme other than "fixed" or "when idle"."""
        if value not in ["fixed", "when idle"]:
            # single formatted message (the print-style two-argument form
            # rendered as a tuple repr)
            raise ValueError(f"Invalid sampling scheme: {value!r}")

    processing: str = field(default="parallel")
    """Processing scheme: "serial" or "parallel".

    "parallel" computes the output time by adding the delay for a sample
    onto the sample time, so if the delay is 2 ms, for example, while the
    sample period is only 1 ms, some of the processing is happening in
    parallel. Output order matches input order even if the computed
    output time for a sample is sooner than that for a previous
    sample.

    "serial" computes the output time by adding the delay for a sample
    onto the output time of the previous sample, rather than the sampling
    time. Note this may be of limited utility because it essentially means
    the *entire* round trip cannot be in parallel at all. More realistic
    is that simply each block or phase of computation must be serial. If
    anyone cares enough about this, it will have to be implemented in the
    future.
    """

    @processing.validator
    def _validate_processing(self, attribute, value):
        """Reject any processing scheme other than "serial" or "parallel"."""
        if value not in ["serial", "parallel"]:
            raise ValueError(f"Invalid processing scheme: {value!r}")

    out_buffer: deque[Tuple[dict, float]] = field(factory=deque, init=False, repr=False)
    """Queue of pending ``(output, output time)`` pairs, oldest first."""

    # Declared as a field so the flag exists before the first call to
    # `is_sampling_now`/`put_state` and so that slotted (`@define`) subclasses
    # can assign it. `eq=False` keeps generated equality unaffected.
    _needs_off_schedule_sample: bool = field(
        default=False, init=False, repr=False, eq=False
    )
    """Whether an on-schedule sample was skipped in "when idle" mode because
    the previous output was still pending."""

    def put_state(self, state_dict: dict, t_samp: Quantity):
        """Record the sample time, process the sample and queue its output.

        Parameters
        ----------
        state_dict : dict
            Network state passed straight through to :meth:`process`.
        t_samp : Quantity
            The time at which the sample was taken.
        """
        self.t_samp = unit_safe_append(self.t_samp, t_samp)
        out, t_out = self.process(state_dict, t_samp)
        if self.processing == "serial" and self.out_buffer:
            prev_t_out = self.out_buffer[-1][1]
            # add delay onto the output time of the last computation
            t_out = prev_t_out + t_out - t_samp
        self.out_buffer.append((out, t_out))
        # a sample has just been taken, so no catch-up sample is owed
        self._needs_off_schedule_sample = False

    def get_ctrl_signals(self, t_query):
        """Return the oldest queued output if it is due at `t_query`.

        At most one queued output is removed per call. An empty dict is
        returned when nothing is queued or the oldest output is not yet due.
        """
        if not self.out_buffer:
            return {}
        next_out_signal, next_t_out = self.out_buffer[0]
        if t_query >= next_t_out:
            self.out_buffer.popleft()
            return next_out_signal
        return {}

    def _is_currently_idle(self, t_query):
        """Whether nothing is queued or the oldest queued output is already due."""
        return len(self.out_buffer) == 0 or self.out_buffer[0][1] <= t_query

    def is_sampling_now(self, t_query):
        """Decide whether a sample should be taken at `t_query`.

        Returns
        -------
        bool
            For "fixed" sampling, True exactly when `t_query` falls on the
            sampling schedule. For "when idle" sampling, True on schedule if
            idle, or off schedule if a scheduled sample was skipped and the
            processor has since become idle.
        """
        resid_ms = np.round((t_query % self.sample_period) / ms, 6)
        period_ms = np.round(self.sample_period / ms, 6)
        # Floating-point modulo can yield a residual of ~one full period
        # instead of ~0 for an exact multiple, so both count as on schedule.
        # This applies to both sampling schemes.
        on_schedule = bool(np.isclose(resid_ms, 0) or np.isclose(resid_ms, period_ms))

        if self.sampling == "fixed":
            return on_schedule

        if self.sampling == "when idle":
            if on_schedule:
                if self._is_currently_idle(t_query):
                    self._needs_off_schedule_sample = False
                    return True
                # not done computing: remember that a sample is owed
                self._needs_off_schedule_sample = True
                return False
            # off-schedule, only sample if the last sampling period
            # was missed (there was an overrun)
            return self._needs_off_schedule_sample and self._is_currently_idle(
                t_query
            )

        return False

    @abstractmethod
    def process(self, state_dict: dict, t_samp: Quantity) -> Tuple[dict, Quantity]:
        """Process network state to generate output to update stimulators.

        This is the function the user must implement to define the signal processing
        pipeline.

        Parameters
        ----------
        state_dict : dict
            {`recorder_name`: `state`} dictionary from :func:`~cleo.CLSimulator.get_state()`
        t_samp : Quantity
            The time at which the sample was taken.

        Returns
        -------
        Tuple[dict, Quantity]
            {'stim_name': `ctrl_signal`} dictionary and output time (including unit).
        """
        pass

    def _base_reset(self):
        """Restore the sample record, output queue and catch-up flag to their
        initial values, using the attrs field factories for the containers."""
        self.t_samp = fields(type(self)).t_samp.default.factory()
        self.out_buffer = fields(type(self)).out_buffer.default.factory()
        self._needs_off_schedule_sample = False
class RecordOnlyProcessor(LatencyIOProcessor):
    """Processor that takes samples but never produces control signals.

    Intended for experiments that only record.
    """

    def __init__(self, sample_period, **kwargs):
        """Forward `sample_period` and any keyword arguments to the parent
        initializer unchanged."""
        super().__init__(sample_period, **kwargs)

    def process(self, state_dict: dict, sample_time: float) -> Tuple[dict, float]:
        """Ignore the state and return an empty control dictionary whose
        output time equals the sample time."""
        no_ctrl_signals: dict = {}
        return no_ctrl_signals, sample_time
def exp_firing_rate_estimate(
    spike_counts: UInt[np.ndarray, "num_spike_sources"],
    dt: Quantity,
    prev_rate: Quantity,
    tau: Quantity,
) -> Quantity:
    """Update a firing rate estimate using a recursive exponential filter.

    The previous estimate is decayed by ``exp(-dt / tau)`` and blended with
    the rate observed over the latest interval, ``spike_counts / dt``.

    Parameters
    ----------
    spike_counts: np.ndarray
        n-length vector of spike counts
    dt: Quantity
        Time since last measurement (with Brian temporal unit)
    prev_rate: Quantity
        n-length vector of previously estimated firing rates
    tau: Quantity
        Time constant of exponential filter (with Brian temporal unit)

    Returns
    -------
    Quantity
        n-length vector of estimated firing rates (with Brian units)
    """
    decay = np.exp(-dt / tau)
    carried_over = prev_rate * decay
    # weight the counts first, then divide by dt (same order of operations
    # as the one-line formula)
    weighted_counts = (1 - decay) * spike_counts
    return carried_over + weighted_counts / dt
def pi_ctrl(
    measurement: float,
    reference: float,
    integ_error: float,
    dt: Quantity,
    Kp: float,
    Ki: Quantity = 0 / ms,
):
    """Compute one step of proportional-integral (PI) control.

    Parameters
    ----------
    measurement : float
        Current measured value.
    reference : float
        Target value; the error is ``reference - measurement``.
    integ_error : float
        Integrated error accumulated over previous steps.
    dt : Quantity
        Time step by which the current error is weighted before being added
        to the integrated error.
    Kp : float
        Proportional gain.
    Ki : Quantity, optional
        Integral gain, by default ``0 / ms`` (no integral term).

    Returns
    -------
    tuple
        ``(control_signal, integ_error)``, where `control_signal` is
        ``Kp * error + Ki * integ_error`` and `integ_error` is the updated
        integrated error to pass into the next call.
    """
    error = reference - measurement
    # Build a new value instead of using ``+=``. An in-place add would
    # silently modify an array passed in by the caller, and fails when the
    # array's dtype cannot hold the increment (e.g. an integer array).
    integ_error = integ_error + error * dt
    return Kp * error + Ki * integ_error, integ_error