"""Contains opsin models, parameters, and OptogeneticIntervention device"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Tuple, Any
import warnings
from brian2 import (
Synapses,
NeuronGroup,
Unit,
BrianObjectException,
get_unit,
Equations,
)
from brian2.units import (
mm,
mm2,
nmeter,
meter,
kgram,
Quantity,
second,
ms,
second,
psiemens,
mV,
volt,
amp,
mwatt,
)
from brian2.units.allunits import meter2, radian
import brian2.units.unitsafefunctions as usf
import numpy as np
import matplotlib
from matplotlib import colors
from matplotlib.artist import Artist
from matplotlib.collections import PathCollection
from cleo.utilities import uniform_cylinder_rθz, wavelength_to_rgb, xyz_from_rθz
from cleo.stimulators import Stimulator
ChR2_four_state = dict(
    g0=114000 * psiemens,
    gamma=0.00742,
    # units are photons/mm2/s; Brian2 has no "photon" unit, so it is implicit
    phim=2.33e17 / mm2 / second,
    k1=4.15 / ms,
    k2=0.868 / ms,
    p=0.833,
    Gf0=0.0373 / ms,
    kf=0.0581 / ms,
    Gb0=0.0161 / ms,
    kb=0.063 / ms,
    q=1.94,
    Gd1=0.105 / ms,
    Gd2=0.0138 / ms,
    Gr0=0.00033 / ms,
    E=0 * mV,
    v0=43 * mV,
    v1=17.1 * mV,
)
"""Parameters for the 4-state ChR2 model.

Taken from try.projectpyrho.org's default 4-state params.
"""
class OpsinModel(ABC):
    """Base class for opsin model"""

    model: str
    """Basic Brian model equations string.

    Should contain a `rho_rel` term reflecting relative expression
    levels. Will likely also contain special NeuronGroup-dependent
    symbols such as V_VAR_NAME to be replaced on injection in
    :meth:`~OpsinModel.modify_model_and_params_for_ng`."""

    params: dict
    """Parameter values for model, passed in as a namespace dict"""

    required_vars: list[Tuple[str, Unit]]
    """Default names of state variables required in the neuron group,
    along with units, e.g., [('Iopto', amp)].

    It is assumed that non-default values can be passed in on injection
    as a keyword argument ``[default_name]_var_name=[non_default_name]``
    and that these are found in the model string as
    ``[DEFAULT_NAME]_VAR_NAME`` before replacement."""

    def modify_model_and_params_for_ng(
        self, neuron_group: NeuronGroup, injct_params: dict, model="class-defined"
    ) -> Tuple[Equations, dict]:
        """Adapt model for given neuron group on injection

        This enables the specification of variable names
        differently for each neuron group, allowing for custom names
        and avoiding conflicts.

        Parameters
        ----------
        neuron_group : NeuronGroup
            NeuronGroup this opsin model is being connected to
        injct_params : dict
            kwargs passed in on injection, could contain variable
            names to plug into the model

        Keyword Args
        ------------
        model : str, optional
            Model to start with, by default that defined for the class.
            This allows for prior string manipulations before it can be
            parsed as an `Equations` object.

        Returns
        -------
        Equations, dict
            A tuple containing an Equations object
            and a parameter dictionary, constructed from :attr:`~model`
            and :attr:`~params`, respectively, with modified names for use
            in :attr:`~cleo.opto.OptogeneticIntervention.opto_syns`

        Raises
        ------
        BrianObjectException
            If a required variable (see :attr:`required_vars`) is missing
            from the neuron group or has incompatible dimensions.
        """
        if model == "class-defined":
            model = self.model

        for default_name, unit in self.required_vars:
            var_name = injct_params.get(f"{default_name}_var_name", default_name)
            if var_name not in neuron_group.variables or not neuron_group.variables[
                var_name
            ].unit.has_same_dimensions(unit):
                raise BrianObjectException(
                    (
                        f"{var_name} : {unit.name} needed in the model of NeuronGroup "
                        f"{neuron_group.name} to connect OptogeneticIntervention."
                    ),
                    neuron_group,
                )
            # opsin synapse model needs modified names: the placeholder in the
            # model string is the upper-cased "<default_name>_var_name"
            to_replace = f"{default_name}_var_name".upper()
            model = model.replace(to_replace, var_name)

        # Synapse variable and parameter names cannot be the same as any
        # neuron group variable name
        return self._fix_name_conflicts(model, neuron_group)

    def _fix_name_conflicts(
        self, modified_model: str, neuron_group: NeuronGroup
    ) -> Tuple[Equations, dict]:
        """Rename model variables/params that clash with neuron group variables.

        Every model variable or parameter whose name is also a variable of
        ``neuron_group`` gets the suffix ``_syn``.

        Returns
        -------
        Equations, dict
            The parsed, substituted equations and a copy of :attr:`params`
            with conflicting keys renamed the same way.
        """
        suffix = "_syn"
        modified_params = self.params.copy()

        # get variables to rename
        opsin_eqs = Equations(modified_model)
        substitutions = {}
        for var in opsin_eqs.names:
            if var in neuron_group.variables:
                substitutions[var] = f"{var}{suffix}"

        # and parameters (iterating the original dict, mutating the copy)
        for param in self.params:
            if param in neuron_group.variables:
                new_name = f"{param}{suffix}"
                substitutions[param] = new_name
                modified_params[new_name] = modified_params[param]
                del modified_params[param]

        mod_opsin_eqs = opsin_eqs.substitute(**substitutions)
        return mod_opsin_eqs, modified_params

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Initializes appropriate variables in Synapses implementing the model

        Can also be used to reset the variables. The base implementation
        does nothing; subclasses with state variables override it.

        Parameters
        ----------
        opto_syn : Synapses
            The synapses object implementing this model
        """
