Optimal control

This tutorial will be more comprehensive than the others, bringing together all of cleo’s main capabilities—electrode recording, optogenetics, and latency modeling—as well as introducing more sophisticated model-based optimal feedback control. To achieve the latter, we will use the Python bindings to the ldsCtrlEst C++ library.

First some boilerplate:

from brian2 import ms, mm, um, second, mV, Hz, np
import brian2.only as b2
import matplotlib.pyplot as plt
import cleo

cleo.utilities.style_plots_for_docs()

# numpy faster than cython for lightweight example
b2.prefs.codegen.target = "numpy"
np.random.seed(17320222)
b2.seed(17870917)
cleo.utilities.set_seed(17991214)

cy1 = "#C500CC"
cy2 = "#df87e1"
cu1 = cleo.utilities.wavelength_to_rgb(463 * b2.nmeter)
cu2 = cleo.utilities.wavelength_to_rgb(483 * b2.nmeter)

Network setup

For the network model we adapt an E/I network implementation from the Neuronal Dynamics textbook. Let’s run it and plot the spiking output:

b2.defaultclock.dt = 0.5 * b2.ms

n_exc = 80
n_inh = int(n_exc / 4)
connection_probability = 0.1
n_ext = int(n_exc * connection_probability)
w0 = 0.1 * b2.mV  # 0.2
g = 4
w_inh = -g * w0
synaptic_delay = 1.5 * b2.ms
w_external = w0
v_rest = -70 * b2.mV
v_reset = -60 * b2.mV
firing_threshold = -50 * b2.mV
tau_m = 20 * b2.ms
Rm = 100 * b2.Mohm
abs_refractory_period = 2 * b2.ms
thresh_rate = (firing_threshold - v_rest) / (w_external * n_ext * tau_m)
poisson_input_rate = 1 * thresh_rate
bias_current = 11 * b2.pA

bias_ng = b2.NeuronGroup(
    1, "dI_bias/dt = -I_bias / second + bias_current*xi/sqrt(tau_m) : amp"
)

lif_dynamics = """
dv/dt = (-(v-v_rest) + Rm * (I + I_bias)) / tau_m : volt (unless refractory)
I : amp
I_bias : amp (linked)
"""

ng = b2.NeuronGroup(
    n_exc + n_inh,
    model=lif_dynamics,
    threshold="v>firing_threshold",
    reset="v=v_reset",
    refractory=abs_refractory_period,
    method="linear",
)
ng.I_bias = b2.linked_var(bias_ng, "I_bias")

# random v init
ng.v = (
    np.random.uniform(
        v_rest / b2.mV, high=firing_threshold / b2.mV, size=(n_exc + n_inh)
    )
    * b2.mV
)
ng_exc = ng[:n_exc]
ng_inh = ng[n_exc:]

syn = b2.Synapses(
    ng,
    model="w : volt",
    on_pre="v += w",
    delay=synaptic_delay,
)
syn.connect(p=connection_probability)
syn[f"i < {n_exc}"].w = w0
syn[f"i >= {n_exc}"].w = w_inh

external_poisson_input = b2.PoissonInput(
    target=ng,
    target_var="v",
    N=n_ext,
    rate=poisson_input_rate,
    weight=w_external,
)

rate_monitor = b2.PopulationRateMonitor(ng)
spike_monitor = b2.SpikeMonitor(ng, record=True)

net = b2.Network(
    ng,
    bias_ng,
    syn,
    external_poisson_input,
    rate_monitor,
    spike_monitor,
)

runtime = 1 * b2.second
net.store()
net.run(runtime)

plt.scatter(spike_monitor.t / b2.ms, spike_monitor.i, s=1)
net.restore()
INFO       No numerical integration method specified for group 'neurongroup', using method 'euler' (took 0.01s, trying other methods took 0.00s). [brian2.stateupdaters.base.method_choice]
[Figure: spike raster of the network over the 1 s run (spike time in ms vs. neuron index)]

The spiking looks the way we want: relatively stable, with random fluctuations in global activity driven by the bias current we added. These fluctuations will make things more interesting for the controller.
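Two numbers implied by the parameters above are worth checking by hand. The sketch below recomputes them in plain Python (SI units, values copied from the network setup): the external Poisson rate at which the mean external drive alone would bring the membrane exactly to threshold, using the standard mean-drive argument, and the stationary spread of the bias current, reading the bias_ng equation as an Ornstein-Uhlenbeck process with a 1 s time constant. That OU reading is our interpretation of the equation, not something cleo or Brian2 reports, and the uppercase names are chosen only so that running the cell does not overwrite the Brian2 quantities the model equations refer to.

```python
import math

# Values copied from the network setup above, as plain SI floats. Distinct
# uppercase names avoid overwriting the Brian2 quantities the model equations use.
V_REST = -70e-3        # V
V_THRESH = -50e-3      # V (firing_threshold)
TAU_M = 20e-3          # s
W_EXT = 0.1e-3         # V per external spike (w_external)
N_EXT = int(80 * 0.1)  # n_exc * connection_probability = 8 external inputs
R_M = 100e6            # ohm
BIAS_SCALE = 11e-12    # A (bias_current)
TAU_BIAS = 1.0         # s, from the "-I_bias / second" decay term

# Mean depolarization from N_EXT Poisson inputs at rate nu is W_EXT * N_EXT * nu * TAU_M,
# so the rate that brings the mean membrane potential exactly to threshold is:
nu_thr = (V_THRESH - V_REST) / (W_EXT * N_EXT * TAU_M)

# Read as an Ornstein-Uhlenbeck process, dI = -I / TAU_BIAS dt + sigma dW with
# sigma = BIAS_SCALE / sqrt(TAU_M), whose stationary std is sigma * sqrt(TAU_BIAS / 2)
sigma = BIAS_SCALE / math.sqrt(TAU_M)
bias_std = sigma * math.sqrt(TAU_BIAS / 2)

print(f"threshold rate: {nu_thr:.0f} Hz per external input")
print(f"stationary std of I_bias: {bias_std * 1e12:.1f} pA")
print(f"as a voltage offset, R_M * std: {R_M * bias_std * 1e3:.1f} mV")
```

This gives 1250 Hz per input and a bias standard deviation of about 55 pA. Because the 1 s bias time constant is much longer than tau_m, the membrane roughly tracks Rm * I_bias, so the shared bias shifts every neuron's effective drive by a few millivolts, consistent with the slow swings in global activity seen above.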

Coordinates, stimulation, and recording

Here we assign coordinates to the neurons and configure the optogenetic intervention and recording setup:

hor_lim = 0.1
cleo.coords.assign_coords_rand_rect_prism(
    ng, xlim=(-hor_lim, hor_lim), ylim=(-hor_lim, hor_lim), zlim=(0.4, 0.6)
)

fibers = cleo.light.Light(
    name="fiber",
    light_model=cleo.light.fiber473nm(R0=50 * um),
    coords=[(-0.1, 0, 0.3), (0.1, 0, 0.3)] * mm,
)

opsin = cleo.opto.chr2_4s()

spikes = cleo.ephys.SortedSpiking(name="spikes")
probe = cleo.ephys.Probe(
    coords=[[0, 0, 0.4], [0, 0, 0.6]] * mm,
    signals=[spikes],
    save_history=True,
)

cleo.viz.plot(
    ng,
    colors=["xkcd:fuchsia"],
    xlim=(-0.2, 0.2),
    ylim=(-0.2, 0.2),
    zlim=(0.3, 0.8),
    devices=[probe, fibers],
    scatterargs={"alpha": 1},
    axis_scale_unit=mm,
)
(<Figure size 640x480 with 1 Axes>,
 <Axes3D: xlabel='x (mm)', ylabel='y (mm)', zlabel='z (mm)'>)
[Figure: 3D view of neuron positions, the probe, and the two optic fibers; axes x, y, z in mm]

Looks right. Let’s set up the simulation and inject the devices:

sim = cleo.CLSimulator(net)
sim.inject(fibers, ng)
sim.inject(opsin, ng, Iopto_var_name="I")
sim.inject(probe, ng)
assert spikes.n_sorted >= 2

Collect training data

Our goal will be to control two neurons’ firing rates simultaneously. To do this, we will use the LQR technique explained in Bolus et al., 2021 (“State-space optimal feedback control of optogenetically driven neural activity”). LQR is a model-based technique, and to fit a model of the system’s dynamics we first need training data. We generate it using Gaussian random-walk inputs, wrapped with the modulo operator into a typical stimulation range (0-75 mW/mm²). We structure the data into trials, the format ldsCtrlEst is designed for; \(u\) represents the input and \(z\) the spike output.

n_trials = 5
n_samp = 500
u = []
z = []
n_u = 2  # 2-dimensional input (one per optic fiber)
n_z = 2  # we'll be controlling two neurons
for trial in range(n_trials):
    u_trial = 10 * np.cumsum(np.random.randn(n_u, n_samp), axis=1) % 75
    u.append(u_trial)
    z.append(np.zeros((n_z, n_samp)))
# zero the second half of the last trial so the data include an unstimulated baseline
u[-1][:, (n_samp // 2) :] = 0
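Wrapping the random walk with the modulo operator keeps every training input inside the stimulation range while letting it wander across the whole range. The sketch below (an arbitrary seed, independent of the tutorial's seeds; the demo_ names are hypothetical and only avoid clobbering the tutorial's variables) checks the bounds and counts how often each channel wraps around, which shows up as an abrupt jump in the light input.

```python
import numpy as np

demo_rng = np.random.default_rng(0)  # arbitrary seed, independent of the tutorial's

# Unbounded Gaussian random walk for 2 channels x 500 samples, as in the loop above
demo_walk = 10 * np.cumsum(demo_rng.standard_normal((2, 500)), axis=1)
# With a positive divisor, NumPy's % returns results with the divisor's sign,
# so even negative excursions of the walk map into the 0-75 range
demo_u = demo_walk % 75

# Each change in floor(walk / 75) is one wrap-around, i.e. an abrupt jump in demo_u
demo_wraps = np.count_nonzero(np.diff(np.floor_divide(demo_walk, 75), axis=1), axis=1)
print("raw walk range:", demo_walk.min().round(1), "to", demo_walk.max().round(1))
print("wrapped range:", demo_u.min().round(1), "to", demo_u.max().round(1))
print("wrap-arounds per channel:", demo_wraps)
```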

We configure the LatencyIOProcessor to deliver our pre-computed stimulus and record the results, resetting after each trial:

class TrainingStimIOP(cleo.LatencyIOProcessor):
    i_samp = 0
    i_trial = 0

    # here we just feed in the training inputs and record the outputs
    def process(self, state_dict, t_samp):
        i, t, z_t = state_dict["Probe"]["spikes"]
        z[self.i_trial][:, self.i_samp] = z_t[:n_z]  # just first two neurons
        out = {"fiber": u[self.i_trial][:, self.i_samp] * b2.mwatt / b2.mm2}
        self.i_samp += 1
        return out, t_samp

    # gets called with sim.reset()
    def reset(self):
        self.i_samp = 0


dt = 5 * ms
training_stim_iop = TrainingStimIOP(sample_period=dt)
sim.set_io_processor(training_stim_iop)

for i_trial in range(n_trials):
    training_stim_iop.i_trial = i_trial
    sim.run(n_samp * dt)
    sim.reset()
INFO       No numerical integration method specified for group 'syn_ChR2_neurongroup_1', using method 'euler' (took 0.03s, trying other methods took 0.04s). [brian2.stateupdaters.base.method_choice]
WARNING    'dt' is an internal variable of group 'syn_ChR2_neurongroup_1', but also exists in the run namespace with the value 5. * msecond. The internal variable will be used. [brian2.groups.group.Group.resolve.resolution_conflict]
WARNING    /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:594: RuntimeWarning: overflow encountered in exp
  collision_prob_fn: Callable[[Quantity], float] = lambda t: 0.2 * np.exp(
 [py.warnings]
WARNING    /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:465: RuntimeWarning: invalid value encountered in multiply
  collision_prob = self.collision_prob_fn(t_diff) * same_chan
 [py.warnings]
WARNING    /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/ephys/spiking.py:469: RuntimeWarning: invalid value encountered in multiply
  collision_prob *= t_diff >= 0
 [py.warnings]

Let’s plot our training data:

fig, axs = plt.subplots(
    n_trials, 2, figsize=(12, 6), layout="compressed", sharex=True, sharey="col"
)

for i, (utrial, ztrial) in enumerate(zip(u, z)):
    # Plot spiking data
    ax_spike = axs[i, 0]
    ax_spike.imshow(
        ztrial,
        aspect="auto",
        cmap="gray",
        extent=[0, (n_samp * dt) / ms, n_z - 0.5, -0.5],  # row k centered at y = k
        interpolation="nearest",
        zorder=0,
    )
    ax_spike.set(ylabel="channel #", yticks=range(n_z))
    if i == n_trials - 1:
        ax_spike.set(xlabel="t [ms]")

    # Plot input data
    ax_input = axs[i, 1]
    lines = ax_input.plot(
        np.arange(n_samp) * dt / ms,
        utrial.T,
        alpha=0.8,
        lw=1,
    )
    # use i_u so the outer trial index i is not overwritten
    for i_u, ln, c in zip(range(n_u), lines, [cu1, cu2]):
        ln.set(color=c, label=f"fiber {i_u + 1}")
    ax_input.set(ylabel="u [mW/mm²]")
    if i == n_trials - 1:
        ax_input.set(xlabel="t [ms]")

axs[0, 1].legend()
axs[0, 0].set(title="spikes")
axs[0, 1].set(title="light inputs")
plt.show()
[Figure: training data per trial; left column “spikes” (channel # vs. t in ms), right column “light inputs” (u in mW/mm² for fibers 1 and 2)]

Model fitting

Now we have u and z in the form ldsctrlest’s fitting functions expect: n_trials-length lists of n × n_samp arrays (n = n_u for u and n_z for z). We will now fit a Gaussian linear dynamical system using the subspace identification (SSID) algorithm. See the ldsCtrlEst documentation for more detailed explanations.
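For orientation, the Gaussian linear dynamical system being fit has, in simplified form, latent dynamics \(x_{k+1} = A x_k + B u_k + w_k\) and observations \(z_k = C x_k + d + v_k\), with Gaussian noise covariances \(Q\) and \(R\). The model ldsCtrlEst uses also carries an input gain \(g\) and a disturbance \(m\) (both visible in the system printout further down), which this sketch omits. It simulates such a system with made-up parameters and returns data in the same layout as u and z; simulate_glds and the toy_ names are hypothetical helpers, not part of ldsctrlest.

```python
import numpy as np


def simulate_glds(A, B, C, d, Q, R, u_trials, rng):
    """Simulate x_{k+1} = A x_k + B u_k + w_k, z_k = C x_k + d + v_k for each trial.

    u_trials is a list of (n_u, n_samp) arrays; returns a matching list of
    (n_z, n_samp) arrays, the same layout as u and z in the tutorial.
    """
    n_x, n_z = A.shape[0], C.shape[0]
    z_trials = []
    for u_trial in u_trials:
        x = np.zeros(n_x)
        z_trial = np.empty((n_z, u_trial.shape[1]))
        for k in range(u_trial.shape[1]):
            z_trial[:, k] = C @ x + d + rng.multivariate_normal(np.zeros(n_z), R)
            x = A @ x + B @ u_trial[:, k] + rng.multivariate_normal(np.zeros(n_x), Q)
        z_trials.append(z_trial)
    return z_trials


# Made-up parameters for illustration only (not the fitted ones)
toy_rng = np.random.default_rng(1)
toy_A = np.diag([0.9, 0.8, 0.5])                  # stable latent dynamics
toy_B = toy_rng.normal(scale=0.01, size=(3, 2))
toy_C = toy_rng.normal(size=(2, 3))
toy_d = np.full(2, 0.4)                           # baseline output per time bin
toy_Q = 0.01 * np.eye(3)                          # process noise covariance
toy_R = 0.1 * np.eye(2)                           # measurement noise covariance

toy_u = [toy_rng.uniform(0, 75, size=(2, 500)) for _ in range(5)]
toy_z = simulate_glds(toy_A, toy_B, toy_C, toy_d, toy_Q, toy_R, toy_u, toy_rng)
print(len(toy_z), toy_z[0].shape)
```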

import ldsctrlest as lds
import ldsctrlest.gaussian as glds

n_x_fit = 3  # latent dimensionality of system
n_h = 100  # size of block Hankel data matrix
u_train = lds.UniformMatrixList(u, free_dim=2)
z_train = lds.UniformMatrixList(z, free_dim=2)
ssid = glds.FitSSID(n_x_fit, n_h, dt / second, u_train, z_train)
fit, sing_vals = ssid.Run(lds.SSIDWt.kMOESP)
fit_sys_ssid = glds.System(fit)
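Why should singular values reveal the model order? A classic result behind subspace identification is that a block Hankel matrix built from the impulse response (the Markov parameters \(C A^k B\)) of a minimal \(n_x\)-dimensional system has rank \(n_x\). The sketch below demonstrates this on a made-up, noise-free system; it is a simplified, Ho-Kalman-style illustration and not the MOESP-weighted computation FitSSID performs. With real, noisy spike data the trailing singular values shrink rather than vanish, which is why we look for a drop.

```python
import numpy as np


def hankel_singular_values(A, B, C, n_blocks):
    """Singular values of the block Hankel matrix of Markov parameters H_k = C A^k B."""
    H, A_k = [], np.eye(A.shape[0])
    for _ in range(2 * n_blocks - 1):   # H_0 .. H_{2 n_blocks - 2}
        H.append(C @ A_k @ B)
        A_k = A @ A_k
    # Block (i, j) of the Hankel matrix is H_{i+j}
    hankel = np.block([[H[i + j] for j in range(n_blocks)] for i in range(n_blocks)])
    return np.linalg.svd(hankel, compute_uv=False)


# Made-up, noise-free 3-state system with distinct modes (not the fitted model)
toy_rng = np.random.default_rng(2)
toy_A = np.diag([0.9, 0.7, 0.4])
toy_B = toy_rng.normal(size=(3, 2))
toy_C = toy_rng.normal(size=(2, 3))

toy_sv = hankel_singular_values(toy_A, toy_B, toy_C, n_blocks=6)
print("singular values:", toy_sv.round(6))
print("numerical rank:", np.sum(toy_sv > 1e-9 * toy_sv[0]))
```

Only the first three singular values are appreciably nonzero; the remaining nine sit at floating-point rounding level.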

Here we plot the singular values of the data matrix; with a decent fit, we should see a drop at or before our chosen model order. We also visualize impulse responses, where we expect each fiber to increase firing rates:

n_samp_imp = 50
y_imp = fit_sys_ssid.simulate_imp(n_samp_imp)
t_imp = np.arange(n_samp_imp) * dt / ms

fig, axs = plt.subplots(1, 3, figsize=(2 + n_u * 3, 4), layout="compressed")
axs[0].semilogy(sing_vals[: n_x_fit * 3], linewidth=2)
axs[0].set(ylabel="singular values", xlabel="singular value index")

for i_u in range(n_u):
    ax = axs[i_u + 1]
    lines = ax.plot(t_imp, y_imp[i_u].T, linewidth=2)
    lines[0].set_color(cy1)
    lines[1].set_color(cy2)
    ax.set(title=f"Impulse response for $u_{i_u + 1}$", xlabel="time (ms)")


ax.legend(lines, ["$y_1$", "$y_2$"]);
[Figure: singular values of the SSID data matrix (left); impulse responses of y_1 and y_2 to u_1 and u_2 (right)]

We see a sharp drop in singular values after the first few, which justifies our model order choice. For the impulse responses, we expect both fibers to cause transient increases in the firing rate of both recorded neurons. Since that isn’t the case for the first fiber, let’s try refining our fit with expectation-maximization (EM):

em = glds.FitEM(fit, u_train, z_train)
fit_em = em.Run(
    calc_dynamics=True,
    calc_Q=True,
    calc_init=True,
    calc_output=True,
    calc_measurement=True,
    max_iter=50,
    tol=1e-2,
)
fit_sys_em = glds.System(fit_em)
Iteration 1/50 ...
C_new[0]: -0.105012
d_new[0]: 0.391363
R_new[0]: 0.141144
A_new[0]: 0.982364
B_new[0]: 2.38562e-06
Q_new[0]: 0.0114576
x0_new[0]: -2.53407e-05
P0_new[0]: 1.00001e-06
max dtheta: 2.5084

Iteration 2/50 ...
C_new[0]: -0.0943574
d_new[0]: 0.461353
R_new[0]: 0.140272
A_new[0]: 0.977572
B_new[0]: -0.000414522
Q_new[0]: 0.0121686
x0_new[0]: -4.11669e-05
P0_new[0]: 1.00002e-06
max dtheta: 174.758

Iteration 3/50 ...
C_new[0]: -0.0915526
d_new[0]: 0.487437
R_new[0]: 0.138667
A_new[0]: 0.973509
B_new[0]: -0.000737076
Q_new[0]: 0.0128524
x0_new[0]: -5.35115e-05
P0_new[0]: 1.00003e-06
max dtheta: 11.957

Iteration 4/50 ...
C_new[0]: -0.0932068
d_new[0]: 0.489635
R_new[0]: 0.137209
A_new[0]: 0.96977
B_new[0]: -0.00102345
Q_new[0]: 0.0135567
x0_new[0]: -6.43448e-05
P0_new[0]: 1.00005e-06
max dtheta: 1.00887

Iteration 5/50 ...
C_new[0]: -0.0970164
d_new[0]: 0.480574
R_new[0]: 0.1359
A_new[0]: 0.965911
B_new[0]: -0.00130965
Q_new[0]: 0.0142672
x0_new[0]: -7.45381e-05
P0_new[0]: 1.00008e-06
max dtheta: 2.78001

Iteration 6/50 ...
C_new[0]: -0.101797
d_new[0]: 0.46656
R_new[0]: 0.134703
A_new[0]: 0.961747
B_new[0]: -0.00161025
Q_new[0]: 0.0149682
x0_new[0]: -8.44291e-05
P0_new[0]: 1.0001e-06
max dtheta: 10.8644

Iteration 7/50 ...
C_new[0]: -0.106955
d_new[0]: 0.450582
R_new[0]: 0.133571
A_new[0]: 0.957235
B_new[0]: -0.00192869
Q_new[0]: 0.0156501
x0_new[0]: -9.41219e-05
P0_new[0]: 1.00014e-06
max dtheta: 4.36099

Iteration 8/50 ...
C_new[0]: -0.112155
d_new[0]: 0.434213
R_new[0]: 0.132473
A_new[0]: 0.952402
B_new[0]: -0.00226317
Q_new[0]: 0.0163071
x0_new[0]: -0.000103627
P0_new[0]: 1.00017e-06
max dtheta: 1.75518

Iteration 9/50 ...
C_new[0]: -0.117176
d_new[0]: 0.418421
R_new[0]: 0.131404
A_new[0]: 0.947307
B_new[0]: -0.00260934
Q_new[0]: 0.0169355
x0_new[0]: -0.000112919
P0_new[0]: 1.00021e-06
max dtheta: 0.775979

Iteration 10/50 ...
C_new[0]: -0.121863
d_new[0]: 0.403843
R_new[0]: 0.130366
A_new[0]: 0.942028
B_new[0]: -0.0029617
Q_new[0]: 0.0175338
x0_new[0]: -0.000121961
P0_new[0]: 1.00026e-06
max dtheta: 0.51042

Iteration 11/50 ...
C_new[0]: -0.126112
d_new[0]: 0.390877
R_new[0]: 0.129363
A_new[0]: 0.936648
B_new[0]: -0.00331439
Q_new[0]: 0.018102
x0_new[0]: -0.000130711
P0_new[0]: 1.00031e-06
max dtheta: 0.373061

Iteration 12/50 ...
C_new[0]: -0.129868
d_new[0]: 0.37972
R_new[0]: 0.1284
A_new[0]: 0.93125
B_new[0]: -0.00366167
Q_new[0]: 0.0186418
x0_new[0]: -0.000139133
P0_new[0]: 1.00036e-06
max dtheta: 0.277814

Iteration 13/50 ...
C_new[0]: -0.133115
d_new[0]: 0.370404
R_new[0]: 0.127476
A_new[0]: 0.925911
B_new[0]: -0.00399819
Q_new[0]: 0.0191559
x0_new[0]: -0.000147195
P0_new[0]: 1.00041e-06
max dtheta: 0.337983

Iteration 14/50 ...
C_new[0]: -0.135871
d_new[0]: 0.362844
R_new[0]: 0.126592
A_new[0]: 0.920704
B_new[0]: -0.00431912
Q_new[0]: 0.0196477
x0_new[0]: -0.000154872
P0_new[0]: 1.00047e-06
max dtheta: 0.560341

Iteration 15/50 ...
C_new[0]: -0.138171
d_new[0]: 0.356879
R_new[0]: 0.125748
A_new[0]: 0.915695
B_new[0]: -0.00462026
Q_new[0]: 0.0201215
x0_new[0]: -0.00016215
P0_new[0]: 1.00054e-06
max dtheta: 1.35138

Iteration 16/50 ...
C_new[0]: -0.140066
d_new[0]: 0.35231
R_new[0]: 0.124944
A_new[0]: 0.910944
B_new[0]: -0.00489815
Q_new[0]: 0.0205815
x0_new[0]: -0.000169019
P0_new[0]: 1.0006e-06
max dtheta: 3.93717

Iteration 17/50 ...
C_new[0]: -0.141608
d_new[0]: 0.348928
R_new[0]: 0.124183
A_new[0]: 0.906501
B_new[0]: -0.00515016
Q_new[0]: 0.0210321
x0_new[0]: -0.000175482
P0_new[0]: 1.00067e-06
max dtheta: 0.788479

Iteration 18/50 ...
C_new[0]: -0.142849
d_new[0]: 0.346533
R_new[0]: 0.123468
A_new[0]: 0.902407
B_new[0]: -0.0053746
Q_new[0]: 0.0214775
x0_new[0]: -0.000181547
P0_new[0]: 1.00075e-06
max dtheta: 0.421416

Iteration 19/50 ...
C_new[0]: -0.143838
d_new[0]: 0.344946
R_new[0]: 0.122802
A_new[0]: 0.89869
B_new[0]: -0.00557078
Q_new[0]: 0.021921
x0_new[0]: -0.000187229
P0_new[0]: 1.00082e-06
max dtheta: 0.516032

Iteration 20/50 ...
C_new[0]: -0.144618
d_new[0]: 0.344006
R_new[0]: 0.122186
A_new[0]: 0.895363
B_new[0]: -0.005739
Q_new[0]: 0.0223651
x0_new[0]: -0.000192546
P0_new[0]: 1.0009e-06
max dtheta: 1.39452

Iteration 21/50 ...
C_new[0]: -0.145228
d_new[0]: 0.343577
R_new[0]: 0.121621
A_new[0]: 0.892428
B_new[0]: -0.00588045
Q_new[0]: 0.0228111
x0_new[0]: -0.000197523
P0_new[0]: 1.00098e-06
max dtheta: 7.81432

Iteration 22/50 ...
C_new[0]: -0.145698
d_new[0]: 0.343545
R_new[0]: 0.121107
A_new[0]: 0.889872
B_new[0]: -0.00599704
Q_new[0]: 0.0232597
x0_new[0]: -0.000202186
P0_new[0]: 1.00106e-06
max dtheta: 1.00709

Iteration 23/50 ...
C_new[0]: -0.146058
d_new[0]: 0.343815
R_new[0]: 0.120641
A_new[0]: 0.887674
B_new[0]: -0.00609114
Q_new[0]: 0.0237102
x0_new[0]: -0.000206562
P0_new[0]: 1.00114e-06
max dtheta: 0.557072

Iteration 24/50 ...
C_new[0]: -0.14633
d_new[0]: 0.34431
R_new[0]: 0.12022
A_new[0]: 0.885805
B_new[0]: -0.00616545
Q_new[0]: 0.0241616
x0_new[0]: -0.000210678
P0_new[0]: 1.00123e-06
max dtheta: 0.390198

Iteration 25/50 ...
C_new[0]: -0.146531
d_new[0]: 0.344969
R_new[0]: 0.119841
A_new[0]: 0.884232
B_new[0]: -0.00622268
Q_new[0]: 0.0246122
x0_new[0]: -0.000214561
P0_new[0]: 1.00131e-06
max dtheta: 0.535436

Iteration 26/50 ...
C_new[0]: -0.146678
d_new[0]: 0.345743
R_new[0]: 0.1195
A_new[0]: 0.882919
B_new[0]: -0.00626549
Q_new[0]: 0.0250604
x0_new[0]: -0.000218235
P0_new[0]: 1.0014e-06
max dtheta: 1.00021

Iteration 27/50 ...
C_new[0]: -0.146781
d_new[0]: 0.346594
R_new[0]: 0.119192
A_new[0]: 0.881833
B_new[0]: -0.00629635
Q_new[0]: 0.025504
x0_new[0]: -0.000221725
P0_new[0]: 1.00149e-06
max dtheta: 4230.77

Iteration 28/50 ...
C_new[0]: -0.146852
d_new[0]: 0.347492
R_new[0]: 0.118914
A_new[0]: 0.880941
B_new[0]: -0.00631743
Q_new[0]: 0.0259416
x0_new[0]: -0.000225049
P0_new[0]: 1.00158e-06
max dtheta: 0.873001

Iteration 29/50 ...
C_new[0]: -0.146897
d_new[0]: 0.348415
R_new[0]: 0.118662
A_new[0]: 0.880213
B_new[0]: -0.00633065
Q_new[0]: 0.0263713
x0_new[0]: -0.000228229
P0_new[0]: 1.00167e-06
max dtheta: 0.409528

Iteration 30/50 ...
C_new[0]: -0.146923
d_new[0]: 0.349345
R_new[0]: 0.118432
A_new[0]: 0.879623
B_new[0]: -0.00633764
Q_new[0]: 0.026792
x0_new[0]: -0.00023128
P0_new[0]: 1.00177e-06
max dtheta: 0.257248

Iteration 31/50 ...
C_new[0]: -0.146934
d_new[0]: 0.350268
R_new[0]: 0.118222
A_new[0]: 0.879147
B_new[0]: -0.00633975
Q_new[0]: 0.0272026
x0_new[0]: -0.000234218
P0_new[0]: 1.00186e-06
max dtheta: 0.182744

Iteration 32/50 ...
C_new[0]: -0.146935
d_new[0]: 0.351176
R_new[0]: 0.118028
A_new[0]: 0.878766
B_new[0]: -0.00633811
Q_new[0]: 0.0276021
x0_new[0]: -0.000237055
P0_new[0]: 1.00196e-06
max dtheta: 0.139274

Iteration 33/50 ...
C_new[0]: -0.146928
d_new[0]: 0.352061
R_new[0]: 0.117849
A_new[0]: 0.878462
B_new[0]: -0.00633364
Q_new[0]: 0.0279902
x0_new[0]: -0.000239804
P0_new[0]: 1.00205e-06
max dtheta: 0.124395

Iteration 34/50 ...
C_new[0]: -0.146917
d_new[0]: 0.352917
R_new[0]: 0.117683
A_new[0]: 0.878221
B_new[0]: -0.00632708
Q_new[0]: 0.0283665
x0_new[0]: -0.000242474
P0_new[0]: 1.00215e-06
max dtheta: 0.129348

Iteration 35/50 ...
C_new[0]: -0.146902
d_new[0]: 0.353742
R_new[0]: 0.117528
A_new[0]: 0.878032
B_new[0]: -0.00631903
Q_new[0]: 0.0287308
x0_new[0]: -0.000245074
P0_new[0]: 1.00225e-06
max dtheta: 0.134967

Iteration 36/50 ...
C_new[0]: -0.146886
d_new[0]: 0.354531
R_new[0]: 0.117382
A_new[0]: 0.877884
B_new[0]: -0.00630995
Q_new[0]: 0.0290831
x0_new[0]: -0.000247612
P0_new[0]: 1.00235e-06
max dtheta: 0.141478

Iteration 37/50 ...
C_new[0]: -0.146871
d_new[0]: 0.355285
R_new[0]: 0.117245
A_new[0]: 0.877768
B_new[0]: -0.00630022
Q_new[0]: 0.0294237
x0_new[0]: -0.000250096
P0_new[0]: 1.00246e-06
max dtheta: 0.149187

Iteration 38/50 ...
C_new[0]: -0.146856
d_new[0]: 0.356002
R_new[0]: 0.117115
A_new[0]: 0.877679
B_new[0]: -0.00629013
Q_new[0]: 0.0297528
x0_new[0]: -0.00025253
P0_new[0]: 1.00256e-06
max dtheta: 0.15852

Iteration 39/50 ...
C_new[0]: -0.146843
d_new[0]: 0.356681
R_new[0]: 0.116993
A_new[0]: 0.87761
B_new[0]: -0.00627991
Q_new[0]: 0.0300707
x0_new[0]: -0.000254921
P0_new[0]: 1.00267e-06
max dtheta: 0.170103

Iteration 40/50 ...
C_new[0]: -0.146832
d_new[0]: 0.357325
R_new[0]: 0.116876
A_new[0]: 0.877557
B_new[0]: -0.00626973
Q_new[0]: 0.030378
x0_new[0]: -0.000257274
P0_new[0]: 1.00278e-06
max dtheta: 0.184892

Iteration 41/50 ...
C_new[0]: -0.146825
d_new[0]: 0.357932
R_new[0]: 0.116765
A_new[0]: 0.877516
B_new[0]: -0.00625972
Q_new[0]: 0.0306751
x0_new[0]: -0.000259592
P0_new[0]: 1.00289e-06
max dtheta: 0.204436

Iteration 42/50 ...
C_new[0]: -0.14682
d_new[0]: 0.358504
R_new[0]: 0.116658
A_new[0]: 0.877484
B_new[0]: -0.00624998
Q_new[0]: 0.0309625
x0_new[0]: -0.000261879
P0_new[0]: 1.003e-06
max dtheta: 0.231434

Iteration 43/50 ...
C_new[0]: -0.146819
d_new[0]: 0.359043
R_new[0]: 0.116556
A_new[0]: 0.877459
B_new[0]: -0.00624056
Q_new[0]: 0.0312407
x0_new[0]: -0.000264138
P0_new[0]: 1.00312e-06
max dtheta: 0.271044

Iteration 44/50 ...
C_new[0]: -0.146821
d_new[0]: 0.35955
R_new[0]: 0.116459
A_new[0]: 0.877439
B_new[0]: -0.00623152
Q_new[0]: 0.0315103
x0_new[0]: -0.000266373
P0_new[0]: 1.00323e-06
max dtheta: 0.334533

Iteration 45/50 ...
C_new[0]: -0.146827
d_new[0]: 0.360026
R_new[0]: 0.116365
A_new[0]: 0.877422
B_new[0]: -0.00622289
Q_new[0]: 0.0317718
x0_new[0]: -0.000268586
P0_new[0]: 1.00335e-06
max dtheta: 0.452143

Iteration 46/50 ...
C_new[0]: -0.146837
d_new[0]: 0.360474
R_new[0]: 0.116274
A_new[0]: 0.877408
B_new[0]: -0.00621469
Q_new[0]: 0.0320256
x0_new[0]: -0.00027078
P0_new[0]: 1.00347e-06
max dtheta: 0.742147

Iteration 47/50 ...
C_new[0]: -0.146849
d_new[0]: 0.360894
R_new[0]: 0.116186
A_new[0]: 0.877394
B_new[0]: -0.00620691
Q_new[0]: 0.0322723
x0_new[0]: -0.000272956
P0_new[0]: 1.00359e-06
max dtheta: 2.58805

Iteration 48/50 ...
C_new[0]: -0.146866
d_new[0]: 0.361289
R_new[0]: 0.116102
A_new[0]: 0.87738
B_new[0]: -0.00619957
Q_new[0]: 0.0325123
x0_new[0]: -0.000275116
P0_new[0]: 1.00371e-06
max dtheta: 1.46554

Iteration 49/50 ...
C_new[0]: -0.146885
d_new[0]: 0.361659
R_new[0]: 0.11602
A_new[0]: 0.877367
B_new[0]: -0.00619266
Q_new[0]: 0.0327461
x0_new[0]: -0.000277262
P0_new[0]: 1.00384e-06
max dtheta: 0.53465

Iteration 50/50 ...
C_new[0]: -0.146908
d_new[0]: 0.362008
R_new[0]: 0.11594
A_new[0]: 0.877352
B_new[0]: -0.00618617
Q_new[0]: 0.0329742
x0_new[0]: -0.000279396
P0_new[0]: 1.00396e-06
max dtheta: 0.313476

We’ll plot impulse responses again:

y_imp = fit_sys_em.simulate_imp(n_samp_imp)

fig, axs = plt.subplots(1, 2, figsize=(3 * n_u, 4), layout="compressed")
for i_u in range(n_u):
    ax = axs[i_u]
    lines = ax.plot(t_imp, y_imp[i_u].T, linewidth=2)
    lines[0].set_color(cy1)
    lines[1].set_color(cy2)
    ax.set(title=f"Impulse response for $u_{i_u + 1}$", xlabel="time (ms)")

ax.legend(lines, ["$y_1$", "$y_2$"]);
[Figure: impulse responses of the EM-refined fit to u_1 and u_2]

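Looking better. EM doesn’t produce singular values to inspect, so let’s compare the two fits quantitatively by how much variance their one-step-ahead predictions explain (r², the squared correlation between predicted and observed spike counts):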

z_arr = np.array(z)


def pct_exp_var(sys, nstep=1):
    """Squared correlation (r²) between nstep-ahead predictions and observed spike counts."""
    x_filt, x_pred, y_pred = sys.nstep_pred_block(u_train, z_train, nstep)
    y_pred = np.array(y_pred)
    # predictions cover each trial from sample nstep onward
    assert y_pred.shape == z_arr[:, :, nstep:].shape, y_pred.shape
    corrcoef_mat = np.corrcoef(z_arr[:, :, nstep:].flatten(), y_pred.flatten())
    return corrcoef_mat[0, 1] ** 2


print("SSID r²:", pct_exp_var(fit_sys_ssid))
print("EM r²:", pct_exp_var(fit_sys_em))
SSID r²: 0.07508708150847317
EM r²: 0.26515411771446507

As the impulse responses suggested, the EM-refined fit is much better, explaining about 3.5 times as much variance as the SSID fit.
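One caveat on the metric: despite its name, pct_exp_var returns the squared correlation between predictions and observations, a fraction rather than a percentage. Squared correlation equals the variance explained after the best affine rescaling of the predictions, so it does not penalize a systematic offset or gain error. The toy comparison below against the coefficient of determination, 1 - SSE/SST, illustrates the difference; r2_corr and r2_sse are hypothetical helper names.

```python
import numpy as np


def r2_corr(obs, pred):
    """Squared Pearson correlation, the quantity pct_exp_var computes."""
    return np.corrcoef(obs.ravel(), pred.ravel())[0, 1] ** 2


def r2_sse(obs, pred):
    """Coefficient of determination 1 - SSE/SST, which also penalizes offset and scale errors."""
    obs, pred = obs.ravel(), pred.ravel()
    return 1 - np.sum((obs - pred) ** 2) / np.sum((obs - obs.mean()) ** 2)


toy_obs = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
toy_pred = 2 * toy_obs + 1  # perfectly correlated with toy_obs, but mis-scaled and offset
print("squared correlation:", r2_corr(toy_obs, toy_pred))
print("1 - SSE/SST:", r2_sse(toy_obs, toy_pred))
```

Here the squared correlation is 1 (up to rounding) while 1 - SSE/SST is -4.5. Both fits above are scored with the same metric, so their ranking is still meaningful.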

LQR optimal controller design

We now use the fitted parameters to create the controller and set its remaining parameters. The feedback gain \( K_c \) is especially important: it determines how the controller responds to the current “error”, the difference between where the system is (estimated to be) now and where we want it to be. Optimal control is the field concerned with designing a controller that minimizes a cost function reflecting what we care about.

With a linear system (obtained from the fitting procedure above) and a quadratic per-timestep cost \(L\) penalizing distance from the reference \(x^*\) and the input \(u\)

\[ L = \frac{1}{2} (x - x^*)^T Q (x - x^*) + \frac{1}{2} u^T R u \]

we can use the closed-form optimal solution called the Linear Quadratic Regulator (LQR), which minimizes the sum of \(L\) over an infinite horizon:

\[ K = (R + B^T P B)^{-1}(B^T P A) \quad\quad u = -Kx\]

The \(P\) matrix is obtained by numerically solving the discrete algebraic Riccati equation:

\[ P=A^{T} P A-\left(A^{T} P B\right)\left(R+B^{T} P B\right)^{-1}\left(B^{T} P A\right)+Q \]
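The Riccati equation above is a fixed-point equation in \(P\), so one simple way to solve it is to iterate it until it stops changing (value iteration), starting from \(P = Q\). The sketch below does that for a made-up system with the same cost structure as the tutorial (\(Q = C^T C\), \(R = 10^{-4} I\)), checks the result against scipy.linalg.solve_discrete_are, and confirms that the resulting gain makes \(A - BK\) stable. It relies on the standard conditions for convergence ((A, B) stabilizable and (A, Q^{1/2}) detectable), which hold here because the toy A is itself stable; riccati_by_iteration is a hypothetical helper.

```python
import numpy as np
from scipy.linalg import solve_discrete_are


def riccati_by_iteration(A, B, Q, R, tol=1e-12, max_iter=10_000):
    """Iterate P <- A'PA - (A'PB)(R + B'PB)^{-1}(B'PA) + Q from P = Q until it stops changing."""
    P = Q.copy()
    for _ in range(max_iter):
        P_next = A.T @ P @ A - (A.T @ P @ B) @ np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A) + Q
        if np.max(np.abs(P_next - P)) < tol:
            return P_next
        P = P_next
    return P


# Made-up stable system standing in for the fit (not the fitted A, B, C)
toy_rng = np.random.default_rng(3)
toy_A = np.array([[0.9, 0.1, 0.0],
                  [0.0, 0.8, 0.2],
                  [0.0, -0.3, 0.5]])
toy_B = toy_rng.normal(scale=0.01, size=(3, 2))
toy_C = toy_rng.normal(size=(2, 3))
toy_Q = toy_C.T @ toy_C     # penalize output error, like Q_cost in the tutorial
toy_R = 1e-4 * np.eye(2)    # small input penalty, like R_cost

toy_P = riccati_by_iteration(toy_A, toy_B, toy_Q, toy_R)
toy_P_scipy = solve_discrete_are(toy_A, toy_B, toy_Q, toy_R)
# K = (R + B'PB)^{-1} (B'PA)
toy_K = np.linalg.solve(toy_R + toy_B.T @ toy_P @ toy_B, toy_B.T @ toy_P @ toy_A)
toy_eig_cl = np.linalg.eigvals(toy_A - toy_B @ toy_K)
print("max |P_iter - P_scipy|:", np.max(np.abs(toy_P - toy_P_scipy)))
print("closed-loop eigenvalue magnitudes:", np.abs(toy_eig_cl).round(4))
```

In practice a dedicated solver like solve_discrete_are is preferable; the iteration is shown only because it makes the fixed-point structure of the equation explicit.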
# upper and lower bounds on control signal (optic fiber light intensity)
u_lb = 0  # mW/mm2
u_ub = 75  # mW/mm2
controller = glds.Controller(fit_sys_em, u_lb, u_ub)
# careful not to use this anymore since controller made a copy
del fit_sys_em
from scipy.linalg import solve_discrete_are

A, B, C = controller.sys.A, controller.sys.B, controller.sys.C
# cost matrices
# Q reflects how much we care about state error
# we use C'C since we really care about output error, not latent state
Q_cost = C.T @ C
R_cost = 1e-4 * np.eye(n_u)  # reflects how much we care about minimizing the stimulus
P = solve_discrete_are(A, B, Q_cost, R_cost)
controller.Kc = np.linalg.inv(R_cost + B.T @ P @ B) @ (B.T @ P @ A)
print(controller)
print("For controlled system dynamics A - BK:")
print("eigvals:", np.linalg.eigvals(A - B @ controller.Kc))
print("magnitude of eigvals:", np.abs(np.linalg.eigvals(A - B @ controller.Kc)))
 ********** SYSTEM ********** 
x: 
   0.2416
   0.2736
  -0.4363

P: 
   0.1522  -0.0049   0.0092
  -0.0049   0.2306  -0.0099
   0.0092  -0.0099   0.2138

A: 
   0.8774  -0.1411   0.1746
   0.2014   0.0053   0.4933
   0.0106  -0.8962  -0.3273

B: 
  -0.0062  -0.0065
   0.0161   0.0073
   0.0017   0.0012

g: 
   1.0000
   1.0000

m: 
        0
        0
        0

Q: 
   0.0330  -0.0765  -0.0013
  -0.0765   0.2050   0.0040
  -0.0013   0.0040   0.0198

Q_m: 
   1.0000e-06            0            0
            0   1.0000e-06            0
            0            0   1.0000e-06

d: 
   0.3620
   0.3859

C: 
  -0.1469   0.1281  -0.0339
  -0.1480   0.1816  -0.0983

y: 
   0.3764
   0.4427

R: 
   0.1159  -0.0039
  -0.0039   0.1031

g_design :    1.0000
   1.0000

u_lb : 0
u_ub : 75

For controlled system dynamics A - BK:
eigvals: [ 0.72559149+0.j         -0.15320079+0.55360005j -0.15320079-0.55360005j]
magnitude of eigvals: [0.72559149 0.57440709 0.57440709]

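All three closed-loop eigenvalues have magnitude below 1, so the controlled model is stable. Note that the Q and R in the printout above are the fitted process and measurement noise covariances, not the cost matrices Q_cost and R_cost. We now configure a LatencyIOProcessor to use our controller: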

y_ref = np.mean(z) * 0.75  # target rate per timebin


class CtrlLoop(cleo.LatencyIOProcessor):
    def __init__(self, sample_period, controller):
        super().__init__(sample_period)
        self.controller = controller
        self.sys = controller.sys
        self.do_control = False  # allows us to turn on and off control

        # for post hoc visualization/analysis:
        self.x_hat = np.empty((n_x_fit, 0))
        self.y_hat = np.empty((n_z, 0))
        self.z = np.empty((n_z, 0))

    def process(self, state_dict, t_samp):
        i, t, z_t = state_dict["Probe"]["spikes"]
        z_t = z_t[:n_z].reshape((-1, 1))  # just first n_z neurons
        self.controller.y_ref = np.ones((n_z, 1)) * y_ref

        u_t = self.controller.ControlOutputReference(z_t, do_control=self.do_control)
        out = {fibers.name: u_t.squeeze() * b2.mwatt / b2.mm2}

        # record variables from this timestep
        self.y_hat = np.hstack([self.y_hat, self.sys.y])
        self.x_hat = np.hstack([self.x_hat, self.sys.x])
        self.z = np.hstack((self.z, z_t))

        return out, t_samp + 3 * ms  # 3 ms delay


ctrl_loop = CtrlLoop(sample_period=dt, controller=controller)
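At each sample, the controller uses the new spike-count vector to update its latent state estimate (sys.x, which CtrlLoop records as x_hat). For a Gaussian LDS the textbook estimator for this is the Kalman filter. The sketch below shows one predict/update step for the simplified model without the g and m terms, plus a toy check that filtering beats a fixed guess; it illustrates the standard algorithm and is not a transcription of ldsCtrlEst’s internals. kalman_step and the toy_ names are hypothetical.

```python
import numpy as np


def kalman_step(x, P, u, z, A, B, C, d, Q, R):
    """One predict/update cycle of a standard Kalman filter for the simplified model
    x_{k+1} = A x_k + B u_k + w_k,  z_k = C x_k + d + v_k,  w ~ N(0, Q), v ~ N(0, R).
    x, P: previous estimate and covariance; u: previous input; z: new observation."""
    # Predict the current state from the previous estimate and input
    x_pred = A @ x + B @ u
    P_pred = A @ P @ A.T + Q
    # Correct the prediction with the new observation
    S = C @ P_pred @ C.T + R                 # innovation covariance
    K_f = P_pred @ C.T @ np.linalg.inv(S)    # Kalman gain
    x_new = x_pred + K_f @ (z - (C @ x_pred + d))
    P_new = (np.eye(len(x)) - K_f @ C) @ P_pred
    return x_new, P_new


# Toy check on a made-up 2-state system (not the fitted one), with no input
toy_rng = np.random.default_rng(4)
toy_A, toy_B, toy_C = np.diag([0.95, 0.9]), np.zeros((2, 1)), np.eye(2)
toy_d, toy_Q, toy_R = np.full(2, 0.4), 0.05 * np.eye(2), 0.1 * np.eye(2)

x_true, x_hat, P_hat = np.zeros(2), np.zeros(2), np.eye(2)
err_kf, err_prior = [], []
for _ in range(2000):
    x_true = toy_A @ x_true + toy_rng.multivariate_normal(np.zeros(2), toy_Q)
    z_k = toy_C @ x_true + toy_d + toy_rng.multivariate_normal(np.zeros(2), toy_R)
    x_hat, P_hat = kalman_step(x_hat, P_hat, np.zeros(1), z_k,
                               toy_A, toy_B, toy_C, toy_d, toy_Q, toy_R)
    err_kf.append(np.sum((x_hat - x_true) ** 2))
    err_prior.append(np.sum(x_true ** 2))  # error of always guessing the prior mean, 0
mse_kf, mse_prior = np.mean(err_kf), np.mean(err_prior)
print(f"Kalman filter MSE: {mse_kf:.3f}, prior-mean MSE: {mse_prior:.3f}")
```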

Run the experiment

We’ll now run the simulation with and without control to compare.

sim.reset()  # only needed when rerunning
sim.set_io_processor(ctrl_loop)
T0 = 200 * ms
sim.run(T0)

ctrl_loop.do_control = True
T1 = 1000 * ms
sim.run(T1)

Now we plot the results to see how well the controller was able to match the desired firing rate:

def y_ref_fn(t, in_hertz=True):
    y_baseline = np.mean(controller.sys.d)
    if in_hertz:
        yr = y_ref / dt / Hz
        y_baseline /= dt * Hz
    else:
        yr = y_ref
    return y_baseline * (t < T0) + yr * (t >= T0)  # baseline before control onset (T0), reference after


def plot_ctrl(loop):
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
    ax1.set(ylabel="spikes/s", title="real-time estimated firing rates")
    ax1.plot(
        loop.t_samp / ms,
        loop.y_hat[0] / dt,
        c=cy1,
        alpha=0.7,
        label="$\\hat{y}_1$",
    )
    ax1.plot(
        loop.t_samp / ms,
        loop.y_hat[1] / dt,
        c=cy2,
        alpha=0.7,
        label="$\\hat{y}_2$",
    )
    ax1.plot(loop.t_samp / ms, y_ref_fn(loop.t_samp), c="green", label="reference")
    ax1.axvline(T0 / ms, c="xkcd:red", linestyle=":", label="ctrl on")
    ax1.legend(loc="right")
    ax2.plot(fibers.t / b2.ms, fibers.irradiance_[:, 0], c=cu1, label="$u_1$")
    ax2.plot(fibers.t / b2.ms, fibers.irradiance_[:, 1], c=cu2, label="$u_2$")
    ax2.legend()
    ax2.set(xlabel="t (ms)", ylabel="u (mW/mm$^2$)", title="input signal")


plot_ctrl(ctrl_loop)
[Figure: real-time estimated firing rates with reference and control onset (top); input signal u_1, u_2 in mW/mm² (bottom)]

What gives? Why does the controller keep the input on when it’s clearly over the target? Let’s check a Gaussian-smoothed post hoc estimated firing rate to confirm the firing rates are too high:

from scipy.ndimage import gaussian_filter1d


def plot_post_hoc(loop):
    i_baseline = loop.t_samp < T0
    i_static = (loop.t_samp >= T0) & (loop.t_samp < T0 + T1)
    print("Results (spikes/second):")
    print("baseline =", np.sum(loop.z[:, i_baseline], axis=1) / T0)
    print("target =", y_ref / dt)
    print(
        "achieved =",
        (np.sum(loop.z[:, i_static], axis=1) / T1).round(1),
    )

    win_len = 25 * ms / dt  # Gaussian sigma: 25 ms expressed in samples
    smoothed = gaussian_filter1d(loop.z, sigma=win_len, axis=1) / dt
    plt.plot(loop.t_samp / ms, y_ref_fn(loop.t_samp), c="green", label="reference")
    plt.axvline(T0 / ms, c="r", ls=":")
    plt.xlabel("t (ms)")
    plt.ylabel("spikes/s")
    plt.title("Gaussian-smoothed firing rate")
    plt.plot(loop.t_samp / ms, smoothed[0], c=cy1)
    plt.plot(loop.t_samp / ms, smoothed[1], c=cy2)


plot_post_hoc(ctrl_loop)
Results (spikes/second):
baseline = [5. 0.] Hz
target = 132.6 Hz
achieved = [198. 181.] Hz
[Figure: Gaussian-smoothed firing rates of both neurons with the reference]

And so they are. One possible reason is that the steady-state set point computed by the controller is unattainable. Let’s check:

controller.u_ref
array([[-31.29477828],
       [ 84.5755037 ]])

It looks like it. The set point calls for negative light on fiber 1 (\(u_1 \approx -31\) mW/mm²) to offset a strong input on fiber 2 (\(u_2 \approx 85\) mW/mm²); unfortunately, we can’t deliver negative light. Even if we could, the set point ignores the upper limit of 75 mW/mm², which \(u_2\) exceeds. This is a fundamental limitation of LQR: it can’t account for constraints. What’s the solution?
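How might such a set point arise? One standard way to compute a steady-state target (not necessarily ldsCtrlEst’s exact computation) is to solve \(x = Ax + Bu\) and \(Cx + d = y_{ref}\) jointly for \(x\) and \(u\), which is a square linear system when there are as many inputs as outputs. The sketch below uses a made-up system with our dimensions (3 states, 2 inputs, 2 outputs) in which both outputs share one reference but have different baselines \(d\); steady_state_target and the toy_ names are hypothetical.

```python
import numpy as np


def steady_state_target(A, B, C, d, y_target):
    """Solve x = A x + B u and C x + d = y_target jointly for (x, u).
    The system is square (and solvable with np.linalg.solve) when n_u == n_y."""
    n_x, n_u = B.shape
    n_y = C.shape[0]
    M = np.block([[np.eye(n_x) - A, -B], [C, np.zeros((n_y, n_u))]])
    rhs = np.concatenate([np.zeros(n_x), y_target - d])
    sol = np.linalg.solve(M, rhs)
    return sol[:n_x], sol[n_x:]


# Made-up system with our dimensions: 3 states, 2 inputs, 2 outputs (not the fit)
toy_A = 0.5 * np.eye(3)
toy_B = 0.1 * np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
toy_C = np.array([[1.0, 0.2, 0.0], [0.9, 0.5, 0.0]])
toy_d = np.array([0.3, 0.8])          # the two outputs have different baselines
toy_y_target = np.array([1.3, 1.3])   # but share one reference

x_ss, u_ss = steady_state_target(toy_A, toy_B, toy_C, toy_d, toy_y_target)
print("unconstrained u_ss:", u_ss.round(4))   # the second input comes out negative

# Clipping to the physical light range no longer reaches the target
u_clip = np.clip(u_ss, 0, 75)
x_clip = np.linalg.solve(np.eye(3) - toy_A, toy_B @ u_clip)
print("steady-state output with clipped u:", (toy_C @ x_clip + toy_d).round(4))
```

The exact solution here is \(u = (6.25, -6.25)\); clipping the negative input to 0 leaves both outputs (1.55 and 1.925) above the 1.3 target, the same qualitative failure we saw in the simulation.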

Accounting for input constraints

Let’s help LQR out by computing a set point that respects our constraints. This calls for convex optimization tools; here we use CVXPY.

import cvxpy as cp


def opt_u_x():
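    """
    Find a steady-state set point (x = Ax + Bu) whose output best tracks y_ref,
    with a small input penalty, while keeping u within [u_lb, u_ub], using CVXPY.
    """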
    # Define the optimization variables
    u = cp.Variable(n_u)
    x = cp.Variable(n_x_fit)
    y_r = np.full(n_z, y_ref)

    # Define the cost function
    cost = cp.sum_squares(C @ x + controller.sys.d - y_r) + cp.quad_form(u, R_cost)

    # Define the constraints
    constraints = [
        x == A @ x + B @ u,
        u >= u_lb,
        u <= u_ub,
    ]

    # Define the optimization problem
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve()

    return u.value, x.value


u_ref, x_ref = opt_u_x()
u_ref, x_ref
WARNING    /home/kyle/miniforge3/envs/cleo/lib/python3.12/site-packages/cvxpy/reductions/solvers/solving_chain_utils.py:30: UserWarning: The problem includes expressions that don't support CPP backend. Defaulting to the SCIPY backend for canonicalization.
  warnings.warn(UserWarning(
 [py.warnings]
(array([17.06548526, 12.72169718]),
 array([-1.59446947,  0.04281098, -0.007822  ]))

That u_ref looks much better! Let’s try again, this time using the controller’s Control method, which lets us set the x and u reference values ourselves instead of having the controller compute the infeasible values we saw.

class CtrlLoop2(cleo.LatencyIOProcessor):
    def __init__(self, sample_period, controller):
        super().__init__(sample_period)
        self.controller = controller
        self.sys = controller.sys
        self.do_control = False  # allows us to turn on and off control

        # for post hoc visualization/analysis:
        self.x_hat = np.empty((n_x_fit, 0))
        self.y_hat = np.empty((n_z, 0))
        self.z = np.empty((n_z, 0))

    def process(self, state_dict, t_samp):
        i, t, z_t = state_dict["Probe"]["spikes"]
        z_t = z_t[:n_z].reshape((-1, 1))  # just first n_z neurons

        self.controller.u_ref = u_ref
        self.controller.x_ref = x_ref
        u_t = self.controller.Control(z_t, do_control=self.do_control)
        if not self.do_control:
            u_t = np.zeros((n_u, 1))  # no control
        out = {fibers.name: u_t.squeeze() * b2.mwatt / b2.mm2}

        # record variables from this timestep
        self.y_hat = np.hstack([self.y_hat, self.sys.y])
        self.x_hat = np.hstack([self.x_hat, self.sys.x])
        self.z = np.hstack((self.z, z_t))

        return out, t_samp + 3 * ms  # 3 ms delay


ctrl_loop2 = CtrlLoop2(sample_period=dt, controller=controller)
sim.reset()
sim.set_io_processor(ctrl_loop2)
sim.run(T0)

ctrl_loop2.do_control = True
sim.run(T1)
plot_ctrl(ctrl_loop2)
[Figure: plot_ctrl(ctrl_loop2) output]
plot_post_hoc(ctrl_loop2)
Results (spikes/second):
baseline = [15. 20.] Hz
target = 132.6 Hz
achieved = [151. 152.] Hz
[Figure: plot_post_hoc(ctrl_loop2) output]

There we go! That looks much better.
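A consistent (u_ref, x_ref) pair helps for a clear reason. We haven't inspected ldsCtrlEst's exact control law here, so consider instead a textbook discrete-time LQR tracking law, u = u_ref - K (x - x_ref), on a hypothetical scalar system. If the pair satisfies x_ref = A x_ref + B u_ref, the noise-free closed loop settles exactly on x_ref. If the feedforward term is zeroed, feedback alone leaves a steady-state offset. The helpers dlqr_gain and run_tracking are illustrative and are not ldsctrlest functions.

```python
import numpy as np


def dlqr_gain(A, B, Q, R, n_iter=500):
    """Infinite-horizon discrete-time LQR gain by iterating the Riccati recursion."""
    P = Q.copy()
    for _ in range(n_iter):
        K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        P = Q + A.T @ P @ (A - B @ K)
    return K


def run_tracking(A, B, K, x_ref, u_ref, x0, n_steps=200):
    """Simulate u = u_ref - K (x - x_ref) on the noise-free model; return the final x."""
    x = x0.copy()
    for _ in range(n_steps):
        u = u_ref - K @ (x - x_ref)  # feedforward setpoint + state feedback
        x = A @ x + B @ u
    return x


# hypothetical scalar system
A_toy = np.array([[0.9]])
B_toy = np.array([[1.0]])
K_toy = dlqr_gain(A_toy, B_toy, Q=np.eye(1), R=0.1 * np.eye(1))
x_ref_toy = np.ones(1)
# consistent feedforward: solve B u_ref = (I - A) x_ref
u_ref_toy = np.linalg.lstsq(B_toy, (np.eye(1) - A_toy) @ x_ref_toy, rcond=None)[0]

x_consistent = run_tracking(A_toy, B_toy, K_toy, x_ref_toy, u_ref_toy, np.zeros(1))
x_no_ff = run_tracking(A_toy, B_toy, K_toy, x_ref_toy, np.zeros(1), np.zeros(1))
```

With the consistent u_ref, the feedforward term alone holds the state at x_ref, so the feedback term only has to correct deviations. This is also why the setpoint must respect the input bounds: a reference that needs an infeasible input cannot be held.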

Conclusion

As a recap, in this tutorial we’ve seen how to:

  • inject optogenetic stimulation into an existing Brian network

  • inject an electrode into an existing Brian network to record spikes

  • generate training data and fit a Gaussian linear dynamical system to the spiking output using ldsctrlest

  • configure an ldsctrlest LQR controller based on that linear system and design optimal gains

  • use that controller to run a complete simulated feedback control experiment

  • work around the limitations of LQR by computing a feasible setpoint using convex optimization

Bonus: model predictive control

The stage is now set for model predictive control (MPC), which, by solving an optimization problem at every timestep, gains two important capabilities that LQR lacks:

  1. It can look ahead, optimizing a time-varying control trajectory over a horizon of specified length.

  2. It can account for linear constraints on states and inputs.

The main downside is the increased computational cost compared to LQR.
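The receding-horizon idea can be made concrete with a minimal sketch of box-constrained linear MPC. The procedure is: predict the state over N steps as a linear function of the stacked inputs, minimize the resulting quadratic cost subject to input bounds, apply only the first input, then repeat at the next step. This is not how lqmpc is implemented. The quadratic program here is solved with simple projected gradient descent so the example needs only NumPy, and the function names and scalar test system are hypothetical. Solving this optimization at every step is where the extra computational cost comes from.

```python
import numpy as np


def prediction_matrices(A, B, N):
    """Stacked prediction: [x_1; ...; x_N] = Phi @ x0 + Gamma @ [u_0; ...; u_{N-1}]."""
    n_x, n_u = B.shape
    Phi = np.vstack([np.linalg.matrix_power(A, k + 1) for k in range(N)])
    Gamma = np.zeros((N * n_x, N * n_u))
    for k in range(N):  # block row k predicts x_{k+1}
        for j in range(k + 1):  # input u_j reaches x_{k+1} through A^(k-j) B
            Gamma[k * n_x:(k + 1) * n_x, j * n_u:(j + 1) * n_u] = (
                np.linalg.matrix_power(A, k - j) @ B
            )
    return Phi, Gamma


def mpc_first_input(A, B, Q, R, x0, x_ref, N, u_min, u_max, n_iter=2000):
    """Solve a box-constrained finite-horizon LQ problem and return only u_0."""
    n_u = B.shape[1]
    Phi, Gamma = prediction_matrices(A, B, N)
    Q_bar = np.kron(np.eye(N), Q)
    R_bar = np.kron(np.eye(N), R)
    r = np.tile(x_ref, N) - Phi @ x0  # what Gamma @ U should equal
    # cost (Gamma U - r)^T Q_bar (Gamma U - r) + U^T R_bar U = U^T H U - 2 g^T U + const
    H = Gamma.T @ Q_bar @ Gamma + R_bar
    g = Gamma.T @ Q_bar @ r
    step = 1.0 / (2 * np.linalg.norm(H, 2))  # 1 / Lipschitz constant of the gradient
    lb, ub = np.tile(u_min, N), np.tile(u_max, N)
    U = np.clip(np.zeros(N * n_u), lb, ub)
    for _ in range(n_iter):  # projected gradient on the QP
        U = np.clip(U - step * 2 * (H @ U - g), lb, ub)
    return U[:n_u]  # receding horizon: apply u_0 only


# hypothetical scalar plant, input limited to [0, 0.5]
A_toy = np.array([[0.9]])
B_toy = np.array([[1.0]])
x_toy = np.zeros(1)
applied = []
for _ in range(50):
    u_toy = mpc_first_input(
        A_toy, B_toy, np.eye(1), 0.01 * np.eye(1), x_toy, x_ref=np.ones(1),
        N=10, u_min=np.zeros(1), u_max=np.full(1, 0.5),
    )
    applied.append(u_toy[0])
    x_toy = A_toy @ x_toy + B_toy @ u_toy  # noise-free plant step
```

The early inputs saturate at the 0.5 bound while the state climbs. The plan is then re-optimized each step from the new state, which is what distinguishes MPC from applying a single precomputed trajectory.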

Let’s try this experiment again with MPC instead, using the lqmpc package.

MPC setup

import lqmpc

H = 20
mpc = lqmpc.LQMPC(dt / second, A, B, C)
# N, M are prediction and control horizons
mpc.set_control(Q_cost, R_cost, N=H, M=H - 1)
mpc.set_constraints(umin=np.zeros(n_u), umax=np.full(n_u, u_ub))

Since MPC solves a quadratic program under the hood at every timestep, one advantage we should have is not needing to work out the optimal setpoint by hand as we just did. However, lqmpc only supports a state reference, not an output reference, so for now we’ll reuse the same x_ref. Adding MPC to ldsctrlest, including control by output reference, is planned; once it is implemented, this tutorial should be updated.

# lqmpc doesn't support directly optimizing y_ref, so we use the same steady-state x_ref
def xref_fn(t):
    return (t >= T0) * x_ref[:, None]
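xref_fn relies on NumPy broadcasting. A boolean vector of length len(t) multiplied by the (n_x, 1) column x_ref[:, None] yields an (n_x, len(t)) array with one reference column per future timestep and zeros before T0. Here is a toy check with hypothetical values, using plain integer milliseconds rather than Brian quantities and _toy names so the tutorial's variables are not overwritten:

```python
import numpy as np

x_ref_toy = np.array([1.0, 2.0, 3.0])  # hypothetical 3-dim steady-state target
T0_toy = 10  # hypothetical reference onset, in ms


def xref_fn_toy(t):
    # (len(t),) boolean times (n_x, 1) column broadcasts to (n_x, len(t))
    return (t >= T0_toy) * x_ref_toy[:, None]


traj_toy = xref_fn_toy(np.arange(5) * 5)  # t = 0, 5, 10, 15, 20 ms
```

traj_toy has shape (3, 5): its first two columns are zero, and the remaining columns equal x_ref_toy.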

We mentioned a downside of MPC was the increased computational cost. Let’s measure how long the controller takes:

mpc.simulate(
    t_sim=dt / second,
    x0=np.zeros(n_x_fit),
    u0=u_ref,
    xr=xref_fn(np.arange(200) * dt),
    L=150,
);
Simulation Time: 0.30167532 s 
Mean Step Time: 0.00201117 s 
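lqmpc reports its own timings above. For controllers without built-in timing, a generic helper can wrap any step function with time.perf_counter. The helper below is hypothetical and is not part of cleo or lqmpc. It reports the worst case alongside the mean, because the loop latency has to cover the slowest steps, not just the typical one.

```python
import time

import numpy as np


def step_time_stats(step_fn, n_steps=100):
    """Return (mean, max) wall-clock duration in seconds of step_fn() over n_steps calls."""
    durations = np.empty(n_steps)
    for k in range(n_steps):
        start = time.perf_counter()
        step_fn()
        durations[k] = time.perf_counter() - start
    return durations.mean(), durations.max()


# example: time a small linear solve as a stand-in for one controller step
rng_toy = np.random.default_rng(0)
M_toy = rng_toy.standard_normal((50, 50)) + 50 * np.eye(50)
v_toy = rng_toy.standard_normal(50)
mean_s, max_s = step_time_stats(lambda: np.linalg.solve(M_toy, v_toy))
print(f"mean {mean_s * 1e3:.3f} ms, max {max_s * 1e3:.3f} ms")
```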

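The mean step time is about 2 ms. Good to know! Added to our original 3 ms loop latency, that comes to roughly 5 ms, so let’s increase the latency of our control loop from 3 to 6 ms to leave some margin. Unlike ldsCtrlEst, lqmpc does not have built-in Kalman filtering, so we’ll need to add that too: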

class MPCLoop(cleo.LatencyIOProcessor):
    def __init__(self, sample_period):
        super().__init__(sample_period)
        # for post hoc visualization/analysis:
        self.x_hat = np.empty((n_x_fit, 0))
        self.y_hat = np.empty((n_z, 0))
        self.z = np.empty((n_z, 0))

        mpc.xi = np.zeros(n_x_fit)
        mpc.ui = u_ref
        self.P_cov = controller.sys.P  # initialize state covariance matrix
        self.Q = controller.sys.Q
        self.R = controller.sys.R
        self.d = controller.sys.d.squeeze()

    def process(self, state_dict, t_samp):
        i, t, z_t = state_dict["Probe"]["spikes"]
        z_t = z_t[:n_z]  # just first n_z neurons

        # update state covariance for Kalman filtering
        # predict step
        P_pred = A @ self.P_cov @ A.T + self.Q
        # Kalman gain
        S = C @ P_pred @ C.T + self.R
        K = P_pred @ C.T @ np.linalg.inv(S)
        # Update step
        z_err = z_t - (C @ mpc.xi + self.d)  # Measurement residual
        mpc.xi = mpc.xi + K @ z_err
        self.P_cov = (np.eye(len(P_pred)) - K @ C) @ P_pred
        x_hat = mpc.xi.copy()

        mpc.step(
            dt / second, mpc.xi, mpc.ui, xref_fn(t_samp + np.arange(H) * dt), out=False
        )
        u_t = mpc.ui

        out = {fibers.name: u_t.squeeze() * b2.mwatt / b2.mm2}

        # record variables from this timestep
        y_hat = C @ x_hat + self.d
        self.y_hat = np.column_stack([self.y_hat, y_hat])
        self.x_hat = np.column_stack([self.x_hat, x_hat])
        self.z = np.column_stack((self.z, z_t))

        return out, t_samp + 6 * ms  # 6 ms delay


mpc_loop = MPCLoop(sample_period=dt)
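MPCLoop.process writes out the covariance prediction and measurement update inline. For reference, here is a generic, self-contained sketch of one full Kalman predict/update cycle for the same measurement model (z = C x + d plus noise). It propagates the state mean explicitly with x_pred = A x_hat + B u and computes the gain with np.linalg.solve rather than an explicit inverse. It makes no assumption about whether lqmpc's step already advances mpc.xi. kalman_step and the scalar test case are hypothetical.

```python
import numpy as np


def kalman_step(x_hat, P, u, z, A, B, C, d, Q, R):
    """One predict/update cycle for x' = A x + B u + w, z = C x + d + v."""
    # predict: propagate mean and covariance through the dynamics
    x_pred = A @ x_hat + B @ u
    P_pred = A @ P @ A.T + Q
    # update: weigh the measurement residual by the Kalman gain
    S = C @ P_pred @ C.T + R
    K = np.linalg.solve(S, C @ P_pred).T  # equals P_pred C^T S^{-1}, since S and P_pred are symmetric
    x_new = x_pred + K @ (z - (C @ x_pred + d))
    P_new = (np.eye(len(x_hat)) - K @ C) @ P_pred
    return x_new, P_new


# hypothetical test: estimate a constant (5.0) from unit-variance noisy readings
rng_toy = np.random.default_rng(0)
A_toy, B_toy, C_toy = np.eye(1), np.zeros((1, 1)), np.eye(1)
d_toy, Q_toy, R_toy = np.zeros(1), 1e-6 * np.eye(1), np.eye(1)
x_hat_toy, P_toy = np.zeros(1), 10 * np.eye(1)
for z_toy in 5 + rng_toy.standard_normal((500, 1)):
    x_hat_toy, P_toy = kalman_step(
        x_hat_toy, P_toy, np.zeros(1), z_toy, A_toy, B_toy, C_toy, d_toy, Q_toy, R_toy
    )
```

With near-zero process noise, the filter behaves like a running average, so x_hat_toy ends up close to 5 and P_toy shrinks toward zero.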

Rerunning the experiment

sim.reset()
sim.set_io_processor(mpc_loop)
sim.run(T0 + T1)
WARNING    /home/kyle/GaTech Dropbox/Kyle Johnsen/projects/cleo/cleo/light/light.py:557: UserWarning: fiber: negative light value clipped to 0
  warnings.warn(f"{self.name}: negative light value clipped to 0")
 [py.warnings]
plot_ctrl(mpc_loop)
[Figure: plot_ctrl(mpc_loop) output]
plot_post_hoc(mpc_loop)
Results (spikes/second):
baseline = [50. 60.] Hz
target = 132.6 Hz
achieved = [135. 135.] Hz
[Figure: plot_post_hoc(mpc_loop) output]

Also looking good. We don’t see much of an advantage over LQR with a constraint-informed setpoint here, since for a static reference the setpoint optimization we performed was most of what was needed. In effect, it did part of MPC’s job for it; the difference is that MPC optimizes an entire trajectory over a receding horizon, ensuring constraints are met at every step. We do see less overshoot, though, which makes sense since the controller looks 100 ms ahead.

This should be revisited with a time-varying reference once MPC with output-reference control is available in ldsctrlest.