"""Contains multi-unit and sorted spiking signals."""
from __future__ import annotations
import warnings
from abc import abstractmethod
from datetime import datetime
from functools import cache
from typing import Callable, Tuple
import brian2.only as b2
import neo
import numpy as np
import quantities as pq
from attrs import define, field, fields
from brian2 import NeuronGroup, Quantity, SpikeMonitor, mm, ms, um
from jaxtyping import Bool, Float, Int
from scipy import signal
from scipy.stats import norm
from cleo.base import NeoExportable
from cleo.ephys.probes import Signal
from cleo.utilities import rng, unit_safe_cat
@define(eq=False)
class Spiking(Signal, NeoExportable):
    """Base class for probabilistically detecting spikes.

    See ``notebooks/spike_detection.py`` for an interactive explanation of the methods
    and parameters involved."""

    r_noise_floor: Quantity = 80 * um
    """Radius (with Brian unit) at which the measured spike amplitude
    equals the background noise standard deviation.
    i.e., ``spike_amplitude(r_noise_floor) = sigma_noise = 1``.
    80 μm by default."""

    threshold_sigma: int = 4
    """Spike detection threshold, as a multiple of sigma_noise.
    Values in real experiments typically range from 3 to 6. 4 by default."""

    spike_amplitude_cv: float = 0.05
    """Coefficient of variation of the spike amplitude, i.e., `|sigma_amp/mu_amp|`.
    From what we have seen in Allen Cell Types data, this ranges from 0 to 0.2,
    but is most often very low. 0.05 by default."""

    r0: Quantity = 5 * um
    """A small distance added to r before computing the amplitude to avoid division
    by 0 for the power law decay. 5 μm by default.
    It also makes some physical sense as the minimum distance from the current source
    it is possible to place an electrode, 5 μm being reasonable as the radius of a
    typical soma."""

    recording_recall_cutoff: float = 0.001
    """*Multi-channel* recall, above which neurons will be considered.
    I.e., the probability a spike is detected on at least one channel.
    You shouldn't need to change this; it's mainly for efficiency, allowing amplitude
    sampling and threshold crossing to operate on fewer spikes by ignoring neurons
    very unlikely to produce a spike that crosses the threshold."""

    eap_decay_fn: Callable[[Quantity], float] = lambda r: r**-2
    """The function describing the decay of the measured extracellular action potential
    amplitude. By default 1/r^2.
    This inverse square decay is a good approximation in accordance with the detailed
    simulations by `Pettersen et al. (2008) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2186261/>`_,
    though they find the exponent ranges from 2 to 3 depending on the cell type and distance. """

    collision_prob_fn: Callable[[Quantity], float] = field(default=None)
    """The probability of failing to detect the latter of two overlapping threshold
    crossings on a given channel, as a function of ``t2 - t1``.
    Values for ``t2 - t1 < 0`` are ignored.
    For :class:`SortedSpiking`, the default is a decaying exponential.
    See `Garcia et al. (2022) <https://www.eneuro.org/content/9/5/ENEURO.0105-22.2022>`_
    for what this might look like for different sorters.
    By default simply enforces a hard 1 ms refractory period
    per channel for :class:`MultiUnitActivity`.
    """

    @collision_prob_fn.validator
    def _validate_coll_prob_fn(self, attribute, value):
        """Check that a user-supplied ``collision_prob_fn`` yields probabilities.

        ``None`` (the base-class default) is accepted without checks.
        """
        if value is not None:
            assert callable(value), "collision_prob_fn must be callable"
            probs = np.asarray(value([0, 1, 10] * ms), dtype=float)
            # element-wise bounds check: a chained ``0 <= arr <= 1`` would try to
            # take the truth value of a whole array and raise instead of validating
            assert np.all((probs >= 0) & (probs <= 1)), (
                "collision_prob_fn must return a value between 0 and 1"
            )

    simulate_false_positives: bool = True
    """Whether to simulate false positives from noise. In the case of :class:`SortedSpiking`,
    these aren't reported, but still affect collision sampling."""

    t: Quantity = field(
        init=False, factory=lambda: ms * np.array([], dtype=float), repr=False
    )
    """Spike times with Brian units, stored if
    :attr:`~cleo.InterfaceDevice.save_history` on :attr:`~Signal.probe`"""

    i: Int[np.ndarray, "n_recorded_spikes"] = field(
        init=False, factory=lambda: np.array([], dtype=int), repr=False
    )
    """Channel (for multi-unit) or neuron (for sorted) indices
    of spikes, stored if
    :attr:`~cleo.InterfaceDevice.save_history` on :attr:`~Signal.probe`"""

    t_samp: Quantity = field(
        init=False, factory=lambda: ms * np.array([], dtype=float), repr=False
    )
    """Sample times with Brian units when each spike was recorded, stored if
    :attr:`~cleo.InterfaceDevice.save_history` on :attr:`~Signal.probe`"""

    i_probe_by_ng: dict[NeuronGroup, Int[np.ndarray, "ng_N"]] = field(
        init=False, factory=dict, repr=False
    )
    """neuron_group keys, i_probe values for every neuron in group."""

    i_ng_by_i_probe: list[tuple[NeuronGroup, int]] = field(
        init=False, factory=list, repr=False
    )
    """n_neurons-length list indexed by i_probe returning a neuron_group, i_ng tuple
    to map from i_probe to neuron group and index."""

    # one SpikeMonitor per connected neuron group, in connection order
    _monitors: list[SpikeMonitor] = field(init=False, factory=list, repr=False)
    # per monitor: how many of its spikes have already been consumed
    _mon_spikes_already_seen: list[int] = field(init=False, factory=list, repr=False)
    # mean EAP amplitude (in noise sigmas) for every recorded neuron on every channel
    _mu_eap: Float[np.ndarray, "n_neurons n_channels"] = field(
        init=False, default=None, repr=False
    )
    # simulation time at which the previous noise window ended
    _prev_t: Quantity = field(init=False, default=None, repr=False)
    # band-pass filter state carried over between consecutive noise windows
    _prev_zi: Float[np.ndarray, "n_sos 2"] = field(init=False, default=None, repr=False)

    def _init_saved_vars(self):
        """Reset the stored history to fresh empty arrays (if history is saved)."""
        if self.probe.save_history:
            attr_defs = fields(type(self))
            self.t = attr_defs.t.default.factory()
            self.i = attr_defs.i.default.factory()
            self.t_samp = attr_defs.t_samp.default.factory()

    def _update_saved_vars(self, t, i, t_samp):
        """Append newly detected spikes to the stored history (if history is saved).

        Parameters
        ----------
        t : Quantity
            Times of the newly detected spikes.
        i : np.ndarray
            Source indices of the newly detected spikes.
        t_samp : Quantity
            Sample time at which these spikes were reported; repeated per spike.
        """
        if self.probe.save_history:
            self.i = np.concatenate([self.i, i])
            self.t = unit_safe_cat([self.t, t])
            t_samp_rep = np.full_like(t, t_samp)
            self.t_samp = unit_safe_cat([self.t_samp, t_samp_rep])

    @property
    @abstractmethod
    def n(self):
        """Number of spike sources: channels for :class:`MultiUnitActivity` or sorted neurons
        for :class:`SortedSpiking`."""
        pass

    @property
    def n_channels(self) -> int:
        """Number of channels on probe"""
        return self.probe.n

    @property
    def n_neurons(self) -> int:
        """Number of neurons recorded by probe"""
        # -2 marks neurons of a connected group that are not recorded
        return int(
            sum(np.sum(i_probe != -2) for i_probe in self.i_probe_by_ng.values())
        )

    @property
    def r_threshold(self, resolution: Quantity = um / 10) -> Quantity:
        """The distance from a contact at which the SNR equals the detection threshold.
        This also means 50% single-channel recall.

        As a property this is always evaluated with the default ``resolution``;
        call :meth:`r_for_snr` directly to use another one."""
        return self.r_for_snr(self.threshold_sigma, resolution=resolution)

    def _first_distance_at_or_below(
        self, metric_by_distance, target, resolution, upper_limit
    ) -> Quantity:
        """Smallest sampled distance at which ``metric_by_distance(r) <= target``.

        Distances are sampled from 0 up to ``upper_limit`` in steps of
        ``resolution``; the range is doubled (a bounded number of times) whenever
        the target is not reached.

        Raises
        ------
        ValueError
            If the target is not reached within the largest range searched.
        """
        if upper_limit is None:
            upper_limit = 10 * self.r_noise_floor
        # bound the search: an unreachable target (e.g. a recall below the
        # noise-only crossing probability) would otherwise keep doubling the
        # sampled range until memory is exhausted
        max_doublings = 10
        for _ in range(max_doublings + 1):
            rr = np.arange(0, upper_limit / um, resolution / um) * um
            hits = np.flatnonzero(metric_by_distance(rr) <= target)
            if hits.size > 0:
                return rr[hits[0]]
            upper_limit = upper_limit * 2
        raise ValueError(
            f"target value {target} is not reached within {upper_limit / 2} of a contact"
        )

    def r_for_recall(
        self,
        recall: float,
        resolution: Quantity = um / 10,
        upper_limit: Quantity = None,
    ) -> Quantity:
        """The distance from a contact at which the single-channel recall
        (detection probability) equals the specified value.

        Parameters
        ----------
        recall : float
            Target single-channel detection probability.
        resolution : Quantity, optional
            Spacing of the distances that are sampled, by default 0.1 μm.
        upper_limit : Quantity, optional
            Initial largest distance searched, by default ``10 * r_noise_floor``.
            The range is enlarged automatically if the target is not reached.

        Returns
        -------
        Quantity
            First sampled distance at which recall has dropped to ``recall`` or below.

        Raises
        ------
        ValueError
            If ``recall`` is not reached even after enlarging the search range.
        """
        return self._first_distance_at_or_below(
            self.recall_by_distance, recall, resolution, upper_limit
        )

    def r_for_snr(
        self, snr: float, resolution: Quantity = um / 10, upper_limit: Quantity = None
    ) -> Quantity:
        """The distance from a contact at which the SNR equals the specified value.

        Parameters
        ----------
        snr : float
            Target signal-to-noise ratio.
        resolution : Quantity, optional
            Spacing of the distances that are sampled, by default 0.1 μm.
        upper_limit : Quantity, optional
            Initial largest distance searched, by default ``10 * r_noise_floor``.
            The range is enlarged automatically if the target is not reached.

        Returns
        -------
        Quantity
            First sampled distance at which the SNR has dropped to ``snr`` or below.

        Raises
        ------
        ValueError
            If ``snr`` is not reached even after enlarging the search range.
        """
        return self._first_distance_at_or_below(
            self.snr_by_distance, snr, resolution, upper_limit
        )

    def recall_by_snr(self, snr: float) -> float:
        """Probability of detecting a spike at distance r from the neuron
        as a function of SNR."""
        # measured amplitude ~ N(mu, sigma): unit background noise plus
        # amplitude variability proportional to the mean
        mu = snr
        sigma = np.sqrt(1 + (mu * self.spike_amplitude_cv) ** 2)
        # P(amplitude > threshold) = 1 - P(spike not detected)
        return norm.sf(self.threshold_sigma, loc=mu, scale=sigma)

    def recall_by_distance(self, r: Quantity) -> float:
        """Probability of detecting a spike at distance r from the neuron
        as a function of distance."""
        return self.recall_by_snr(self.snr_by_distance(r))

    def snr_by_distance(self, r: Quantity) -> float:
        """The mean extracellular action potential amplitude as a function of distance
        from the neuron, in units of background noise standard deviation."""
        # normalized so that the result is exactly 1 at r == r_noise_floor
        return self.eap_decay_fn(r + self.r0) / self.eap_decay_fn(
            self.r0 + self.r_noise_floor
        )

    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams):
        """Configure signal to record from specified neuron group

        Parameters
        ----------
        neuron_group : NeuronGroup
            Neuron group to record from

        Returns
        -------
        tuple[np.ndarray, np.ndarray] or None
            ``(peak_snr, i_probe)`` for the neurons of the group that are recorded:
            their SNR on their best channel and their newly assigned probe indices.
            ``None`` if no neuron of the group passes ``recording_recall_cutoff``
            (a warning is issued and the group is skipped).

        Raises
        ------
        ValueError
            If the signal is already connected to ``neuron_group``.
        """
        super(Spiking, self).connect_to_neuron_group(neuron_group, **kwparams)
        if neuron_group in self.i_probe_by_ng:
            raise ValueError(
                f"Spiking signal {self.name} already connected to NeuronGroup {neuron_group.name}"
            )
        # could support separate detection probabilities per group using kwparams
        # squared neuron-contact distances: n_neurons X n_channels
        dist2 = np.zeros((len(neuron_group), self.n_channels))
        for dim in ["x", "y", "z"]:
            dim_ng, dim_probe = np.meshgrid(
                getattr(neuron_group, dim),
                getattr(self.probe, f"{dim}s"),
                indexing="ij",
            )
            # proactively strip units to avoid numpy maybe doing so
            dist2 += (dim_ng / mm - dim_probe / mm) ** 2
        distances = np.sqrt(dist2) * mm
        snr = self.snr_by_distance(distances)
        # 1 from baseline noise, mu * cv from spike amp variability
        sigma_eap = np.sqrt(1 + (snr * self.spike_amplitude_cv) ** 2)
        probs_miss = norm.cdf(self.threshold_sigma, loc=snr, scale=sigma_eap)
        assert probs_miss.shape == (len(neuron_group), self.n_channels)
        # a spike is detected unless it is missed on every single channel
        multi_channel_recall = 1 - np.prod(probs_miss, axis=1)
        where_to_keep = multi_channel_recall >= self.recording_recall_cutoff
        i2keep = np.nonzero(where_to_keep)[0]
        # spike sorters don't sort by multi-channel recall; we only do this
        # for computational efficiency, to not get super distant neurons
        n2keep = len(i2keep)
        if n2keep == 0:
            warnings.warn(
                f"NeuronGroup {neuron_group.name} has no neurons with multi-channel recall >= {self.recording_recall_cutoff}."
                " Skipping this group."
            )
            return

        # create monitor
        mon = SpikeMonitor(neuron_group)
        self._monitors.append(mon)
        self.brian_objects.add(mon)
        self._mon_spikes_already_seen.append(0)

        # update mapping from neuron group to probe index
        i_probe_start = self.n_neurons
        new_i_probe = np.arange(i_probe_start, i_probe_start + n2keep)
        # -2 means not recorded
        i_probe_for_ng = np.full(neuron_group.N, -2, dtype=int)
        i_probe_for_ng[where_to_keep] = new_i_probe
        self.i_probe_by_ng[neuron_group] = i_probe_for_ng

        # update mapping from probe index to neuron group & index
        assert len(self.i_ng_by_i_probe) == i_probe_start
        self.i_ng_by_i_probe.extend((neuron_group, i_ng) for i_ng in i2keep)

        # only initialize once; ``is None`` rather than truthiness so that a
        # legitimate start time of 0 ms is not mistaken for "unset"
        if self._prev_t is None:
            self._prev_t = self.probe.sim.network.t

        # store neuron-channel mean EAPs to use in subclasses
        mu_eap_new = snr[where_to_keep]
        if self._mu_eap is None:
            self._mu_eap = mu_eap_new
        else:
            self._mu_eap = np.concatenate((self._mu_eap, mu_eap_new), axis=0)
        return mu_eap_new.max(axis=1), new_i_probe

    @abstractmethod
    def get_state(
        self,
    ) -> tuple[Int[np.ndarray, "n_spikes"], Quantity, Int[np.ndarray, "{self.n}"]]:
        """Return spikes since method was last called (i, t, y)

        Returns
        -------
        tuple[Int[np.ndarray, "n_spikes"], Quantity, Int[np.ndarray, "{self.n}"]]
            (i, t, y) where i is channel (for multi-unit) or neuron (for sorted) spike
            indices, t is spike times, and y is a spike count vector suitable for control-
            theoretic uses---i.e., a 0 for every channel/neuron that hasn't spiked and a 1
            for a single spike.
        """
        pass

    def _get_new_spikes(self) -> Tuple[Int[np.ndarray, "n_spikes"], Quantity]:
        """Collect ground-truth spikes of recorded neurons not yet consumed.

        Returns
        -------
        tuple[np.ndarray, Quantity]
            Probe indices and times of the new spikes, grouped by monitor
            (i.e., not globally time-sorted).
        """
        i_probe = np.array([], dtype=int)
        t = ms * np.array([], dtype=float)
        for j, mon in enumerate(self._monitors):
            n_seen = self._mon_spikes_already_seen[j]
            i_ng = mon.i[n_seen:]  # can contain spikes we don't care about
            # filter out spikes we don't care about (-2 == not recorded)
            i_probe_unfilt = self.i_probe_by_ng[mon.source][i_ng]
            i2keep = i_probe_unfilt != -2
            i_probe = np.concatenate((i_probe, i_probe_unfilt[i2keep]))
            t = unit_safe_cat([t, mon.t[n_seen:][i2keep]])
            self._mon_spikes_already_seen[j] = mon.num_spikes
        return i_probe, t

    @staticmethod
    @cache
    def _prep_noise(dt_s: float) -> tuple[np.ndarray, float]:
        """Design the spike-band filter and the matching white-noise gain.

        Parameters
        ----------
        dt_s : float
            Simulation timestep in seconds.

        Returns
        -------
        tuple[np.ndarray, float]
            Second-order sections of a 6th-order Butterworth band-pass
            (300 Hz to ``min(3000, 0.45 * fs)`` Hz) and the factor by which unit
            white noise must be scaled so that its RMS after filtering is 1.
        """
        lowcut_Hz = 300
        fs_Hz = 1 / dt_s
        highcut_Hz = min(3000, 0.45 * fs_Hz)
        sos = signal.butter(
            6, [lowcut_Hz, highcut_Hz], fs=fs_Hz, btype="band", output="sos"
        )
        w, h = signal.sosfreqz(sos, 2**14, fs=fs_Hz)
        # ``np.trapz`` is deprecated since NumPy 2.0 in favor of ``np.trapezoid``
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        # equivalent noise bandwidth of the filter
        enbw = trapezoid(np.abs(h) ** 2, w)
        assert np.isclose(enbw, highcut_Hz - lowcut_Hz, rtol=0.01), (
            f"{enbw} != {highcut_Hz - lowcut_Hz}"
        )
        nyq = 0.5 * fs_Hz
        pre_filter_factor = np.sqrt(nyq / enbw)
        # inline test:
        white_noise = rng.standard_normal((400, 32))
        # scale noise so RMS after filtering is 1
        white_noise *= pre_filter_factor
        filtered_noise = signal.sosfilt(sos, white_noise, axis=0)
        assert np.isclose(np.std(filtered_noise), 1, rtol=0.05), (
            f"Filtered noise RMS {np.std(filtered_noise)} != 1"
        )
        return sos, pre_filter_factor

    def _generate_noise(self) -> tuple[Float[np.ndarray, "n_t n_channels"], Quantity]:
        """generate noise in spiking band

        Covers every timestep from the end of the previous window (``_prev_t``) up
        to, but excluding, the current simulation time. The filter state is carried
        over between calls so consecutive windows join seamlessly.

        Returns
        -------
        tuple[np.ndarray, Quantity]
            Unit-RMS band-limited noise of shape (n_t, n_channels) and the time of
            each of its rows.
        """
        dt = b2.defaultclock.dt
        n_t = int(round((self.probe.sim.network.t - self._prev_t) / dt))
        t_window = np.arange(n_t) * dt + self._prev_t
        sos, pre_filter_factor = self._prep_noise(dt / b2.second)
        if n_t <= 0:
            return np.zeros((0, self.n_channels)), [] * ms
        # generate white noise
        white_noise = rng.standard_normal((n_t, self.n_channels))
        # scale noise so RMS after filtering is 1
        white_noise *= pre_filter_factor
        if self._prev_zi is None:
            self._prev_zi = np.zeros((sos.shape[0], 2, self.n_channels))
        noise_filt, zi = signal.sosfilt(sos, white_noise, axis=0, zi=self._prev_zi)
        # I assume we won't have spikes at the current timestep here
        # since Cleo's NetworkOperation is scheduled for start of timestep
        self._prev_zi = zi
        return noise_filt, t_window

    def _noisily_get_true_tcs(
        self, i_probe, t
    ) -> Tuple[
        Quantity,
        Int[np.ndarray, "n_tcs"],
        Int[np.ndarray, "n_tcs"],
        Float[np.ndarray, "n_tcs"],
        Float[np.ndarray, "n_t_window {self.n_channels}"],
        Quantity,
    ]:
        """Find threshold crossings (TCs) produced by true spikes.

        Each spike's amplitude on each channel is its mean EAP plus one
        amplitude-variability draw per spike (shared across channels) plus the
        background noise at the spike's timestep; a TC is every (spike, channel)
        pair whose amplitude exceeds ``threshold_sigma``.
        Also advances ``_prev_t`` to the current simulation time.

        Parameters
        ----------
        i_probe : np.ndarray
            Probe indices of the new ground-truth spikes.
        t : Quantity
            Times of the new ground-truth spikes.

        Returns
        -------
        tuple
            ``(t_tcs, i_probe_tcs, i_chan_tcs, amp_tcs, noise, t_noise)``: time,
            source neuron, channel, and amplitude of each threshold crossing,
            followed by the noise window used and the time of each of its rows.
        """
        n_spks = len(i_probe)
        noise, t_noise = self._generate_noise()
        # mu and sigma arrays: n_spikes x n_channels
        mu_eap_for_spikes = self._mu_eap[i_probe]
        sigma_spike_amps = mu_eap_for_spikes * self.spike_amplitude_cv
        # add noise at right timesteps
        i_t_spks = ((t - self._prev_t) / b2.defaultclock.dt).round().astype(int)
        noise_at_spks = noise[i_t_spks]
        assert (
            mu_eap_for_spikes.shape
            == sigma_spike_amps.shape
            == noise_at_spks.shape
            == (n_spks, self.n_channels)
        )
        amps = (
            rng.standard_normal((n_spks, 1)) * sigma_spike_amps
            + mu_eap_for_spikes
            + noise_at_spks
        )
        # ⬇ nonzero gives row, column indices of each nonzero element
        i_spk_tcs, i_chan_tcs = (amps > self.threshold_sigma).nonzero()
        i_probe_tcs = i_probe[i_spk_tcs]
        t_tcs = t[i_spk_tcs]
        amp_tcs = amps[i_spk_tcs, i_chan_tcs]
        self._prev_t = self.probe.sim.network.t
        return t_tcs, i_probe_tcs, i_chan_tcs, amp_tcs, noise, t_noise

    @staticmethod
    @cache
    def _max_collision_interval(dt_ms, collision_prob_fn):
        """Longest interval over which an earlier crossing can still mask a later one.

        Parameters
        ----------
        dt_ms : float
            Simulation timestep in ms.
        collision_prob_fn : Callable
            Collision probability as a function of the interval between crossings;
            assumed to be non-increasing.

        Returns
        -------
        Quantity
            First multiple of ``dt`` (searched below 10 ms) at which the collision
            probability is at most 1e-3; 10 ms, with a warning, if there is none.
        """
        horizon_ms = 10
        intervals = np.arange(horizon_ms / dt_ms) * dt_ms * ms
        probs = np.asarray(collision_prob_fn(intervals), dtype=float)
        # searchsorted needs ascending values, hence the negation
        i = np.searchsorted(-probs, -1e-3)
        if i == len(intervals):
            warnings.warn(
                "collision_prob_fn(10 ms) > 1e-3. "
                "Will not look for collisions over 10 ms in the past."
            )
            # fall back to the full horizon: returning nothing here would make
            # the caller's time arithmetic fail
            return horizon_ms * ms
        return intervals[i]

    def _sample_collisions(self, t, i_chan, amps) -> Bool[np.ndarray, "n_spikes"]:
        """Filter out spikes that are too close together in time on the same channel.
        For simplicity, the first spike is kept, or the largest if simultaneous.
        Note this operates on candidate threshold crossings, not called spikes.
        This is mainly for computational efficiency, so we don't have to iterate.

        Parameters
        ----------
        t : Quantity
            Time-sorted times of the candidate threshold crossings.
        i_chan : np.ndarray
            Channel of each crossing.
        amps : np.ndarray
            Amplitude of each crossing.

        Returns
        -------
        np.ndarray
            Boolean mask over the crossings passed in, True where the crossing is
            lost to a collision.
        """
        assert np.all(np.diff(t) >= 0), "should be time-sorted"
        # need to combine with previous t, i_chan, amps
        # (these attributes only exist after the first call)
        prev_t_tcs = getattr(self, "_prev_t_tcs", None)
        if prev_t_tcs is None:
            where_window_starts = 0
        else:
            where_window_starts = len(prev_t_tcs)
            t = unit_safe_cat([prev_t_tcs, t])
            i_chan = np.concatenate([self._prev_i_chan_tcs, i_chan])
            amps = np.concatenate([self._prev_amp_tcs, amps])

        # rows=spike 2, cols=spike 1
        t_diff = t[:, None] - t[None, :]
        amp_diff = amps[:, None] - amps[None, :]
        same_chan = i_chan[:, None] == i_chan[None, :]
        collision_prob = self.collision_prob_fn(t_diff) * same_chan
        # remove self-pairs
        np.fill_diagonal(collision_prob, 0)
        # only consider same-channel pairs, and only consider where t_spk2 >= t_spk1
        collision_prob *= t_diff >= 0
        # should be roughly lower triangular at this point if spikes are ordered by time
        # for simultaneous spikes, make sure biggest amplitude wins
        # by removing simultaneous pairs where the second is bigger
        collision_prob[(t_diff == 0) & (amp_diff > 0)] = 0
        which_collided = np.any(
            rng.uniform(size=collision_prob.shape) < collision_prob, axis=1
        )

        # save t, i_chan, amps for next call
        # earliest time needed from current time (more than enough for next sample)
        t_needed = self.probe.sim.network.t - self._max_collision_interval(
            b2.defaultclock.dt / ms, self.collision_prob_fn
        )
        # TODO: use searchsorted elsewhere
        i_oldest_needed = max(np.searchsorted(t, t_needed) - 1, 0)
        self._prev_t_tcs = t[i_oldest_needed:]
        self._prev_i_chan_tcs = i_chan[i_oldest_needed:]
        self._prev_amp_tcs = amps[i_oldest_needed:]
        # only report on the crossings passed in, not on the carried-over ones
        return which_collided[where_window_starts:]

    def reset(self, **kwargs) -> None:
        """Skip all spikes recorded so far and clear the stored history."""
        # crucial that this be called after network restore
        # since that would reset monitors
        # NOTE(review): noise/collision carry-over state (_prev_t, _prev_zi,
        # _prev_*_tcs) is left untouched here -- confirm that is intended
        for j, mon in enumerate(self._monitors):
            self._mon_spikes_already_seen[j] = mon.num_spikes
        self._init_saved_vars()

    def to_neo(self) -> neo.Group:
        """Export the stored spikes as a Neo group, one spike train per source index."""
        group = neo.Group(allowed_types=[neo.SpikeTrain])
        t_stop = self.probe.sim.network.t / ms * pq.ms
        # np.unique gives a deterministic, ascending order of source indices
        for i in np.unique(self.i):
            st = neo.SpikeTrain(
                times=self.t[self.i == i] / ms * pq.ms,
                t_stop=t_stop,
            )
            st.annotate(i=int(i))
            group.add(st)
        group.annotate(export_datetime=datetime.now())
        group.name = f"{self.probe.name}.{self.name}"
        group.description = f"Exported from Cleo {self.__class__.__name__} object"
        return group
