-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathfvd.py
More file actions
137 lines (109 loc) · 4.96 KB
/
fvd.py
File metadata and controls
137 lines (109 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
from tats.fvd.pytorch_i3d import InceptionI3d
import os
from sklearn.metrics.pairwise import polynomial_kernel
MAX_BATCH = 16
FVD_SAMPLE_SIZE = 2048
TARGET_RESOLUTION = (224, 224)
def preprocess(videos, target_resolution):
# videos in {0, ..., 255} as np.uint8 array
b, t, h, w, c = videos.shape
all_frames = torch.FloatTensor(videos).flatten(end_dim=1) # (b * t, h, w, c)
all_frames = all_frames.permute(0, 3, 1, 2).contiguous() # (b * t, c, h, w)
resized_videos = F.interpolate(all_frames, size=target_resolution,
mode='bilinear', align_corners=False)
resized_videos = resized_videos.view(b, t, c, *target_resolution)
output_videos = resized_videos.transpose(1, 2).contiguous() # (b, c, t, *)
scaled_videos = 2. * output_videos / 255. - 1 # [-1, 1]
return scaled_videos
def get_fvd_logits(videos, i3d, device):
videos = preprocess(videos, TARGET_RESOLUTION)
embeddings = get_logits(i3d, videos, device)
return embeddings
def load_fvd_model(device):
i3d = InceptionI3d(400, in_channels=3).to(device)
current_dir = os.path.dirname(os.path.abspath(__file__))
i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt')
i3d.load_state_dict(torch.load(i3d_path, map_location=device))
i3d.eval()
return i3d
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
def _symmetric_matrix_square_root(mat, eps=1e-10):
u, s, v = torch.svd(mat)
si = torch.where(s < eps, s, torch.sqrt(s))
return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
def trace_sqrt_product(sigma, sigma_v):
sqrt_sigma = _symmetric_matrix_square_root(sigma)
sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
'''Estimate a covariance matrix given data.
Covariance indicates the level to which two variables vary together.
If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
then the covariance matrix element `C_{ij}` is the covariance of
`x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
Args:
m: A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
observation of all those variables.
rowvar: If `rowvar` is True, then each row represents a
variable, with observations in the columns. Otherwise, the
relationship is transposed: each column represents a variable,
while the rows contain observations.
Returns:
The covariance matrix of the variables.
'''
if m.dim() > 2:
raise ValueError('m has more than 2 dimensions')
if m.dim() < 2:
m = m.view(1, -1)
if not rowvar and m.size(0) != 1:
m = m.t()
fact = 1.0 / (m.size(1) - 1) # unbiased estimate
m_center = m - torch.mean(m, dim=1, keepdim=True)
mt = m_center.t() # if complex: mt = m.t().conj()
return fact * m_center.matmul(mt).squeeze()
def frechet_distance(x1, x2):
x1 = x1.flatten(start_dim=1)
x2 = x2.flatten(start_dim=1)
m, m_w = x1.mean(dim=0), x2.mean(dim=0)
sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
mean = torch.sum((m - m_w) ** 2)
fd = trace + mean
return fd
def polynomial_mmd(X, Y):
m = X.shape[0]
n = Y.shape[0]
# compute kernels
K_XX = polynomial_kernel(X)
K_YY = polynomial_kernel(Y)
K_XY = polynomial_kernel(X, Y)
# compute mmd distance
K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1))
K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1))
K_XY_sum = K_XY.sum() / (m * n)
mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum
return mmd
def get_logits(i3d, videos, device):
assert videos.shape[0] % MAX_BATCH == 0
with torch.no_grad():
logits = []
for i in range(0, videos.shape[0], MAX_BATCH):
batch = videos[i:i + MAX_BATCH].to(device)
logits.append(i3d(batch))
logits = torch.cat(logits, dim=0)
return logits
def compute_fvd(real, samples, i3d, device=torch.device('cpu')):
# real, samples are (N, T, H, W, C) numpy arrays in np.uint8
real, samples = preprocess(real, (224, 224)), preprocess(samples, (224, 224))
first_embed = get_logits(i3d, real, device)
second_embed = get_logits(i3d, samples, device)
return frechet_distance(first_embed, second_embed)