PI control

In this tutorial we’ll introduce

  1. PI control, a commonly used model-free control method,

  2. the concept of decomposing the IOProcessor’s computation into ProcessingBlocks, and

  3. modeling computation delays on those blocks to reflect hardware and algorithmic speed limitations present in a real experiment.

Preamble:

from brian2 import *
import matplotlib.pyplot as plt
from cleo import *

utilities.style_plots_for_docs()

np.random.seed(7000)

# the default Cython compilation target isn't worth it for this trivial example
prefs.codegen.target = "numpy"

Create the Brian network

We’ll create a population of 10 LIF neurons mainly driven by feedforward input but with some recurrent connections as well.

n = 10
population = NeuronGroup(n, '''
            dv/dt = (-v - 70*mV + Rm*I) / tau : volt
            tau: second
            Rm: ohm
            I: amp''',
        threshold='v>-50*mV',
        reset='v=-70*mV'
)
population.tau = 10*ms
population.Rm = 100*Mohm
population.I = 0*mA
population.v = -70*mV

input_group = PoissonGroup(n, np.linspace(20, 200, n)*Hz)

S = Synapses(input_group, population, on_pre='v+=5*mV')
S.connect(condition='abs(i-j)<=3')
S2 = Synapses(population, population, on_pre='v+=2*mV')
S2.connect(p=0.2)

pop_mon = SpikeMonitor(population)

net = Network(population, input_group, S, S2, pop_mon)
population.equations
\[\begin{split}\begin{align*}\frac{\mathrm{d}v}{\mathrm{d}t} &= \frac{I Rm - 70 mV - v}{\tau} && \text{(unit of $v$: $\mathrm{V}$)}\\ I &&& \text{(unit: $\mathrm{A}$)}\\ Rm &&& \text{(unit: $\mathrm{ohm}$)}\\ \tau &&& \text{(unit: $\mathrm{s}$)}\end{align*}\end{split}\]
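As a rough guide to how much current the controller will need, here is a sketch of the idealized, noise-free behavior of one neuron under these equations with a constant current I. It assumes the Poisson and recurrent synaptic input are absent and ignores Brian's discrete time step; the helper names are our own. Because v rests and resets at -70 mV and the threshold is -50 mV, the neuron only fires if Rm·I exceeds 20 mV, and the time from reset to threshold is τ·ln(Rm·I / (Rm·I - 20 mV)).

```python
import math

# Parameters taken from the network above; synaptic input is ignored here
TAU_MS = 10.0      # membrane time constant tau
RM_MOHM = 100.0    # membrane resistance Rm
GAP_MV = 20.0      # threshold (-50 mV) minus rest/reset potential (-70 mV)


def rheobase_nA():
    """Smallest constant current whose steady state reaches threshold."""
    # Rm * I must exceed the 20 mV gap; mV / MOhm = nA
    return GAP_MV / RM_MOHM


def rate_Hz(I_nA):
    """Firing rate of the idealized, noise-free neuron under constant current."""
    drive_mV = RM_MOHM * I_nA  # steady-state depolarization above -70 mV
    if drive_mV <= GAP_MV:
        return 0.0  # v approaches but never crosses threshold
    # time from reset to threshold: tau * ln(drive / (drive - gap))
    isi_ms = TAU_MS * math.log(drive_mV / (drive_mV - GAP_MV))
    return 1000.0 / isi_ms


print(rheobase_nA())           # 0.2
print(round(rate_Hz(1.0)))     # 448
```

In this idealized case, about 1 nA already gives roughly 450 Hz, so control currents on the order of nA are a plausible range; in the real network the synaptic input adds to this drive.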

Run simulation without control

net.run(100*ms)
INFO       No numerical integration method specified for group 'neurongroup', using method 'exact' (took 0.06s). [brian2.stateupdaters.base.method_choice]
fig, ax = plt.subplots()
ax.scatter(pop_mon.t / ms, pop_mon.i, marker='|', s=200); 
ax.set(title='population spiking', ylabel='neuron index', xlabel='time (ms)');
[Figure: population spiking raster, neuron index vs. time (ms)]

Constructing a closed-loop simulation

We will use PI control, a popular model-free method, to control a single neuron's firing rate. PI stands for proportional-integral: the control signal is the sum of a term proportional to the instantaneous error and a term proportional to the error integrated over time.
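In discrete time, the PI control law can be written as u[k] = Kp·e[k] + Ki·Σ e[j]·Δt, with error e[k] = r[k] - y[k]. The class below is a minimal illustration of that idea, not cleo's PIController implementation. The toy first-order plant and the gains are arbitrary choices for the demo and are unrelated to the gains used later in this tutorial.

```python
class DiscretePI:
    """Minimal discrete-time PI controller (illustrative, not cleo's PIController).

    u[k] = Kp * e[k] + Ki * sum_j e[j] * dt,  where e[k] = r[k] - y[k]
    """

    def __init__(self, Kp, Ki, dt_s):
        self.Kp = Kp
        self.Ki = Ki
        self.dt_s = dt_s
        self.integrated_error = 0.0

    def step(self, reference, measurement):
        error = reference - measurement
        # accumulate the error integral with a rectangle (forward Euler) rule
        self.integrated_error += error * self.dt_s
        return self.Kp * error + self.Ki * self.integrated_error


def simulate(controller, reference=1.0, tau_s=0.02, dt_s=0.001, n_steps=1000):
    """Close the loop around a toy plant dy/dt = (-y + u) / tau and return final y."""
    y = 0.0
    for _ in range(n_steps):
        u = controller.step(reference, y)
        y += dt_s / tau_s * (-y + u)  # forward Euler update of the plant
    return y


y_pi = simulate(DiscretePI(Kp=0.5, Ki=20.0, dt_s=0.001))  # PI control
y_p = simulate(DiscretePI(Kp=0.5, Ki=0.0, dt_s=0.001))    # proportional only
print(round(y_pi, 3), round(y_p, 3))  # 1.0 0.333
```

The integral term is what removes steady-state error here: with Ki = 0 the proportional-only loop settles at Kp/(1 + Kp) = 1/3 of the reference, while the PI loop settles on the reference itself.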

First we construct a CLSimulator from the network:

from cleo import CLSimulator
sim = CLSimulator(net)

Then, to control neuron \(i\), we need to:

  1. capture spiking using a GroundTruthSpikeRecorder

from cleo.recorders import GroundTruthSpikeRecorder
i = 0  # neuron to control
rec = GroundTruthSpikeRecorder(name='spike_rec')
sim.inject(rec, population[i])
CLSimulator(io_processor=None, devices={GroundTruthSpikeRecorder(brian_objects={<SpikeMonitor, recording from 'spikemonitor_1'>}, sim=..., name='spike_rec', _mon=<SpikeMonitor, recording from 'spikemonitor_1'>, _num_spikes_seen=0, neuron_group=<Subgroup 'neurongroup_subgroup' of 'neurongroup' from 0 to 1>)})
  2. define the firing rate trajectory we want our target neuron to follow

# the target firing rate trajectory, as a function of time
def target_Hz(t_ms):
    if t_ms < 250:  # constant target at first
        return 400
    else:  # sinusoidal afterwards
        a = 200
        t_s = t_ms / 1000
        return a + a * np.sin(2 * np.pi * 20 * t_s)
  3. estimate its firing rate from incoming spikes using a FiringRateEstimator

  4. compute the stimulus intensity with a PIController

  5. output that value for a StateVariableSetter stimulator to use
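For the rate-estimation step, one common way to estimate a firing rate online is an exponential filter with time constant τ. The sketch below assumes that approach (an exponential moving average of spike counts divided by the sample period); it is not necessarily how cleo's FiringRateEstimator is written, and the class name is hypothetical.

```python
import math


class ExpRateEstimator:
    """Online firing-rate estimate via an exponential filter (hypothetical sketch).

    Each sample, the instantaneous rate n_spikes / sample_period is blended
    into the running estimate with weight (1 - alpha), alpha = exp(-dt / tau).
    """

    def __init__(self, tau_ms, sample_period_ms):
        self.alpha = math.exp(-sample_period_ms / tau_ms)  # per-sample decay
        self.sample_period_s = sample_period_ms / 1000
        self.rate_Hz = 0.0

    def update(self, n_spikes):
        inst_rate_Hz = n_spikes / self.sample_period_s
        self.rate_Hz = self.alpha * self.rate_Hz + (1 - self.alpha) * inst_rate_Hz
        return self.rate_Hz


est = ExpRateEstimator(tau_ms=15, sample_period_ms=1)
# regular 100 Hz spike train: one spike every 10 one-millisecond samples
estimates = [est.update(1 if k % 10 == 0 else 0) for k in range(2000)]
steady = estimates[1000:]  # skip the initial transient
print(round(sum(steady) / len(steady), 1))  # 100.0
```

The moving average keeps the mean estimate equal to the true rate, but it responds to changes with a time constant of τ, which is one source of lag in a closed loop.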

Here we create the processing blocks in the IOProcessor's constructor and, in its process method, define how network output is turned into the control signal.

from cleo.ioproc import (
    LatencyIOProcessor,
    FiringRateEstimator,
    ConstantDelay,
    PIController,
)

class PIRateIOProcessor(LatencyIOProcessor):
    delta = 1  # ms

    def __init__(self):
        super().__init__(sample_period_ms=self.delta, processing="parallel")
        self.rate_estimator = FiringRateEstimator(
            tau_ms=15,
            sample_period_ms=self.delta,
            delay=ConstantDelay(4.1),  # latency in ms
            save_history=True,  # lets us plot later
        )

        # using hand-tuned gains that seem reasonable
        self.pi_controller = PIController(
            target_Hz,
            Kp=0.005,
            Ki=0.04,
            sample_period_ms=self.delta,
            delay=ConstantDelay(2.87),  # latency in ms
            save_history=True,  # lets us plot later
        )

    def process(self, state_dict, sample_time_ms):
        spikes = state_dict["spike_rec"]
        # feed the output and its timestamp through each block in turn
        out, time_ms = self.rate_estimator.process(
            spikes, sample_time_ms, sample_time_ms=sample_time_ms
        )
        out, time_ms = self.pi_controller.process(
            out, time_ms, sample_time_ms=sample_time_ms
        )
        # limit to nonnegative current
        if out < 0:
            out = 0
        # this dictionary output format allows for the flexibility
        # of controlling multiple stimulators
        out_dict = {"I_stim": out}
        # time_ms at the end reflects the delays added by each block
        return out_dict, time_ms

io_processor = PIRateIOProcessor()
sim.set_io_processor(io_processor)
CLSimulator(io_processor=<__main__.PIRateIOProcessor object at 0x7f6be05fdb70>, devices={GroundTruthSpikeRecorder(brian_objects={<SpikeMonitor, recording from 'spikemonitor_1'>}, sim=..., name='spike_rec', _mon=<SpikeMonitor, recording from 'spikemonitor_1'>, _num_spikes_seen=0, neuron_group=<Subgroup 'neurongroup_subgroup' of 'neurongroup' from 0 to 1>)})

Note that we can set delays for individual ProcessingBlocks in the IO processor to better approximate the timing of a real experiment. We use simple constant delays here, but a GaussianDelay class is also available, and other delay models could be implemented.
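To see how per-block delays add up, here is a sketch of blocks that each return their output together with a timestamp advanced by that block's delay, mirroring the (out, time_ms) pairs passed through process above. All classes here are hypothetical stand-ins, not cleo's ConstantDelay, GaussianDelay, or ProcessingBlock; the placeholder computations are arbitrary, and clipping the Gaussian delay at 0 ms is our own choice.

```python
import random


class ConstantDelaySketch:
    """Hypothetical fixed processing delay in ms."""

    def __init__(self, delay_ms):
        self.delay_ms = delay_ms

    def compute(self):
        return self.delay_ms


class GaussianDelaySketch:
    """Hypothetical random delay: normal(loc, scale) in ms, clipped at 0."""

    def __init__(self, loc_ms, scale_ms, rng=None):
        self.loc_ms = loc_ms
        self.scale_ms = scale_ms
        self.rng = rng if rng is not None else random.Random()

    def compute(self):
        return max(0.0, self.rng.gauss(self.loc_ms, self.scale_ms))


class BlockSketch:
    """Hypothetical processing block: transforms a value and advances its timestamp."""

    def __init__(self, fn, delay):
        self.fn = fn
        self.delay = delay

    def process(self, value, t_in_ms):
        return self.fn(value), t_in_ms + self.delay.compute()


# two blocks in series with the delays from PIRateIOProcessor;
# the lambdas are placeholder computations, not real estimation/control
estimator = BlockSketch(lambda x: 2 * x, ConstantDelaySketch(4.1))
controller = BlockSketch(lambda x: x - 1, ConstantDelaySketch(2.87))

out, t_ms = estimator.process(10, t_in_ms=100.0)
out, t_ms = controller.process(out, t_ms)
print(out, round(t_ms, 2))  # 19 106.97

# a jittered delay never goes negative thanks to the clipping
jittery = GaussianDelaySketch(4.1, 1.0, rng=random.Random(0))
print(min(jittery.compute() for _ in range(100)) >= 0)  # True
```

With the two constant delays used in this tutorial, an output computed from a sample taken at 100 ms only becomes available at about 106.97 ms.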

Now we inject the stimulator:

from cleo.stimulators import StateVariableSetter
sim.inject(
    StateVariableSetter(name='I_stim', variable_to_ctrl='I', unit=nA),
    population[i],
)
CLSimulator(io_processor=<__main__.PIRateIOProcessor object at 0x7f6be05fdb70>, devices={StateVariableSetter(brian_objects=set(), sim=..., name='I_stim', value=0, default_value=0, save_history=True, variable_to_ctrl='I', unit=namp, neuron_groups=[<Subgroup 'neurongroup_subgroup_1' of 'neurongroup' from 0 to 1>]), GroundTruthSpikeRecorder(brian_objects={<SpikeMonitor, recording from 'spikemonitor_1'>}, sim=..., name='spike_rec', _mon=<SpikeMonitor, recording from 'spikemonitor_1'>, _num_spikes_seen=0, neuron_group=<Subgroup 'neurongroup_subgroup' of 'neurongroup' from 0 to 1>)})

Run the simulation

sim.run(300*ms)
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True);
ax1.plot(pop_mon.t / ms, pop_mon.i[:], '|');
ax1.plot(pop_mon.t[pop_mon.i == i]/ms, pop_mon.i[pop_mon.i==i], '|', c='xkcd:hot pink') 
ax1.set(title='population spiking', ylabel='neuron index')

ax2.plot(io_processor.rate_estimator.t_in_ms, io_processor.rate_estimator.values, c='xkcd:hot pink');
ax2.plot(io_processor.rate_estimator.t_in_ms,
         [target_Hz(t) for t in io_processor.rate_estimator.t_in_ms],
         c='xkcd:green');
ax2.set(ylabel='firing rate (Hz)', title=f'neuron {i} activity');
ax2.legend(['estimated firing rate', 'target firing rate']);

ax3.plot(io_processor.pi_controller.t_out_ms, io_processor.pi_controller.values, c='xkcd:cerulean')
ax3.set(title='control input', ylabel='$I_{stim}$ (nA)', xlabel='time (ms)')

fig.tight_layout()
plt.show()
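[Figure: three stacked panels over time (ms): population spiking with neuron 0 highlighted; estimated vs. target firing rate (Hz) of neuron 0; control input I_stim (nA)]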

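Note the lag in keeping up with the target firing rate. Much of it can be attributed to the ~7 ms of total delay (4.1 ms + 2.87 ms) we assigned to the IO processor's blocks, though the rate estimator's 15 ms time constant also smooths and delays the estimate.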
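As a back-of-the-envelope check on the delay's share of the lag: a pure delay d shifts a sinusoid of frequency f by 360·f·d degrees. The sketch below applies this to the 20 Hz sinusoidal target with the two block delays summed. It ignores the rate estimator's smoothing and the closed-loop dynamics, so it captures only part of the lag.

```python
total_delay_ms = 4.1 + 2.87  # rate estimator delay + PI controller delay
target_freq_Hz = 20          # frequency of the sinusoidal part of target_Hz

# a pure delay d shifts a sinusoid of frequency f by 360 * f * d degrees
lag_deg = 360 * target_freq_Hz * total_delay_ms / 1000
print(round(total_delay_ms, 2), round(lag_deg, 1))  # 6.97 50.2
```

A shift of about 50 degrees is roughly a seventh of a cycle, so the fixed delay alone is already noticeable at this target frequency.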

Conclusion

In this tutorial, we’ve learned how to

  • use PI control to interact with a Brian simulation,

  • decompose processing steps into blocks, and

  • assign delays to processing blocks to model real-life latency.