Spaces:
Sleeping
Sleeping
| import functools | |
| from abc import ABC, abstractmethod | |
| from collections import deque | |
| from typing import Callable, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from jaxtyping import Float | |
| from ibydmt.payoff import HSIC, cMMD, xMMD | |
| from ibydmt.wealth import get_wealth | |
# Type alias used in the annotations below: either a NumPy array or a torch tensor.
Array = Union[np.ndarray, torch.Tensor]
class Tester(ABC):
    """Common base for the testers in this module.

    Concrete testers override :meth:`test`. The base implementation holds no
    state and its ``test`` simply returns ``None``.
    """

    def __init__(self):
        """Create a tester; the base class keeps no state."""

    def test(self, *args, **kwargs) -> Tuple[bool, int]:
        """Run the test and report a ``(rejected, step)`` pair.

        This base version is a placeholder that ignores its arguments and
        returns ``None``; subclasses supply the actual procedure.
        """
        return None
class SequentialTester(Tester):
    """Base class for testers that update a wealth object step by step.

    Builds ``self.wealth`` from ``config`` through ``get_wealth`` and records
    the step budget ``self.tau_max`` read from ``config.tau_max``.
    """

    def __init__(self, config):
        super().__init__()
        # get_wealth maps the configured name to a class/factory, which is
        # then instantiated with the full config.
        wealth_factory = get_wealth(config.wealth)
        self.wealth = wealth_factory(config)
        self.tau_max = config.tau_max
class SKIT(SequentialTester):
    """Global Independence Tester"""

    def __init__(self, config):
        super().__init__(config)
        self.payoff = HSIC(config)

    def test(self, Y: Float[Array, "N"], Z: Float[Array, "N"]) -> Tuple[bool, int]:
        """Sequentially test independence of ``Y`` and ``Z``.

        The samples are consumed two at a time. At step ``t`` the pair
        ``D[2t : 2t + 2]`` is compared against a null pair built by keeping its
        ``Y`` values and flipping (swapping) its ``Z`` values. All earlier rows
        ``D[: 2t]`` are passed to the payoff as history, so the first pair only
        ever serves as history.

        Args:
            Y: Values of the first variable, one per sample.
            Z: Values of the second variable, aligned with ``Y``.

        Returns:
            ``(True, t)`` for the first step ``t`` after which
            ``self.wealth.rejected`` is true, otherwise ``(False, tau_max - 1)``.

        Note:
            Assumes ``Y``/``Z`` hold at least ``2 * tau_max`` samples -- TODO
            confirm with callers; with fewer, later slices are short or empty.
        """
        D = np.stack([Y, Z], axis=1)
        # Last step index. Used as the non-rejection result, consistent with
        # xSKIT. Also avoids an unbound loop variable when tau_max <= 1.
        tau = self.tau_max - 1
        for t in range(1, self.tau_max):
            d = D[2 * t : 2 * (t + 1)]
            prev_d = D[: 2 * t]
            # Null pair: same Y values, Z values in reversed order.
            null_d = np.stack([d[:, 0], np.flip(d[:, 1])], axis=1)
            payoff = self.payoff.compute(d, null_d, prev_d)
            self.wealth.update(payoff)
            if self.wealth.rejected:
                return (True, t)
        return (False, tau)