@define(eq=False)
class MultiUnitActivity(Spiking):
    """Detects (unsorted) spikes per channel."""

    # hard 1 ms refractory period per channel: a crossing less than 1 ms after
    # another one on the same channel is always lost
    collision_prob_fn: Callable[[Quantity], float] = lambda t: t < 1 * ms

    @property
    def n(self):
        """Number of spike sources, i.e., channels on the probe."""
        return self.probe.n

    def get_state(
        self,
    ) -> tuple[
        Int[np.ndarray, "n_spikes"], Quantity, Int[np.ndarray, "{self.n_channels}"]
    ]:
        """Return threshold crossings detected since the method was last called.

        Returns
        -------
        tuple
            ``(i, t, y)``: channel index and time of each detected crossing
            (time-sorted), and the per-channel count of detected crossings.
        """
        t_samp = self.probe.sim.network.t
        i_probe, t = self._get_new_spikes()
        t_tcs, _, i_chan_tcs, amp_tcs, noise, t_noise = self._noisily_get_true_tcs(
            i_probe, t
        )
        # get false positives from noise
        if self.simulate_false_positives:
            i_t_fps, i_chan_fps = (noise > self.threshold_sigma).nonzero()
            t_fps = t_noise[i_t_fps]
            amp_fps = noise[i_t_fps, i_chan_fps]
        else:
            # typed empties: plain lists would turn the concatenated channel
            # indices (and thus the returned/stored ``i``) into floats
            t_fps = ms * np.array([], dtype=float)
            i_chan_fps = np.array([], dtype=int)
            amp_fps = np.array([], dtype=float)
        i_chan = np.concatenate([i_chan_tcs, i_chan_fps])
        t = unit_safe_cat([t_tcs, t_fps])
        amps = np.concatenate([amp_tcs, amp_fps])

        # sort by time
        sort_idx = np.argsort(t)
        t = t[sort_idx]
        i_chan = i_chan[sort_idx]
        amps = amps[sort_idx]

        # sample collisions
        which_collided = self._sample_collisions(t, i_chan, amps)
        t_detected = t[~which_collided]
        i_chan_detected = i_chan[~which_collided]

        # minlength includes 0s for upper indices not seen; kept as float64, which
        # is what the previous zero-padding via concatenation produced
        y = np.bincount(
            i_chan_detected.astype(int), minlength=self.n_channels
        ).astype(float)
        self._update_saved_vars(t_detected, i_chan_detected, t_samp)
        return i_chan_detected, t_detected, y

    def to_neo(self) -> neo.Group:
        """Export stored spikes to Neo, annotating each train with its contact position."""
        group = super(MultiUnitActivity, self).to_neo()
        for st in group.spiketrains:
            i = int(st.annotations["i"])
            st.annotate(
                i_channel=i,
                x_contact=self.probe.coords[i, 0] / mm * pq.mm,
                y_contact=self.probe.coords[i, 1] / mm * pq.mm,
                z_contact=self.probe.coords[i, 2] / mm * pq.mm,
            )
        return group