class MarkovModel(OpsinModel):
    """Base class for Markov state models à la Evans et al., 2016

    Subclasses define :attr:`model`; the parameter namespace is supplied
    at construction time.
    """

    # every Markov opsin model writes a current and reads membrane potential
    required_vars: list[Tuple[str, Unit]] = [("Iopto", amp), ("v", volt)]

    def __init__(self, params: dict) -> None:
        """
        Parameters
        ----------
        params : dict
            Namespace of parameter values referenced by :attr:`model`
        """
        super().__init__()
        self.params = params
class FourStateModel(MarkovModel):
    """4-state model from PyRhO (Evans et al. 2016).

    rho_rel is channel density relative to standard model fit;
    modifying it post-injection allows for heterogeneous opsin expression.

    IOPTO_VAR_NAME and V_VAR_NAME are substituted on injection.

    The voltage dependence follows PyRhO,
    ``fv = v1 * (1 - exp(-(V-E)/v0)) / (V-E)``. The current only needs
    ``fv*(V-E)``, which simplifies to ``v1 * (1 - exp(-(V-E)/v0))``. That
    product is used directly so the current stays finite (and reverses sign)
    at ``V == E``, where ``fv`` itself is 0/0.
    """

    model: str = """
        dC1/dt = Gd1*O1 + Gr0*C2 - Ga1*C1 : 1 (clock-driven)
        dO1/dt = Ga1*C1 + Gb*O2 - (Gd1+Gf)*O1 : 1 (clock-driven)
        dO2/dt = Ga2*C2 + Gf*O1 - (Gd2+Gb)*O2 : 1 (clock-driven)
        C2 = 1 - C1 - O1 - O2 : 1
        # dC2/dt = Gd2*O2 - (Gr0+Ga2)*C2 : 1 (clock-driven)

        Theta = int(phi > 0*phi) : 1
        Hp = Theta * phi**p/(phi**p + phim**p) : 1
        Ga1 = k1*Hp : hertz
        Ga2 = k2*Hp : hertz
        Hq = Theta * phi**q/(phi**q + phim**q) : 1
        Gf = kf*Hq + Gf0 : hertz
        Gb = kb*Hq + Gb0 : hertz

        fphi = O1 + gamma*O2 : 1
        # PyRhO voltage dependence; 0/0 at V == E, so it is kept for
        # monitoring only and the current below uses the simplified product
        fv = v1*(1 - exp(-(V_VAR_NAME_post-E)/v0)) / (V_VAR_NAME_post-E) : 1

        # fv*(V-E) == v1*(1 - exp(-(V-E)/v0)), finite everywhere
        IOPTO_VAR_NAME_post = -g0*fphi*v1*(1 - exp(-(V_VAR_NAME_post-E)/v0))*rho_rel : ampere (summed)
        rho_rel : 1
    """

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Reset light and channel states.

        Sets Irr0 to 0 and puts all channels in the closed state C1
        (C1 = 1, O1 = O2 = 0; C2 follows as 1 - C1 - O1 - O2).
        """
        for varname, value in {"Irr0": 0, "C1": 1, "O1": 0, "O2": 0}.items():
            setattr(opto_syn, varname, value)
class ProportionalCurrentModel(OpsinModel):
    """A simple model delivering current proportional to light intensity"""

    # IOPTO_UNIT is a placeholder that modify_model_and_params_for_ng replaces
    # with the actual unit name before the string is parsed as Equations
    model: str = """
        IOPTO_VAR_NAME_post = gain * Irr * rho_rel : IOPTO_UNIT (summed)
        rho_rel : 1
    """

    def __init__(self, Iopto_per_mW_per_mm2: Quantity) -> None:
        """
        Parameters
        ----------
        Iopto_per_mW_per_mm2 : Quantity
            How much current (in amps or unitless, depending on neuron model) to
            deliver per mW/mm2
        """
        super().__init__()
        self.params = {"gain": Iopto_per_mW_per_mm2 / (mwatt / mm2)}
        if isinstance(Iopto_per_mW_per_mm2, Quantity):
            self._Iopto_unit = get_unit(Iopto_per_mW_per_mm2.dim)
        else:
            # plain number: the neuron model's current is dimensionless
            self._Iopto_unit = radian
        self.required_vars = [("Iopto", self._Iopto_unit)]

    def modify_model_and_params_for_ng(
        self, neuron_group: NeuronGroup, injct_params: dict, model="class-defined"
    ) -> Tuple[Equations, dict]:
        """Fill in the current's unit, then adapt the model as the base class does.

        Accepts the same ``model`` keyword as
        :meth:`OpsinModel.modify_model_and_params_for_ng`; IOPTO_UNIT in it
        is replaced by the unit inferred in ``__init__``.
        """
        if model == "class-defined":
            model = self.model
        mod_model = model.replace("IOPTO_UNIT", self._Iopto_unit.name)
        return super().modify_model_and_params_for_ng(
            neuron_group, injct_params, model=mod_model
        )
default_blue = dict(
    R0=0.1 * mm,  # optical fiber radius
    NAfib=0.37,  # optical fiber numerical aperture
    wavelength=473 * nmeter,
    # NOTE: the remaining values depend on wavelength and tissue properties,
    # so they would differ for another wavelength
    K=0.125 / mm,  # absorption coefficient
    S=7.37 / mm,  # scattering coefficient
    ntis=1.36,  # tissue index of refraction
)
"""Light parameters for 473 nm wavelength delivered via an optic fiber.

