Optimal control¶
This tutorial will be more comprehensive than the others, bringing together all of cleo’s main capabilities—electrode recording, optogenetics, and latency modeling—as well as introducing more sophisticated model-based optimal feedback control.
To achieve the latter, we will use the Python bindings to the ldsCtrlEst C++ library.
First some boilerplate:
from brian2 import ms, mm, um, second, mV, Hz, np
import brian2.only as b2
import matplotlib.pyplot as plt
import cleo
# apply Cleo's plotting style for the docs
cleo.utilities.style_plots_for_docs()
# numpy faster than cython for lightweight example
b2.prefs.codegen.target = "numpy"
# seed NumPy, Brian, and Cleo separately
np.random.seed(17320222)
b2.seed(17870917)
cleo.utilities.set_seed(17991214)
# plot colors: cy1/cy2 for the two recorded outputs, cu1/cu2 for the two light inputs
cy1 = "#C500CC"
cy2 = "#df87e1"
cu1 = cleo.utilities.wavelength_to_rgb(463 * b2.nmeter)
cu2 = cleo.utilities.wavelength_to_rgb(483 * b2.nmeter)
Network setup¶
For the network model we adapt an E/I network implementation from the Neuronal Dynamics textbook. Let’s run it and plot the spiking output:
Show code cell source
# simulation timestep
b2.defaultclock.dt = 0.5 * b2.ms
# population sizes: four excitatory neurons per inhibitory neuron
n_exc = 80
n_inh = int(n_exc / 4)
connection_probability = 0.1
# number of external Poisson inputs per neuron
n_ext = int(n_exc * connection_probability)
# excitatory weight: voltage jump added to the target per presynaptic spike
w0 = 0.1 * b2.mV # 0.2
# inhibitory weight is -g times the excitatory one
g = 4
w_inh = -g * w0
synaptic_delay = 1.5 * b2.ms
w_external = w0
# leaky integrate-and-fire membrane parameters
v_rest = -70 * b2.mV
v_reset = -60 * b2.mV
firing_threshold = -50 * b2.mV
tau_m = 20 * b2.ms
Rm = 100 * b2.Mohm
abs_refractory_period = 2 * b2.ms
# input rate at which w_external * n_ext * rate * tau_m equals the
# distance from rest to threshold
thresh_rate = (firing_threshold - v_rest) / (w_external * n_ext * tau_m)
poisson_input_rate = 1 * thresh_rate
# scale of the noise term in the shared bias current defined below
bias_current = 11 * b2.pA
# one-element group holding a noisy bias current: decays with a 1 s time
# constant and is driven by the white-noise term xi
bias_ng = b2.NeuronGroup(
1, "dI_bias/dt = -I_bias / second + bias_current*xi/sqrt(tau_m) : amp"
)
# LIF membrane equation; I is the variable the opsin current is injected
# into later, I_bias is linked to bias_ng below
lif_dynamics = """
dv/dt = (-(v-v_rest) + Rm * (I + I_bias)) / tau_m : volt (unless refractory)
I : amp
I_bias : amp (linked)
"""
# excitatory and inhibitory neurons live in a single group
ng = b2.NeuronGroup(
n_exc + n_inh,
model=lif_dynamics,
threshold="v>firing_threshold",
reset="v=v_reset",
refractory=abs_refractory_period,
method="linear",
)
# every neuron reads the same I_bias value from bias_ng
ng.I_bias = b2.linked_var(bias_ng, "I_bias")
# random v init
ng.v = (
np.random.uniform(
v_rest / b2.mV, high=firing_threshold / b2.mV, size=(n_exc + n_inh)
)
* b2.mV
)
# the first n_exc neurons are excitatory, the remainder inhibitory
ng_exc = ng[:n_exc]
ng_inh = ng[n_exc:]
# recurrent synapses (only ng is passed, no separate target group);
# each presynaptic spike adds w to the postsynaptic v after synaptic_delay
syn = b2.Synapses(
ng,
model="w : volt",
on_pre="v += w",
delay=synaptic_delay,
)
syn.connect(p=connection_probability)
# weight is chosen by presynaptic index: excitatory sources come first
syn[f"i < {n_exc}"].w = w0
syn[f"i >= {n_exc}"].w = w_inh
# external Poisson drive applied directly to v
external_poisson_input = b2.PoissonInput(
target=ng,
target_var="v",
N=n_ext,
rate=poisson_input_rate,
weight=w_external,
)
rate_monitor = b2.PopulationRateMonitor(ng)
spike_monitor = b2.SpikeMonitor(ng, record=True)
net = b2.Network(
ng,
bias_ng,
syn,
external_poisson_input,
rate_monitor,
spike_monitor,
)
runtime = 1 * b2.second
# store the state before the preview run and restore it afterwards
net.store()
net.run(runtime)
plt.scatter(spike_monitor.t / b2.ms, spike_monitor.i, s=1)
net.restore()
INFO No numerical integration method specified for group 'neurongroup', using method 'euler' (took 0.01s, trying other methods took 0.00s). [brian2.stateupdaters.base.method_choice]
The spiking looks like we want: relatively stable with random fluctuations in global activity from the bias current we added. This will make things more interesting for the controller.
Coordinates, stimulation, and recording¶
Here we assign coordinates to the neurons and configure the optogenetic intervention and recording setup:
# half-width of the box neurons are scattered in along x and y
# (presumably in mm, matching the device coordinates below — confirm)
hor_lim = 0.1
cleo.coords.assign_coords_rand_rect_prism(
ng, xlim=(-hor_lim, hor_lim), ylim=(-hor_lim, hor_lim), zlim=(0.4, 0.6)
)
# one light device with two coordinates, i.e. two fibers offset in x
fibers = cleo.light.Light(
name="fiber",
light_model=cleo.light.fiber473nm(R0=50 * um),
coords=[(-0.1, 0, 0.3), (0.1, 0, 0.3)] * mm,
)
opsin = cleo.opto.chr2_4s()
# sorted-spiking signal attached to a probe with two contacts along z
spikes = cleo.ephys.SortedSpiking(name="spikes")
probe = cleo.ephys.Probe(
coords=[[0, 0, 0.4], [0, 0, 0.6]] * mm,
signals=[spikes],
save_history=True,
)
# 3D view of the neurons together with the probe and fibers
cleo.viz.plot(
ng,
colors=["xkcd:fuchsia"],
xlim=(-0.2, 0.2),
ylim=(-0.2, 0.2),
zlim=(0.3, 0.8),
devices=[probe, fibers],
scatterargs={"alpha": 1},
axis_scale_unit=mm,
)
(<Figure size 640x480 with 1 Axes>,
<Axes3D: xlabel='x (mm)', ylabel='y (mm)', zlabel='z (mm)'>)
Looks right. Let’s set up the simulation and inject the devices:
sim = cleo.CLSimulator(net)
sim.inject(fibers, ng)
# the opsin's current is delivered through the neurons' I variable
sim.inject(opsin, ng, Iopto_var_name="I")
sim.inject(probe, ng)
# the rest of the tutorial uses the first two sorted units
assert spikes.n_sorted >= 2
Collect training data¶
Our goal will be to control two neurons’ firing rates simultaneously.
To do this, we will use the LQR technique explained in Bolus et al., 2021 (“State-space optimal feedback control of optogenetically driven neural activity”).
LQR is a model-based technique, and to fit a model of the system’s dynamics, we first need training data.
We will generate training data using Gaussian random walk inputs, wrapped with a modulo operation into a normal stimulation range (0-75 mW/mm²).
Here we structure our data into trials, as ldsCtrlEst is designed for.
\(u\) represents the input and \(z\) the spike output.
n_trials = 5
n_samp = 500
# per-trial lists of inputs (u) and recorded spike counts (z)
u = []
z = []
n_u = 2 # 2-dimensional input (one channel per optic fiber)
n_z = 2 # we'll be controlling two neurons
for trial in range(n_trials):
# Gaussian random walk scaled by 10, wrapped into [0, 75) by the modulo
u_trial = 10 * np.cumsum(np.random.randn(n_u, n_samp), axis=1) % 75
u.append(u_trial)
# output buffer for this trial, initialized to zeros
z.append(np.zeros((n_z, n_samp)))
# add some zeros to get right baseline
# (zeroes the second half of the most recently appended trial's input)
u[-1][:, (n_samp // 2) :] = 0
We configure the LatencyIOProcessor to deliver our pre-computed stimulus and record the results, resetting after each trial:
class TrainingStimIOP(cleo.LatencyIOProcessor):
    """Open-loop processor that replays precomputed inputs and logs spikes.

    On every sample it copies the first ``n_z`` spike counts into the current
    trial's column of ``z`` and emits the matching column of ``u`` as the
    fiber irradiance, with no added latency.
    """

    # index of the next sample within the current trial
    i_samp = 0
    # index of the trial whose ``u``/``z`` arrays are in use (set by the caller)
    i_trial = 0

    def process(self, state_dict, t_samp):
        """Record this sample's spike counts and return the scheduled input.

        Returns the output dict for the fiber and the unchanged sample time.
        """
        _, _, spike_counts = state_dict["Probe"]["spikes"]
        trial, samp = self.i_trial, self.i_samp
        z[trial][:, samp] = spike_counts[:n_z]  # just first two neurons
        stim = u[trial][:, samp] * b2.mwatt / b2.mm2
        self.i_samp = samp + 1
        return {"fiber": stim}, t_samp

    def reset(self):
        """Rewind the sample counter; gets called with sim.reset()."""
        self.i_samp = 0
# sampling period of the I/O processor (one time bin per sample)
dt = 5 * ms
training_stim_iop = TrainingStimIOP(sample_period=dt)
sim.set_io_processor(training_stim_iop)
for i_trial in range(n_trials):
# point the processor at this trial's input/output arrays
training_stim_iop.i_trial = i_trial
sim.run(n_samp * dt)
# reset between trials; this also rewinds the processor's sample index
sim.reset()
INFO No numerical integration method specified for group 'syn_ChR2_neurongroup_1', using method 'euler' (took 0.03s, trying other methods took 0.04s). [brian2.stateupdaters.base.method_choice]
WARNING 'dt' is an internal variable of group 'syn_ChR2_neurongroup_1', but also exists in the run namespace with the value 5. * msecond. The internal variable will be used. [brian2.groups.group.Group.resolve.resolution_conflict]
WARNING /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:594: RuntimeWarning: overflow encountered in exp
collision_prob_fn: Callable[[Quantity], float] = lambda t: 0.2 * np.exp(
[py.warnings]
WARNING /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:465: RuntimeWarning: invalid value encountered in multiply
collision_prob = self.collision_prob_fn(t_diff) * same_chan
[py.warnings]
WARNING /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:469: RuntimeWarning: invalid value encountered in multiply
collision_prob *= t_diff >= 0
[py.warnings]
Let’s plot our training data:
Show code cell source
# One row per trial: recorded spikes on the left, light inputs on the right.
# squeeze=False keeps axs two-dimensional even when n_trials == 1.
fig, axs = plt.subplots(
    n_trials,
    2,
    figsize=(12, 6),
    layout="compressed",
    sharex=True,
    sharey="col",
    squeeze=False,
)
for i_trial, (utrial, ztrial) in enumerate(zip(u, z)):
    is_last_row = i_trial == n_trials - 1

    # Plot spiking data
    ax_spike = axs[i_trial, 0]
    ax_spike.imshow(
        ztrial,
        aspect="auto",
        cmap="gray",
        # (left, right, bottom, top): with the default origin="upper" this
        # centers row k on y = k, so the integer yticks line up with channels
        extent=[0, (n_samp * dt) / ms, n_z - 0.5, -0.5],
        interpolation="nearest",
        zorder=0,
    )
    ax_spike.set(ylabel="channel #", yticks=range(n_z))
    if is_last_row:
        ax_spike.set(xlabel="t [ms]")

    # Plot input data
    ax_input = axs[i_trial, 1]
    lines = ax_input.plot(
        np.arange(n_samp) * dt / ms,
        utrial.T,
        alpha=0.8,
        lw=1,
    )
    # a separate index is needed here: reusing the trial index would break
    # the last-row check below
    for i_fiber, (ln, c) in enumerate(zip(lines, [cu1, cu2])):
        ln.set(color=c, label=f"fiber {i_fiber + 1}")
    ax_input.set(ylabel="u [mW/mm²]")
    if is_last_row:
        ax_input.set(xlabel="t [ms]")
axs[0, 1].legend()
axs[0, 0].set(title="spikes")
axs[0, 1].set(title="light inputs")
plt.show()
Model fitting¶
Now we have u and z in the form we need for ldsctrlest’s fitting functions: n_trial-length lists of n by n_samp arrays. We will now fit Gaussian linear dynamical systems using the SSID algorithm. See the documentation for more detailed explanations.
import ldsctrlest as lds
import ldsctrlest.gaussian as glds
n_x_fit = 3 # latent dimensionality of system
n_h = 100 # size of block Hankel data matrix
# wrap the per-trial lists in ldsctrlest's matrix-list type
u_train = lds.UniformMatrixList(u, free_dim=2)
z_train = lds.UniformMatrixList(z, free_dim=2)
# dt / second passes the sample period as a plain number of seconds
ssid = glds.FitSSID(n_x_fit, n_h, dt / second, u_train, z_train)
# Run returns the fit together with the singular values plotted next
fit, sing_vals = ssid.Run(lds.SSIDWt.kMOESP)
fit_sys_ssid = glds.System(fit)
Here we plot the singular values of the data matrix—we should see a drop at or before our chosen model order if we have a decent fit. We also visualize impulse responses: we should see increases in firing rate for each of the fibers:
# Impulse responses of the SSID fit, alongside the data-matrix singular values
n_samp_imp = 50
y_imp = fit_sys_ssid.simulate_imp(n_samp_imp)
t_imp = np.arange(n_samp_imp) * dt / ms
# one panel for the singular values plus one impulse-response panel per input,
# so the layout follows n_u just like figsize does
fig, axs = plt.subplots(1, 1 + n_u, figsize=(2 + n_u * 3, 4), layout="compressed")
axs[0].semilogy(sing_vals[: n_x_fit * 3], linewidth=2)
axs[0].set(ylabel="singular values", xlabel="singular value index")
for i_u in range(n_u):
    ax = axs[i_u + 1]
    lines = ax.plot(t_imp, y_imp[i_u].T, linewidth=2)
    lines[0].set_color(cy1)
    lines[1].set_color(cy2)
    ax.set(title=f"Impulse response for $u_{i_u + 1}$", xlabel="time (ms)")
    ax.legend(lines, ["$y_1$", "$y_2$"])
We see a sharp drop in singular values after the first few, which justifies our model order choice. For the impulse responses, we expect both fibers to cause transitory increases in firing rate for both electrodes. Since we see that isn’t the case for the first, let’s try refining our fit with expectation-maximization (EM):
# start EM from the SSID fit and the same training data
em = glds.FitEM(fit, u_train, z_train)
# all calc_* options are enabled; at most 50 iterations with tolerance 1e-2
fit_em = em.Run(
calc_dynamics=True,
calc_Q=True,
calc_init=True,
calc_output=True,
calc_measurement=True,
max_iter=50,
tol=1e-2,
)
fit_sys_em = glds.System(fit_em)
Show code cell output
Iteration 1/50 ...
C_new[0]: -0.105012
d_new[0]: 0.391363
R_new[0]: 0.141144
A_new[0]: 0.982364
B_new[0]: 2.38562e-06
Q_new[0]: 0.0114576
x0_new[0]: -2.53407e-05
P0_new[0]: 1.00001e-06
max dtheta: 2.5084
Iteration 2/50 ...
C_new[0]: -0.0943574
d_new[0]: 0.461353
R_new[0]: 0.140272
A_new[0]: 0.977572
B_new[0]: -0.000414522
Q_new[0]: 0.0121686
x0_new[0]: -4.11669e-05
P0_new[0]: 1.00002e-06
max dtheta: 174.758
Iteration 3/50 ...
C_new[0]: -0.0915526
d_new[0]: 0.487437
R_new[0]: 0.138667
A_new[0]: 0.973509
B_new[0]: -0.000737076
Q_new[0]: 0.0128524
x0_new[0]: -5.35115e-05
P0_new[0]: 1.00003e-06
max dtheta: 11.957
Iteration 4/50 ...
C_new[0]: -0.0932068
d_new[0]: 0.489635
R_new[0]: 0.137209
A_new[0]: 0.96977
B_new[0]: -0.00102345
Q_new[0]: 0.0135567
x0_new[0]: -6.43448e-05
P0_new[0]: 1.00005e-06
max dtheta: 1.00887
Iteration 5/50 ...
C_new[0]: -0.0970164
d_new[0]: 0.480574
R_new[0]: 0.1359
A_new[0]: 0.965911
B_new[0]: -0.00130965
Q_new[0]: 0.0142672
x0_new[0]: -7.45381e-05
P0_new[0]: 1.00008e-06
max dtheta: 2.78001
Iteration 6/50 ...
C_new[0]: -0.101797
d_new[0]: 0.46656
R_new[0]: 0.134703
A_new[0]: 0.961747
B_new[0]: -0.00161025
Q_new[0]: 0.0149682
x0_new[0]: -8.44291e-05
P0_new[0]: 1.0001e-06
max dtheta: 10.8644
Iteration 7/50 ...
C_new[0]: -0.106955
d_new[0]: 0.450582
R_new[0]: 0.133571
A_new[0]: 0.957235
B_new[0]: -0.00192869
Q_new[0]: 0.0156501
x0_new[0]: -9.41219e-05
P0_new[0]: 1.00014e-06
max dtheta: 4.36099
Iteration 8/50 ...
C_new[0]: -0.112155
d_new[0]: 0.434213
R_new[0]: 0.132473
A_new[0]: 0.952402
B_new[0]: -0.00226317
Q_new[0]: 0.0163071
x0_new[0]: -0.000103627
P0_new[0]: 1.00017e-06
max dtheta: 1.75518
Iteration 9/50 ...
C_new[0]: -0.117176
d_new[0]: 0.418421
R_new[0]: 0.131404
A_new[0]: 0.947307
B_new[0]: -0.00260934
Q_new[0]: 0.0169355
x0_new[0]: -0.000112919
P0_new[0]: 1.00021e-06
max dtheta: 0.775979
Iteration 10/50 ...
C_new[0]: -0.121863
d_new[0]: 0.403843
R_new[0]: 0.130366
A_new[0]: 0.942028
B_new[0]: -0.0029617
Q_new[0]: 0.0175338
x0_new[0]: -0.000121961
P0_new[0]: 1.00026e-06
max dtheta: 0.51042
Iteration 11/50 ...
C_new[0]: -0.126112
d_new[0]: 0.390877
R_new[0]: 0.129363
A_new[0]: 0.936648
B_new[0]: -0.00331439
Q_new[0]: 0.018102
x0_new[0]: -0.000130711
P0_new[0]: 1.00031e-06
max dtheta: 0.373061
Iteration 12/50 ...
C_new[0]: -0.129868
d_new[0]: 0.37972
R_new[0]: 0.1284
A_new[0]: 0.93125
B_new[0]: -0.00366167
Q_new[0]: 0.0186418
x0_new[0]: -0.000139133
P0_new[0]: 1.00036e-06
max dtheta: 0.277814
Iteration 13/50 ...
C_new[0]: -0.133115
d_new[0]: 0.370404
R_new[0]: 0.127476
A_new[0]: 0.925911
B_new[0]: -0.00399819
Q_new[0]: 0.0191559
x0_new[0]: -0.000147195
P0_new[0]: 1.00041e-06
max dtheta: 0.337983
Iteration 14/50 ...
C_new[0]: -0.135871
d_new[0]: 0.362844
R_new[0]: 0.126592
A_new[0]: 0.920704
B_new[0]: -0.00431912
Q_new[0]: 0.0196477
x0_new[0]: -0.000154872
P0_new[0]: 1.00047e-06
max dtheta: 0.560341
Iteration 15/50 ...
C_new[0]: -0.138171
d_new[0]: 0.356879
R_new[0]: 0.125748
A_new[0]: 0.915695
B_new[0]: -0.00462026
Q_new[0]: 0.0201215
x0_new[0]: -0.00016215
P0_new[0]: 1.00054e-06
max dtheta: 1.35138
Iteration 16/50 ...
C_new[0]: -0.140066
d_new[0]: 0.35231
R_new[0]: 0.124944
A_new[0]: 0.910944
B_new[0]: -0.00489815
Q_new[0]: 0.0205815
x0_new[0]: -0.000169019
P0_new[0]: 1.0006e-06
max dtheta: 3.93717
Iteration 17/50 ...
C_new[0]: -0.141608
d_new[0]: 0.348928
R_new[0]: 0.124183
A_new[0]: 0.906501
B_new[0]: -0.00515016
Q_new[0]: 0.0210321
x0_new[0]: -0.000175482
P0_new[0]: 1.00067e-06
max dtheta: 0.788479
Iteration 18/50 ...
C_new[0]: -0.142849
d_new[0]: 0.346533
R_new[0]: 0.123468
A_new[0]: 0.902407
B_new[0]: -0.0053746
Q_new[0]: 0.0214775
x0_new[0]: -0.000181547
P0_new[0]: 1.00075e-06
max dtheta: 0.421416
Iteration 19/50 ...
C_new[0]: -0.143838
d_new[0]: 0.344946
R_new[0]: 0.122802
A_new[0]: 0.89869
B_new[0]: -0.00557078
Q_new[0]: 0.021921
x0_new[0]: -0.000187229
P0_new[0]: 1.00082e-06
max dtheta: 0.516032
Iteration 20/50 ...
C_new[0]: -0.144618
d_new[0]: 0.344006
R_new[0]: 0.122186
A_new[0]: 0.895363
B_new[0]: -0.005739
Q_new[0]: 0.0223651
x0_new[0]: -0.000192546
P0_new[0]: 1.0009e-06
max dtheta: 1.39452
Iteration 21/50 ...
C_new[0]: -0.145228
d_new[0]: 0.343577
R_new[0]: 0.121621
A_new[0]: 0.892428
B_new[0]: -0.00588045
Q_new[0]: 0.0228111
x0_new[0]: -0.000197523
P0_new[0]: 1.00098e-06
max dtheta: 7.81432
Iteration 22/50 ...
C_new[0]: -0.145698
d_new[0]: 0.343545
R_new[0]: 0.121107
A_new[0]: 0.889872
B_new[0]: -0.00599704
Q_new[0]: 0.0232597
x0_new[0]: -0.000202186
P0_new[0]: 1.00106e-06
max dtheta: 1.00709
Iteration 23/50 ...
C_new[0]: -0.146058
d_new[0]: 0.343815
R_new[0]: 0.120641
A_new[0]: 0.887674
B_new[0]: -0.00609114
Q_new[0]: 0.0237102
x0_new[0]: -0.000206562
P0_new[0]: 1.00114e-06
max dtheta: 0.557072
Iteration 24/50 ...
C_new[0]: -0.14633
d_new[0]: 0.34431
R_new[0]: 0.12022
A_new[0]: 0.885805
B_new[0]: -0.00616545
Q_new[0]: 0.0241616
x0_new[0]: -0.000210678
P0_new[0]: 1.00123e-06
max dtheta: 0.390198
Iteration 25/50 ...
C_new[0]: -0.146531
d_new[0]: 0.344969
R_new[0]: 0.119841
A_new[0]: 0.884232
B_new[0]: -0.00622268
Q_new[0]: 0.0246122
x0_new[0]: -0.000214561
P0_new[0]: 1.00131e-06
max dtheta: 0.535436
Iteration 26/50 ...
C_new[0]: -0.146678
d_new[0]: 0.345743
R_new[0]: 0.1195
A_new[0]: 0.882919
B_new[0]: -0.00626549
Q_new[0]: 0.0250604
x0_new[0]: -0.000218235
P0_new[0]: 1.0014e-06
max dtheta: 1.00021
Iteration 27/50 ...
C_new[0]: -0.146781
d_new[0]: 0.346594
R_new[0]: 0.119192
A_new[0]: 0.881833
B_new[0]: -0.00629635
Q_new[0]: 0.025504
x0_new[0]: -0.000221725
P0_new[0]: 1.00149e-06
max dtheta: 4230.77
Iteration 28/50 ...
C_new[0]: -0.146852
d_new[0]: 0.347492
R_new[0]: 0.118914
A_new[0]: 0.880941
B_new[0]: -0.00631743
Q_new[0]: 0.0259416
x0_new[0]: -0.000225049
P0_new[0]: 1.00158e-06
max dtheta: 0.873001
Iteration 29/50 ...
C_new[0]: -0.146897
d_new[0]: 0.348415
R_new[0]: 0.118662
A_new[0]: 0.880213
B_new[0]: -0.00633065
Q_new[0]: 0.0263713
x0_new[0]: -0.000228229
P0_new[0]: 1.00167e-06
max dtheta: 0.409528
Iteration 30/50 ...
C_new[0]: -0.146923
d_new[0]: 0.349345
R_new[0]: 0.118432
A_new[0]: 0.879623
B_new[0]: -0.00633764
Q_new[0]: 0.026792
x0_new[0]: -0.00023128
P0_new[0]: 1.00177e-06
max dtheta: 0.257248
Iteration 31/50 ...
C_new[0]: -0.146934
d_new[0]: 0.350268
R_new[0]: 0.118222
A_new[0]: 0.879147
B_new[0]: -0.00633975
Q_new[0]: 0.0272026
x0_new[0]: -0.000234218
P0_new[0]: 1.00186e-06
max dtheta: 0.182744
Iteration 32/50 ...
C_new[0]: -0.146935
d_new[0]: 0.351176
R_new[0]: 0.118028
A_new[0]: 0.878766
B_new[0]: -0.00633811
Q_new[0]: 0.0276021
x0_new[0]: -0.000237055
P0_new[0]: 1.00196e-06
max dtheta: 0.139274
Iteration 33/50 ...
C_new[0]: -0.146928
d_new[0]: 0.352061
R_new[0]: 0.117849
A_new[0]: 0.878462
B_new[0]: -0.00633364
Q_new[0]: 0.0279902
x0_new[0]: -0.000239804
P0_new[0]: 1.00205e-06
max dtheta: 0.124395
Iteration 34/50 ...
C_new[0]: -0.146917
d_new[0]: 0.352917
R_new[0]: 0.117683
A_new[0]: 0.878221
B_new[0]: -0.00632708
Q_new[0]: 0.0283665
x0_new[0]: -0.000242474
P0_new[0]: 1.00215e-06
max dtheta: 0.129348
Iteration 35/50 ...
C_new[0]: -0.146902
d_new[0]: 0.353742
R_new[0]: 0.117528
A_new[0]: 0.878032
B_new[0]: -0.00631903
Q_new[0]: 0.0287308
x0_new[0]: -0.000245074
P0_new[0]: 1.00225e-06
max dtheta: 0.134967
Iteration 36/50 ...
C_new[0]: -0.146886
d_new[0]: 0.354531
R_new[0]: 0.117382
A_new[0]: 0.877884
B_new[0]: -0.00630995
Q_new[0]: 0.0290831
x0_new[0]: -0.000247612
P0_new[0]: 1.00235e-06
max dtheta: 0.141478
Iteration 37/50 ...
C_new[0]: -0.146871
d_new[0]: 0.355285
R_new[0]: 0.117245
A_new[0]: 0.877768
B_new[0]: -0.00630022
Q_new[0]: 0.0294237
x0_new[0]: -0.000250096
P0_new[0]: 1.00246e-06
max dtheta: 0.149187
Iteration 38/50 ...
C_new[0]: -0.146856
d_new[0]: 0.356002
R_new[0]: 0.117115
A_new[0]: 0.877679
B_new[0]: -0.00629013
Q_new[0]: 0.0297528
x0_new[0]: -0.00025253
P0_new[0]: 1.00256e-06
max dtheta: 0.15852
Iteration 39/50 ...
C_new[0]: -0.146843
d_new[0]: 0.356681
R_new[0]: 0.116993
A_new[0]: 0.87761
B_new[0]: -0.00627991
Q_new[0]: 0.0300707
x0_new[0]: -0.000254921
P0_new[0]: 1.00267e-06
max dtheta: 0.170103
Iteration 40/50 ...
C_new[0]: -0.146832
d_new[0]: 0.357325
R_new[0]: 0.116876
A_new[0]: 0.877557
B_new[0]: -0.00626973
Q_new[0]: 0.030378
x0_new[0]: -0.000257274
P0_new[0]: 1.00278e-06
max dtheta: 0.184892
Iteration 41/50 ...
C_new[0]: -0.146825
d_new[0]: 0.357932
R_new[0]: 0.116765
A_new[0]: 0.877516
B_new[0]: -0.00625972
Q_new[0]: 0.0306751
x0_new[0]: -0.000259592
P0_new[0]: 1.00289e-06
max dtheta: 0.204436
Iteration 42/50 ...
C_new[0]: -0.14682
d_new[0]: 0.358504
R_new[0]: 0.116658
A_new[0]: 0.877484
B_new[0]: -0.00624998
Q_new[0]: 0.0309625
x0_new[0]: -0.000261879
P0_new[0]: 1.003e-06
max dtheta: 0.231434
Iteration 43/50 ...
C_new[0]: -0.146819
d_new[0]: 0.359043
R_new[0]: 0.116556
A_new[0]: 0.877459
B_new[0]: -0.00624056
Q_new[0]: 0.0312407
x0_new[0]: -0.000264138
P0_new[0]: 1.00312e-06
max dtheta: 0.271044
Iteration 44/50 ...
C_new[0]: -0.146821
d_new[0]: 0.35955
R_new[0]: 0.116459
A_new[0]: 0.877439
B_new[0]: -0.00623152
Q_new[0]: 0.0315103
x0_new[0]: -0.000266373
P0_new[0]: 1.00323e-06
max dtheta: 0.334533
Iteration 45/50 ...
C_new[0]: -0.146827
d_new[0]: 0.360026
R_new[0]: 0.116365
A_new[0]: 0.877422
B_new[0]: -0.00622289
Q_new[0]: 0.0317718
x0_new[0]: -0.000268586
P0_new[0]: 1.00335e-06
max dtheta: 0.452143
Iteration 46/50 ...
C_new[0]: -0.146837
d_new[0]: 0.360474
R_new[0]: 0.116274
A_new[0]: 0.877408
B_new[0]: -0.00621469
Q_new[0]: 0.0320256
x0_new[0]: -0.00027078
P0_new[0]: 1.00347e-06
max dtheta: 0.742147
Iteration 47/50 ...
C_new[0]: -0.146849
d_new[0]: 0.360894
R_new[0]: 0.116186
A_new[0]: 0.877394
B_new[0]: -0.00620691
Q_new[0]: 0.0322723
x0_new[0]: -0.000272956
P0_new[0]: 1.00359e-06
max dtheta: 2.58805
Iteration 48/50 ...
C_new[0]: -0.146866
d_new[0]: 0.361289
R_new[0]: 0.116102
A_new[0]: 0.87738
B_new[0]: -0.00619957
Q_new[0]: 0.0325123
x0_new[0]: -0.000275116
P0_new[0]: 1.00371e-06
max dtheta: 1.46554
Iteration 49/50 ...
C_new[0]: -0.146885
d_new[0]: 0.361659
R_new[0]: 0.11602
A_new[0]: 0.877367
B_new[0]: -0.00619266
Q_new[0]: 0.0327461
x0_new[0]: -0.000277262
P0_new[0]: 1.00384e-06
max dtheta: 0.53465
Iteration 50/50 ...
C_new[0]: -0.146908
d_new[0]: 0.362008
R_new[0]: 0.11594
A_new[0]: 0.877352
B_new[0]: -0.00618617
Q_new[0]: 0.0329742
x0_new[0]: -0.000279396
P0_new[0]: 1.00396e-06
max dtheta: 0.313476
We’ll plot impulse responses again:
# Impulse responses of the EM-refined fit, one panel per input
y_imp = fit_sys_em.simulate_imp(n_samp_imp)
# panel count follows n_u; squeeze=False keeps axs indexable when n_u == 1
fig, axs = plt.subplots(
    1, n_u, figsize=(3 * n_u, 4), layout="compressed", squeeze=False
)
for i_u, ax in enumerate(axs[0]):
    lines = ax.plot(t_imp, y_imp[i_u].T, linewidth=2)
    lines[0].set_color(cy1)
    lines[1].set_color(cy2)
    ax.set(title=f"Impulse response for $u_{i_u + 1}$", xlabel="time (ms)")
    ax.legend(lines, ["$y_1$", "$y_2$"])
Looking better. Since we can’t visualize singular values, let’s compare the two fits with percent explained variance:
z_arr = np.array(z)


def pct_exp_var(sys, nstep=1):
    """Return the squared correlation between data and n-step predictions.

    Args:
        sys: fitted system providing ``nstep_pred_block``.
        nstep: how many steps ahead to predict.

    Returns:
        Squared Pearson correlation between the flattened training outputs
        (from sample ``nstep`` onward) and the flattened predictions.
    """
    _, _, y_pred = sys.nstep_pred_block(u_train, z_train, nstep)
    predicted = np.array(y_pred)
    observed = z_arr[:, :, nstep:]
    # predictions must line up one-to-one with the observed samples
    assert predicted.shape == observed.shape, predicted.shape
    r = np.corrcoef(observed.flatten(), predicted.flatten())[0, 1]
    return r**2


for label, fitted_sys in (("SSID r²:", fit_sys_ssid), ("EM r²:", fit_sys_em)):
    print(label, pct_exp_var(fitted_sys))
SSID r²: 0.07508708150847317
EM r²: 0.26515411771446507
As we suspected from the impulse responses, the EM-refined fit is much better.
LQR optimal controller design¶
We now use the fit parameters to create the controller system and set additional parameters. The feedback gain, \( K_c \), is especially important, determining how the controller responds to the current “error”—the difference between where the system is (estimated to be) now and where we want it to be. The field of optimal control deals with how to design the controller so as to minimize a cost function reflecting what we care about.
With a linear system (obtained from the fitting procedure above) and quadratic per-timestep cost function \(J\) penalizing distance from the reference \(x^*\) and the input \(u\)
\[ J = (x - x^*)^T Q (x - x^*) + u^T R u \]
we can use the closed-form optimal solution called the Linear Quadratic Regulator (LQR):
\[ K_c = (R + B^T P B)^{-1} (B^T P A) \]
The \(P\) matrix is obtained by numerically solving the discrete algebraic Riccati equation:
\[ P = A^T P A - (A^T P B)(R + B^T P B)^{-1}(B^T P A) + Q \]
# upper and lower bounds on control signal (optic fiber light intensity)
u_lb = 0  # mW/mm2
u_ub = 75  # mW/mm2
controller = glds.Controller(fit_sys_em, u_lb, u_ub)
# careful not to use this anymore since controller made a copy
del fit_sys_em
from scipy.linalg import solve_discrete_are

A, B, C = controller.sys.A, controller.sys.B, controller.sys.C
# cost matrices
# Q reflects how much we care about state error
# we use C'C since we really care about output error, not latent state
Q_cost = C.T @ C
R_cost = 1e-4 * np.eye(n_u)  # reflects how much we care about minimizing the stimulus
P = solve_discrete_are(A, B, Q_cost, R_cost)
# Kc = (R + B'PB)^-1 (B'PA), obtained with a linear solve rather than by
# forming the explicit inverse
controller.Kc = np.linalg.solve(R_cost + B.T @ P @ B, B.T @ P @ A)
print(controller)
print("For controlled system dynamics A - BK:")
# eigenvalues of the closed-loop dynamics, computed once and reused
closed_loop_eigvals = np.linalg.eigvals(A - B @ controller.Kc)
print("eigvals:", closed_loop_eigvals)
print("magnitude of eigvals:", np.abs(closed_loop_eigvals))
********** SYSTEM **********
x:
0.2416
0.2736
-0.4363
P:
0.1522 -0.0049 0.0092
-0.0049 0.2306 -0.0099
0.0092 -0.0099 0.2138
A:
0.8774 -0.1411 0.1746
0.2014 0.0053 0.4933
0.0106 -0.8962 -0.3273
B:
-0.0062 -0.0065
0.0161 0.0073
0.0017 0.0012
g:
1.0000
1.0000
m:
0
0
0
Q:
0.0330 -0.0765 -0.0013
-0.0765 0.2050 0.0040
-0.0013 0.0040 0.0198
Q_m:
1.0000e-06 0 0
0 1.0000e-06 0
0 0 1.0000e-06
d:
0.3620
0.3859
C:
-0.1469 0.1281 -0.0339
-0.1480 0.1816 -0.0983
y:
0.3764
0.4427
R:
0.1159 -0.0039
-0.0039 0.1031
g_design : 1.0000
1.0000
u_lb : 0
u_ub : 75
For controlled system dynamics A - BK:
eigvals: [ 0.72559149+0.j -0.15320079+0.55360005j -0.15320079-0.55360005j]
magnitude of eigvals: [0.72559149 0.57440709 0.57440709]
We now configure a LatencyIOProcessor to use our controller:
y_ref = np.mean(z) * 0.75  # target rate per timebin


class CtrlLoop(cleo.LatencyIOProcessor):
    """Feedback loop that tracks an output reference with the LQR controller.

    Each sample, the first ``n_z`` spike counts are passed to the controller's
    ``ControlOutputReference`` and its output is sent to the fibers 3 ms later.
    State and output estimates and the raw counts are logged column by column.
    """

    def __init__(self, sample_period, controller):
        super().__init__(sample_period)
        self.controller = controller
        self.sys = controller.sys
        # allows us to turn on and off control
        self.do_control = False
        # for post hoc visualization/analysis (one column per sample):
        self.x_hat = np.empty((n_x_fit, 0))
        self.y_hat = np.empty((n_z, 0))
        self.z = np.empty((n_z, 0))

    def process(self, state_dict, t_samp):
        """Compute the next light input from the latest spike counts."""
        _, _, counts = state_dict["Probe"]["spikes"]
        z_col = counts[:n_z].reshape((-1, 1))  # just first n_z neurons
        self.controller.y_ref = np.full((n_z, 1), y_ref)
        u_col = self.controller.ControlOutputReference(
            z_col, do_control=self.do_control
        )
        # record variables from this timestep
        self.y_hat = np.concatenate((self.y_hat, self.sys.y), axis=1)
        self.x_hat = np.concatenate((self.x_hat, self.sys.x), axis=1)
        self.z = np.concatenate((self.z, z_col), axis=1)
        stim = u_col.squeeze() * b2.mwatt / b2.mm2
        return {fibers.name: stim}, t_samp + 3 * ms  # 3 ms delay


ctrl_loop = CtrlLoop(sample_period=dt, controller=controller)
Run the experiment¶
We’ll now run the simulation with and without control to compare.
sim.reset() # only needed when rerunning
sim.set_io_processor(ctrl_loop)
# baseline period: the loop runs with do_control still False
T0 = 200 * ms
sim.run(T0)
# control period
ctrl_loop.do_control = True
T1 = 1000 * ms
sim.run(T1)
Now we plot the results to see how well the controller was able to match the desired firing rate:
def y_ref_fn(t, in_hertz=True):
    """Return the reference trace at time(s) ``t``.

    Before ``T0`` the value is the mean of the fitted output bias ``d``;
    from ``T0`` onward it is ``y_ref``. With ``in_hertz`` both are converted
    from counts per time bin to spikes per second using ``dt``.
    """
    baseline = np.mean(controller.sys.d)
    target = y_ref
    if in_hertz:
        target = target / dt / Hz
        baseline = baseline / (dt * Hz)
    before_ctrl = t < T0
    from_ctrl_on = t >= T0
    return baseline * before_ctrl + target * from_ctrl_on  # target rate per timebin
def plot_ctrl(loop):
    """Plot a loop's real-time rate estimates (top) and the light inputs (bottom).

    Args:
        loop: processor exposing ``t_samp`` and a ``y_hat`` array with one row
            per controlled output.
    """
    t_ms = loop.t_samp / ms
    fig, (ax_rate, ax_light) = plt.subplots(2, 1, sharex=True)

    # estimated firing rates, converted from counts per bin to spikes/s
    ax_rate.set(ylabel="spikes/s", title="real-time estimated firing rates")
    for k, color in enumerate((cy1, cy2)):
        ax_rate.plot(
            t_ms,
            loop.y_hat[k] / dt,
            c=color,
            alpha=0.7,
            label=f"$\\hat{{y}}_{k + 1}$",
        )
    ax_rate.plot(t_ms, y_ref_fn(loop.t_samp), c="green", label="reference")
    ax_rate.axvline(T0 / ms, c="xkcd:red", linestyle=":", label="ctrl on")
    ax_rate.legend(loc="right")

    # irradiance delivered through each fiber
    light_t_ms = fibers.t / b2.ms
    for k, color in enumerate((cu1, cu2)):
        ax_light.plot(
            light_t_ms, fibers.irradiance_[:, k], c=color, label=f"$u_{k + 1}$"
        )
    ax_light.legend()
    ax_light.set(xlabel="t (ms)", ylabel="u (mW/mm$^2$)", title="input signal")
plot_ctrl(ctrl_loop)
What gives? Why does the controller keep the input on when it’s clearly over the target? Let’s check a Gaussian-smoothed post hoc estimated firing rate to confirm the firing rates are too high:
from scipy.ndimage import gaussian_filter1d
def plot_post_hoc(loop):
    """Print mean rates before/during control and plot smoothed firing rates.

    Args:
        loop: processor exposing ``t_samp`` and the recorded counts ``z``.
    """
    t = loop.t_samp
    in_baseline = t < T0
    in_control = (t >= T0) & (t < T0 + T1)
    baseline_rate = np.sum(loop.z[:, in_baseline], axis=1) / T0
    achieved_rate = (np.sum(loop.z[:, in_control], axis=1) / T1).round(1)
    print("Results (spikes/second):")
    print("baseline =", baseline_rate)
    print("target =", y_ref / dt)
    print(
        "achieved =",
        achieved_rate,
    )

    # Gaussian kernel width expressed in samples (25 ms worth of bins)
    sigma_bins = 25 * ms / dt
    smoothed = gaussian_filter1d(loop.z, sigma=sigma_bins, axis=1) / dt

    t_ms = t / ms
    plt.plot(t_ms, y_ref_fn(t), c="green", label="reference")
    plt.axvline(T0 / ms, c="r", ls=":")
    plt.xlabel("t (ms)")
    plt.ylabel("spikes/s")
    plt.title("Gaussian-smoothed firing rate")
    for k, color in enumerate((cy1, cy2)):
        plt.plot(t_ms, smoothed[k], c=color)
plot_post_hoc(ctrl_loop)
Results (spikes/second):
baseline = [5. 0.] Hz
target = 132.6 Hz
achieved = [198. 181.] Hz
And so they are. Another possibility is that the steady-state set point calculated by the controller is unattainable. Let’s see:
controller.u_ref
array([[-31.29477828],
[ 84.5755037 ]])
It looks to be the case. The controller kept \(u_2\) high because it was trying to use a negative \(u_1\) at the same time—unfortunately, we can’t exactly use negative light. Even if the controller didn’t fail in this way, its ignorance of the upper limit (75 mW/mm²) poses another problem. This is a fundamental limitation of LQR; it can’t account for constraints. What’s the solution?
Accounting for input constraints¶
Let’s help out LQR by computing a set point that respects our constraints. We need to break out convex optimization tools to do so though.
import cvxpy as cp
def opt_u_x():
    """Compute a steady-state setpoint that respects the input bounds.

    Uses CVXPY to find the input ``u`` and latent state ``x`` minimizing the
    squared output error ``||C x + d - y_ref||^2`` plus the input penalty
    ``u' R_cost u``, subject to ``x`` being a fixed point of the fitted
    dynamics (``x == A x + B u``) and ``u_lb <= u <= u_ub``.

    Returns:
        tuple: ``(u, x)`` as arrays of length ``n_u`` and ``n_x_fit``.

    Raises:
        RuntimeError: if the solver does not report an optimal solution.
    """
    # Define the optimization variables
    u = cp.Variable(n_u)
    x = cp.Variable(n_x_fit)
    y_r = np.full(n_z, y_ref)
    # Flatten the output bias so the residual is a plain n_z-vector. The
    # tutorial squeezes d before using it elsewhere, which suggests it may be
    # a column; adding a column to the 1-D C @ x would otherwise either fail
    # or broadcast to a matrix.
    d = np.asarray(controller.sys.d).reshape(-1)
    # Define the cost function
    cost = cp.sum_squares(C @ x + d - y_r) + cp.quad_form(u, R_cost)
    # Define the constraints
    constraints = [
        x == A @ x + B @ u,
        u >= u_lb,
        u <= u_ub,
    ]
    # Define the optimization problem
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve()
    # fail loudly here rather than handing None setpoints to the controller
    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f"setpoint optimization failed: {prob.status}")
    return u.value, x.value
u_ref, x_ref = opt_u_x()
u_ref, x_ref
WARNING /home/kyle/miniforge3/envs/cleo/lib/python3.12/site-packages/cvxpy/reductions/solvers/solving_chain_utils.py:30: UserWarning: The problem includes expressions that don't support CPP backend. Defaulting to the SCIPY backend for canonicalization.
warnings.warn(UserWarning(
[py.warnings]
(array([17.06548526, 12.72169718]),
array([-1.59446947, 0.04281098, -0.007822 ]))
That u_ref looks much better!
Let’s try again, this time using the Control method of the controller, that lets us set the x and u reference values ourselves, rather than computing the unhelpful values we saw.
class CtrlLoop2(cleo.LatencyIOProcessor):
    """Feedback loop that drives the controller with explicit setpoints.

    Each sample, ``u_ref`` and ``x_ref`` are assigned to the controller and
    its ``Control`` method is called; while ``do_control`` is False the
    emitted input is replaced by zeros. Output is delivered 3 ms later.
    """

    def __init__(self, sample_period, controller):
        super().__init__(sample_period)
        self.controller = controller
        self.sys = controller.sys
        # allows us to turn on and off control
        self.do_control = False
        # for post hoc visualization/analysis (one column per sample):
        self.x_hat = np.empty((n_x_fit, 0))
        self.y_hat = np.empty((n_z, 0))
        self.z = np.empty((n_z, 0))

    def process(self, state_dict, t_samp):
        """Compute the next light input from the latest spike counts."""
        _, _, counts = state_dict["Probe"]["spikes"]
        z_col = counts[:n_z].reshape((-1, 1))  # just first n_z neurons
        self.controller.u_ref = u_ref
        self.controller.x_ref = x_ref
        # Control is called even with control off; only its output is discarded
        u_col = self.controller.Control(z_col, do_control=self.do_control)
        if not self.do_control:
            u_col = np.zeros((n_u, 1))  # no control
        # record variables from this timestep
        self.y_hat = np.concatenate((self.y_hat, self.sys.y), axis=1)
        self.x_hat = np.concatenate((self.x_hat, self.sys.x), axis=1)
        self.z = np.concatenate((self.z, z_col), axis=1)
        stim = u_col.squeeze() * b2.mwatt / b2.mm2
        return {fibers.name: stim}, t_samp + 3 * ms  # 3 ms delay
ctrl_loop2 = CtrlLoop2(sample_period=dt, controller=controller)
# same protocol as before: T0 without control, then T1 with control
sim.reset()
sim.set_io_processor(ctrl_loop2)
sim.run(T0)
ctrl_loop2.do_control = True
sim.run(T1)
plot_ctrl(ctrl_loop2)
plot_post_hoc(ctrl_loop2)
Results (spikes/second):
baseline = [15. 20.] Hz
target = 132.6 Hz
achieved = [151. 152.] Hz
There we go! That looks much better.
Conclusion¶
As a recap, in this tutorial we’ve seen how to:
inject optogenetic stimulation into an existing Brian network
inject an electrode into an existing Brian network to record spikes
generate training data and fit a Gaussian linear dynamical system to the spiking output using ldsctrlest
configure an ldsctrlest LQR controller based on that linear system and design optimal gains
use that controller in running a complete simulated feedback control experiment
work around the limitations of LQR by computing a feasible setpoint using convex optimization
Bonus: model predictive control¶
The stage is now set for model predictive control (MPC), which, by solving an optimization problem at every timestep, has two important capabilities LQR does not:
It can look ahead, optimizing a time-varying control trajectory over a horizon of specified length.
It can account for linear constraints on states and inputs.
The main downside is the increased computational cost compared to LQR.
Let’s try this experiment again with MPC instead, using the lqmpc package.
MPC setup¶
import lqmpc
# horizon length in samples
H = 20
# build the MPC from the same A, B, C used for the LQR design
mpc = lqmpc.LQMPC(dt / second, A, B, C)
# N, M are prediction and control horizons
mpc.set_control(Q_cost, R_cost, N=H, M=H - 1)
# same input bounds as before: 0 to u_ub on every input
mpc.set_constraints(umin=np.zeros(n_u), umax=np.full(n_u, u_ub))
Since MPC runs a quadratic program under the hood every timestep, one advantage we should have is not needing to think about the optimal setpoints like we just did.
However, lqmpc only supports a state reference, not an output reference, so we’ll just use the same reference value for now.
Adding MPC to ldsctrlest is planned, including control by output reference; when implemented, this tutorial should be updated.
# lqmpc doesn't support directly optimizing y_ref, so we use the same steady-state x_ref
def xref_fn(t):
    """Return the state reference at time(s) ``t``: zero before ``T0``, ``x_ref`` from then on.

    The result has one column per entry of ``t``.
    """
    ctrl_on = t >= T0
    return ctrl_on * x_ref[:, None]
We mentioned a downside of MPC was the increased computational cost. Let’s measure how long the controller takes:
# run lqmpc's own simulation to see how long a step takes (it prints timings);
# the trailing semicolon suppresses the notebook's display of the return value
mpc.simulate(
t_sim=dt / second,
x0=np.zeros(n_x_fit),
u0=u_ref,
xr=xref_fn(np.arange(200) * dt),
L=150,
);
Simulation Time: 0.30167532 s
Mean Step Time: 0.00201117 s
Roughly 2 ms per step.
Good to know!
To leave some headroom, let’s increase the latency of our control loop from 3 to 6 ms.
Unlike ldsCtrlEst, lqmpc does not have built-in Kalman filtering, so we’ll need to add that too:
# Feedback loop that filters the spike counts into a state estimate and asks
# the module-level mpc object for the next input at every sample
class MPCLoop(cleo.LatencyIOProcessor):
def __init__(self, sample_period):
super().__init__(sample_period)
# for post hoc visualization/analysis:
self.x_hat = np.empty((n_x_fit, 0))
self.y_hat = np.empty((n_z, 0))
self.z = np.empty((n_z, 0))
# initialize the shared mpc object's state estimate and last input
mpc.xi = np.zeros(n_x_fit)
mpc.ui = u_ref
# reuse the fitted system's covariances and output bias for the filter
self.P_cov = controller.sys.P # initialize state covariance matrix
self.Q = controller.sys.Q
self.R = controller.sys.R
self.d = controller.sys.d.squeeze()
def process(self, state_dict, t_samp):
i, t, z_t = state_dict["Probe"]["spikes"]
z_t = z_t[:n_z] # just first n_z neurons
# update state covariance for Kalman filtering
# predict step
# NOTE(review): only the covariance is predicted here; presumably mpc.step
# advances mpc.xi to the next state — confirm against lqmpc
P_pred = A @ self.P_cov @ A.T + self.Q
# Kalman gain
S = C @ P_pred @ C.T + self.R
K = P_pred @ C.T @ np.linalg.inv(S)
# Update step
z_err = z_t - (C @ mpc.xi + self.d) # Measurement residual
mpc.xi = mpc.xi + K @ z_err
self.P_cov = (np.eye(len(P_pred)) - K @ C) @ P_pred
# keep the filtered estimate for logging before mpc.step runs
x_hat = mpc.xi.copy()
# the state reference covers the next H samples starting at t_samp
mpc.step(
dt / second, mpc.xi, mpc.ui, xref_fn(t_samp + np.arange(H) * dt), out=False
)
u_t = mpc.ui
out = {fibers.name: u_t.squeeze() * b2.mwatt / b2.mm2}
# record variables from this timestep
y_hat = C @ x_hat + self.d
self.y_hat = np.column_stack([self.y_hat, y_hat])
self.x_hat = np.column_stack([self.x_hat, x_hat])
self.z = np.column_stack((self.z, z_t))
# output is delivered 6 ms after the sample
return out, t_samp + 6 * ms
mpc_loop = MPCLoop(sample_period=dt)
Rerunning the experiment¶
sim.reset()
sim.set_io_processor(mpc_loop)
sim.run(T0 + T1)
WARNING /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/light/light.py:557: UserWarning: fiber: negative light value clipped to 0
warnings.warn(f"{self.name}: negative light value clipped to 0")
[py.warnings]
plot_ctrl(mpc_loop)
plot_post_hoc(mpc_loop)
Results (spikes/second):
baseline = [50. 60.] Hz
target = 132.6 Hz
achieved = [135. 135.] Hz
Also looking good. We don’t see much of an advantage over LQR with a constraint-informed setpoint here since the optimization we performed was most of what was needed for a static reference. This essentially did part of MPC’s job for it; the difference being that MPC optimizes an entire trajectory (over a receding horizon), ensuring constraints are met at every step. We do see less overshoot though, which makes sense since our controller is looking 100 ms ahead.
This should be revisited with a time-varying reference once ldsctrlest’s MPC with output reference control is available.