| from abc import ABC, abstractmethod | |
| from typing import Generic, TypeVar | |
| from jaxtyping import Float | |
| from torch import Tensor | |
# Type parameters for the generic visualizer base class declared in this file.
# Both are unconstrained TypeVars: one stands for the configuration object's
# type, the other for the encoder's type.
T_cfg = TypeVar("T_cfg")
T_encoder = TypeVar("T_encoder")
class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base class for encoder visualizers.

    Generic over the configuration type (``T_cfg``) and the encoder type
    (``T_encoder``). Both objects are stored unchanged on the instance and
    are available as ``self.cfg`` and ``self.encoder``.

    Subclasses must override :meth:`visualize`; the base class cannot be
    instantiated directly.
    """

    # Configuration object, exactly as passed to ``__init__``.
    cfg: T_cfg
    # Encoder object, exactly as passed to ``__init__``.
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Store the configuration and the encoder on the instance.

        Args:
            cfg: Configuration object for this visualizer.
            encoder: Encoder this visualizer is associated with.
        """
        self.cfg = cfg
        self.encoder = encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualizations.

        Marked abstract so that a subclass which forgets to override this
        method fails at instantiation time instead of silently returning
        ``None`` where a dict is expected.

        Args:
            context: Input data for the visualization. The expected keys are
                not visible in this file -- TODO confirm against the
                concrete subclasses and their callers.
            global_step: Integer step counter (presumably the current
                training step -- confirm against the caller).

        Returns:
            A mapping from a name to a floating-point tensor whose first
            dimension has size 3 and whose remaining two dimensions are
            unconstrained (presumably channel, height, width images --
            confirm).
        """