Running Estimate Multivariable Gaussian Distributions

Estimate the mean and covariance of a multivariable Gaussian distribution from data samples

给定若干个 data point,假设这些样本是从多元高斯分布种采样出来的,估计该高斯分布的参数 (均值向量 $\mu$ 和协方差矩阵 $\sigma$)

有很多情况下,我们无法一次性获取所有的数据,数据是分批次逐渐获得的,比如在神经网络中估计样本特征的均值和协方差。 或者有时候一次性计算所有数据的参数根本无法做到,比如计算100亿个样本的协方差和均值,普通机器根本无法做到。

本 notebook 主要讲述如何 针对分批次的数据来 running estimate 高斯分布的参数

In [1]:
import torch, torchvision
import numpy as np
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

A 2D example 首先用二元的高斯分布举例,方便可视化

我们先设定一个2元高斯分布的参数,然后生成若干个样本,并通过这些样本估计分布的参数。

In [2]:
torch.manual_seed(4)

mean = torch.tensor([5, 5]).float()

cov = torch.randn(2, 2)
cov = cov.mm(cov.t())

n = 6000
dist = MultivariateNormal(loc=mean, covariance_matrix=cov)
samples = dist.sample(sample_shape=(n,))

print(mean)
print(cov)
tensor([5., 5.])
tensor([[ 2.6310, -3.3986],
        [-3.3986,  5.7349]])

可视化这些样本

In [3]:
colors = np.random.rand(n)
print(samples[:, 0].shape)
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, c=colors)
plt.grid()
torch.Size([6000])

每个batch 输入50个样本,running estimate 均值和协方差

In [4]:
# batch-nize data
samples = samples.view(-1, 50, 2)

sigma = torch.eye(2)
mean = torch.zeros(1, 2)

t = 0
n = 0
for batch in samples:
    bs = batch.size(0)
    t = t + 1
    n = n + bs

    m = batch.mean(dim=0)

    mean1 = (mean * bs * (t-1) + m * bs) / n
    sigma = sigma + batch.t().mm(batch) + (n + bs -1) * (mean.t().mm(mean) - mean1.t().mm(mean1)) - bs * mean.t().mm(mean)
    mean = mean1
sigma = sigma / (n-1)

打印出估计出来的均值和协方差矩阵,和上面模型的均值和协方差矩阵很接近。

In [5]:
mean
Out[5]:
tensor([[4.9513, 5.0701]])
In [6]:
sigma
Out[6]:
tensor([[ 2.4376, -3.6059],
        [-3.6059,  5.5210]])

A high dimensional distribution 更高维的高斯分布

first generate samples for estimation 生成数据

In [7]:
d = 128
m = torch.randn(d)
cov = torch.randn(d, d)
cov = cov.mm(cov.t())

dist = MultivariateNormal(loc=m, covariance_matrix=cov)
samples = dist.sample(sample_shape=(100000,)).view(-1, 50, d)
In [8]:
# initiate covariance (sigma) and mean
sigma = torch.eye(d)
mean = torch.zeros(1, d)

t = 0
n = 0
for batch in tqdm(samples):
    bs = batch.size(0)
    t = t + 1
    n = n + bs

    m = batch.mean(dim=0)

    mean1 = (mean * bs * (t-1) + m * bs) / n
    sigma = sigma + batch.t().mm(batch) + (n + bs -1) * (mean.t().mm(mean) - mean1.t().mm(mean1)) - bs * mean.t().mm(mean)
    mean = mean1
sigma = sigma / (n-1)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 7076.28it/s]
$$ \overline{X}_t = \frac{\overline{X}_{t-1} \cdot bs \cdot (t-1) + \overline{X}\cdot bs}{t \cdot bs} $$$$ M^t = M^{t-1} + X_t^TX_t + (n + bs -1)(\overline{X}_{t-1}^T\overline{X}_{t-1} - \overline{X}_t^T\overline{X}_t) - bs\cdot\overline{X}_{t-1}^T\overline{X}_{t-1} $$
  • $t$: time step;
  • $bs$: batch size
  • $M^t$: unnormalized covariance at time step $t$;
  • $\overline{X}_{t}$: batch mean at time step $t$;
  • $X_{t}$: estimated mean at time step $t$;

