Spaces:
Running
Running
File size: 4,990 Bytes
377127e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import torch
def flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0=0.1, epsilon=1e-08):
    """Evaluate the flexible kernel (as in MMDu) between two sets of samples.

    Two pairwise distance matrices are built with ``Pdist2``: one from
    ``X``/``Y`` and one from ``X_org``/``Y_org``. The kernel is a weighted
    sum of two exponentials:

        (1 - epsilon) * exp(-(D_xy / sigma0) - D_org / sigma)
            + epsilon * exp(-D_org / sigma)

    Args:
        X, Y: 2-D tensors whose rows are compared pairwise.
        X_org, Y_org: 2-D tensors paired row-for-row with ``X`` and ``Y``;
            their distance matrix must be shape-compatible with that of
            ``X``/``Y``.
        sigma: Scale dividing the ``X_org``/``Y_org`` distances.
        sigma0: Scale dividing the ``X``/``Y`` distances.
        epsilon: Mixing weight of the term that uses only the ``*_org``
            distances.

    Returns:
        Tensor of kernel values with one row per row of ``X`` and one
        column per row of ``Y``.
    """
    power = 1  # exponent applied to the scaled X/Y distances
    feat_dist = Pdist2(X, Y)
    org_dist = Pdist2(X_org, Y_org)

    org_scaled = org_dist / sigma
    joint_term = torch.exp(-((feat_dist / sigma0) ** power) - org_scaled)
    org_term = torch.exp(-org_scaled)
    return (1 - epsilon) * joint_term + epsilon * org_term
def MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon=1e-08):
    """Compute the variance of the difference statistic MMDXY - MMDXZ.

    Referenced from: https://github.com/eugenium/MMD/blob/master/mmd.py

    Args:
        Kyy: Kernel matrix of Y against itself; its diagonal is removed
            before use.
        Kzz: Kernel matrix of Z against itself; its diagonal is removed
            before use.
        Kxy: Kernel matrix with one row per X sample and one column per
            Y sample.
        Kxz: Kernel matrix with one row per X sample and one column per
            Z sample.
        epsilon: Lower bound applied to both zeta terms. May be a Python
            number or a tensor (including Tensor subclasses).

    Returns:
        A tuple ``(Var, Var_z2, data)`` where ``Var`` is the variance
        estimate built from ``zeta1`` only, ``Var_z2`` additionally
        includes the ``zeta2`` term, and ``data`` is a dict of Python
        floats holding ``t1``..``t9``, ``zeta1`` and ``zeta2``.

    Raises:
        ZeroDivisionError: If any of the three sample counts is below 2.

    Note:
        The ``zeta2`` term adds ``Kyy``, ``Kzz``, ``Kxy``, ``Kxz`` and the
        transposes of the latter two elementwise, so all of them must be
        shape-compatible (in practice: square and of equal size).
    """
    m = Kxy.shape[0]
    n = Kyy.shape[0]
    r = Kzz.shape[0]

    # Remove diagonal elements
    Kyynd = Kyy - torch.diag(torch.diag(Kyy))
    Kzznd = Kzz - torch.diag(torch.diag(Kzz))

    u_yy = torch.sum(Kyynd) * (1.0 / (n * (n - 1)))
    u_zz = torch.sum(Kzznd) * (1.0 / (r * (r - 1)))
    u_xy = torch.sum(Kxy) / (m * n)
    u_xz = torch.sum(Kxz) / (m * r)

    t1 = (1.0 / n**3) * torch.sum(Kyynd.T @ Kyynd) - u_yy**2
    t2 = (1.0 / (n**2 * m)) * torch.sum(Kxy.T @ Kxy) - u_xy**2
    t3 = (1.0 / (n * m**2)) * torch.sum(Kxy @ Kxy.T) - u_xy**2
    t4 = (1.0 / r**3) * torch.sum(Kzznd.T @ Kzznd) - u_zz**2
    t5 = (1.0 / (r * m**2)) * torch.sum(Kxz @ Kxz.T) - u_xz**2
    t6 = (1.0 / (r**2 * m)) * torch.sum(Kxz.T @ Kxz) - u_xz**2
    t7 = (1.0 / (n**2 * m)) * torch.sum(Kyynd @ Kxy.T) - u_yy * u_xy
    t8 = (1.0 / (n * m * r)) * torch.sum(Kxy.T @ Kxz) - u_xz * u_xy
    t9 = (1.0 / (r**2 * m)) * torch.sum(Kzznd @ Kxz.T) - u_zz * u_xz

    # isinstance (rather than an exact type comparison) also covers Tensor
    # subclasses such as nn.Parameter; either way the floor ends up on the
    # same device as the kernel matrices.
    if isinstance(epsilon, torch.Tensor):
        epsilon_tensor = epsilon.clone().detach().to(Kyy.device)
    else:
        epsilon_tensor = torch.tensor(epsilon, device=Kyy.device)

    # Both zeta terms are floored at epsilon so they stay strictly positive.
    zeta1 = torch.max(
        t1 + t2 + t3 + t4 + t5 + t6 - 2 * (t7 + t8 + t9), epsilon_tensor
    )
    zeta2 = torch.max(
        (1 / m / (m - 1))
        * torch.sum((Kyynd - Kzznd - Kxy.T - Kxy + Kxz + Kxz.T) ** 2)
        - (u_yy - 2 * u_xy - (u_zz - 2 * u_xz)) ** 2,
        epsilon_tensor,
    )

    data = {
        "t1": t1.item(),
        "t2": t2.item(),
        "t3": t3.item(),
        "t4": t4.item(),
        "t5": t5.item(),
        "t6": t6.item(),
        "t7": t7.item(),
        "t8": t8.item(),
        "t9": t9.item(),
        "zeta1": zeta1.item(),
        "zeta2": zeta2.item(),
    }

    Var = (4 * (m - 2) / (m * (m - 1))) * zeta1
    Var_z2 = Var + (2.0 / (m * (m - 1))) * zeta2
    return Var, Var_z2, data