From Foutz et al., 2012"""
class OptogeneticIntervention(Stimulator):
    """Enables optogenetic stimulation of the network.

    Essentially "transfects" neurons and provides a light source.
    Under the hood, it delivers current via a Brian :class:`~brian2.synapses.synapses.Synapses`
    object.

    Requires neurons to have 3D spatial coordinates already assigned.
    Also requires that the neuron model has a current term
    (by default Iopto) which is assumed to be positive (unlike the
    convention in many opsin modeling papers, where the current is
    described as negative).

    See :meth:`connect_to_neuron_group` for optional keyword parameters
    that can be specified when calling
    :meth:`cleo.CLSimulator.inject_stimulator`.

    Visualization kwargs
    --------------------
    n_points : int, optional
        The number of points used to represent light intensity in space.
        By default 1e4.
    T_threshold : float, optional
        The transmittance below which no points are plotted. By default
        1e-3.
    intensity : float, optional
        How bright the light appears, should be between 0 and 1. By default 0.5.
    rasterized : bool, optional
        Whether to render as rasterized in vector output, True by default.
        Useful since so many points makes later rendering and editing slow.
    """

    opto_syns: dict[str, Synapses]
    """Stores the synapse objects implementing the opsin model,
    with NeuronGroup name keys and Synapse values."""

    max_Irr0_mW_per_mm2: float
    """The maximum irradiance the light source can emit.

    Usually determined by hardware in a real experiment."""

    max_Irr0_mW_per_mm2_viz: float
    """Maximum irradiance for visualization purposes.

    i.e., the level at or above which the light appears maximally bright.
    Only relevant in video visualization.
    """

    def __init__(
        self,
        name: str,
        opsin_model: OpsinModel,
        light_model_params: dict,
        location: Quantity = (0, 0, 0) * mm,
        direction: Tuple[float, float, float] = (0, 0, 1),
        max_Irr0_mW_per_mm2: float = None,
        save_history: bool = False,
    ):
        """
        Parameters
        ----------
        name : str
            Unique identifier for stimulator
        opsin_model : OpsinModel
            OpsinModel object defining how light affects target
            neurons. See :class:`FourStateModel` and :class:`ProportionalCurrentModel`
            for examples.
        light_model_params : dict
            Parameters for the light propagation model in Foutz et al., 2012.
            See :attr:`default_blue` for an example.
        location : Quantity, optional
            (x, y, z) coords with Brian unit specifying where to place
            the base of the light source, by default (0, 0, 0)*mm
        direction : Tuple[float, float, float], optional
            (x, y, z) vector specifying direction in which light
            source is pointing, by default (0, 0, 1). Must be nonzero.
        max_Irr0_mW_per_mm2 : float, optional
            Set :attr:`max_Irr0_mW_per_mm2`.
        save_history : bool, optional
            Determines whether :attr:`~values` and :attr:`~t_ms` are saved.

        Raises
        ------
        ValueError
            If ``direction`` is the zero vector.
        """
        super().__init__(name, 0, save_history)
        self.opsin_model = opsin_model
        self.light_model_params = light_model_params
        self.location = location
        # direction unit vector; a zero vector has no direction (would be NaN)
        direction = np.asarray(direction, dtype=float)
        norm = np.linalg.norm(direction)
        if norm == 0:
            raise ValueError(f"{name}: direction must be a nonzero vector")
        self.dir_uvec = direction / norm
        self.opto_syns = {}
        self.max_Irr0_mW_per_mm2 = max_Irr0_mW_per_mm2
        self.max_Irr0_mW_per_mm2_viz = None

    def _Foutz12_transmittance(self, r, z, scatter=True, spread=True, gaussian=True):
        """Foutz et al. 2012 transmittance model: Gaussian cone with Kubelka-Munk propagation

        Parameters
        ----------
        r, z
            Radial and axial distances from the fiber tip (Brian length
            quantities); they are broadcast against each other.
        scatter, spread, gaussian : bool
            Toggle the Kubelka-Munk scattering/absorption term, the conical
            spread of the beam, and the Gaussian radial profile.

        Returns
        -------
        Transmittance T, with T = 0 wherever z < 0 (behind the fiber).
        """
        params = self.light_model_params
        if spread:
            # divergence half-angle of cone
            theta_div = np.arcsin(params["NAfib"] / params["ntis"])
            # radius as light spreads ("apparent radius" from original code)
            Rz = params["R0"] + z * np.tan(theta_div)
            C = (params["R0"] / Rz) ** 2
        else:
            Rz = params["R0"]  # "apparent radius"
            C = 1

        if gaussian:
            G = 1 / np.sqrt(2 * np.pi) * np.exp(-2 * (r / Rz) ** 2)
        else:
            G = 1

        if scatter:
            S = params["S"]
            a = 1 + params["K"] / S
            b = np.sqrt(a**2 - 1)
            dist = np.sqrt(r**2 + z**2)
            M = b / (a * np.sinh(b * S * dist) + b * np.cosh(b * S * dist))
        else:
            M = 1

        T = G * C * M
        T[z < 0] = 0
        return T

    def _get_rz_for_xyz(self, x, y, z):
        """Convert coordinates to (radial, axial) distances w.r.t. the fiber.

        Assumes x, y, z already have units.
        """
        # column_stack strips units (values are in base SI units), so add
        # the meter unit back on
        coords = np.column_stack([x, y, z]) * meter
        rel_coords = coords - self.location  # relative to fiber location
        # must use brian2's dot function for matrix multiply to preserve
        # units correctly.
        zc = usf.dot(rel_coords, self.dir_uvec)  # distance along cylinder axis
        # just need length (norm) of radius vectors
        # not using np.linalg.norm because it strips units
        r = np.sqrt(
            np.sum((rel_coords - zc[..., np.newaxis] * self.dir_uvec) ** 2, axis=1)
        )
        return r, zc

    def connect_to_neuron_group(
        self, neuron_group: NeuronGroup, **kwparams: Any
    ) -> None:
        """Configure opsin and light source to stimulate given neuron group.

        Parameters
        ----------
        neuron_group : NeuronGroup
            The neuron group to stimulate with the given opsin and light source

        Keyword args
        ------------
        p_expression : float
            Probability (0 <= p <= 1) that a given neuron in the group
            will express the opsin. 1 by default.
        rho_rel : float
            The expression level, relative to the standard model fit,
            of the opsin. 1 by default. For heterogeneous expression,
            this would have to be modified in the opsin synapse post-injection,
            e.g., ``opto.opto_syns["neuron_group_name"].rho_rel = ...``
        Iopto_var_name : str
            The name of the variable in the neuron group model representing
            current from the opsin
        v_var_name : str
            The name of the variable in the neuron group model representing
            membrane potential
        """
        # get modified opsin model string (i.e., with names/units specified)
        (
            mod_opsin_model,
            mod_opsin_params,
        ) = self.opsin_model.modify_model_and_params_for_ng(neuron_group, kwparams)

        # fmt: off
        # Ephoton = h*c/lambda
        E_photon = (
            6.63e-34 * meter2 * kgram / second
            * 2.998e8 * meter / second
            / self.light_model_params["wavelength"]
        )
        # fmt: on

        light_model = Equations(
            """
            Irr = Irr0*T : watt/meter**2
            Irr0 : watt/meter**2
            T : 1
            phi = Irr / Ephoton : 1/second/meter**2
            """
        )

        opto_syn = Synapses(
            neuron_group,
            model=mod_opsin_model + light_model,
            namespace=mod_opsin_params,
            name=f"synapses_{self.name}_{neuron_group.name}",
            method="rk2",
        )
        opto_syn.namespace["Ephoton"] = E_photon

        # one-to-one connections; with p < 1 only a random subset expresses
        p_expression = kwparams.get("p_expression", 1)
        if p_expression == 1:
            opto_syn.connect(j="i")
        else:
            opto_syn.connect(condition="i==j", p=p_expression)

        self.opsin_model.init_opto_syn_vars(opto_syn)

        # relative channel density
        opto_syn.rho_rel = kwparams.get("rho_rel", 1)
        # calculate transmittance coefficient for each point
        r, z = self._get_rz_for_xyz(neuron_group.x, neuron_group.y, neuron_group.z)
        T = self._Foutz12_transmittance(r, z).flatten()
        assert len(T) == len(neuron_group)
        # reduce to subset expressing opsin before assigning
        T = T[opto_syn.i]
        opto_syn.T = T

        self.opto_syns[neuron_group.name] = opto_syn
        self.brian_objects.add(opto_syn)

    def add_self_to_plot(self, ax, axis_scale_unit, **kwargs) -> list[Artist]:
        """Draw the light as a point cloud colored by transmittance.

        See the class docstring for the visualization kwargs. A patch in the
        light's color is added to the axes legend (a legend is created if
        the axes has none yet).

        Returns
        -------
        list[Artist]
            A single-element list with the scatter point cloud.
        """
        # show light with point field, assigning r and z coordinates
        # to all points
        # filter out points with < T_threshold transmittance to make plotting faster
        T_threshold = kwargs.get("T_threshold", 0.001)
        n_points = int(kwargs.get("n_points", 1e4))
        intensity = kwargs.get("intensity", 0.5)
        r_thresh, zc_thresh = self._find_rz_thresholds(T_threshold)
        r, theta, zc = uniform_cylinder_rθz(n_points, r_thresh, zc_thresh)

        T = self._Foutz12_transmittance(r, zc)

        end = self.location + zc_thresh * self.dir_uvec
        x, y, z = xyz_from_rθz(r, theta, zc, self.location, end)

        idx_to_plot = T >= T_threshold
        x = x[idx_to_plot]
        y = y[idx_to_plot]
        z = z[idx_to_plot]
        T = T[idx_to_plot]
        point_cloud = ax.scatter(
            x / axis_scale_unit,
            y / axis_scale_unit,
            z / axis_scale_unit,
            c=T,
            cmap=self._alpha_cmap_for_wavelength(intensity),
            marker="o",
            edgecolors="none",
            label=self.name,
        )
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*Rasterization.*will be ignored.*"
            )
            # to make manageable in SVGs
            point_cloud.set_rasterized(kwargs.get("rasterized", True))

        # get_legend() returns None when the axes has no legend yet
        legend = ax.get_legend()
        if legend is None:
            handles = []
        else:
            # legend_handles replaced legendHandles in matplotlib 3.7;
            # the old name was removed in 3.9
            existing = getattr(legend, "legend_handles", None)
            if existing is None:
                existing = legend.legendHandles
            handles = list(existing)
        c = wavelength_to_rgb(self.light_model_params["wavelength"] / nmeter)
        opto_patch = matplotlib.patches.Patch(color=c, label=self.name)
        handles.append(opto_patch)
        ax.legend(handles=handles)

        return [point_cloud]

    def _find_rz_thresholds(self, thresh):
        """find r and z thresholds for visualization purposes"""
        res_mm = 0.01
        # distances are descending so T is ascending, as searchsorted requires;
        # indices are clamped so an unreachable threshold can't run off the end
        zc = np.arange(20, 0, -res_mm) * mm
        T = self._Foutz12_transmittance(0 * mm, zc)
        zc_idx = min(int(np.searchsorted(T, thresh)), len(zc) - 1)
        zc_thresh = zc[zc_idx]
        # look at half the z threshold for the r threshold
        r = np.arange(20, 0, -res_mm) * mm
        T = self._Foutz12_transmittance(r, zc_thresh / 2)
        r_idx = min(int(np.searchsorted(T, thresh)), len(r) - 1)
        r_thresh = r[r_idx]
        # multiply by 1.2 just in case
        return r_thresh * 1.2, zc_thresh

    def update_artists(
        self, artists: list[Artist], value, *args, **kwargs
    ) -> list[Artist]:
        """Recolor the point cloud to reflect the current light intensity.

        Brightness scales linearly with ``value`` up to the visualization
        maximum (:attr:`max_Irr0_mW_per_mm2_viz`, else
        :attr:`max_Irr0_mW_per_mm2`) and saturates above it.

        Returns
        -------
        list[Artist]
            The updated point cloud, or an empty list when nothing changed.

        Raises
        ------
        Exception
            If neither maximum irradiance attribute is set.
        """
        assert len(artists) == 1
        point_cloud = artists[0]
        # skip redrawing when neither the value nor the artist changed
        if value == getattr(self, "_prev_value", None) and point_cloud is getattr(
            self, "_prev_artist", None
        ):
            return []

        if self.max_Irr0_mW_per_mm2_viz is not None:
            max_Irr0 = self.max_Irr0_mW_per_mm2_viz
        elif self.max_Irr0_mW_per_mm2 is not None:
            max_Irr0 = self.max_Irr0_mW_per_mm2
        else:
            raise Exception(
                f"OptogeneticIntervention '{self.name}' needs max_Irr0_mW_per_mm2_viz "
                "or max_Irr0_mW_per_mm2 "
                "set to visualize light intensity."
            )

        # relative intensity in [0, 1]: at or above max_Irr0 the light is
        # maximally bright (an unclipped value would give alpha > 1)
        intensity = min(max(value / max_Irr0, 0), 1)
        point_cloud.set_cmap(self._alpha_cmap_for_wavelength(intensity))
        self._prev_value = value
        self._prev_artist = point_cloud
        return [point_cloud]

    def update(self, Irr0_mW_per_mm2: float):
        """Set the light intensity, in mW/mm2 (without unit)

        Negative values are clipped to 0 (with a warning) and values above
        :attr:`max_Irr0_mW_per_mm2` (when set) are clipped to it.

        Parameters
        ----------
        Irr0_mW_per_mm2 : float
            Desired light intensity for light source
        """
        if Irr0_mW_per_mm2 < 0:
            warnings.warn(f"{self.name}: negative light intensity Irr0 clipped to 0")
            Irr0_mW_per_mm2 = 0
        if (
            self.max_Irr0_mW_per_mm2 is not None
            and Irr0_mW_per_mm2 > self.max_Irr0_mW_per_mm2
        ):
            Irr0_mW_per_mm2 = self.max_Irr0_mW_per_mm2
        super().update(Irr0_mW_per_mm2)
        for opto_syn in self.opto_syns.values():
            opto_syn.Irr0 = Irr0_mW_per_mm2 * mwatt / mm2

    def reset(self, **kwargs):
        """Re-initialize the opsin state variables of every connected group."""
        for opto_syn in self.opto_syns.values():
            self.opsin_model.init_opto_syn_vars(opto_syn)

    def _alpha_cmap_for_wavelength(self, intensity=0.5):
        """Colormap in the light's color going from transparent to alpha 0.6*intensity."""
        c = wavelength_to_rgb(self.light_model_params["wavelength"] / nmeter)
        c_clear = (*c, 0)
        c_opaque = (*c, 0.6 * intensity)
        return colors.LinearSegmentedColormap.from_list(
            "incr_alpha", [(0, c_clear), (1, c_opaque)]
        )