Distribution mean and covariance

In [9]:
print(m)
tensor([ 0.6690, -0.7991,  1.2752, -0.8550,  1.8720, -1.1037,  1.1897, -0.7947,
         0.3185,  3.3946,  1.3640,  1.5085, -0.4943, -1.4743, -0.0226,  2.7434,
         3.0446,  0.5451, -2.6687, -5.8625, -0.3359,  2.1169,  1.2494,  0.2833,
         1.7533,  0.1402, -1.1091,  3.6053, -2.3472, -4.3817, -2.7009, -2.0003,
         0.1642, -0.7117, -0.4205,  2.5602,  0.9799,  0.5564,  0.6682, -0.3475,
         0.8833, -0.8802, -0.4283,  1.1144, -0.5187, -0.2829,  1.9375,  2.5031,
         1.5877, -1.3960,  1.6548,  1.2411,  2.2651, -3.5474, -0.1063,  0.5350,
        -1.5288, -1.3198,  0.7773, -1.8456,  2.6918,  2.7315,  1.7986, -1.0009,
        -1.5427, -2.6441, -4.4430,  0.3530,  1.3754, -0.9392, -0.3825, -2.6168,
         2.2516,  2.1361, -0.8463, -3.6970, -0.6834,  0.7744,  4.9241,  1.1855,
         0.4067,  1.7793, -0.5616, -0.8433,  2.8664,  0.3995, -0.6622, -0.7129,
         0.5325,  1.4083,  0.9511, -0.2096,  1.8313, -1.6818, -0.9619, -0.9901,
        -2.3075, -2.0898, -0.5956, -1.6093,  3.3113,  2.8122, -3.1343,  3.1420,
        -1.6974, -4.6162, -0.3918,  0.6847, -2.8809,  2.4475,  2.1285, -0.4980,
         1.6721,  1.6995, -0.2299, -0.0078, -1.5305,  0.8664,  2.9363,  2.0647,
         2.2168,  0.3225, -0.7296,  3.2490, -0.0329,  2.1888,  0.7261, -0.7563])
In [10]:
print(cov)
tensor([[132.3313, -14.9537,  -5.3834,  ...,  -4.4668,  -2.7948,  -2.2385],
        [-14.9537, 116.9795,  18.4168,  ..., -13.1539,  -7.1572,  -7.8003],
        [ -5.3834,  18.4168, 130.3849,  ..., -12.9975,  -0.5502,  -8.7534],
        ...,
        [ -4.4668, -13.1539, -12.9975,  ..., 128.9788,  -3.1381,   3.9312],
        [ -2.7948,  -7.1572,  -0.5502,  ...,  -3.1381, 133.4015,  10.8676],
        [ -2.2385,  -7.8003,  -8.7534,  ...,   3.9312,  10.8676, 120.8049]])

Estimated mean and covariance

