import kornia
import open_clip
import torch
from torch import nn


class CLIPConditioner(nn.Module):
    mean: torch.Tensor
    std: torch.Tensor

    def __init__(self):
        super().__init__()
        # Load the OpenCLIP ViT-H/14 image encoder pretrained on LAION-2B.
        self.module = open_clip.create_model_and_transforms(
            "ViT-H-14", pretrained="laion2b_s32b_b79k"
        )[0]
        # Freeze the encoder: it is used only for conditioning, not training.
        self.module.eval().requires_grad_(False)  # type: ignore
        # CLIP normalization statistics as non-persistent buffers, so they
        # follow the module's device but are not written to checkpoints.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        # Resize to CLIP's 224x224 input resolution with antialiased bicubic sampling.
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=True,
        )
        # Map inputs from [-1, 1] to [0, 1], then apply CLIP's channel-wise normalization.
        x = (x + 1.0) / 2.0
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.preprocess(x)
        x = self.module.encode_image(x)
        return x
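
# Minimal usage sketch (not part of the original module): it assumes image
# batches of shape (B, 3, H, W) scaled to [-1, 1], which is what preprocess()
# expects. Device placement and dtype handling are left to the caller.
if __name__ == "__main__":
    conditioner = CLIPConditioner()
    # Hypothetical batch of two random "images" in [-1, 1].
    images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0
    with torch.no_grad():
        embeddings = conditioner(images)
    # ViT-H-14 image embeddings; expected shape (2, 1024).
    print(embeddings.shape)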