def Pdist2(x, y):
    """Return the matrix of pairwise squared Euclidean distances.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b`` and replaces
    any negative entries (which can only come from rounding) with zero.

    Args:
        x: 2-D tensor; each row is one point.
        y: 2-D tensor with the same number of columns as ``x``, or ``None``
            to compare ``x`` against itself.

    Returns:
        Tensor with one row per row of ``x`` and one column per row of
        ``y``.
    """
    x_sq = x.pow(2).sum(dim=1).view(-1, 1)
    if y is None:
        # Self-distance: reuse the row norms already computed for x.
        y = x
        y_sq = x_sq.view(1, -1)
    else:
        y_sq = y.pow(2).sum(dim=1).view(1, -1)

    cross = torch.mm(x, y.transpose(0, 1))
    sq_dist = x_sq + y_sq - 2.0 * cross
    return sq_dist.masked_fill(sq_dist < 0, 0)
def MMD_batch2(
    Fea,
    len_s,
    Fea_org,
    sigma,
    sigma0=0.1,
    epsilon=10 ** (-10),
    is_var_computed=True,
    use_1sample_U=True,
    coeff_xy=2,
):
    """Score every row after the first ``len_s`` rows against the leading rows.

    The first ``len_s`` rows of ``Fea`` form the reference set X and the
    remaining rows form Y. With the kernel ``exp(-d / sigma0)`` applied to
    the distances from ``Pdist2``, the value returned for each row ``y`` of
    Y is

        mean(K(X, X)) - 2 * mean(K(X, y)) + 1

    where the mean over ``K(X, X)`` includes the diagonal and the trailing
    ``1`` is the kernel value at zero distance (the Y-Y distance is taken
    to be zero for every row).

    Args:
        Fea: 2-D tensor holding the X rows followed by the Y rows.
        len_s: Number of leading rows of ``Fea`` that belong to X.
        Fea_org: Accepted but not used by this function.
        sigma: Accepted but not used by this function.
        sigma0: Scale dividing the distances inside the exponential.
        epsilon: Accepted but not used by this function.
        is_var_computed: Accepted but not used by this function.
        use_1sample_U: Accepted but not used by this function.
        coeff_xy: Accepted but not used by this function.

    Returns:
        1-D tensor with one value per row of Y.
    """
    X = Fea[0:len_s, :]
    Y = Fea[len_s:, :]
    nx = X.shape[0]

    Dxx = Pdist2(X, X)
    # One zero distance per Y row, so the Y-Y kernel term is exp(0) each.
    Dyy = torch.zeros(Fea.shape[0] - len_s, 1, device=Dxx.device)
    Dxy = Pdist2(X, Y).transpose(0, 1)  # rows index Y, columns index X

    Kx = torch.exp(-Dxx / sigma0)
    Ky = torch.exp(-Dyy / sigma0)
    Kxy = torch.exp(-Dxy / sigma0)

    xx = torch.div(torch.sum(Kx), nx * nx)
    yy = Ky.reshape(-1)
    xy = torch.div(torch.sum(Kxy, dim=1), nx)

    mmd2 = xx - 2 * xy + yy
    return mmd2