@define(eq=False)
class SortedSpiking(Spiking):
    """Detect spikes identified by neuron indices.

    The indices used by the probe do not correspond to those
    coming from neuron groups, since the probe must consider
    multiple potential groups and within a group ignores those
    neurons that are too far away to be easily detected."""

    snr_cutoff: float = 6
    """The signal-to-noise ratio a unit must have for its spikes to be reported.
    SNR is defined as the mean spike amplitude divided by the standard deviation of
    the background noise for the peak (closest) channel.
    Should be higher than :attr:`~Spiking.threshold_sigma`.
    Spikes from units with SNR < snr_cutoff still factor into collision sampling and
    are reported as unsorted (index -1), essentially "multi-unit activity"."""

    # decaying exponential: 20% collision probability for simultaneous
    # crossings, with a 0.3 ms decay constant
    collision_prob_fn: Callable[[Quantity], float] = lambda t: 0.2 * np.exp(
        -t / (0.3 * ms)
    )

    # lookup tables between probe indices (all recorded neurons) and sorted
    # indices (recorded neurons with SNR >= snr_cutoff); -1 == recorded, not sorted
    _i_sorted_by_i_probe: Int[np.ndarray, "{self.n_neurons}"] = field(
        init=False, factory=lambda: np.zeros(0, dtype=int), repr=False
    )
    _i_probe_by_i_sorted: Int[np.ndarray, "{self.n_sorted}"] = field(
        init=False, factory=lambda: np.zeros(0, dtype=int), repr=False
    )

    @property
    def n(self):
        """Number of spike sources, i.e., sorted neurons."""
        return self.n_sorted

    @property
    def n_sorted(self):
        """Number of sorted neurons"""
        return len(self._i_probe_by_i_sorted)

    def i_sorted_by_ng(self, ng: NeuronGroup) -> Int[np.ndarray, "{ng.N}"]:
        """Get the sorted indices for a given neuron group.
        -1 means recorded, but not sorted. -2 means not recorded."""
        i_probe = self.i_probe_by_ng[ng][ng.i]
        recorded = i_probe != -2
        # start from "not recorded" everywhere and only look up recorded neurons,
        # so the -2 sentinels are never used as (negative) indices into the table
        i_sorted = np.full(i_probe.shape, -2, dtype=int)
        i_sorted[recorded] = self._i_sorted_by_i_probe[i_probe[recorded]]
        return i_sorted

    @property
    def i_ng_by_i_sorted(self) -> list[tuple[NeuronGroup, int]]:
        """Get a list of (ng, i_ng) tuples for all sorted neurons, in order.
        That is, this maps from sorted indices back to the original neuron group and
        indices."""
        return [self.i_ng_by_i_probe[i] for i in self._i_probe_by_i_sorted]

    @property
    def r_cutoff(self, resolution: Quantity = um / 10) -> Quantity:
        """The distance from a contact at which the SNR is high enough for a neuron
        to be included.

        As a property this is always evaluated with the default ``resolution``;
        call :meth:`~Spiking.r_for_snr` directly to use another one."""
        return self.r_for_snr(self.snr_cutoff, resolution=resolution)

    @property
    def sorted_units_snr(self) -> Float[np.ndarray, "{self.n_sorted} n_channels"]:
        """The mean spike amplitude (in noise standard deviations) of each sorted
        neuron, in order, with one row per neuron and one column per channel.

        NOTE(review): per :attr:`snr_cutoff`, a unit's SNR is its value on the peak
        channel, i.e., the row-wise maximum of this array -- confirm whether
        callers expect that 1-D reduction here."""
        return self._mu_eap[self._i_probe_by_i_sorted]

    def connect_to_neuron_group(self, neuron_group, **kwparams):
        """Configure signal to record from specified neuron group

        Neurons recorded by the base class are additionally given a sorted index
        if their peak-channel SNR is at least :attr:`snr_cutoff`.

        Parameters
        ----------
        neuron_group : NeuronGroup
            Neuron group to record from
        """
        result = super().connect_to_neuron_group(neuron_group, **kwparams)
        if result is None:
            # the base class skipped this group (no neuron within recording
            # range), so there is nothing to sort
            return
        snr, i_probe = result
        n_recorded_ng = len(snr)

        # filter by SNR (measured on peak channel)
        above_cutoff = snr >= self.snr_cutoff
        n_above_cutoff = np.sum(above_cutoff)
        n_prev_sorted = self.n_sorted
        i_srt_new_range = np.arange(n_prev_sorted, n_prev_sorted + n_above_cutoff)

        # update map from i_probe to i_sorted
        # -1 means not reported
        i_srt_by_i_probe_for_ng = np.full(n_recorded_ng, -1, dtype=int)
        i_srt_by_i_probe_for_ng[above_cutoff] = i_srt_new_range
        self._i_sorted_by_i_probe = np.concatenate(
            [self._i_sorted_by_i_probe, i_srt_by_i_probe_for_ng]
        )

        # update map from i_sorted to i_probe
        i_probe_by_i_sorted_for_ng = i_probe[above_cutoff]
        assert len(i_probe_by_i_sorted_for_ng) == n_above_cutoff
        self._i_probe_by_i_sorted = np.concatenate(
            [self._i_probe_by_i_sorted, i_probe_by_i_sorted_for_ng]
        )
        assert n_prev_sorted + n_above_cutoff == self.n_sorted

    def get_state(
        self,
    ) -> tuple[Int[np.ndarray, "n_spikes"], Quantity, Int[np.ndarray, "{self.n}"]]:
        """Return sorted spikes detected since the method was last called.

        Returns
        -------
        tuple
            ``(i, t, y)``: sorted-neuron index and time of each reported spike, and
            the per-neuron count of reported spikes. Noise crossings and spikes of
            unsorted neurons take part in collision sampling but are not reported.
        """
        t_samp = self.probe.sim.network.t
        i_probe, t = self._get_new_spikes()
        t_tcs, i_probe_tcs, i_chan_tcs, amp_tcs, noise, t_noise = (
            self._noisily_get_true_tcs(i_probe, t)
        )
        # get false positives from noise
        if self.simulate_false_positives:
            i_t_fps, i_chan_fps = (noise > self.threshold_sigma).nonzero()
            t_fps = t_noise[i_t_fps]
            amp_fps = noise[i_t_fps, i_chan_fps]
        else:
            # typed empties keep the concatenated index arrays integer-valued
            t_fps = ms * np.array([], dtype=float)
            i_chan_fps = np.array([], dtype=int)
            amp_fps = np.array([], dtype=float)
        i_chan = np.concatenate([i_chan_tcs, i_chan_fps])
        t = unit_safe_cat([t_tcs, t_fps])
        amps = np.concatenate([amp_tcs, amp_fps])
        # -3 marks a noise crossing, which has no source neuron
        i_probe = np.concatenate([i_probe_tcs, np.full_like(i_chan_fps, -3)]).astype(
            int
        )

        # sort by time
        sort_idx = np.argsort(t)
        t = t[sort_idx]
        i_chan = i_chan[sort_idx]
        amps = amps[sort_idx]
        i_probe = i_probe[sort_idx]

        which_collided = self._sample_collisions(t, i_chan, amps)
        # filter out collisions and false positives
        is_true_detection = ~which_collided & (i_probe != -3)
        t_detected = t[is_true_detection]
        i_probe_detected = i_probe[is_true_detection]

        # remove repeat t, i_nrn spikes (spikes detected on >1 channel)
        _, spk_detected_any_channel = np.unique(
            np.array([t_detected / ms, i_probe_detected]), axis=1, return_index=True
        )
        # get spikes detected on any channel
        i_probe_detected = i_probe_detected[spk_detected_any_channel]
        t_detected = t_detected[spk_detected_any_channel]

        # convert to sorted indices, including getting -1s for unsorted
        i_sorted_detected = self._i_sorted_by_i_probe[i_probe_detected]
        # filter out -1s
        to_keep = i_sorted_detected != -1
        i_srt_dtct_filt = i_sorted_detected[to_keep]
        t_dtct_filt = t_detected[to_keep]

        # minlength includes 0s for upper indices not seen; kept as float64, which
        # is what the previous zero-padding via concatenation produced
        y = np.bincount(
            i_srt_dtct_filt.astype(int), minlength=self.n_sorted
        ).astype(float)
        self._update_saved_vars(t_dtct_filt, i_srt_dtct_filt, t_samp)
        return i_srt_dtct_filt, t_dtct_filt, y