In [11]:
mean
Out[11]:
tensor([[ 6.4535e-01, -5.2039e-01,  9.3006e-02, -2.5598e+00,  1.6549e-01,
         -7.9256e-02, -4.6649e-01, -1.2485e+00,  2.6971e-01,  2.2544e-01,
          1.9445e-01,  5.8159e-01, -8.0804e-01, -6.3226e-01, -4.3310e-01,
          4.4581e-02,  8.5828e-01,  8.2413e-02,  7.5887e-01, -1.3219e+00,
         -1.6376e+00,  9.0990e-01,  3.7048e-01,  2.9472e-01,  1.9567e+00,
          2.3271e-01, -8.9447e-01,  2.0350e+00, -6.4032e-01, -7.0798e-01,
         -3.9520e-01, -3.3322e+00,  7.2658e-01, -1.3341e+00, -2.0584e-01,
          1.6834e+00,  2.8762e-01, -8.4816e-01, -9.5242e-01, -9.5150e-01,
          3.5367e-02, -1.1643e+00, -7.2251e-01,  6.4093e-01, -8.0365e-02,
         -4.7422e-01,  1.5139e-01,  1.6515e-01,  6.3490e-01,  3.0903e-01,
         -1.9743e-01,  4.0403e-01,  5.0163e-01, -5.2807e-01, -6.4054e-01,
          2.5552e-02, -2.9152e-01,  1.8208e-02,  8.7059e-02, -6.4096e-01,
          1.1698e+00,  1.0166e+00,  9.6652e-01,  6.9987e-01, -6.3494e-01,
         -2.0679e+00,  9.1055e-01,  8.7085e-01,  6.1761e-01, -9.3352e-01,
         -7.7460e-01, -9.8226e-01,  1.5804e+00,  1.0149e+00, -6.7836e-01,
         -1.8506e+00,  2.2699e-01,  9.6232e-01,  1.3539e+00,  2.3260e-01,
          1.0604e+00,  2.6834e-01, -1.1632e+00, -1.4171e-01,  4.8008e-01,
          2.9846e-01, -4.8141e-01,  2.6709e-01,  2.6692e-01, -2.9069e-01,
         -7.3659e-03, -1.4171e-01,  1.6185e-01, -6.9675e-01, -1.3564e+00,
         -1.5361e+00, -1.4834e+00, -1.5550e+00, -5.2030e-01, -1.4789e+00,
         -1.4643e-01,  1.1346e+00,  3.0055e-02,  1.5238e+00, -1.1367e+00,
         -1.3612e-03, -5.9794e-01, -1.4952e+00, -2.6915e+00,  1.7095e+00,
          3.4237e-01, -7.1614e-01,  1.1129e+00,  5.3432e-01, -2.7895e-01,
          2.1422e-01, -5.4506e-02,  5.1532e-01, -8.7199e-01,  6.8189e-01,
          9.1447e-01, -3.3461e-02, -1.2643e+00, -1.4579e-01, -1.2808e+00,
          1.0968e-01,  3.5911e-01,  2.5529e-01]])
In [12]:
sigma
Out[12]:
tensor([[132.4417, -15.1356,  -5.9752,  ...,  -4.1392,  -2.6158,  -2.1068],
        [-15.1356, 116.3259,  18.9425,  ..., -13.2735,  -6.9948,  -8.1337],
        [ -5.9752,  18.9425, 130.5436,  ..., -13.5383,  -0.9726,  -9.6370],
        ...,
        [ -4.1392, -13.2735, -13.5383,  ..., 129.2036,  -3.1350,   4.0759],
        [ -2.6158,  -6.9948,  -0.9726,  ...,  -3.1350, 134.3331,  10.6512],
        [ -2.1068,  -8.1337,  -9.6370,  ...,   4.0759,  10.6512, 120.4067]])

Estimate covariance with a pretrained model

估计一个预训练模型的特征

In [13]:
model = torchvision.models.mobilenet_v3_small(pretrained=True).cuda()
dataset = torch.randn(100, 32, 3, 224, 224).cuda()

# ONLY extract features, remove classifier
model = model.features
# dimension according to your own model
dim = 576

# initial mean and covariance
sigma = torch.eye(dim).cuda()
mean = torch.zeros(1, dim).cuda()

t = 0
n = 0

with torch.no_grad():
    for batch in tqdm(dataset):
        batch = model(batch).mean(dim=(2,3))
        bs = batch.size(0)
        t = t + 1
        n = n + bs

        m = batch.mean(dim=0)

        mean1 = (mean * bs * (t-1) + m * bs) / n
        sigma = sigma + batch.t().mm(batch) + (n + bs -1) * (mean.t().mm(mean) - mean1.t().mm(mean1)) - bs * mean.t().mm(mean)
        mean = mean1

sigma = sigma / (n-1)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 167.58it/s]
In [14]:
sigma
Out[14]:
tensor([[ 0.0588,  0.0108, -0.0010,  ...,  0.0120,  0.0069,  0.0032],
        [ 0.0108,  0.0572, -0.0009,  ...,  0.0104,  0.0040,  0.0048],
        [-0.0010, -0.0009,  0.0620,  ...,  0.0074,  0.0098, -0.0023],
        ...,
        [ 0.0120,  0.0104,  0.0074,  ...,  0.0646,  0.0064,  0.0048],
        [ 0.0069,  0.0040,  0.0098,  ...,  0.0064,  0.0567,  0.0069],
        [ 0.0032,  0.0048, -0.0023,  ...,  0.0048,  0.0069,  0.0398]],
       device='cuda:0')
For more notebooks, visit https://kaizhao.net/misc.