# MMD for three samples
def MMD_3_Sample_Test(
    ref_fea,
    fea_y,
    fea_z,
    ref_fea_org,
    fea_y_org,
    fea_z_org,
    sigma,
    sigma0,
    epsilon,
    alpha,
):
    """Run the three-sample test (TST) using the deep kernel.

    All six inputs are cloned and detached first, so nothing computed here
    is tracked for gradients through them. Kernel matrices are built with
    ``flexible_kernel`` and the variance comes from ``MMD_Diff_Var``. The
    raw statistic is

        t = u_yy - 2 * u_xy - (u_zz - 2 * u_xz)

    where each ``u_*`` is the mean of the corresponding kernel matrix (the
    diagonal is excluded for ``u_yy`` and ``u_zz``).

    Args:
        ref_fea, fea_y, fea_z: Feature tensors for the samples X, Y and Z.
        ref_fea_org, fea_y_org, fea_z_org: Companion tensors for X, Y and Z
            passed as the ``*_org`` arguments of ``flexible_kernel``.
        sigma, sigma0: Scales forwarded to ``flexible_kernel``.
        epsilon: Forwarded to ``flexible_kernel`` and ``MMD_Diff_Var``; also
            used as a floor for the variance. Number or tensor.
        alpha: Threshold compared against the p-value.

    Returns:
        A tuple ``(h, p_value, t)``: ``h`` is 0 when ``p_value > alpha`` and
        1 otherwise, ``p_value`` is the standard normal CDF evaluated at
        ``-t / sqrt(variance)``, and ``t`` is the statistic divided by
        ``sqrt(variance)``. The last two are Python floats.
    """
    X = ref_fea.clone().detach()
    Y = fea_y.clone().detach()
    Z = fea_z.clone().detach()
    X_org = ref_fea_org.clone().detach()
    Y_org = fea_y_org.clone().detach()
    Z_org = fea_z_org.clone().detach()

    Kyy = flexible_kernel(Y, Y, Y_org, Y_org, sigma, sigma0, epsilon)
    Kzz = flexible_kernel(Z, Z, Z_org, Z_org, sigma, sigma0, epsilon)
    Kxy = flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0, epsilon)
    Kxz = flexible_kernel(X, Z, X_org, Z_org, sigma, sigma0, epsilon)

    Kyynd = Kyy - torch.diag(torch.diag(Kyy))
    Kzznd = Kzz - torch.diag(torch.diag(Kzz))

    Diff_Var, _, _ = MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon)

    u_yy = torch.sum(Kyynd) / (Y.shape[0] * (Y.shape[0] - 1))
    u_zz = torch.sum(Kzznd) / (Z.shape[0] * (Z.shape[0] - 1))
    u_xy = torch.sum(Kxy) / (X.shape[0] * Y.shape[0])
    u_xz = torch.sum(Kxz) / (X.shape[0] * Z.shape[0])

    t = u_yy - 2 * u_xy - (u_zz - 2 * u_xz)

    if Diff_Var.item() <= 0:
        # The variance can be exactly 0 (e.g. with two X samples). Floor it,
        # converting epsilon first because torch.max rejects a plain Python
        # number as its first argument.
        eps_tensor = torch.as_tensor(epsilon, device=t.device)
        Diff_Var = torch.max(eps_tensor, torch.tensor(1e-08, device=t.device))

    std = torch.sqrt(Diff_Var)
    p_value = torch.distributions.Normal(0, 1).cdf(-t / std)
    t = t / std

    h = 0 if p_value > alpha else 1
    return h, p_value.item(), t.item()
|