Variational Inference: Speed vs Accuracy

Fast approximate Bayesian inference for real-time applications

Estimated time: 1.5-2 hours | Advanced session | Requires: MCMC only (can skip stochastic sessions)

Introduction

Do this session if you need faster inference than MCMC (real-time forecasting, many datasets) or want to understand modern ML-influenced approaches to Bayesian inference.

So far in this course, we’ve used Markov Chain Monte Carlo (MCMC) for Bayesian inference. MCMC is asymptotically exact and provides reliable uncertainty quantification. But it has a key limitation: it’s slow.

For a single dataset with a simple model, MCMC takes minutes to hours. But what if we need:

  • Real-time forecasting: Update predictions as new data arrives every hour
  • Many datasets: Analyse hundreds of outbreaks from different locations
  • High-dimensional models: Age-stratified models with 50+ parameters
  • Rapid prototyping: Quickly explore many model variants

This is where variational inference (VI) comes in. VI trades some accuracy for speed improvements - the gains depend on problem complexity, from modest speedups for simple models to orders of magnitude for large neural networks or models with many latent variables.

Objectives

By the end of this session, you will be able to:

  1. Understand the variational inference framework and ELBO optimisation
  2. Implement mean-field and full-rank Gaussian approximations
  3. Compare VI and MCMC trade-offs (speed vs accuracy)
  4. Choose appropriate inference methods for different scenarios
  5. Apply VI to real-time forecasting workflows

Setup

using DifferentialEquations
using Distributions
using DataFrames
using CSV
using Plots
using StatsPlots
using Turing
using Turing.Variational: vi, q_meanfield_gaussian, q_fullrank_gaussian
using ADTypes: AutoForwardDiff
using MCMCChains
using Random
using LinearAlgebra
using DrWatson
using Statistics: mean, std, cor

Random.seed!(1234)
TaskLocalRNG()
Note: Coming from Stan?

The closest equivalent is Stan’s ADVI (e.g. vb() in rstan, $variational() in cmdstanr, or CmdStanModel.variational() in cmdstanpy). The interface is similar: define a model and call the VI method instead of the sampling method. The same trade-offs apply — faster than full MCMC but approximate.

Load Tristan da Cunha data

flu_tdc = CSV.read(datadir("flu_tdc_1971.csv"), DataFrame)

scatter(
    flu_tdc.time, flu_tdc.obs,
    xlabel = "Time (days)", ylabel = "Daily incidence",
    label = "Observed cases", markersize = 4, color = :red,
    title = "Tristan da Cunha 1971 Influenza Outbreak",
    legend = :topright
)

Simple SIR model

For clarity, we’ll use a simple SIR model:

"""
SIR model for variational inference demonstrations.
"""
function simulate_sir(θ, init_state, times)
    R_0, D_inf = θ[:R_0], θ[:D_inf]

    function sir_ode!(du, u, p, t)
        S, I, R = u
        β, γ = p
        N = S + I + R

        du[1] = -β * S * I / N
        du[2] = β * S * I / N - γ * I
        du[3] = γ * I
    end

    β = R_0 / D_inf
    γ = 1.0 / D_inf
    u0 = [init_state[:S], init_state[:I], init_state[:R]]
    prob = ODEProblem(sir_ode!, u0, (times[1], times[end]), [β, γ])
    sol = solve(prob, Tsit5(), saveat=times)

    # Extract incidence
    S_values = sol[1, :]
    incidence = [0.0; -diff(S_values)]

    return DataFrame(time=times, S=S_values,
                     I=sol[2, :],
                     R=sol[3, :],
                     Inc=incidence)
end

# Test
times = flu_tdc.time
init_state = Dict(:S => 270.0, :I => 3.0, :R => 0.0)
Dict{Symbol, Float64} with 3 entries:
  :I => 3.0
  :R => 0.0
  :S => 270.0

Turing model for inference

@model function sir_model(times, init_state, n_obs)
    # Priors
    R_0 ~ Uniform(1.0, 20.0)
    D_inf ~ Uniform(1.0, 10.0)
    ρ ~ Uniform(0.1, 1.0)

    # Simulate
    θ = Dict(:R_0 => R_0, :D_inf => D_inf)
    traj = simulate_sir(θ, init_state, times)

    # Likelihood
    lambdas = [max(ρ * traj.Inc[i], 1e-10) for i in 1:n_obs]
    obs ~ arraydist(Poisson.(lambdas))

    return traj
end

model = sir_model(times, init_state, length(flu_tdc.obs)) | (; obs = flu_tdc.obs)
DynamicPPL.Model{typeof(sir_model), (:times, :init_state, :n_obs), (), (), Tuple{Vector{Int64}, Dict{Symbol, Float64}, Int64}, Tuple{}, DynamicPPL.ConditionContext{@NamedTuple{obs::Vector{Int64}}, DynamicPPL.DefaultContext}, false}(sir_model, (times = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  50, 51, 52, 53, 54, 55, 56, 57, 58, 59], init_state = Dict(:I => 3.0, :R => 0.0, :S => 270.0), n_obs = 59), NamedTuple(), ConditionContext((obs = [0, 1, 0, 10, 6, 32, 47, 37, 29, 11, 13, 8, 2, 2, 2, 0, 2, 1, 2, 2, 5, 2, 3, 2, 3, 5, 7, 8, 5, 7, 4, 12, 7, 5, 3, 4, 1, 0, 5, 2, 2, 0, 6, 1, 1, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1],), DynamicPPL.DefaultContext()))

obs ~ arraydist(Poisson.(lambdas)) declares a vector of independent Poisson observations, one per time point. The | operator, applied when creating the model instance below, conditions the model on the actual data — see the MCMC session for a full explanation of this pattern.
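To make the likelihood concrete, the joint log-likelihood of independent Poisson counts is just the sum of the individual log-pmfs. A dependency-free sketch of what `arraydist(Poisson.(lambdas))` evaluates once the model is conditioned on `obs` (the `poisson_logpmf` helper and the numbers below are illustrative, not part of Turing):

```julia
# Independent Poisson log-likelihood, written out by hand
poisson_logpmf(k, λ) = k * log(λ) - λ - sum(log.(1:k))  # log(k!) as a sum of logs

λs  = [2.0, 5.0, 1.5]   # expected daily incidence (illustrative values)
obs = [1, 6, 0]         # observed counts (illustrative values)

# Joint log-likelihood of independent observations = sum of individual terms
loglik = sum(poisson_logpmf.(obs, λs))
```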

Baseline: MCMC with NUTS

Let’s start with full MCMC to establish a baseline:

println("Running MCMC with NUTS...")
t_mcmc = @elapsed chain_mcmc = sample(model, NUTS(0.65), 500, progress=false)
println("MCMC time: $(round(t_mcmc, digits=1)) seconds")
Running MCMC with NUTS...
Info: Found initial step size
  ϵ = 0.2
MCMC time: 25.3 seconds
  • @elapsed - returns the time in seconds to execute the expression.
  • NUTS(0.65) - sets the target acceptance rate during warmup. 0.65 is a good default; use 0.8-0.9 for difficult posteriors.
# Plot MCMC posteriors
plot(chain_mcmc[[:R_0, :D_inf, :ρ]], size=(900, 600))
# Summary
summarystats(chain_mcmc)
Summary Statistics

  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

         R_0    1.7456    0.1141    0.0070   300.1356   314.8640    0.9993     ⋯
       D_inf    3.2594    0.3184    0.0190   299.1105   298.2930    0.9989     ⋯
           ρ    0.9920    0.0082    0.0006   114.3943    91.5676    1.0037     ⋯

                                                                1 column omitted

Note the time: MCMC takes around 25 seconds here for this small model (exact times vary by hardware); for more complex models it can take minutes to hours.

Approach 1: Mean-Field Variational Inference

What is variational inference?

MCMC works by drawing samples from the posterior distribution. This gives us exact results (in the limit), but requires many sequential steps as the sampler explores parameter space.

Variational inference takes a completely different approach: instead of sampling from the true posterior \(p(\theta | \text{data})\), VI finds a simpler approximation \(q(\theta)\) that’s “close” to the true posterior.

The key idea is to turn inference into an optimisation problem:

  1. Choose a family of simple distributions \(q(\theta)\) (e.g., Gaussians)
  2. Optimise parameters of \(q\) to make it as close as possible to the true posterior
  3. Use gradient descent instead of MCMC sampling

Why is this faster? Optimisation can use parallel gradient computations and converges in a fixed number of steps, whereas MCMC must run sequentially until convergence (which can take many thousands of iterations).

The ELBO objective

How do we measure “closeness” between \(q(\theta)\) and the true posterior? VI uses the Kullback-Leibler (KL) divergence, which measures how much information is lost when we use \(q\) to approximate the true posterior.

Directly minimising KL divergence is intractable (we can’t compute the true posterior!), but we can instead maximise an equivalent quantity called the Evidence Lower Bound (ELBO):

\[\text{ELBO}(q) = \mathbb{E}_{q}[\log p(\text{data}, \theta)] - \mathbb{E}_{q}[\log q(\theta)]\]

The first term rewards \(q\) for placing probability mass where the joint distribution \(p(\text{data}, \theta)\) is high. The second term (entropy of \(q\)) prevents \(q\) from collapsing to a point estimate. Together, they push \(q\) to approximate the posterior well.

Maximising ELBO is equivalent to minimising \(\text{KL}(q || p)\), and importantly, we can estimate the ELBO using samples from \(q\) - no need to evaluate the intractable posterior directly.
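The "estimate from samples of \(q\)" point can be made concrete with a toy conjugate model where everything is known in closed form. This is an illustrative sketch (the model and helper names are made up for the example), not how Turing computes the ELBO internally:

```julia
using Random, Statistics

# Normal log-density, written out to keep the sketch dependency-free
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Toy conjugate model: prior θ ~ N(0, 1), likelihood y ~ N(θ, 1), one datum y = 2
y = 2.0
logjoint(θ) = normal_logpdf(y, θ, 1.0) + normal_logpdf(θ, 0.0, 1.0)

# Monte Carlo ELBO for a candidate q(θ) = N(μ, σ):
# ELBO ≈ (1/n) Σᵢ [log p(y, θᵢ) - log q(θᵢ)], with θᵢ drawn from q
function elbo_estimate(μ, σ; n = 10_000, rng = MersenneTwister(1))
    θs = μ .+ σ .* randn(rng, n)
    mean(logjoint(θ) - normal_logpdf(θ, μ, σ) for θ in θs)
end

# The exact posterior here is N(1, 1/2), so the ELBO is highest for that q
elbo_estimate(1.0, sqrt(0.5)) > elbo_estimate(0.0, 1.0)  # true
```

When \(q\) equals the posterior exactly, the integrand is constant and the ELBO equals the log evidence \(\log p(y)\) — a handy sanity check for hand-rolled ELBO code.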

Note: Mathematical detail

The derivation of the ELBO from KL divergence involves Jensen’s inequality and properties of logarithms. If you are not familiar with these concepts, you can safely accept the ELBO formula above on faith and skip ahead to the practical section below — the key intuition is simply that a higher ELBO means a better approximation.

Mean-field approximation

The simplest VI uses a mean-field approximation, which assumes all parameters are independent:

\[q(\theta) = q(\theta_1) \times q(\theta_2) \times \cdots \times q(\theta_n)\]

For our SIR model with three parameters: \[q(R_0, D_{\text{inf}}, \rho) = \mathcal{N}(\mu_1, \sigma_1^2) \times \mathcal{N}(\mu_2, \sigma_2^2) \times \mathcal{N}(\mu_3, \sigma_3^2)\]

This factorisation means we only need to learn 6 numbers (3 means + 3 variances) instead of the full joint distribution. The optimisation becomes tractable, but we lose the ability to capture correlations between parameters.

Note: The actual optimisation happens in unconstrained space (using log transforms for positive parameters, logit for bounded parameters), then samples are transformed back.
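A minimal sketch of what that unconstrained-space machinery looks like (Turing handles this automatically via Bijectors.jl; the helper names below are made up for illustration). A parameter with bounded support, like \(R_0 \sim \text{Uniform}(1, 20)\), is mapped to the real line with a scaled logit, optimised there, and mapped back:

```julia
# (a, b) → ℝ: scaled logit, so the optimiser works over unbounded values
to_unconstrained(x, a, b) = log((x - a) / (b - x))

# ℝ → (a, b): inverse transform; any real z maps to a valid parameter value
to_constrained(z, a, b) = a + (b - a) / (1 + exp(-z))

z = to_unconstrained(1.75, 1.0, 20.0)   # R_0 = 1.75 in unconstrained space
to_constrained(z, 1.0, 20.0)            # round-trips back to 1.75
```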

The function q_meanfield_gaussian(model) inspects the Turing model and creates an initial variational distribution: one independent Gaussian per model parameter, in unconstrained space. It sets all means to zero and all standard deviations to one. This is the starting point — the vi function then adjusts these means and standard deviations to maximise the ELBO.

println("\nRunning Mean-Field VI...")
q_init_mf = q_meanfield_gaussian(model)
t_vi_mf = @elapsed (q_mf, stats_mf) = vi(model, q_init_mf, 500; adtype=AutoForwardDiff(), show_progress=false)
println("Mean-field VI time: $(round(t_vi_mf, digits=1)) seconds")
println("Final ELBO: ", round(stats_mf[end].elbo, digits=2))
Running Mean-Field VI...
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
Mean-field VI time: 9.8 seconds
Final ELBO: -381.98
  • q_meanfield_gaussian(model) - creates an initial mean-field Gaussian variational distribution for the model, with independent Normal distributions for each parameter (all initialised to standard normals in unconstrained space).
  • vi(model, q_init, n_iters; ...) - runs variational inference, optimising the ELBO for n_iters steps. Returns the optimised distribution q and training statistics.
  • adtype=AutoForwardDiff() - specifies ForwardDiff.jl for computing gradients. Forward-mode AD is efficient when the number of parameters is small (like here).
  • stats_mf[end].elbo - the final ELBO value. Higher (less negative) is better, indicating a closer approximation to the true posterior.

Monitoring ELBO convergence

A critical diagnostic for VI is checking that the ELBO has converged — i.e., that the optimisation has run for long enough. We can plot the ELBO over iterations and look for a plateau:

elbo_vals_mf = [s.elbo for s in stats_mf]
plot(elbo_vals_mf, xlabel="Iteration", ylabel="ELBO",
     title="ELBO convergence (mean-field)", label="ELBO",
     linewidth=2, color=:blue)

The ELBO should increase (become less negative) and eventually plateau. If the curve is still climbing at the end, you need more iterations. Typical counts range from a few hundred to a few thousand, depending on model complexity — for our simple 3-parameter model, 500 iterations is usually sufficient, but more complex models may need 1000–5000 or more.
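Eyeballing the plot works, but the check can also be automated. A simple heuristic sketch (`elbo_plateaued` is a made-up helper, not part of Turing's API): compare the mean ELBO over the last two windows and call the run converged when the relative change is small.

```julia
using Statistics

# Heuristic plateau check: compare mean ELBO over the last two windows
function elbo_plateaued(elbo_trace; window = 50, tol = 1e-3)
    length(elbo_trace) >= 2window || return false
    recent   = mean(elbo_trace[end-window+1:end])
    previous = mean(elbo_trace[end-2window+1:end-window])
    abs(recent - previous) / abs(previous) < tol
end

# Synthetic traces for illustration
still_climbing = collect(-500.0:1.0:-301.0)                  # steadily improving
flat = vcat(collect(-500.0:2.0:-382.0), fill(-380.0, 140))   # levels off

(elbo_plateaued(still_climbing), elbo_plateaued(flat))  # (false, true)
```

On a real run you would pass `[s.elbo for s in stats_mf]` and increase the iteration count whenever the check fails.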

Much faster! But is it accurate?

# Sample from variational approximation
samples_vi_mf = rand(q_mf, 10_000)

# Convert to DataFrame for comparison
# Note: rand(q, n) returns draws already mapped back to the constrained
# space (the variational distribution wraps the bijector), so no manual
# transformation is needed here
vi_mf_df = DataFrame(
    R_0 = samples_vi_mf[1, :],
    D_inf = samples_vi_mf[2, :],
    ρ = samples_vi_mf[3, :]
)

# Plot comparison
p1 = histogram(vec(chain_mcmc[:R_0]), bins=30, alpha=0.5, normalize=:pdf,
               label="MCMC", xlabel="R_0", ylabel="Density")
histogram!(p1, vi_mf_df.R_0, bins=30, alpha=0.5, normalize=:pdf,
           label="VI (mean-field)")

p2 = histogram(vec(chain_mcmc[:D_inf]), bins=30, alpha=0.5, normalize=:pdf,
               label="MCMC", xlabel="D_inf", ylabel="Density")
histogram!(p2, vi_mf_df.D_inf, bins=30, alpha=0.5, normalize=:pdf,
           label="VI (mean-field)")

plot(p1, p2, layout=(1, 2), size=(900, 400),
     title=["R_0 Comparison" "D_inf Comparison"])

The independence assumption problem

Mean-field assumes parameters are independent. But in epidemic models, \(R_0\) and \(D_{\text{inf}}\) are often strongly correlated in the posterior!

Why does this correlation arise? Consider the SIR model dynamics:

  • The transmission rate \(\beta = R_0 / D_{\text{inf}}\) determines how fast susceptibles become infected
  • If we increase \(R_0\) while keeping \(\beta\) roughly constant (to match observed epidemic speed), we must also increase \(D_{\text{inf}}\)
  • Conversely, a shorter infectious period requires a lower \(R_0\) to produce the same transmission rate

In other words, there are multiple combinations of (\(R_0\), \(D_{\text{inf}}\)) that produce similar epidemic trajectories. The data constrains their ratio more tightly than their individual values. This creates a ridge of high posterior probability running diagonally through parameter space - exactly the kind of correlation mean-field VI cannot capture.
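The ridge is easy to see numerically: scaling \(R_0\) and \(D_{\text{inf}}\) by the same factor leaves \(\beta\) unchanged (the parameter values below are illustrative):

```julia
# β = R_0 / D_inf is the quantity the data pins down most tightly
β(R0, Dinf) = R0 / Dinf

β(1.7, 3.2) ≈ β(2.125, 4.0)  # true — both pairs give β = 0.53125
```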

# Check correlation in MCMC samples
println("MCMC correlation (R_0, D_inf): ",
        round(cor(vec(chain_mcmc[:R_0]), vec(chain_mcmc[:D_inf])), digits=3))

# Mean-field VI forces correlation to 0
println("VI mean-field correlation (R_0, D_inf): ",
        round(cor(vi_mf_df.R_0, vi_mf_df.D_inf), digits=3))
MCMC correlation (R_0, D_inf): 0.901
VI mean-field correlation (R_0, D_inf): -0.006

The mean-field approximation forces this correlation to zero, which can lead to:

  • Overconfident marginals: Each parameter appears more certain than it really is
  • Implausible joint samples: Combinations that the true posterior considers unlikely
  • Poor predictive performance: Forward simulations may not match observed data well

Approach 2: Full-Rank Gaussian VI

Capturing correlations

The solution to the independence problem is to use a full covariance matrix instead of a diagonal one:

\[q(\theta) = \mathcal{N}(\mu, \Sigma)\]

where \(\Sigma\) is a full \(n \times n\) covariance matrix (not just diagonal elements).

For our 3-parameter model, this means:

  • Mean-field: 6 parameters (3 means + 3 variances)
  • Full-rank: 9 parameters (3 means + 6 unique covariance entries, since \(\Sigma\) is symmetric)

The additional parameters allow the approximation to capture correlations. The elliptical contours of a full Gaussian can be tilted and stretched to match the ridge-like structure of the true posterior.

The trade-off: more parameters means more to optimise, so full-rank VI is slower than mean-field (though still much faster than MCMC). For high-dimensional problems, the covariance matrix grows as \(O(n^2)\), which can become prohibitive - in those cases, low-rank approximations or structured covariances are used.
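A sketch of the parameterisation `q_fullrank_gaussian` uses under the hood (as we understand it: a mean vector plus a lower-triangular Cholesky factor, so the covariance \(\Sigma = L L^\top\) is positive semi-definite by construction; all numbers below are made up):

```julia
using LinearAlgebra, Random, Statistics

# Mean vector and lower-triangular Cholesky factor (illustrative values)
μ = [0.5, 3.0, 1.5]
L = LowerTriangular([0.10 0.0  0.0;
                     0.25 0.30 0.0;
                     0.0  0.0  0.05])

# Reparameterised sampling: θ = μ + L * z with z ~ N(0, I) — the same trick
# that makes the ELBO differentiable with respect to μ and L
rng = MersenneTwister(42)
samples = [μ .+ L * randn(rng, 3) for _ in 1:50_000]
x1 = [s[1] for s in samples]
x2 = [s[2] for s in samples]

cor(x1, x2)  # ≈ 0.64, the correlation implied by Σ = L * L'
```

The off-diagonal entry of `L` is what lets the approximation tilt its contours to follow the posterior ridge; a mean-field distribution has no such entry.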

println("\nRunning Full-Rank VI...")
q_init_full = q_fullrank_gaussian(model)
t_vi_full = @elapsed (q_full, stats_full) = vi(model, q_init_full, 500; adtype=AutoForwardDiff(), show_progress=false)
println("Full-rank VI time: $(round(t_vi_full, digits=1)) seconds")
println("Final ELBO: ", round(stats_full[end].elbo, digits=2))
Running Full-Rank VI...
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
Full-rank VI time: 8.4 seconds
Final ELBO: -377.74
  • q_fullrank_gaussian(model) - creates a full-rank Gaussian variational distribution that can capture correlations between parameters. Uses a lower-triangular Cholesky factor for the covariance matrix.
# Sample from full-rank approximation
samples_vi_full = rand(q_full, 10_000)

vi_full_df = DataFrame(
    R_0 = samples_vi_full[1, :],
    D_inf = samples_vi_full[2, :],
    ρ = samples_vi_full[3, :]
)

# Check correlation - should now be closer to MCMC
println("VI full-rank correlation (R_0, D_inf): ",
        round(cor(vi_full_df.R_0, vi_full_df.D_inf), digits=3))
VI full-rank correlation (R_0, D_inf): 0.926

We can see that the full-rank approximation now captures the strong positive correlation between \(R_0\) and \(D_{\text{inf}}\), giving a much better approximation to the true posterior.

# Scatter plot comparison
p1 = scatter(vec(chain_mcmc[:R_0]), vec(chain_mcmc[:D_inf]),
             alpha=0.3, label="MCMC", markersize=2,
             xlabel="R_0", ylabel="D_inf",
             title="MCMC: Captures Correlation")

p2 = scatter(vi_mf_df.R_0, vi_mf_df.D_inf,
             alpha=0.3, label="VI (mean-field)", markersize=2,
             xlabel="R_0", ylabel="D_inf",
             title="Mean-Field VI: No Correlation")

p3 = scatter(vi_full_df.R_0, vi_full_df.D_inf,
             alpha=0.3, label="VI (full-rank)", markersize=2,
             xlabel="R_0", ylabel="D_inf",
             title="Full-Rank VI: Captures Correlation")

plot(p1, p2, p3, layout=(1, 3), size=(1200, 400))

Comparison: MCMC vs VI

Quantitative comparison

Let’s compare posterior means and standard deviations:

# Extract summaries
mcmc_summary = summarystats(chain_mcmc)
println("MCMC Summary:")
println(mcmc_summary)

println("\n" * "="^60)
println("Mean-Field VI Summary:")
println("  R_0: mean = ", round(mean(vi_mf_df.R_0), digits=2),
        ", std = ", round(std(vi_mf_df.R_0), digits=2))
println("  D_inf: mean = ", round(mean(vi_mf_df.D_inf), digits=2),
        ", std = ", round(std(vi_mf_df.D_inf), digits=2))
println("  ρ: mean = ", round(mean(vi_mf_df.ρ), digits=2),
        ", std = ", round(std(vi_mf_df.ρ), digits=2))

println("\n" * "="^60)
println("Full-Rank VI Summary:")
println("  R_0: mean = ", round(mean(vi_full_df.R_0), digits=2),
        ", std = ", round(std(vi_full_df.R_0), digits=2))
println("  D_inf: mean = ", round(mean(vi_full_df.D_inf), digits=2),
        ", std = ", round(std(vi_full_df.D_inf), digits=2))
println("  ρ: mean = ", round(mean(vi_full_df.ρ), digits=2),
        ", std = ", round(std(vi_full_df.ρ), digits=2))
MCMC Summary:
Summary Statistics (3 x 8)

============================================================
Mean-Field VI Summary:
  R_0: mean = 1.74, std = 0.05
  D_inf: mean = 3.24, std = 0.14
  ρ: mean = 0.99, std = 0.01

============================================================
Full-Rank VI Summary:
  R_0: mean = 1.76, std = 0.13
  D_inf: mean = 3.29, std = 0.37
  ρ: mean = 0.99, std = 0.01

Speed comparison

println("Speed comparison (from runs above):")
println("="^50)
println("MCMC (500 samples):      $(round(t_mcmc, digits=1)) seconds")
println("Mean-field VI (500 iter): $(round(t_vi_mf, digits=1)) seconds")
println("Full-rank VI (500 iter):  $(round(t_vi_full, digits=1)) seconds")
println("="^50)
println("Mean-field speedup: $(round(t_mcmc / t_vi_mf, digits=1))×")
println("Full-rank speedup:  $(round(t_mcmc / t_vi_full, digits=1))×")
Speed comparison (from runs above):
==================================================
MCMC (500 samples):      25.3 seconds
Mean-field VI (500 iter): 9.8 seconds
Full-rank VI (500 iter):  8.4 seconds
==================================================
Mean-field speedup: 2.6×
Full-rank speedup:  3.0×

The speedup here is modest (around 3×) because the model is small; the gap widens substantially for larger models, longer chains, or many datasets.

Trade-offs summary

Method          Speed                  Accuracy                 Correlations   Use case
MCMC            Slow (minutes-hours)   Exact (asymptotically)   Yes            Final analysis, publications
Mean-field VI   Very fast (seconds)    Approximate              No             Quick exploration, many datasets
Full-rank VI    Fast (seconds)         Better than mean-field   Yes            Real-time forecasting

When to use each method

Use MCMC when:

  • You need exact inference for publication
  • You have time (single dataset, offline analysis)
  • Posterior may be multimodal or complex
  • You need reliable uncertainty quantification
  • Problem is low-to-moderate dimensional (<50 parameters)

Use Mean-Field VI when:

  • You need extremely fast approximate inference
  • Parameters are roughly independent
  • You’re doing exploratory analysis or prototyping
  • You need to analyse many datasets quickly
  • Rough uncertainty estimates are sufficient

Use Full-Rank VI when:

  • Speed is essential and MCMC is too slow (real-time systems)
  • Parameters are correlated (otherwise mean-field is fine)
  • You’ve validated that VI and MCMC give similar results for your problem
Warning: When VI fails

Gaussian VI (both mean-field and full-rank) assumes the posterior is well-approximated by a single Gaussian. This assumption breaks down in several important cases:

  • Multimodal posteriors: If the posterior has multiple peaks (e.g., a model with label-switching symmetry), a single Gaussian will either collapse onto one mode or spread across modes, giving a poor approximation in either case.
  • Heavy-tailed posteriors: Gaussian tails decay exponentially, so VI will systematically underestimate the probability of extreme parameter values. This leads to overconfident uncertainty intervals.
  • Strong nonlinear correlations: Full-rank VI captures linear correlations via the covariance matrix, but banana-shaped or other nonlinear posterior structures cannot be represented by any Gaussian.

When you suspect these issues, MCMC remains the more reliable choice. A useful diagnostic: if VI and MCMC give substantially different results, trust MCMC — VI is likely struggling with the posterior geometry.

Practical workflow

Why speed matters in practice

The speed comparisons above might seem academic for a single analysis - who cares if inference takes 30 seconds vs 5 minutes? But the difference becomes critical in operational settings:

Real-time epidemic forecasting: During an outbreak, you might need to update estimates as new data arrives - daily or even hourly. If each MCMC run takes 5 minutes, you can afford to wait. But if you’re fitting models for 100 different regions, that’s over 8 hours of compute time per update cycle. With VI, the same analysis might take under an hour.

Model development and debugging: When developing a new model, you’ll run inference hundreds of times as you iterate on the specification, fix bugs, and tune priors. Fast iteration cycles (seconds vs minutes) dramatically accelerate development.

Sensitivity analyses: Checking how results change with different priors, data subsets, or model variants requires many inference runs. VI makes comprehensive sensitivity analysis practical.

Ensemble forecasting: Combining predictions from multiple models improves forecast accuracy. VI’s speed makes it practical to run many model variants.

The practical recommendation: validate VI against MCMC once (to ensure they give similar results for your problem), then use VI for operational workflows where speed matters.
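That validation step can be scripted. A hedged sketch (`compare_posteriors` is a made-up helper, not a library function): given two parameters-by-draws sample matrices, flag parameters whose posterior mean or spread disagrees by more than a tolerance scaled to the posterior width.

```julia
using Random, Statistics

# Flag parameters where two methods' posterior summaries disagree
function compare_posteriors(samples_a, samples_b; rtol = 0.2)
    flagged = Int[]
    for i in 1:size(samples_a, 1)
        ma, mb = mean(samples_a[i, :]), mean(samples_b[i, :])
        sa, sb = std(samples_a[i, :]), std(samples_b[i, :])
        width = max(sa, sb)
        if abs(ma - mb) > rtol * width || abs(sa - sb) > rtol * width
            push!(flagged, i)
        end
    end
    return flagged
end

# Synthetic check: two methods agree on parameter 1, disagree on parameter 2
rng = MersenneTwister(7)
A = randn(rng, 2, 20_000)     # "method A" draws (parameters × samples)
B = randn(rng, 2, 20_000)     # "method B" draws
B[2, :] .= B[2, :] .+ 1.0     # shift parameter 2's posterior mean
compare_posteriors(A, B)      # [2]
```

With the fits above you would call it on parameters-by-draws matrices, e.g. `rand(q_full, 10_000)` against the MCMC draws, and fall back to MCMC whenever a parameter is flagged.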

Exercises

Exercise 1: ELBO convergence and iteration count

Run mean-field VI with different numbers of iterations (50, 200, 500, 2000) and observe how the ELBO converges and how the posterior approximation changes.

Use a loop over iteration counts, storing the final ELBO and posterior mean of \(R_0\) for each run.

iter_counts = [50, 200, 500, 2000]
results = []

for n_iter in iter_counts
    q_init = q_meanfield_gaussian(model)
    (q_result, stats_result) = vi(model, q_init, n_iter; adtype=AutoForwardDiff(), show_progress=false)
    samples = rand(q_result, 5_000)
    push!(results, (
        n_iter = n_iter,
        final_elbo = stats_result[end].elbo,
        R0_mean = mean(samples[1, :]),
        R0_std = std(samples[1, :]),
        elbo_trace = [s.elbo for s in stats_result]
    ))
end

# Plot ELBO traces
p = plot(xlabel="Iteration", ylabel="ELBO",
         title="ELBO convergence by iteration count", legend=:bottomright)
for r in results
    plot!(p, r.elbo_trace, label="$(r.n_iter) iters", linewidth=1.5)
end
p
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
# Summarise results
println("Iterations | Final ELBO | R₀ mean ± std")
println("-"^50)
for r in results
    println("$(lpad(r.n_iter, 9)) | $(lpad(round(r.final_elbo, digits=1), 10)) | $(round(r.R0_mean, digits=2)) ± $(round(r.R0_std, digits=2))")
end
Iterations | Final ELBO | R₀ mean ± std
--------------------------------------------------
       50 |     -392.0 | 2.65 ± 0.15
      200 |     -386.3 | 1.77 ± 0.06
      500 |     -379.0 | 1.74 ± 0.05
     2000 |     -378.5 | 1.73 ± 0.05

You should see that with too few iterations (e.g., 50), the ELBO has not converged and the posterior means may be inaccurate. By around 500 iterations the ELBO typically plateaus for this model, and increasing to 2000 gives negligible improvement. This tells us 500 iterations is sufficient here.

Exercise 2: Comparing VI and MCMC posteriors

Run both mean-field VI and full-rank VI, and compare their posteriors against the MCMC baseline. Produce overlay histograms for all three parameters and assess where VI agrees with MCMC and where it diverges.

Use the MCMC chain chain_mcmc and VI distributions q_mf and q_full that we already fitted above. Sample 10,000 draws from each VI distribution and overlay the histograms.

# Draw samples (reuse existing VI fits from above)
samp_mf = rand(q_mf, 10_000)
samp_full = rand(q_full, 10_000)

param_names = ["R_0", "D_inf", "ρ"]
plots_list = []

for (i, pname) in enumerate(param_names)
    p = histogram(vec(chain_mcmc[Symbol(pname)]), bins=40, alpha=0.4,
                  normalize=:pdf, label="MCMC", xlabel=pname, ylabel="Density")
    histogram!(p, samp_mf[i, :], bins=40, alpha=0.4,
               normalize=:pdf, label="Mean-field VI")
    histogram!(p, samp_full[i, :], bins=40, alpha=0.4,
               normalize=:pdf, label="Full-rank VI")
    push!(plots_list, p)
end

plot(plots_list..., layout=(1, 3), size=(1200, 400),
     plot_title="Posterior comparison: MCMC vs VI")
# Quantitative comparison
println("Parameter | MCMC mean (std) | MF-VI mean (std) | FR-VI mean (std)")
println("-"^75)
for (i, pname) in enumerate(param_names)
    mcmc_vals = vec(chain_mcmc[Symbol(pname)])
    println("$(rpad(pname, 9)) | $(round(mean(mcmc_vals), digits=2)) ($(round(std(mcmc_vals), digits=2))) | " *
            "$(round(mean(samp_mf[i, :]), digits=2)) ($(round(std(samp_mf[i, :]), digits=2))) | " *
            "$(round(mean(samp_full[i, :]), digits=2)) ($(round(std(samp_full[i, :]), digits=2)))")
end
Parameter | MCMC mean (std) | MF-VI mean (std) | FR-VI mean (std)
---------------------------------------------------------------------------
R_0       | 1.75 (0.11) | 1.74 (0.05) | 1.76 (0.13)
D_inf     | 3.26 (0.32) | 3.24 (0.14) | 3.29 (0.37)
ρ         | 0.99 (0.01) | 0.99 (0.01) | 0.99 (0.01)

Look at where the three methods agree and disagree. Typical findings for this model:

  • Posterior means are usually similar across all three methods — VI tends to find the right location.
  • Posterior widths may differ: mean-field VI can be overconfident (narrower) because it ignores correlations.
  • Full-rank VI generally gives a closer match to MCMC than mean-field, particularly for correlated parameters like \(R_0\) and \(D_\text{inf}\).

If you see substantial disagreement between full-rank VI and MCMC, it may indicate that 500 VI iterations was not enough, or that the posterior has features (e.g., skewness) that even full-rank Gaussian VI cannot capture.

Summary

Variational inference reframes Bayesian inference as optimisation: find a simple distribution \(q(\theta)\) that approximates the true posterior. This trades some accuracy for speed - useful when MCMC is too slow.

Mean-field VI assumes parameter independence. Fast, but it missed the R₀-D_inf correlation in our SIR model, producing overconfident intervals. Full-rank VI uses a full covariance matrix, capturing correlations at modest additional cost. For epidemic models with correlated parameters, full-rank is usually worth it.

VI is not a replacement for MCMC - it’s a complementary tool. Use VI for rapid exploration and real-time systems; validate against MCMC when accuracy matters.

Tip: Learning points
  • Variational inference turns Bayesian inference into optimisation: find a simple distribution \(q(\theta)\) that approximates the true posterior
  • Mean-field VI assumes parameter independence - very fast, but misses correlations that are common in epidemic models
  • Full-rank VI captures correlations using a full covariance matrix - good balance between speed and accuracy
  • ELBO (Evidence Lower Bound) is the objective function - higher values mean better approximations
  • Speed vs accuracy trade-off: VI is much faster than MCMC but provides approximations; use VI for exploration and real-time systems, MCMC for final publication-quality analysis

VI is not a replacement for MCMC - it’s a complementary tool that enables applications where MCMC is too slow!

Going further

The Gaussian VI methods covered in this session are the most commonly used in practice, but more flexible approaches exist. Normalising flows use sequences of invertible transformations to create non-Gaussian approximations, allowing them to handle multimodal and skewed posteriors — see Bijectors.jl for Julia implementations. Amortised inference trains a neural network to predict posterior parameters directly from data, eliminating per-dataset optimisation — useful when analysing many similar datasets. These are active research areas; the references below provide entry points.

Next session

In Universal Differential Equations, we’ll explore how to combine mechanistic models with neural networks to discover unknown transmission patterns directly from data.

References