Source code for cleo.opto.opsins

"""Contains opsin models and default parameters"""
from __future__ import annotations
from typing import Any, Callable, Tuple
import warnings

from attrs import define, field, asdict, fields_dict
from brian2 import (
    Synapses,
    Function,
    NeuronGroup,
    Unit,
    BrianObjectException,
    get_unit,
    Equations,
    implementation,
    check_units,
)
from nptyping import NDArray
from brian2.units import (
    mm,
    mm2,
    cm2,
    nmeter,
    Quantity,
    second,
    ms,
    second,
    psiemens,
    nsiemens,
    msiemens,
    mV,
    volt,
    amp,
    mM,
)
from brian2.units.allunits import radian
import numpy as np
from scipy.interpolate import CubicSpline

from cleo.base import InterfaceDevice
from cleo.coords import assign_coords
from cleo.opto.registry import lor_for_sim
from cleo.utilities import wavelength_to_rgb


def linear_interpolator(lambdas_nm, epsilons, lambda_new_nm):
    """Piecewise-linearly interpolate action-spectrum data.

    A thin wrapper over :func:`numpy.interp`, with arguments reordered to
    match the ``(lambdas_nm, epsilons, lambda_new_nm)`` interpolator
    signature expected by :attr:`Opsin.action_spectrum_interpolator`.
    ``lambdas_nm`` should be increasing. Queries outside the data range
    take the epsilon of the nearest endpoint (numpy's default behavior).
    """
    return np.interp(x=lambda_new_nm, xp=lambdas_nm, fp=epsilons)
def cubic_interpolator(lambdas_nm, epsilons, lambda_new_nm):
    """Interpolate action-spectrum data with a cubic spline.

    Fits a :class:`scipy.interpolate.CubicSpline` (default boundary
    conditions) through ``(lambdas_nm, epsilons)`` and evaluates it at
    ``lambda_new_nm``, which may be a scalar or an array. Note that the
    spline is not clipped, so it can overshoot the data between points.
    """
    spline = CubicSpline(x=lambdas_nm, y=epsilons)
    return spline(lambda_new_nm)