class cSKIT(SequentialTester):
    """Global Conditional Independence Tester"""

    def __init__(self, config):
        super().__init__(config)
        self.payoff = cMMD(config)

    def _sample(
        self,
        z: Float[Array, "N D"],
        j: int = None,
        cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]] = None,
    ) -> Tuple[Float[Array, "N"], Float[Array, "N"], Float[Array, "N D-1"]]:
        """Split ``z`` around column ``j`` and draw a null replacement for it.

        Args:
            z: Feature rows.
            j: Index of the column under test.
            cond_p: Called as ``cond_p(z, C)`` with ``C`` the other columns;
                its column ``j`` is used as the null value of feature ``j``.

        Returns:
            ``(zj, null_zj, cond_z)``: column ``j`` of ``z``, column ``j`` of the
            ``cond_p`` sample, and the columns ``C`` of ``z``. All are kept
            two-dimensional via list indexing.
        """
        # Conditioning columns in explicit ascending order, rather than
        # relying on the iteration order of a set.
        C = [i for i in range(z.shape[1]) if i != j]
        zj, cond_z = z[:, [j]], z[:, C]
        samples = cond_p(z, C)
        null_zj = samples[:, [j]]
        return zj, null_zj, cond_z

    def test(
        self,
        Y: Float[Array, "N"],
        Z: Float[Array, "N D"],
        j: int,
        cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]],
    ) -> Tuple[bool, int]:
        """Sequentially test whether ``Y`` depends on feature ``j`` given the rest.

        The first observation only seeds the history. At each step ``t`` the
        observed row ``[y, z_j, z_C]`` is compared with ``[y, null_z_j, z_C]``.
        Then ``[y, z_j, null_z_j, z_C]`` is appended to the history passed to
        the payoff.

        Args:
            Y: Target values, one per observation.
            Z: Feature rows aligned with ``Y``.
            j: Index of the feature under test.
            cond_p: Conditional sampler, see :meth:`_sample`.

        Returns:
            ``(True, t)`` for the first step ``t`` after which
            ``self.wealth.rejected`` is true, otherwise ``(False, tau_max - 1)``.

        Note:
            Indexes ``Y[[t]]``/``Z[[t]]`` for ``t < tau_max``. This presumably
            requires at least ``tau_max`` observations -- TODO confirm.
        """
        sample = functools.partial(self._sample, j=j, cond_p=cond_p)

        # Seed the history with the first observation.
        prev_y, prev_z = Y[:1][:, None], Z[:1]
        prev_zj, prev_null_zj, prev_cond_z = sample(prev_z)
        prev_d = np.concatenate([prev_y, prev_zj, prev_null_zj, prev_cond_z], axis=-1)

        # Last step index. Used as the non-rejection result, consistent with
        # xSKIT. Also avoids an unbound loop variable when tau_max <= 1.
        tau = self.tau_max - 1
        for t in range(1, self.tau_max):
            y, z = Y[[t]][:, None], Z[[t]]
            zj, null_zj, cond_z = sample(z)
            u = np.concatenate([y, zj, cond_z], axis=-1)
            null_u = np.concatenate([y, null_zj, cond_z], axis=-1)
            payoff = self.payoff.compute(u, null_u, prev_d)
            self.wealth.update(payoff)
            # Grow the history only after the payoff used the previous one.
            d = np.concatenate([y, zj, null_zj, cond_z], axis=-1)
            prev_d = np.vstack([prev_d, d])
            if self.wealth.rejected:
                return (True, t)
        return (False, tau)
class xSKIT(SequentialTester):
    """Local Conditional Independence Tester"""

    def __init__(self, config):
        super().__init__(config)
        self.payoff = xMMD(config)
        # Buffer of (y, null_y) pairs filled in batches by _sample.
        self._queue = deque()

    def _sample(
        self,
        z: Float[Array, "D"],
        j: int,
        C: list[int],
        cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
        model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
    ) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
        """Return one ``(y, null_y)`` pair, refilling the buffer when it is empty.

        A refill calls ``cond_p(z, C + [j], tau_max)`` and ``cond_p(z, C,
        tau_max)``, which presumably draw ``tau_max`` rows each -- TODO confirm.
        It then evaluates ``model`` on both and buffers the paired outputs,
        each reshaped to a length-1 column.
        """
        if not self._queue:
            Cuj = C + [j]
            h = cond_p(z, Cuj, self.tau_max)
            null_h = cond_p(z, C, self.tau_max)
            y = model(h)[:, None]
            null_y = model(null_h)[:, None]
            self._queue.extend(zip(y, null_y))
        return self._queue.pop()

    def test(
        self,
        z: Float[Array, "D"],
        j: int,
        C: list[int],
        cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
        model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
        interrupt_on_rejection: bool = True,
    ) -> Tuple[bool, int]:
        """Sequentially test whether feature ``j`` matters for ``model`` at ``z``.

        Model outputs on samples conditioned on ``C + [j]`` are compared with
        outputs on samples conditioned on ``C`` only. The first pair seeds the
        history passed to the payoff.

        Args:
            z: The input point being explained.
            j: Index of the feature under test.
            C: Indices of the conditioning features.
            cond_p: Conditional sampler called as ``cond_p(z, subset, tau_max)``.
            model: Maps sampled inputs to outputs.
            interrupt_on_rejection: Stop at the first rejection if true;
                otherwise run all ``tau_max - 1`` steps.

        Returns:
            ``(self.wealth.rejected, tau)`` where ``tau`` is the first step at
            which the wealth rejected, or ``tau_max - 1`` if it never did.
        """
        # Drop pairs left over from an earlier call that stopped early. They
        # may have been drawn for a different z / j / C / model.
        self._queue.clear()
        sample = functools.partial(self._sample, z, j, C, cond_p, model)
        tau = self.tau_max - 1
        prev_d = np.stack(sample(), axis=1)
        for t in range(1, self.tau_max):
            y, null_y = sample()
            payoff = self.payoff.compute(y, null_y, prev_d)
            self.wealth.update(payoff)
            d = np.stack([y, null_y], axis=1)
            prev_d = np.vstack([prev_d, d])
            if self.wealth.rejected:
                # t only grows, so this keeps the first rejection step.
                tau = min(tau, t)
                if interrupt_on_rejection:
                    break
        return (self.wealth.rejected, tau)