.. currentmodule:: brian2

.. _modelfitting_sbi:

Example: modelfitting_sbi
=========================


        .. only:: html

            .. |launchbinder| image:: file:///usr/share/doc/python-brian-doc/docs/badge.svg
            .. _launchbinder: https://mybinder.org/v2/gh/brian-team/brian2-binder/master?filepath=examples/advanced/modelfitting_sbi.ipynb

            .. note::
               You can launch an interactive, editable version of this
               example without installing any local files
               using the Binder service (although note that at some times this
               may be slow or fail to open): |launchbinder|_

        

Model fitting with simulation-based inference
---------------------------------------------

In this example, an HH-type model is used to demonstrate simulation-based
inference with the sbi toolbox (https://www.mackelab.org/sbi/). It is based
on a synthetic current-clamp recording generated from the same model that we
use in the inference process. Two of the parameters (the maximum sodium and
potassium conductances) are treated as unknown and are inferred from this
recording; all other parameters of the model are kept fixed.

For more details about this approach, see the references below.

To run this example, you need to install the sbi package, e.g. with::

    pip install sbi

References:

* https://www.mackelab.org/sbi
* Tejero-Cantero et al., (2020). sbi: A toolkit for simulation-based inference.
  Journal of Open Source Software, 5(52), 2505, https://doi.org/10.21105/joss.02505

::

    import matplotlib.pyplot as plt
    
    from brian2 import *
    import sbi.utils
    import sbi.analysis
    import sbi.inference
    import torch  # PyTorch
    
    # Default integration time step for all Brian2 simulations in this example
    defaultclock.dt = 0.05*ms
    
    def simulate(params, I=1*nA, t_on=50*ms, t_total=350*ms):
        """
        Simulate the HH-type model with Brian2 for every parameter set in `params`.

        A step current of amplitude `I` is injected between `t_on` and
        `t_total - t_on`. One neuron is simulated per parameter set.

        Parameters
        ----------
        params : array-like
            A single parameter set ``[gNa, gK]`` or an array with one such row
            per neuron (shape ``(n, 2)``). Values are interpreted in µS.
        I : Quantity, optional
            Amplitude of the injected step current.
        t_on : Quantity, optional
            Onset of the current step; the step ends `t_on` before the end of
            the simulation.
        t_total : Quantity, optional
            Total simulated time; must be larger than ``2*t_on``.

        Returns
        -------
        dict
            {'t': time steps, 'v': voltage traces (one row per neuron),
             'I_inj': injected current, 'spike_count': spike count per neuron}.
        """
        assert t_total > 2*t_on
        t_off = t_total - t_on

        params = np.atleast_2d(params)
        # fixed parameters
        gleak = 10*nS
        Eleak = -70*mV
        VT = -60.0*mV
        C = 200*pF
        ENa = 53*mV
        EK = -107*mV

        # The conductance-based model
        eqs = '''
             dVm/dt = -(gNa*m**3*h*(Vm - ENa) + gK*n**4*(Vm - EK) + gleak*(Vm - Eleak) - I_inj) / C : volt
             I_inj = int(t >= t_on and t < t_off)*I : amp (shared)
             dm/dt = alpham*(1-m) - betam*m : 1
             dn/dt = alphan*(1-n) - betan*n : 1
             dh/dt = alphah*(1-h) - betah*h : 1

             alpham = (-0.32/mV) * (Vm - VT - 13.*mV) / (exp((-(Vm - VT - 13.*mV))/(4.*mV)) - 1)/ms : Hz
             betam = (0.28/mV) * (Vm - VT - 40.*mV) / (exp((Vm - VT - 40.*mV)/(5.*mV)) - 1)/ms : Hz

             alphah = 0.128 * exp(-(Vm - VT - 17.*mV) / (18.*mV))/ms : Hz
             betah = 4/(1 + exp((-(Vm - VT - 40.*mV)) / (5.*mV)))/ms : Hz

             alphan = (-0.032/mV) * (Vm - VT - 15.*mV) / (exp((-(Vm - VT - 15.*mV)) / (5.*mV)) - 1)/ms : Hz
             betan = 0.5*exp(-(Vm - VT - 10.*mV) / (40.*mV))/ms : Hz
             # The parameters to fit
             gNa : siemens (constant)
             gK : siemens (constant)
             '''
        neurons = NeuronGroup(params.shape[0], eqs, threshold='m>0.5', refractory='m>0.5',
                              method='exponential_euler', name='neurons')
        Vm_mon = StateMonitor(neurons, 'Vm', record=True, name='Vm_mon')
        spike_mon = SpikeMonitor(neurons, record=False, name='spike_mon')  #record=False → do not record times
        # Assign the parameters to fit; the names must match the state variables
        # declared in the equations (gNa, gK), one value per neuron.
        neurons.gNa = params[:, 0]*uS
        neurons.gK = params[:, 1]*uS

        neurons.Vm = 'Eleak'
        neurons.m = '1/(1 + betam/alpham)'         # Would be the solution when dm/dt = 0
        neurons.h = '1/(1 + betah/alphah)'         # Would be the solution when dh/dt = 0
        neurons.n = '1/(1 + betan/alphan)'         # Would be the solution when dn/dt = 0

        run(t_total)
        # For convenient plotting, reconstruct the current
        I_inj = ((Vm_mon.t >= t_on) & (Vm_mon.t < t_off))*I
        return dict(v=Vm_mon.Vm,
                    t=Vm_mon.t,
                    I_inj=I_inj,
                    spike_count=spike_mon.count)
    
    
    def calculate_summary_statistics(x):
        """
        Calculate summary statistics for the simulation results in `x`.

        Parameters
        ----------
        x : dict
            Result of `simulate`, with the keys 'v' (voltage traces, one row
            per neuron), 'I_inj' (injected current per time step) and
            'spike_count' (spike count per neuron).

        Returns
        -------
        numpy.ndarray
            Array with one row per neuron and four columns: spike count, mean
            and standard deviation of the voltage (in mV) during stimulation,
            and maximum voltage (in mV) during stimulation.
        """
        I_inj = x["I_inj"]
        v = x["v"]/mV

        spike_count = x["spike_count"]
        # The stimulation window is every time step with a non-zero injected
        # current; `!= 0` (rather than `> 0`) also covers negative currents.
        v_active = v[:, I_inj != 0*nA]
        # Mean and standard deviation during stimulation
        mean_active = np.mean(v_active, axis=1)
        std_active = np.std(v_active, axis=1)
        # Height of action potential peaks
        max_v = np.max(v_active, axis=1)

        # concatenation of summary statistics: one row per statistic, then
        # transposed to one row per neuron
        sum_stats = np.vstack((spike_count, mean_active, std_active, max_v))

        return sum_stats.T
    
    
    def simulation_wrapper(params):
        """
        Simulate the model for the conductances in `params` and summarize it.

        The result of `simulate` is reduced with `calculate_summary_statistics`
        and handed back as a `torch.Tensor` of dtype ``torch.float32``.
        """
        summary = calculate_summary_statistics(simulate(params))
        return torch.as_tensor(summary).to(torch.float32)
    
    
    if __name__ == '__main__':
        # Define a uniform prior distribution over the parameters (gNa, gK) in µS
        prior_min = [.5, 1e-4]
        prior_max = [80., 15.]
        prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
                                                high=torch.as_tensor(prior_max))
        # Per-parameter [low, high] ranges of the prior, reused for the plot axes
        # so that they always match the prior definition above
        prior_limits = [[low, high] for low, high in zip(prior_min, prior_max)]

        # Simulate samples from the prior distribution
        theta = prior.sample((10_000,))
        print('Simulating samples from prior simulation... ', end='')
        stats = simulation_wrapper(theta.numpy())
        print('done.')

        # Train inference network
        density_estimator_build_fun = sbi.utils.posterior_nn(model='mdn')
        inference = sbi.inference.SNPE(prior,
                                       density_estimator=density_estimator_build_fun)
        print('Training inference network... ')
        inference.append_simulations(theta, stats).train()
        posterior = inference.build_posterior()

        # Ground-truth parameters used to generate the synthetic "observed" data
        true_params = np.array([[32., 1.]])
        true_data = simulate(true_params)
        t = true_data['t']
        I_inj = true_data['I_inj']
        v = true_data['v']
        xo = calculate_summary_statistics(true_data)
        print("The true summary statistics are:  ", xo)

        # Plot estimated posterior distribution
        samples = posterior.sample((1000,), x=xo, show_progress_bars=False)
        labels_params = [r'$\overline{g}_{Na}$', r'$\overline{g}_{K}$']
        sbi.analysis.pairplot(samples,
                              limits=prior_limits,
                              ticks=prior_limits,
                              figsize=(4, 4),
                              points=true_params, labels=labels_params,
                              points_offdiag={'markersize': 6},
                              points_colors=['r'])
        plt.tight_layout()

        # Draw a single sample from the posterior and convert to numpy for plotting.
        posterior_sample = posterior.sample((1,), x=xo,
                                            show_progress_bars=False).numpy()
        x = simulate(posterior_sample)

        # plot observation and sample
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(t/ms, v[0, :]/mV, lw=2, label='observation')
        ax.plot(t/ms, x['v'][0, :]/mV, '--', lw=2, label='posterior sample')
        ax.legend()
        ax.set(xlabel='time (ms)', ylabel='voltage (mV)')
        plt.show()
    