@define(eq=False)
class Opsin(InterfaceDevice):
    """Base class for opsin model.

    We approximate dynamics under multiple wavelengths using a weighted sum of
    photon fluxes, where the :math:`\\varepsilon` factor indicates the
    activation relative to the peak-sensitivity wavelength for an equivalent
    number of photons (see Mager et al, 2018). This weighted sum is an
    approximation of a nonlinear peak-non-peak wavelength relation; see
    notebooks/multi_wavelength_model.ipynb for details."""

    model: str = field(init=False)
    """Basic Brian model equations string.

    Should contain a `rho_rel` term reflecting relative expression levels.
    Will likely also contain special NeuronGroup-dependent symbols such as
    V_VAR_NAME to be replaced on injection in
    :meth:`~Opsin.modify_model_and_params_for_ng`."""

    per_ng_unit_replacements: list[Tuple[str, str]] = field(
        factory=list, init=False, repr=False
    )
    """List of (UNIT_NAME, neuron_group_specific_unit_name) tuples to be
    substituted in the model string on injection and before checking required
    variables."""

    required_vars: list[Tuple[str, Unit]] = field(factory=list, init=False, repr=False)
    """Default names of state variables required in the neuron group, along
    with units, e.g., [('Iopto', amp)].

    It is assumed that non-default values can be passed in on injection as a
    keyword argument ``[default_name]_var_name=[non_default_name]`` and that
    these are found in the model string as ``[DEFAULT_NAME]_VAR_NAME`` before
    replacement."""

    light_agg_ngs: dict[str, NeuronGroup] = field(factory=dict, init=False, repr=False)
    """{target_ng.name: light_agg_ng} dict of light aggregator neuron groups."""

    opto_syns: dict[str, Synapses] = field(factory=dict, init=False, repr=False)
    """{target_ng.name: opto_syn} dict of the synapse objects implementing the
    opsin model, keyed by the name of the target neuron group."""

    action_spectrum: list[tuple[float, float]] = field(
        factory=lambda: [(-1e10, 1), (1e10, 1)]
    )
    """List of (wavelength, epsilon) tuples representing the action spectrum."""

    action_spectrum_interpolator: Callable = field(
        default=cubic_interpolator, repr=False
    )
    """Function of signature (lambdas_nm, epsilons, lambda_new_nm) that
    interpolates the action spectrum data and returns
    :math:`\\varepsilon \\in [0,1]` for the new wavelength."""

    extra_namespace: dict = field(factory=dict, repr=False)
    """Additional items (beyond parameters) to be added to the opto synapse namespace"""

    def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
        """Transfect neuron group with opsin.

        Parameters
        ----------
        neuron_group : NeuronGroup
            The neuron group to stimulate with the given opsin and light source

        Keyword args
        ------------
        p_expression : float
            Probability (0 <= p <= 1) that a given neuron in the group
            will express the opsin. 1 by default.
        rho_rel : float
            The expression level, relative to the standard model fit,
            of the opsin. 1 by default. For heterogeneous expression,
            this would have to be modified in the opsin synapse post-injection,
            e.g., ``opto.opto_syns["neuron_group_name"].rho_rel = ...``
        Iopto_var_name : str
            The name of the variable in the neuron group model representing
            current from the opsin
        v_var_name : str
            The name of the variable in the neuron group model representing
            membrane potential

        Raises
        ------
        ValueError
            If this opsin is already connected to ``neuron_group``.
        BrianObjectException
            If ``neuron_group`` lacks a required variable with matching
            dimensions (raised by :meth:`modify_model_and_params_for_ng`).

        Notes
        -----
        If no neuron ends up expressing the opsin, nothing is created or
        registered and the method returns silently.
        """
        if neuron_group.name in self.light_agg_ngs:
            # both dicts are filled together at the end of this method
            assert neuron_group.name in self.opto_syns
            raise ValueError(
                f"Opsin {self.name} already connected to neuron group {neuron_group.name}"
            )

        # get modified opsin model string (i.e., with names/units specified)
        mod_opsin_model, mod_opsin_params = self.modify_model_and_params_for_ng(
            neuron_group, kwparams
        )

        # handle p_expression: independently draw which neurons express
        p_expression = kwparams.get("p_expression", 1)
        i_expression_bool = np.random.rand(neuron_group.N) < p_expression
        i_expression = np.where(i_expression_bool)[0]
        if len(i_expression) == 0:
            return

        # create light aggregator neurons, one per expressing target neuron,
        # co-located with it
        light_agg_ng = NeuronGroup(
            len(i_expression),
            model="""
            phi : 1/second/meter**2
            Irr : watt/meter**2
            """,
            name=f"light_agg_{self.name}_{neuron_group.name}",
        )
        assign_coords(
            light_agg_ng,
            neuron_group.x[i_expression] / mm,
            neuron_group.y[i_expression] / mm,
            neuron_group.z[i_expression] / mm,
            unit=mm,
        )

        # create opsin synapses: aggregator k -> k-th expressing target neuron
        opto_syn = Synapses(
            light_agg_ng,
            neuron_group,
            model=mod_opsin_model,
            namespace=mod_opsin_params,
            name=f"opto_syn_{self.name}_{neuron_group.name}",
        )
        opto_syn.namespace.update(self.extra_namespace)
        opto_syn.connect(i=range(len(i_expression)), j=i_expression)
        self.init_opto_syn_vars(opto_syn)
        # relative channel density
        opto_syn.rho_rel = kwparams.get("rho_rel", 1)

        # store at the end, after all checks have passed
        self.light_agg_ngs[neuron_group.name] = light_agg_ng
        self.brian_objects.add(light_agg_ng)
        self.opto_syns[neuron_group.name] = opto_syn
        self.brian_objects.add(opto_syn)

        lor = lor_for_sim(self.sim)
        lor.register_opsin(self, neuron_group)

    def modify_model_and_params_for_ng(
        self,
        neuron_group: NeuronGroup,
        injct_params: dict,
        *,
        model: str | None = None,
    ) -> Tuple[Equations, dict]:
        """Adapt model for given neuron group on injection

        This enables the specification of variable names
        differently for each neuron group, allowing for custom names
        and avoiding conflicts.

        Parameters
        ----------
        neuron_group : NeuronGroup
            NeuronGroup this opsin model is being connected to
        injct_params : dict
            kwargs passed in on injection, could contain variable
            names to plug into the model

        Keyword Args
        ------------
        model : str, optional
            Model to start with, by default that defined for the class
            (:attr:`model`). This allows for prior string manipulations
            before it can be parsed as an `Equations` object.

        Returns
        -------
        Equations, dict
            A tuple containing an Equations object
            and a parameter dictionary, constructed from :attr:`~model`
            and :attr:`~params`, respectively, with modified names for use
            in :attr:`~opto_syns`

        Raises
        ------
        BrianObjectException
            If a required variable is missing from ``neuron_group`` or has
            different dimensions than required.
        """
        if model is None:
            model = self.model

        # perform unit substitutions
        for unit_name, neuron_group_unit_name in self.per_ng_unit_replacements:
            model = model.replace(unit_name, neuron_group_unit_name)

        # check required variables/units and replace placeholder names
        for default_name, unit in self.required_vars:
            var_name = injct_params.get(f"{default_name}_var_name", default_name)
            if var_name not in neuron_group.variables or not neuron_group.variables[
                var_name
            ].unit.has_same_dimensions(unit):
                raise BrianObjectException(
                    (
                        f"{var_name} : {unit.name} needed in the model of NeuronGroup "
                        f"{neuron_group.name} to connect Opsin {self.name}."
                    ),
                    neuron_group,
                )
            # opsin synapse model needs modified names
            to_replace = f"{default_name}_var_name".upper()
            model = model.replace(to_replace, var_name)

        # Synapse variable and parameter names cannot be the same as any
        # neuron group variable name
        return self._fix_name_conflicts(model, neuron_group)

    @property
    def params(self) -> dict:
        """Returns a dictionary of parameters for the model

        This is every attrs field of the concrete class, except the generic
        fields declared on :class:`Opsin` (including those it inherits) and
        private (underscore-prefixed) fields. A fresh dict is built on every
        access.
        """
        generic_fields = fields_dict(Opsin)
        return {
            name: value
            for name, value in asdict(self, recurse=False).items()
            if name not in generic_fields and not name.startswith("_")
        }

    def _fix_name_conflicts(
        self, modified_model: str, neuron_group: NeuronGroup
    ) -> Tuple[Equations, dict]:
        """Suffix with ``_syn`` any model variable or parameter whose name
        collides with a variable of ``neuron_group``.

        Returns the substituted Equations and a parameter dict in which the
        clashing keys are renamed accordingly.
        """

        def rename(name: str) -> str:
            return f"{name}_syn"

        # `params` builds a new dict on each access; compute it once
        params = self.params
        modified_params = params.copy()

        # get variables to rename
        opsin_eqs = Equations(modified_model)
        substitutions = {}
        for var in opsin_eqs.names:
            if var in neuron_group.variables:
                substitutions[var] = rename(var)

        # and parameters
        for param in params:
            if param in neuron_group.variables:
                substitutions[param] = rename(param)
                modified_params[rename(param)] = modified_params.pop(param)

        mod_opsin_eqs = opsin_eqs.substitute(**substitutions)
        return mod_opsin_eqs, modified_params

    def reset(self, **kwargs):
        """Re-initialize the state variables of every opsin synapse via
        :meth:`init_opto_syn_vars`. Keyword arguments are accepted but
        ignored."""
        for opto_syn in self.opto_syns.values():
            self.init_opto_syn_vars(opto_syn)

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Initializes appropriate variables in Synapses implementing the model

        Can also be used to reset the variables. The base implementation does
        nothing; subclasses override it.

        Parameters
        ----------
        opto_syn : Synapses
            The synapses object implementing this model
        """
        pass

    def epsilon(self, lambda_new) -> float:
        """Returns the epsilon value for a given lambda (in nm)
        representing the relative sensitivity of the opsin to that wavelength.

        Wavelengths outside the range of :attr:`action_spectrum` data trigger a
        warning and yield 0; otherwise :attr:`action_spectrum_interpolator` is
        used."""
        action_spectrum = np.array(self.action_spectrum)
        lambdas = action_spectrum[:, 0]
        epsilons = action_spectrum[:, 1]
        if lambda_new < min(lambdas) or lambda_new > max(lambdas):
            warnings.warn(
                f"λ = {lambda_new} nm is outside the range of the action spectrum data"
                f" for {self.name}. Assuming ε = 0."
            )
            return 0
        return self.action_spectrum_interpolator(lambdas, epsilons, lambda_new)
def plot_action_spectra(*opsins: Opsin):
    """Plot the action spectra of the given opsins on one set of axes.

    For each opsin, its action spectrum is interpolated at 100 points
    spanning the data range using the opsin's
    ``action_spectrum_interpolator``. The result is drawn as a line colored
    by the wavelength of peak interpolated sensitivity. The raw data points
    are drawn as markers, each colored by its own wavelength.

    Parameters
    ----------
    *opsins : Opsin
        Opsins whose ``action_spectrum`` data will be plotted.

    Returns
    -------
    fig, ax
        The matplotlib figure and axes containing the plot.
    """
    # imported lazily so matplotlib is only required when plotting
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for opsin in opsins:
        action_spectrum = np.array(opsin.action_spectrum)
        lambdas = action_spectrum[:, 0]
        epsilons = action_spectrum[:, 1]
        lambdas_new = np.linspace(min(lambdas), max(lambdas), 100)
        epsilons_new = opsin.action_spectrum_interpolator(
            lambdas, epsilons, lambdas_new
        )
        c_points = [wavelength_to_rgb(lambda_nm) for lambda_nm in lambdas]
        c_line = wavelength_to_rgb(lambdas_new[np.argmax(epsilons_new)])
        ax.plot(lambdas_new, epsilons_new, c=c_line, label=opsin.name)
        ax.scatter(lambdas, epsilons, marker="o", s=50, color=c_points)
    title = "Action spectra" if len(opsins) > 1 else "Action spectrum"
    ax.set(xlabel="λ (nm)", ylabel="ε", title=title)
    fig.legend()
    return fig, ax
@define(eq=False)
class MarkovOpsin(Opsin):
    """Base class for Markov state models à la Evans et al., 2016

    The target neuron group must provide a current variable (default name
    ``Iopto``, amp dimensions) and a membrane-potential variable (default
    name ``v``, volt dimensions). Non-default names can be given on injection
    as ``Iopto_var_name=...`` / ``v_var_name=...``.
    """

    required_vars: list[Tuple[str, Unit]] = field(
        factory=lambda: [("Iopto", amp), ("v", volt)],
        init=False,
    )


@implementation(
    "cython",
    """
    cdef double f_unless_x0(double f, double x, double f_when_x0):
        if x == 0:
            return f_when_x0
        else:
            return f
    """,
)
@check_units(f=1, x=volt, f_when_x0=1, result=1)
def f_unless_x0(f, x, f_when_x0):
    """Return ``f``, except ``f_when_x0`` wherever ``x == 0``.

    Meant for use inside Brian model strings to patch a removable singularity
    (e.g. the 0/0 at ``v == E``) with its limiting value. The Cython string
    above provides the Cython-target implementation; presumably this Python
    body is what Brian's numpy target calls.
    """
    # np.where rather than masked in-place assignment: it does not mutate the
    # caller's array, and it also works for a scalar ``f`` or an array-valued
    # ``f_when_x0``.
    return np.where(x == 0, f_when_x0, f)
@define(eq=False)
class FourStateOpsin(MarkovOpsin):
    """4-state model from PyRhO (Evans et al. 2016).

    rho_rel is channel density relative to standard model fit;
    modifying it post-injection allows for heterogeneous opsin expression.

    IOPTO_VAR_NAME and V_VAR_NAME are substituted on injection.

    Defaults are for ChR2.
    """

    g0: Quantity = 114000 * psiemens
    gamma: Quantity = 0.00742
    phim: Quantity = 2.33e17 / mm2 / second  # *photon, not in Brian2
    k1: Quantity = 4.15 / ms
    k2: Quantity = 0.868 / ms
    p: Quantity = 0.833
    Gf0: Quantity = 0.0373 / ms
    kf: Quantity = 0.0581 / ms
    Gb0: Quantity = 0.0161 / ms
    kb: Quantity = 0.063 / ms
    q: Quantity = 1.94
    Gd1: Quantity = 0.105 / ms
    Gd2: Quantity = 0.0138 / ms
    Gr0: Quantity = 0.00033 / ms
    E: Quantity = 0 * mV
    v0: Quantity = 43 * mV
    v1: Quantity = 17.1 * mV

    # States C1, O1, O2 are integrated; C2 is the remainder so occupancies
    # sum to 1. f_unless_x0 (supplied via extra_namespace) substitutes v1/v0
    # for fv where V - E == 0.
    model: str = field(
        init=False,
        default="""
        dC1/dt = Gd1*O1 + Gr0*C2 - Ga1*C1 : 1 (clock-driven)
        dO1/dt = Ga1*C1 + Gb*O2 - (Gd1+Gf)*O1 : 1 (clock-driven)
        dO2/dt = Ga2*C2 + Gf*O1 - (Gd2+Gb)*O2 : 1 (clock-driven)
        C2 = 1 - C1 - O1 - O2 : 1
        Theta = int(phi_pre > 0*phi_pre) : 1
        Hp = Theta * phi_pre**p/(phi_pre**p + phim**p) : 1
        Ga1 = k1*Hp : hertz
        Ga2 = k2*Hp : hertz
        Hq = Theta * phi_pre**q/(phi_pre**q + phim**q) : 1
        Gf = kf*Hq + Gf0 : hertz
        Gb = kb*Hq + Gb0 : hertz
        fphi = O1 + gamma*O2 : 1
        # TODO: get this voltage dependence right
        # v1/v0 when v-E == 0 via l'Hopital's rule
        # fv = (1 - exp(-(V_VAR_NAME_post-E)/v0)) / -2 : 1
        fv = f_unless_x0( (1 - exp(-(V_VAR_NAME_post-E)/v0)) / ((V_VAR_NAME_post-E)/v1), V_VAR_NAME_post - E, v1/v0 ) : 1
        IOPTO_VAR_NAME_post = -g0*fphi*fv*(V_VAR_NAME_post-E)*rho_rel : ampere (summed)
        rho_rel : 1""",
    )

    # makes the custom f_unless_x0 function resolvable inside the model string
    extra_namespace: dict[str, Any] = field(
        init=False, factory=lambda: dict(f_unless_x0=f_unless_x0)
    )

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Put every channel in closed state C1 (C1 = 1, O1 = O2 = 0)."""
        opto_syn.C1 = 1
        opto_syn.O1 = 0
        opto_syn.O2 = 0
