"""Assorted utilities for developers."""
import warnings
from collections.abc import MutableMapping
import brian2 as b2
import neo
import quantities as pq
from brian2 import Quantity, np, second
from brian2.equations.equations import (
DIFFERENTIAL_EQUATION,
PARAMETER,
SUBEXPRESSION,
Equations,
)
from brian2.groups.group import get_dtype
from matplotlib import pyplot as plt
rng = np.random.default_rng()
"""supposed to be the central random number generator, but not yet used everywhere"""
[docs]
def times_are_regular(times):
if len(times) < 2:
return False
return np.allclose(np.diff(times), times[1] - times[0])
[docs]
def analog_signal(t_ms, values_no_unit, units="") -> neo.core.basesignal.BaseSignal:
if times_are_regular(t_ms):
return neo.AnalogSignal(
values_no_unit,
t_start=t_ms[0] * pq.ms,
units=units,
sampling_period=(t_ms[1] - t_ms[0]) * pq.ms,
)
else:
return neo.IrregularlySampledSignal(
t_ms * pq.ms,
values_no_unit,
units=units,
)
[docs]
def add_to_neo_segment(
segment: neo.core.Segment, *objects: neo.core.dataobject.DataObject
):
"""Conveniently adds multiple objects to a segment.
Taken from :class:`neo.core.group.Group`."""
container_lookup = {
cls_name: getattr(segment, container_name)
for cls_name, container_name in zip(
segment._child_objects, segment._child_containers
)
}
for obj in objects:
cls = obj.__class__
if hasattr(cls, "proxy_for"):
cls = cls.proxy_for
container = container_lookup[cls.__name__]
container.append(obj)
[docs]
def normalize_coords(coords: Quantity) -> Quantity:
"""Normalize coordinates to unit vectors."""
return coords / np.linalg.norm(coords, axis=-1, keepdims=True)
[docs]
def get_orth_vectors_for_V(V):
"""For nx3 block of row vectors V, return nx3 W1, W2 orthogonal
vector blocks"""
V = V.reshape((-1, 3, 1))
n = V.shape[0]
V = np.repeat(V, 3, axis=-1)
assert V.shape == (n, 3, 3)
q, r = np.linalg.qr(V)
# get two vectors orthogonal to v from QR decomp
W1, W2 = q[..., 1], q[..., 2]
assert W1.shape == W2.shape == (n, 3)
return W1.squeeze(), W2.squeeze()
[docs]
def xyz_from_rθz(rs, thetas, zs, xyz_start, xyz_end):
"""Convert from cylindrical to Cartesian coordinates."""
# not using np.linalg.norm because it strips units
m = xyz_start.reshape((-1, 3)).shape[0]
n = len(rs)
cyl_length = np.sqrt(np.sum((xyz_end - xyz_start) ** 2, axis=-1, keepdims=True))
assert cyl_length.shape in [(m, 1), (1,)]
c = (xyz_end - xyz_start) / cyl_length # unit vector in direction of cylinder
# in case cyl_length is 0, producing nans
assert c.shape in [(m, 3), (3,)]
if c.shape == (m, 3):
assert cyl_length.shape == (m, 1)
c[cyl_length.ravel() == 0] = [0, 0, 1]
elif c.shape == (3,):
c[:] = [0, 0, 1]
r1, r2 = get_orth_vectors_for_V(c)
def r_unit_vecs(thetas):
cosines = np.reshape(np.cos(thetas), (len(thetas), 1))
sines = np.reshape(np.sin(thetas), (len(thetas), 1))
# add axis for broadcasting so result is nx3
cosines = np.cos(thetas)[..., np.newaxis]
sines = np.sin(thetas)[..., np.newaxis]
return r1.reshape((m, 1, 3)) * cosines + r2.reshape((m, 1, 3)) * sines
coords = (
xyz_start.reshape((m, 1, 3))
+ c.reshape((m, 1, 3)) * zs.reshape((1, n, 1))
+ rs.reshape((1, n, 1)) * r_unit_vecs(thetas)
)
assert coords.shape == (m, n, 3)
return coords[..., 0], coords[..., 1], coords[..., 2]
# coords = coords.reshape((m * n), 3)
# return coords[:, 0], coords[:, 1], coords[:, 2]
[docs]
def modify_model_with_eqs(neuron_group, eqs_to_add):
"""Adapted from _create_variables() from neurongroup.py from Brian2
source code v2.3.0.2
"""
if type(eqs_to_add) == str:
eqs_to_add = Equations(eqs_to_add)
neuron_group.equations += eqs_to_add
variables = neuron_group.variables
dtype = {}
if isinstance(dtype, MutableMapping):
dtype["lastspike"] = neuron_group._clock.variables["t"].dtype
for eq in eqs_to_add.values():
dtype = get_dtype(eq, dtype) # {} appears to be the default
if eq.type in (DIFFERENTIAL_EQUATION, PARAMETER):
if "linked" in eq.flags:
# 'linked' cannot be combined with other flags
if not len(eq.flags) == 1:
raise SyntaxError(
('The "linked" flag cannot be ' "combined with other flags")
)
neuron_group._linked_variables.add(eq.varname)
else:
constant = "constant" in eq.flags
shared = "shared" in eq.flags
size = 1 if shared else neuron_group._N
variables.add_array(
eq.varname,
size=size,
dimensions=eq.dim,
dtype=dtype,
constant=constant,
scalar=shared,
)
elif eq.type == SUBEXPRESSION:
neuron_group.variables.add_subexpression(
eq.varname,
dimensions=eq.dim,
expr=str(eq.expr),
dtype=dtype,
scalar="shared" in eq.flags,
)
else:
raise AssertionError("Unknown type of equation: " + eq.eq_type)
# Add the conditional-write attribute for variables with the
# "unless refractory" flag
if neuron_group._refractory is not False:
for eq in neuron_group.equations.values():
if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags:
not_refractory_var = neuron_group.variables["not_refractory"]
var = neuron_group.variables[eq.varname]
var.set_conditional_write(not_refractory_var)
# Stochastic variables
for xi in eqs_to_add.stochastic_variables:
try:
neuron_group.variables.add_auxiliary_variable(
xi, dimensions=(second**-0.5).dim
)
except KeyError as ex:
warnings.warn(
"Adding a stochastic variable to a neuron group that already"
" has a variable with the same name."
)
# Check scalar subexpressions
for eq in neuron_group.equations.values():
if eq.type == SUBEXPRESSION and "shared" in eq.flags:
var = neuron_group.variables[eq.varname]
for identifier in var.identifiers:
if identifier in neuron_group.variables:
if not neuron_group.variables[identifier].scalar:
raise SyntaxError(
(
"Shared subexpression %s refers "
"to non-shared variable %s."
)
% (eq.varname, identifier)
)
[docs]
def wavelength_to_rgb(wavelength_nm, gamma=0.8) -> tuple[float, float, float]:
"""taken from http://www.noah.org/wiki/Wavelength_to_RGB_in_Python
This converts a given wavelength of light to an
approximate RGB color value. The wavelength must be given
in nanometers in the range from 380 nm through 750 nm
(789 THz through 400 THz).
Based on code by Dan Bruton
http://www.physics.sfasu.edu/astro/color/spectra.html
"""
wavelength = wavelength_nm
if wavelength < 380:
wavelength = 380.0
if wavelength > 750:
wavelength = 750.0
if wavelength >= 380 and wavelength <= 440:
attenuation = 0.3 + 0.7 * (wavelength - 380) / (440 - 380)
R = ((-(wavelength - 440) / (440 - 380)) * attenuation) ** gamma
G = 0.0
B = (1.0 * attenuation) ** gamma
elif wavelength >= 440 and wavelength <= 490:
R = 0.0
G = ((wavelength - 440) / (490 - 440)) ** gamma
B = 1.0
elif wavelength >= 490 and wavelength <= 510:
R = 0.0
G = 1.0
B = (-(wavelength - 510) / (510 - 490)) ** gamma
elif wavelength >= 510 and wavelength <= 580:
R = ((wavelength - 510) / (580 - 510)) ** gamma
G = 1.0
B = 0.0
elif wavelength >= 580 and wavelength <= 645:
R = 1.0
G = (-(wavelength - 645) / (645 - 580)) ** gamma
B = 0.0
elif wavelength >= 645 and wavelength <= 750:
attenuation = 0.3 + 0.7 * (750 - wavelength) / (750 - 645)
R = (1.0 * attenuation) ** gamma
G = 0.0
B = 0.0
else:
R = 0.0
G = 0.0
B = 0.0
return (R, G, B)
[docs]
def brian_safe_name(name: str) -> str:
return (
name.replace(" ", "_")
.replace("-", "_")
.replace(".", "_")
.replace("(", "_")
.replace(")", "_")
)
[docs]
def style_plots_for_docs(dark=True):
# some hacky workaround for params not being updated until after first plot
f = plt.figure()
plt.plot()
plt.close(f)
if dark:
plt.style.use("dark_background")
for obj in ["figure", "axes", "savefig"]:
plt.rc(obj, facecolor="131416") # color of Furo dark background
else:
plt.style.use("default")
plt.rc("savefig", transparent=False)
plt.rc("axes.spines", top=False, right=False)
plt.rc("font", **{"sans-serif": "Open Sans"})
[docs]
def style_plots_for_paper(fontscale=5 / 6):
"""
fontscale=5/6 goes from default paper font size of 9.6 down to 8
"""
# some hacky workaround for params not being updated until after first plot
f = plt.figure()
plt.plot()
plt.close(f)
try:
import seaborn as sns
sns.set_context("paper", font_scale=fontscale)
except ImportError:
plt.style.use("seaborn-v0_8-paper")
warnings.warn("Seaborn not found, using matplotlib style, ignoring fontscale")
plt.rc("savefig", transparent=True, bbox="tight", dpi=300)
plt.rc("svg", fonttype="none")
plt.rc("axes.spines", top=False, right=False)
plt.rc("font", **{"sans-serif": "Open Sans"})
[docs]
def unit_safe_append(q1: Quantity, q2: Quantity, axis=0):
if not b2.have_same_dimensions(q1, q2):
raise ValueError("Dimensions must match")
if isinstance(q1, Quantity):
assert isinstance(q2, Quantity)
unit = q1.get_best_unit()
return np.append(q1 / unit, q2 / unit, axis=axis) * unit
else:
return np.append(q1, q2, axis=axis)