@define(eq=False)
class BansalFourStateOpsin(MarkovOpsin):
    """4-state model from Bansal et al. 2020.

    The difference from the PyRhO model is that there is no voltage dependence.

    rho_rel is channel density relative to standard model fit;
    modifying it post-injection allows for heterogeneous opsin expression.

    IOPTO_VAR_NAME and V_VAR_NAME are substituted on injection.
    """

    Gd1: Quantity = 0.066 / ms
    Gd2: Quantity = 0.01 / ms
    Gr0: Quantity = 3.33e-4 / ms
    g0: Quantity = 3.2 * nsiemens
    phim: Quantity = 1e16 / mm2 / second  # *photon, not in Brian2
    k1: Quantity = 0.4 / ms
    k2: Quantity = 0.12 / ms
    Gf0: Quantity = 0.018 / ms
    Gb0: Quantity = 0.008 / ms
    kf: Quantity = 0.01 / ms
    kb: Quantity = 0.008 / ms
    gamma: Quantity = 0.05
    p: Quantity = 1
    q: Quantity = 1
    E: Quantity = 0 * mV

    # Same state structure as FourStateOpsin, but the current is purely
    # ohmic in (V - E): no fv voltage-dependence factor.
    model: str = field(
        init=False,
        default="""
        dC1/dt = Gd1*O1 + Gr0*C2 - Ga1*C1 : 1 (clock-driven)
        dO1/dt = Ga1*C1 + Gb*O2 - (Gd1+Gf)*O1 : 1 (clock-driven)
        dO2/dt = Ga2*C2 + Gf*O1 - (Gd2+Gb)*O2 : 1 (clock-driven)
        C2 = 1 - C1 - O1 - O2 : 1
        Theta = int(phi_pre > 0*phi_pre) : 1
        Hp = Theta * phi_pre**p/(phi_pre**p + phim**p) : 1
        Ga1 = k1*Hp : hertz
        Ga2 = k2*Hp : hertz
        Hq = Theta * phi_pre**q/(phi_pre**q + phim**q) : 1
        Gf = kf*Hq + Gf0 : hertz
        Gb = kb*Hq + Gb0 : hertz
        fphi = O1 + gamma*O2 : 1
        IOPTO_VAR_NAME_post = -g0*fphi*(V_VAR_NAME_post-E)*rho_rel : ampere (summed)
        rho_rel : 1""",
    )

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Put every channel in closed state C1 (C1 = 1, O1 = O2 = 0)."""
        opto_syn.C1 = 1
        opto_syn.O1 = 0
        opto_syn.O2 = 0
@define(eq=False)
class BansalThreeStatePump(MarkovOpsin):
    """3-state model from `Bansal et al. 2020
    <https://doi.org/10.1016/j.neuroscience.2020.09.022>`_.

    rho_rel is channel density relative to standard model fit;
    modifying it post-injection allows for heterogeneous opsin expression.

    IOPTO_VAR_NAME and V_VAR_NAME are substituted on injection.

    Work in progress: :meth:`init_opto_syn_vars` currently raises
    ``NotImplementedError``, and the parameters below are placeholders (0).
    """

    Gd: Quantity = 0
    Gr: Quantity = 0
    ka: Quantity = 0
    p: Quantity = 0
    q: Quantity = 0
    phim: Quantity = 0
    E: Quantity = 0
    a: Quantity = 0
    b: Quantity = 0
    # Annotated so attrs treats them as fields: only fields are collected by
    # `params` into the synapse namespace, and the model references both.
    vartheta_max: Quantity = 5 * mM
    kd: Quantity = 16 * mM

    # NOTE(review): I_Cl_leak and I_i below have no ": unit" specification,
    # and dCl_in/dt has no (clock-driven) flag. Brian will likely reject the
    # model as written; confirm before use.
    model: str = field(
        init=False,
        default="""
        dP0/dt = Gr*P6 - Ga*P0 : 1 (clock-driven)
        dP4/dt = Ga*P0 - Gd*P4 : 1 (clock-driven)
        P6 = 1 - P0 - P4 : 1
        Theta = int(phi_pre > 0*phi_pre) : 1
        Hp = Theta * phi_pre**p/(phi_pre**p + phim**p) : 1
        Ga = ka*Hp : hertz
        fphi = P4 : 1
        dCl_in/dt = a*(I_i + b*I_Cl_leak) : molar
        Cl_out : molar
        E_Cl = -26.67*mV * log(Cl_out/Cl_in) : volt
        I_Cl_leak = g_Cl * (E_Cl0 - E_Cl)
        Psi = vartheta_max*Cl_out / (kd + Cl_out) / 4.43 : 1
        I_i = fphi*(V_VAR_NAME_post-E)*Psi*rho_rel
        IOPTO_VAR_NAME_post = -(I_i + I_Cl_leak) : ampere (summed)
        rho_rel : 1""",
    )

    # constants used by the chloride leak term in the model string
    extra_namespace: dict[str, Any] = field(
        init=False, factory=lambda: {"E_Cl0": -70 * mV, "g_Cl": 2.3 * msiemens / cm2}
    )

    def init_opto_syn_vars(self, opto_syn: Synapses) -> None:
        """Not yet implemented: always raises ``NotImplementedError``.

        The statements after the ``raise`` sketch the intended initialization
        and are currently unreachable.
        """
        raise NotImplementedError("Still need to figure out [Cl-_out]")
        # NOTE(review): P6 is a subexpression in the model, so assigning it
        # below would likely fail once this becomes reachable.
        opto_syn.P0 = 1
        opto_syn.P4 = 0
        opto_syn.P6 = 0
        opto_syn.Cl_out = 124 * mM
        opto_syn.Cl_in = np.exp(np.log(124) - 70 / 26.67) * mM
@define(eq=False)
class ProportionalCurrentOpsin(Opsin):
    """A simple model delivering current proportional to light intensity"""

    I_per_Irr: Quantity = field(kw_only=True)
    """
    How much current (in amps or unitless, depending on neuron model) to
    deliver per mW/mm2.
    """

    # would be IOPTO_UNIT but that throws off Equation parsing
    # (IOPTO_UNIT in the model string is a placeholder swapped for a concrete
    # unit name via per_ng_unit_replacements, set in __attrs_post_init__.)
    model: str = field(
        init=False,
        default="""
        IOPTO_VAR_NAME_post = I_per_Irr / (mwatt / mm2) * Irr_pre * rho_rel : IOPTO_UNIT (summed)
        rho_rel : 1
        """,
    )

    required_vars: list[Tuple[str, Unit]] = field(factory=list, init=False)

    def __attrs_post_init__(self):
        """Derive the Iopto unit from I_per_Irr and configure substitutions.

        A dimensioned I_per_Irr gives the current its unit. A plain number
        falls back to ``radian``, used here as a dimensionless unit.
        """
        iopto_unit = (
            get_unit(self.I_per_Irr.dim)
            if isinstance(self.I_per_Irr, Quantity)
            else radian
        )
        self.per_ng_unit_replacements = [("IOPTO_UNIT", iopto_unit.name)]
        self.required_vars = [("Iopto", iopto_unit)]