diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d3f2e7e969ff7f05a8f2b7e840ba56cf6078a9e3 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.txt filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..6001e403779d845b4ffbf7b40030fe73dddd0116
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.idea
+
+.pth
\ No newline at end of file
diff --git a/README.md b/README.md
index d69abd77ea7f147f1e074bf9d70354486e123baa..09987e4997af467217294891e791f16e8ebb7962 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,47 @@
----
-title: Counterfactual World Models
-emoji: 📊
-colorFrom: purple
-colorTo: yellow
-sdk: gradio
-sdk_version: 4.44.0
-app_file: app.py
-pinned: false
-license: mit
-short_description: Vision foundation model that unifies vision structures
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Understanding Physical Dynamics with Counterfactual World Modeling
+
+[**Rahul Venkatesh***](https://rahulvenkk.github.io/)<sup>1</sup> · [**Honglin Chen***](https://web.stanford.edu/~honglinc/)<sup>1</sup> · [**Kevin Feigelis***](https://neuroscience.stanford.edu/people/kevin-t-feigelis)<sup>1</sup> · [**Daniel M. Bear**](https://twitter.com/recursus?lang=en)<sup>1</sup> · [**Khaled Jedoui**](https://web.stanford.edu/~thekej/)<sup>1</sup> · [**Klemen Kotar**](https://klemenkotar.github.io/)<sup>1</sup> · [**Felix Binder**](https://ac.felixbinder.net/)<sup>2</sup> · [**Wanhee Lee**](https://www.linkedin.com/in/wanhee-lee-31102820b/)<sup>1</sup> · [**Sherry Liu**](https://neuroailab.github.io/cwm-physics/)<sup>1</sup> · [**Kevin A. Smith**](https://www.mit.edu/~k2smith/)<sup>3</sup> · [**Judith E. Fan**](https://cogtoolslab.github.io/)<sup>1</sup> · [**Daniel L. K. Yamins**](https://stanford.edu/~yamins/)<sup>1</sup>
+
+(* equal contribution)
+
+<sup>1</sup>Stanford · <sup>2</sup>UCSD · <sup>3</sup>MIT
+
+![CWM teaser](assets/cwm_teaser.gif)
+
+This work presents the Counterfactual World Modeling (CWM) framework. CWM supports counterfactual prediction and extracts vision structures that are useful for understanding physical dynamics.
+
+
+
+## 📣 News
+
+- 2024-06-01: Released the [project page](https://neuroailab.github.io) and [code](https://github.com/rahulvenkk/cwm_release.git)
+
+## 🔨 Installation
+
+```
+git clone https://github.com/rahulvenkk/cwm_release.git
+cd cwm_release
+pip install -e .
+```
+
+## ✨ Usage
+To download and use a pre-trained model, run the following:
+```
+from cwm.model.model_factory import model_factory
+model = model_factory.load_model('vitbase_8x8patch_3frames_1tube')
+```
+This will automatically initialize the appropriate model class and download the specified weights to your `$CACHE` directory.
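+
+For example, a minimal forward pass might look like the sketch below (this assumes a 3-frame, 224x224 clip and uses `RotatedTableMaskingGenerator` from `cwm/data/masking_generator.py`, following the training loop in `cwm/engine_for_pretraining.py`; exact input shapes depend on the checkpoint):
+
+```
+import torch
+from cwm.model.model_factory import model_factory
+from cwm.data.masking_generator import RotatedTableMaskingGenerator
+
+model = model_factory.load_model('vitbase_8x8patch_3frames_1tube').cuda().eval()
+
+# dummy clip with shape [batch, channels, frames, height, width]
+video = torch.randn(1, 3, 3, 224, 224).cuda()
+
+# mask 90% of the patches in the last frame (28x28 patches per 224x224 frame at 8x8 patch size)
+mask_generator = RotatedTableMaskingGenerator(
+    input_size=(3, 28, 28), mask_ratio=0.9, tube_length=1, batch_size=1)
+mask = mask_generator().bool().cuda()
+
+with torch.no_grad():
+    pred = model(video, mask)  # predictions for the masked patches
+```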
+
+## 🔄 Pre-training
+To pre-train the model, run the following script:
+
+```
+./scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh
+```
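+
+The pre-training script reads video list files such as those under `cwm/data/video_file_lists/`. Based on the loader in `cwm/data/dataset.py`, each line is expected to contain a video path followed by an integer label, for example (illustrative paths):
+
+```
+/path/to/kinetics/clip_0001.mp4 0
+/path/to/kinetics/clip_0002.mp4 0
+```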
diff --git a/assets/color_wheel.png b/assets/color_wheel.png
new file mode 100644
index 0000000000000000000000000000000000000000..84cdae7bb397f27457ea8b65457847e79c1f716e
Binary files /dev/null and b/assets/color_wheel.png differ
diff --git a/assets/cwm_teaser.gif b/assets/cwm_teaser.gif
new file mode 100644
index 0000000000000000000000000000000000000000..ff0a7ff2592aebf693544f447058a3ff96e17c86
--- /dev/null
+++ b/assets/cwm_teaser.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fac6e545660c695f81f360a87f1060c44eea95f3ae7dbcf1fecbe2b097e3b6a
+size 12896275
diff --git a/assets/desk_1.jpg b/assets/desk_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e7e1c5fdea5cd0dcc0125e9125b3d4afe0defb51
--- /dev/null
+++ b/assets/desk_1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84a3bfdd40841e8d291b4eb638ecc29b99f054ae6d3ea51b4cdc3090741987c8
+size 4731808
diff --git a/assets/flow_test_videos/libby.mp4 b/assets/flow_test_videos/libby.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d5086b20a9a14f6c3f14e7df3eec2a947ea7688b
--- /dev/null
+++ b/assets/flow_test_videos/libby.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d1887bc796d1883e8a63405398125476257572c0c8d7f1862bf309e422b4828
+size 671950
diff --git a/assets/flow_test_videos/weight_lifter.mp4 b/assets/flow_test_videos/weight_lifter.mp4
new file mode 100755
index 0000000000000000000000000000000000000000..12a2d354ee2e002c858d346e2731ad6c9168c02f
--- /dev/null
+++ b/assets/flow_test_videos/weight_lifter.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:525988b314936079f904236c614e2f987ba64fd5c69f4f12dd5b6b9076311854
+size 1176790
diff --git a/cwm/__init__.py b/cwm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/data/__init__.py b/cwm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/data/dataset.py b/cwm/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b4d0aafd109ce17f8a220f67fce80994f1145f8
--- /dev/null
+++ b/cwm/data/dataset.py
@@ -0,0 +1,453 @@
+import os
+import decord
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+class VideoMAE(torch.utils.data.Dataset):
+ """Load your own video classification dataset.
+ Parameters
+ ----------
+ root : str, required.
+ Path to the root folder storing the dataset.
+ setting : str, required.
+ A text file describing the dataset, each line per video sample.
+ There are three items in each line: (1) video path; (2) video length and (3) video label.
+ train : bool, default True.
+ Whether to load the training or validation set.
+ test_mode : bool, default False.
+ Whether to perform evaluation on the test set.
+ Usually a three-crop or ten-crop evaluation strategy is involved.
+ name_pattern : str, default None.
+ The naming pattern of the decoded video frames.
+ For example, img_00012.jpg.
+ video_ext : str, default 'mp4'.
+ If video_loader is set to True, please specify the video format accordingly.
+ is_color : bool, default True.
+ Whether the loaded image is color or grayscale.
+ modality : str, default 'rgb'.
+ Input modality; only rgb video frames are supported for now.
+ Will add support for rgb difference image and optical flow image later.
+ num_segments : int, default 1.
+ Number of segments to evenly divide the video into clips.
+ A useful technique to obtain global video-level information.
+ Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
+ num_crop : int, default 1.
+ Number of crops for each image. default is 1.
+ Common choices are three crops and ten crops during evaluation.
+ new_length : int, default 1.
+ The length of input video clip. Default is a single image, but it can be multiple video frames.
+ For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
+ new_step : int, default 1.
+ Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
+ new_step=2 means we will extract a video clip of every other frame.
+ temporal_jitter : bool, default False.
+ Whether to temporally jitter if new_step > 1.
+ video_loader : bool, default False.
+ Whether to use video loader to load data.
+ use_decord : bool, default True.
+ Whether to use Decord video loader to load data. Otherwise use mmcv video loader.
+ transform : function, default None.
+ A function that takes data and label and transforms them.
+ data_aug : str, default 'v1'.
+ Type of data augmentation to use. Supports v1, v2, v3 and v4.
+ lazy_init : bool, default False.
+ If set to True, build a dataset instance without loading any dataset.
+ """
+
+ def __init__(self,
+ root,
+ setting,
+ train=True,
+ test_mode=False,
+ name_pattern='img_%05d.jpg',
+ video_ext='mp4',
+ is_color=True,
+ modality='rgb',
+ num_segments=1,
+ num_crop=1,
+ new_length=1,
+ new_step=1,
+ randomize_interframes=False,
+ transform=None,
+ temporal_jitter=False,
+ video_loader=False,
+ use_decord=False,
+ lazy_init=False,
+ is_video_dataset=True):
+
+ super(VideoMAE, self).__init__()
+ self.root = root
+ self.setting = setting
+ self.train = train
+ self.test_mode = test_mode
+ self.is_color = is_color
+ self.modality = modality
+ self.num_segments = num_segments
+ self.num_crop = num_crop
+ self.new_length = new_length
+
+ self.randomize_interframes = randomize_interframes
+ self._new_step = new_step # If randomize_interframes is True, then this is the max, otherwise it's just the skip
+ # self._skip_length = self.new_length * self.new_step # If randomize_interframes is True, then this isn't used, otherwise it's used as calculated
+ self.temporal_jitter = temporal_jitter
+ self.name_pattern = name_pattern
+ self.video_loader = video_loader
+ self.video_ext = video_ext
+ self.use_decord = use_decord
+ self.transform = transform
+ self.lazy_init = lazy_init
+
+ if (not self.lazy_init) and is_video_dataset:
+ self.clips = self._make_dataset(root, setting)
+ if len(self.clips) == 0:
+ raise (RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
+ "Check your data directory (opt.data-dir)."))
+
+ def __getitem__(self, index):
+
+ directory, target = self.clips[index]
+
+ if self.video_loader:
+ if '.' in directory.split('/')[-1]:
+ # data in the "setting" file already have extension, e.g., demo.mp4
+ video_name = directory
+ else:
+ # data in the "setting" file do not have extension, e.g., demo
+ # So we need to provide extension (i.e., .mp4) to complete the file name.
+ video_name = '{}.{}'.format(directory, self.video_ext)
+
+ try:
+ decord_vr = decord.VideoReader(video_name, num_threads=1)
+ except:
+ # return video_name
+ return (self.__getitem__(index + 1))
+ duration = len(decord_vr)
+
+ segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)
+
+ images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets,
+ new_step, skip_length)
+
+ process_data, mask = self.transform((images, None)) # T*C,H,W
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,
+ 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
+
+ return (process_data, mask)
+
+ def __len__(self):
+ return len(self.clips)
+
+ def _make_dataset(self, directory, setting):
+ if not os.path.exists(setting):
+ raise (RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
+ clips = []
+ with open(setting) as split_f:
+ data = split_f.readlines()
+ for line in data:
+ line_info = line.split(' ')
+ # line format: video_path, video_duration, video_label
+ if len(line_info) < 2:
+ raise (RuntimeError('Video input format is not correct, missing one or more elements. %s' % line))
+ elif len(line_info) > 2:
+ line_info = (' '.join(line_info[:-1]), line_info[-1]) # filename has spaces
+ clip_path = os.path.join(line_info[0])
+ target = int(line_info[1])
+ item = (clip_path, target)
+ clips.append(item)
+ # import torch_xla.core.xla_model as xm
+ # print = xm.master_print
+ # print("Dataset created. Number of clips: ", len(clips))
+ return clips
+
+ def _sample_train_indices(self, num_frames):
+ if self.randomize_interframes is False:
+ new_step = self._new_step
+ else:
+ new_step = np.random.randint(1, self._new_step + 1)
+
+ skip_length = self.new_length * new_step
+
+ average_duration = (num_frames - skip_length + 1) // self.num_segments
+ if average_duration > 0:
+ offsets = np.multiply(list(range(self.num_segments)),
+ average_duration)
+ offsets = offsets + np.random.randint(average_duration,
+ size=self.num_segments)
+ elif num_frames > max(self.num_segments, skip_length):
+ offsets = np.sort(np.random.randint(
+ num_frames - skip_length + 1,
+ size=self.num_segments))
+ else:
+ offsets = np.zeros((self.num_segments,))
+
+ if self.temporal_jitter:
+ skip_offsets = np.random.randint(
+ new_step, size=skip_length // new_step)
+ else:
+ skip_offsets = np.zeros(
+ skip_length // new_step, dtype=int)
+ return offsets + 1, skip_offsets, new_step, skip_length
+
+ def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step,
+ skip_length):
+ sampled_list = []
+ frame_id_list = []
+ for seg_ind in indices:
+ offset = int(seg_ind)
+ for i, _ in enumerate(range(0, skip_length, new_step)):
+ if offset + skip_offsets[i] <= duration:
+ frame_id = offset + skip_offsets[i] - 1
+ else:
+ frame_id = offset - 1
+ frame_id_list.append(frame_id)
+ if offset + new_step < duration:
+ offset += new_step
+ try:
+ video_data = video_reader.get_batch(frame_id_list).asnumpy()
+ sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in
+ enumerate(frame_id_list)]
+ except:
+ raise RuntimeError(
+ 'Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory,
+ duration))
+ return sampled_list
+
+
+class ContextAndTargetVideoDataset(VideoMAE):
+ """
+ A video dataset whose provided videos consist of (1) a "context" sequence of length Tc
+ and (2) a "target" sequence Tt.
+
+ These two sequences have the same frame rate (specifiable in real units) but are
+ separated by a specified gap (which may vary across examples).
+
+ The main use case is for training models to predict ahead by some variable amount,
+ given the context.
+ """
+
+ standard_fps = [12, 24, 30, 48, 60, 100]
+
+ def __init__(self,
+ root,
+ setting,
+ train=True,
+ test_mode=False,
+ transform=None,
+ step_units='ms',
+ new_step=150,
+ start_frame=0,
+ context_length=2,
+ target_length=1,
+ channels_first=True,
+ generate_masks=True,
+ mask_generator=None,
+ context_target_gap=[400, 600],
+ normalize_timestamps=True,
+ default_fps=30,
+ min_fps=0.1,
+ seed=0,
+ *args,
+ **kwargs):
+ super(ContextAndTargetVideoDataset, self).__init__(
+ root=root,
+ setting=setting,
+ train=train,
+ test_mode=test_mode,
+ transform=transform,
+ new_length=context_length,
+ use_decord=True,
+ lazy_init=False,
+ video_loader=True,
+ *args, **kwargs)
+
+ # breakpoint()
+
+ self.context_length = self.new_length
+ self.target_length = target_length
+
+ ## convert from fps and step size to frames
+ self._fps = None
+ self._min_fps = min_fps
+ self._default_fps = default_fps
+ self._step_units = step_units
+ self.new_step = new_step
+
+ ## sampling for train and test
+ self._start_frame = start_frame
+ self.gap = context_target_gap
+ self.seed = seed
+ self.rng = np.random.RandomState(seed=seed)
+
+ # breakpoint()
+
+ ## output formatting
+ self._channels_first = channels_first
+ self._normalize_timestamps = normalize_timestamps
+ self._generate_masks = generate_masks
+ self.mask_generator = mask_generator
+
+
+ def _get_frames_per_t(self, t):
+ if self._step_units == 'frames' or (self._step_units is None):
+ return int(t)
+
+ assert self._fps is not None
+ t_per_frame = 1 / self._fps
+ if self._step_units in ['ms', 'milliseconds']:
+ t_per_frame *= 1000.0
+
+ return max(int(np.round(t / t_per_frame)), 1)
+
+ @property
+ def new_step(self):
+ if self._fps is None:
+ return None
+ else:
+ return self._get_frames_per_t(self._new_step)
+
+ @new_step.setter
+ def new_step(self, v):
+ self._new_step = v
+
+ @property
+ def gap(self):
+ if self._fps is None:
+ return [1, 2]
+ else:
+ gap = [self._get_frames_per_t(self._gap[0]),
+ self._get_frames_per_t(self._gap[1])]
+ gap[1] = max(gap[1], gap[0] + 1)
+ return gap
+
+ @gap.setter
+ def gap(self, v):
+ if v is None:
+ v = self._new_step
+ if not isinstance(v, (list, tuple)):
+ v = [v, v]
+ self._gap = v
+
+ def _get_video_name(self, directory):
+ if ''.join(['.', self.video_ext]) in directory.split('/')[-1]:
+ # data in the "setting" file has extension, e.g. demo.mpr
+ video_name = directory
+ else:
+ # data doesn't have an extension
+ video_name = '{}.{}'.format(directory, self.video_ext)
+ return video_name
+
+ def _set_fps(self, reader):
+ """click fps to a standard"""
+ if self._step_units == 'frames' or self._step_units is None:
+ self._fps = None
+ else:
+ self._fps = None
+ fps = reader.get_avg_fps()
+ for st in self.standard_fps:
+ if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
+ self._fps = st
+ if self._fps is None:
+ self._fps = int(np.round(fps))
+
+ if self._fps < self._min_fps:
+ self._fps = self._default_fps
+
+ def _get_step_and_gap(self):
+ step = self.new_step
+ if self.randomize_interframes and self.train:
+ step = self.rng.randint(1, step + 1)
+
+ if self.train:
+ gap = self.rng.randint(*self.gap)
+ else:
+ gap = sum(self.gap) // 2
+ return (step, gap)
+
+ def _sample_frames(self):
+ step, gap = self._get_step_and_gap()
+
+ ## compute total length of sample
+ ## e.g. if context_length = 2, step = 1, gap = 10, target_length = 2:
+ ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
+ ## so len(video) must be >= 13
+ self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
+ if self._total_length > (self._num_frames - self._start_frame):
+ if self.train:
+ return None
+ else:
+ raise ValueError(
+ "movie of length %d starting at fr=%d is too long for video of %d frames" % \
+ (self._total_length, self._start_frame, self._num_frames))
+
+ ## sample the frames randomly (if training) or from the start frame (if test)
+ if self.train:
+ self.start_frame_now = self.rng.randint(
+ min(self._start_frame, self._num_frames - self._total_length),
+ self._num_frames - self._total_length + 1)
+ else:
+ self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)
+
+ frames = [self.start_frame_now + i * step for i in range(self.context_length)]
+ frames += [frames[-1] + gap + i * step for i in range(self.target_length)]
+
+ # breakpoint()
+
+ return frames
+
+ def _decode_frame_images(self, reader, frames):
+ try:
+ video_data = reader.get_batch(frames).asnumpy()
+ video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
+ for t, _ in enumerate(frames)]
+ except:
+ raise RuntimeError(
+ "Error occurred in reading frames {} from video {} of duration {}".format(
+ frames, self.index, self._num_frames))
+ return video_data
+
+ def __getitem__(self, index):
+
+ self.index = index
+ self.directory, target = self.clips[index]
+
+ self.video_name = self._get_video_name(self.directory)
+
+ ## build decord loader
+ try:
+ decord_vr = decord.VideoReader(self.video_name, num_threads=1)
+ self._set_fps(decord_vr)
+ except:
+ # return self.video_name
+ return (self.__getitem__(index + 1))
+
+ ## sample the video
+ self._num_frames = len(decord_vr)
+ self.frames = self._sample_frames()
+ if self.frames is None:
+ print("no movie of length %d for video idx=%d" % (self._total_length, self.index))
+ return self.__getitem__(index + 1)
+
+ ## decode to PIL.Image
+ image_list = self._decode_frame_images(decord_vr, self.frames)
+
+ ## postproc to torch.Tensor and mask generation
+ if self.transform is None:
+ image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
+ else:
+ image_tensor = self.transform((image_list, None))
+
+ image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])
+
+ ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]
+ if self._channels_first:
+ image_tensor = image_tensor.transpose(0, 1)
+
+ if self._generate_masks and self.mask_generator is not None:
+ mask = self.mask_generator()
+ return image_tensor, mask.bool()
+ else:
+ return image_tensor
diff --git a/cwm/data/dataset_utils.py b/cwm/data/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db9057b48eee5df1c91ed00f86405f1b1a619f30
--- /dev/null
+++ b/cwm/data/dataset_utils.py
@@ -0,0 +1,73 @@
+from torchvision import transforms
+from cwm.data.transforms import *
+from cwm.data.dataset import ContextAndTargetVideoDataset
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from cwm.data.masking_generator import RotatedTableMaskingGenerator
+
+class DataAugmentationForVideoMAE(object):
+ def __init__(self, augmentation_type, input_size, augmentation_scales):
+
+ transform_list = []
+
+ self.scale = GroupScale(input_size)
+ transform_list.append(self.scale)
+
+ if augmentation_type == 'multiscale':
+ self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
+ elif augmentation_type == 'center':
+ self.train_augmentation = GroupCenterCrop(input_size)
+
+ transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])
+
+ # Normalize input images
+ normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+ transform_list.append(normalize)
+
+ self.transform = transforms.Compose(transform_list)
+
+ def __call__(self, images):
+ process_data, _ = self.transform(images)
+ return process_data
+
+ def __repr__(self):
+ repr = "(DataAugmentationForVideoMAE,\n"
+ repr += " transform = %s,\n" % str(self.transform)
+ repr += ")"
+ return repr
+
+
+def build_pretraining_dataset(args):
+
+ dataset_list = []
+ data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)
+
+ mask_generator = RotatedTableMaskingGenerator(
+ input_size=args.mask_input_size,
+ mask_ratio=args.mask_ratio,
+ tube_length=args.tubelet_size,
+ batch_size=args.batch_size,
+ mask_type=args.mask_type
+ )
+
+ for data_path in [args.data_path] if args.data_path_list is None else args.data_path_list:
+ dataset = ContextAndTargetVideoDataset(
+ root=None,
+ setting=data_path,
+ video_ext='mp4',
+ is_color=True,
+ modality='rgb',
+ context_length=args.context_frames,
+ target_length=args.target_frames,
+ step_units=args.temporal_units,
+ new_step=args.sampling_rate,
+ context_target_gap=args.context_target_gap,
+ transform=data_transform,
+ randomize_interframes=False,
+ channels_first=True,
+ temporal_jitter=False,
+ train=True,
+ mask_generator=mask_generator,
+ )
+ dataset_list.append(dataset)
+ dataset = torch.utils.data.ConcatDataset(dataset_list)
+ return dataset
diff --git a/cwm/data/masking_generator.py b/cwm/data/masking_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fdf4b2444966a0b9517024f1bcf6cb18d3f125
--- /dev/null
+++ b/cwm/data/masking_generator.py
@@ -0,0 +1,86 @@
+import numpy as np
+import torch
+
+def get_tubes(masks_per_frame, tube_length):
+ rp = torch.randperm(len(masks_per_frame))
+ masks_per_frame = masks_per_frame[rp]
+
+ tubes = [masks_per_frame]
+ for x in range(tube_length - 1):
+ masks_per_frame = masks_per_frame.clone()
+ rp = torch.randperm(len(masks_per_frame))
+ masks_per_frame = masks_per_frame[rp]
+ tubes.append(masks_per_frame)
+
+ tubes = torch.vstack(tubes)
+
+ return tubes
+
+class RotatedTableMaskingGenerator:
+ def __init__(self,
+ input_size,
+ mask_ratio,
+ tube_length,
+ batch_size,
+ mask_type='rotated_table',
+ seed=None,
+ randomize_num_visible=False):
+
+ self.batch_size = batch_size
+
+ self.mask_ratio = mask_ratio
+ self.tube_length = tube_length
+
+ self.frames, self.height, self.width = input_size
+ self.num_patches_per_frame = self.height * self.width
+ self.total_patches = self.frames * self.num_patches_per_frame
+
+ self.seed = seed
+ self.randomize_num_visible = randomize_num_visible
+
+ self.mask_type = mask_type
+
+ def __repr__(self):
+ repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format(
+ self.total_patches, self.tube_length, self.randomize_num_visible, self.seed
+ )
+ return repr_str
+
+ def __call__(self, m=None):
+
+ if self.mask_type == 'rotated_table_magvit':
+ self.mask_ratio = np.random.uniform(low=0.0, high=1)
+ self.mask_ratio = np.cos(self.mask_ratio * np.pi / 2)
+ elif self.mask_type == 'rotated_table_maskvit':
+ self.mask_ratio = np.random.uniform(low=0.5, high=1)
+
+ all_masks = []
+ for b in range(self.batch_size):
+
+ self.num_masks_per_frame = max(0, int(self.mask_ratio * self.num_patches_per_frame))
+ self.total_masks = self.tube_length * self.num_masks_per_frame
+
+ num_masks = self.num_masks_per_frame
+
+ if self.randomize_num_visible:
+ assert "Randomize num visible Not implemented"
+ num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
+
+ if self.mask_ratio == 0:
+ mask_per_frame = torch.hstack([
+ torch.zeros(self.num_patches_per_frame - num_masks),
+ ])
+ else:
+ mask_per_frame = torch.hstack([
+ torch.zeros(self.num_patches_per_frame - num_masks),
+ torch.ones(num_masks),
+ ])
+
+ tubes = get_tubes(mask_per_frame, self.tube_length)
+ top = torch.zeros(self.height * self.width).to(tubes.dtype)
+
+ top = torch.tile(top, (self.frames - self.tube_length, 1))
+ mask = torch.cat([top, tubes])
+ mask = mask.flatten()
+ all_masks.append(mask)
+ return torch.stack(all_masks)
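+
+# Illustrative sketch (hypothetical values): for a 3-frame grid of 28x28 patches with tube_length=1,
+# the first two frames are left fully visible (the "top" of the rotated table) and roughly
+# mask_ratio of the last frame's patches are masked.
+#
+# gen = RotatedTableMaskingGenerator(input_size=(3, 28, 28), mask_ratio=0.9,
+#                                    tube_length=1, batch_size=4)
+# mask = gen()  # float tensor of shape [4, 3 * 28 * 28]; 1 = masked patch, 0 = visible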
diff --git a/cwm/data/transforms.py b/cwm/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16a7373a9e1d7abb1c91fd041827b380b5392ed
--- /dev/null
+++ b/cwm/data/transforms.py
@@ -0,0 +1,206 @@
+import torch
+import torchvision.transforms.functional as F
+import warnings
+import random
+import numpy as np
+import torchvision
+from PIL import Image, ImageOps
+import numbers
+
+
+class GroupRandomCrop(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ w, h = img_group[0].size
+ th, tw = self.size
+
+ out_images = list()
+
+ x1 = random.randint(0, w - tw)
+ y1 = random.randint(0, h - th)
+
+ for img in img_group:
+ assert(img.size[0] == w and img.size[1] == h)
+ if w == tw and h == th:
+ out_images.append(img)
+ else:
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return (out_images, label)
+
+
+class GroupCenterCrop(object):
+ def __init__(self, size):
+ self.worker = torchvision.transforms.CenterCrop(size)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, tensor_tuple):
+ tensor, label = tensor_tuple
+ rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
+ rep_std = self.std * (tensor.size()[0]//len(self.std))
+
+ # TODO: make efficient
+ for t, m, s in zip(tensor, rep_mean, rep_std):
+ t.sub_(m).div_(s)
+
+ return (tensor,label)
+
+
+class GroupGrayScale(object):
+ def __init__(self, size):
+ self.worker = torchvision.transforms.Grayscale(size)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupScale(object):
+ """ Rescales the input PIL.Image to the given 'size'.
+ 'size' will be the size of the smaller edge.
+ For example, if height > width, then image will be
+ rescaled to (size * height / width, size)
+ size: size of the smaller edge
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, interpolation=Image.BILINEAR):
+ self.worker = torchvision.transforms.Resize(size, interpolation)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupMultiScaleCrop(object):
+
+ def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
+ self.max_distort = max_distort
+ self.fix_crop = fix_crop
+ self.more_fix_crop = more_fix_crop
+ self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
+ self.interpolation = Image.BILINEAR
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ im_size = img_group[0].size
+
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
+ crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
+ return (ret_img_group, label)
+
+ def _sample_crop_size(self, im_size):
+ image_w, image_h = im_size[0], im_size[1]
+
+ # find a crop size
+ base_size = min(image_w, image_h)
+ crop_sizes = [int(base_size * x) for x in self.scales]
+ crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
+ crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
+
+ pairs = []
+ for i, h in enumerate(crop_h):
+ for j, w in enumerate(crop_w):
+ if abs(i - j) <= self.max_distort:
+ pairs.append((w, h))
+
+ crop_pair = random.choice(pairs)
+ if not self.fix_crop:
+ w_offset = random.randint(0, image_w - crop_pair[0])
+ h_offset = random.randint(0, image_h - crop_pair[1])
+ else:
+ w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
+
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
+ offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
+ return random.choice(offsets)
+
+ @staticmethod
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
+ w_step = (image_w - crop_w) // 4
+ h_step = (image_h - crop_h) // 4
+
+ ret = list()
+ ret.append((0, 0)) # upper left
+ ret.append((4 * w_step, 0)) # upper right
+ ret.append((0, 4 * h_step)) # lower left
+ ret.append((4 * w_step, 4 * h_step)) # lower right
+ ret.append((2 * w_step, 2 * h_step)) # center
+
+ if more_fix_crop:
+ ret.append((0, 2 * h_step)) # center left
+ ret.append((4 * w_step, 2 * h_step)) # center right
+ ret.append((2 * w_step, 4 * h_step)) # lower center
+ ret.append((2 * w_step, 0 * h_step)) # upper center
+
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
+ ret.append((3 * w_step, 3 * h_step)) # lower right quarter
+ return ret
+
+
+class Stack(object):
+
+ def __init__(self, roll=False):
+ self.roll = roll
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ if img_group[0].mode == 'L':
+ return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
+ elif img_group[0].mode == 'RGB':
+ if self.roll:
+ return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
+ else:
+ return (np.concatenate(img_group, axis=2), label)
+
+
+class ToTorchFormatTensor(object):
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
+ def __init__(self, div=True):
+ self.div = div
+
+ def __call__(self, pic_tuple):
+ pic, label = pic_tuple
+
+ if isinstance(pic, np.ndarray):
+ # handle numpy array
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
+ else:
+ # handle PIL Image
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ return (img.float().div(255.) if self.div else img.float(), label)
+
+
+class IdentityTransform(object):
+
+ def __call__(self, data):
+ return data
diff --git a/cwm/data/video_file_lists/kinetics_400_train_list.txt b/cwm/data/video_file_lists/kinetics_400_train_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..35c95966b7313044367975e20aeb24c444e324e6
--- /dev/null
+++ b/cwm/data/video_file_lists/kinetics_400_train_list.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65e14c0735b4c90c57022add2407a8524d246cc09b3d5a7e83b963ac3b231032
+size 19539143
diff --git a/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt b/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af8496578845cf64678d43290bec5fb4bf7f1b86
--- /dev/null
+++ b/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b18ccdce4616fb32a0aababc2342a640a11f7d73439f49358a16cc99e7eaed3
+size 1943
diff --git a/cwm/engine_for_pretraining.py b/cwm/engine_for_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..854755c3eab5980335ff7d4abf4f8694b2e87400
--- /dev/null
+++ b/cwm/engine_for_pretraining.py
@@ -0,0 +1,92 @@
+import json
+import os
+import time
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+import utils
+
+from datetime import datetime
+
+
+def train_one_epoch(model: torch.nn.Module,
+ data_loader: Iterable,
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+ loss_scaler,
+ start_steps=None,
+ lr_schedule_values=None,
+ wd_schedule_values=None,
+ global_rank=None,
+ args=None,
+ loss_func = nn.MSELoss(),
+ ):
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ if args.eval:
+ model.eval()
+ else:
+ model.train()
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+
+ header = f'Epoch [{epoch}]'
+ patch_size = model.module.encoder.patch_size[-2:]
+ tubelet_size = model.module.encoder.patch_size[0]
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
+
+ for step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+
+ # assign learning rate & weight decay for each iteration
+ it = start_steps + step # global training iteration
+ if (lr_schedule_values is not None or wd_schedule_values is not None) and (step % args.accum_iter == 0):
+ for i, param_group in enumerate(optimizer.param_groups):
+ if lr_schedule_values is not None:
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
+ param_group["weight_decay"] = wd_schedule_values[it]
+
+ # prepare input
+ videos, bool_masked_pos = batch
+ videos = videos.to(device, non_blocking=True)
+ bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1)
+
+ # prepare target
+ with torch.no_grad():
+ unnorm_videos = videos * std + mean # in [0, 1]
+ videos_patch = utils.patchify(unnorm_videos, tubelet_size, patch_size)
+ B, _, C = videos_patch.shape
+ labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
+
+ # feedforward
+ with torch.cuda.amp.autocast(enabled=True):
+ outputs = model(videos, bool_masked_pos)
+ loss = loss_func(input=outputs, target=labels)
+
+ loss_value = loss.item()
+
+ # backward
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+ loss /= args.accum_iter
+ loss_scaler(loss, optimizer, clip_grad=None,
+ parameters=model.parameters(), create_graph=is_second_order,
+ update_grad=(step + 1) % args.accum_iter == 0)
+
+ torch.cuda.synchronize()
+ metric_logger.update(loss=loss_value)
+
+ if (step + 1) % args.accum_iter == 0:
+ optimizer.zero_grad()
+
+ lr = optimizer.param_groups[0]["lr"]
+ metric_logger.update(lr=lr)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/cwm/eval/Action_recognition/__init__.py b/cwm/eval/Action_recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Flow/__init__.py b/cwm/eval/Flow/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Flow/create_spring_submission_parallel.sh b/cwm/eval/Flow/create_spring_submission_parallel.sh
new file mode 100644
index 0000000000000000000000000000000000000000..91a1f8257f8af8acd01687fe019d745673c37884
--- /dev/null
+++ b/cwm/eval/Flow/create_spring_submission_parallel.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Define the path to the dataset and the Python script
+DATASET_PATH="/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/"
+SCRIPT_PATH="./create_spring_submission_unified.py"
+SAVE_DATA_PATH=${1}
+MODEL=${2}
+# Counter for GPUs
+GPU_COUNTER=0
+
+# Number of GPUs available
+NUM_GPUS=8
+
+#kill session
+tmux kill-session -t extraction
+
+tmux new-session -d -s "extraction"
+
+# Iterate through each folder in the dataset
+for FOLDER in $(find $DATASET_PATH -mindepth 1 -maxdepth 1 -type d | sort); do
+ # Extract the folder name for the tmux session name
+ FOLDER_NAME=$(basename $FOLDER)
+
+ # Create a new detached tmux session for each folder
+ tmux new-window -t extraction -n "$FOLDER" "ulimit -n 65535; CUDA_VISIBLE_DEVICES=$GPU_COUNTER python $SCRIPT_PATH --folder $FOLDER --gpu $GPU_COUNTER --save_data_path $SAVE_DATA_PATH --model $MODEL; echo 'Press Enter to continue...'; read -p ''"
+ # Increment the GPU counter and reset if it exceeds the number of GPUs
+ GPU_COUNTER=$((GPU_COUNTER + 1))
+ if [ $GPU_COUNTER -ge $NUM_GPUS ]; then
+ GPU_COUNTER=0
+ fi
+
+ sleep 1
+done
+
+tmux attach-session -t extraction
+
diff --git a/cwm/eval/Flow/create_spring_submission_unified.py b/cwm/eval/Flow/create_spring_submission_unified.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6fb526cc175dd8d2f9a66f624ed2bd8ad46fec6
--- /dev/null
+++ b/cwm/eval/Flow/create_spring_submission_unified.py
@@ -0,0 +1,111 @@
+
+import argparse
+
+# Parse command-line arguments
+import importlib
+import time
+
+parser = argparse.ArgumentParser(description='Process a folder with RAFT')
+parser.add_argument('--folder', type=str, required=True, help='Folder to process')
+parser.add_argument('--model', type=str, required=True, help='Model used to extract flow')
+parser.add_argument('--save_data_path', type=str, required=True, help='where to save the data')
+parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
+args = parser.parse_args()
+import os
+
+os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+import torch
+torch.cuda.set_device(0)
+
+import h5py
+
+def writeFlo5File(flow, filename):
+ with h5py.File(filename, "w") as f:
+ f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
+
+if __name__ == '__main__':
+ module_name, class_name = args.model.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+
+ model = getattr(module, class_name)
+ model = model().cuda().eval()
+
+ folder = args.folder.split('/')[-1]
+
+ import os
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torchvision.transforms as transforms
+
+ # import smurf # Assuming this is your custom inference module
+
+ # Path for the dataset
+ dataset_path = '/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/'
+
+ save_data_path = args.save_data_path
+
+ if not os.path.exists(save_data_path):
+ os.makedirs(save_data_path)
+
+ resize_crop = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ import numpy as np
+
+ def l2norm(x):
+ return np.sqrt((x ** 2).sum(-1))
+
+ all_epe = []
+ # Create a new HDF5 file
+
+ TAG_FLOAT = 202021.25
+
+ # Iterate over each folder in the dataset directory
+ for dir in ['FW', 'BW']:
+ for stereo in ['left', 'right']:
+ files = sorted(os.listdir(os.path.join(dataset_path, folder, f'frame_{stereo}')))
+ output_folder = os.path.join(save_data_path, folder)
+ output_folder = os.path.join(output_folder, f'flow_{dir}_{stereo}')
+
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+
+ for ct_f in range(len(files) - 1):
+ # Read images
+ if dir == 'FW':
+ f1 = files[ct_f]
+ f2 = files[ct_f + 1]
+ else:
+ f2 = files[ct_f]
+ f1 = files[ct_f + 1]
+ t = time.time()
+ image1_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f1)
+ image2_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f2)
+
+ idx = image1_path.split('/')[-1].split('.')[0].split('_')[-1]
+ flow_save_path = os.path.join(output_folder, f'flow_{dir}_{stereo}_' + idx + '.flo5')
+
+ # if os.path.exists(flow_save_path):
+ # try:
+ # with h5py.File(flow_save_path, 'r+') as f:
+ # if f['flow'][:].shape[0] == 2:
+ # flow = f['flow'][:].transpose([1, 2, 0])
+ # del f['flow']
+ # f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
+ # continue
+ # else:
+ # continue
+ # except:
+ # pass
+
+ image1_ = plt.imread(image1_path)
+ image2_ = plt.imread(image2_path)
+
+ image1 = resize_crop(image1_)
+ image2 = resize_crop(image2_)
+
+ forward_flow = model.forward(image1, image2)
+
+ writeFlo5File(forward_flow, flow_save_path)
diff --git a/cwm/eval/Flow/flow_extraction_classes.py b/cwm/eval/Flow/flow_extraction_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1923d72745b998cb144408f0eee763d9f6b360b
--- /dev/null
+++ b/cwm/eval/Flow/flow_extraction_classes.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+
+import cwm.model.model_pretrain as vmae_tranformers
+from . import flow_utils
+from . import losses as bblosses
+
+
+# Normal Resolution
+def l2_norm(x):
+ return x.square().sum(-3, True).sqrt()
+
+
+
+
+# x.shape
+def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
+ fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
+ flow_diff_fwd = flow_fwd + fwd_bck_cycle
+
+ bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
+ flow_diff_bck = flow_bck + bck_fwd_cycle
+
+ norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
+ norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2
+
+ occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
+ occ_thresh_bck = occ_thresh * norm_bck + 0.5
+
+ occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
+ occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()
+
+ return occ_mask_fwd, occ_mask_bck
+
+
+class ExtractFlow(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ return
+
+ def forward(self, img1, img2):
+ '''
+ img1: first frame
+ img2: second frame
+ returns: flow map (h, w, 2)
+ '''
+
+from cwm.data.masking_generator import RotatedTableMaskingGenerator
+
+class CWM(ExtractFlow):
+ def __init__(self, model_name, patch_size, weights_path):
+ super().__init__()
+
+ self.patch_size = patch_size
+ model = getattr(vmae_tranformers, model_name)
+ vmae_8x8_full = model().cuda().eval().requires_grad_(False)
+
+ VMAE_LOAD_PATH = weights_path
+ did_load = vmae_8x8_full.load_state_dict(torch.load(VMAE_LOAD_PATH)['model'], strict=False)
+ print(did_load, VMAE_LOAD_PATH)
+
+ self.predictor = vmae_8x8_full
+
+ self.mask_generator = RotatedTableMaskingGenerator(
+ input_size=(vmae_8x8_full.num_frames, 28, 28),
+ mask_ratio=0.0,
+ tube_length=1,
+ batch_size=1,
+ mask_type='rotated_table'
+ )
+
+ def forward(self, img1, img2):
+ '''
+ img1: [3, 1024, 1024]
+ img2: [3, 1024, 1024]
+ both images are imagenet normalized
+ '''
+
+ with torch.no_grad():
+ FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
+ self.mask_generator, img1[None],
+ img2[None],
+ num_scales=2,
+ min_scale=224,
+ N_mask_samples=1)
+
+ BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
+ self.mask_generator,
+ img2[None],
+ img1[None],
+ num_scales=2,
+ min_scale=224,
+ N_mask_samples=1)
+
+ # FF, _ = flow_utils.get_honglin_3frame_vmae_optical_flow_crop_batched(self.predictor,
+ # self.mask_generator, img1[None],
+ # img2[None], img2[None],
+ # neg_back_flow=True, num_scales=1,
+ # min_scale=224, N_mask_samples=1,
+ # mask_ratio=0.0)
+ #
+ # BF, _ = flow_utils.get_honglin_3frame_vmae_optical_flow_crop_batched(self.predictor,
+ # self.mask_generator, img2[None],
+ # img1[None], img1[None],
+ # neg_back_flow=True, num_scales=1,
+ # min_scale=224, N_mask_samples=1,
+ # mask_ratio=0.0)
+
+ occ_mask = get_occ_masks(FF, BF)[0]
+
+ FF = FF * occ_mask
+
+ FF = FF[0]
+
+ return FF#.cpu().numpy().transpose([1, 2, 0])
+
+
+class CWM_8x8(CWM):
+ def __init__(self):
+ super().__init__('vitb_8x8patch_3frames', 8,
+ '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth')
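+
+# Illustrative sketch (assumes the checkpoint path above is accessible): extract occlusion-masked
+# forward flow between two ImageNet-normalized frames of shape [3, H, W].
+#
+# extractor = CWM_8x8()
+# flow = extractor(img1, img2)  # tensor of shape [2, H', W']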
diff --git a/cwm/eval/Flow/flow_utils.py b/cwm/eval/Flow/flow_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad37600b33e096c07939d9b797dd6a6db927abb7
--- /dev/null
+++ b/cwm/eval/Flow/flow_utils.py
@@ -0,0 +1,569 @@
+import random
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+from . import losses as bblosses
+import kornia
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
+ # Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
+ mask_unrolled = mask_tensor.view(-1)
+
+ second_frame_unmask_indices = torch.where(mask_unrolled[frame_size ** 2:] == False)[0]
+
+ # Divide the embedding tensor into two parts: corresponding to the first and the second frame
+ first_frame_embeddings = embedding_tensor[0, :frame_size ** 2, :]
+ second_frame_embeddings = embedding_tensor[0, frame_size ** 2:, :]
+
+ # print(first_frame_embeddings.shape, second_frame_embeddings.shape, embedding_tensor.shape)
+
+ # Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
+ dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
+ norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
+ cos_sim_matrix = dot_product / norms
+
+ # Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
+ first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
+
+ # Convert the 1D pixel indices into 2D coordinates
+ second_frame_y = second_frame_unmask_indices // frame_size
+ second_frame_x = second_frame_unmask_indices % frame_size
+ first_frame_y = first_frame_most_similar_indices // frame_size
+ first_frame_x = first_frame_most_similar_indices % frame_size
+
+ # Compute the x and y displacements and convert them to float
+ displacements_x = (second_frame_x - first_frame_x).float()
+ displacements_y = (second_frame_y - first_frame_y).float()
+
+ # Initialize optical flow tensor
+ optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
+
+ # Assign the computed displacements to the corresponding pixels in the optical flow tensor
+ optical_flow[0, second_frame_y, second_frame_x] = displacements_x
+ optical_flow[1, second_frame_y, second_frame_x] = displacements_y
+
+ return optical_flow
+
+
+def get_minimal_224_crops_new_batched(video_tensor, N):
+ B, T, C, H, W = video_tensor.shape
+
+ # Calculate the number of crops needed in both the height and width dimensions
+ num_crops_h = math.ceil(H / 224) if H > 224 else 1
+ num_crops_w = math.ceil(W / 224) if W > 224 else 1
+
+ # Calculate the step size for the height and width dimensions
+ step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
+ step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
+
+ # Create a list to store the cropped tensors and their start positions
+ cropped_tensors = []
+ crop_positions = []
+
+ # Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
+ for i in range(num_crops_h):
+ for j in range(num_crops_w):
+ start_h = i * step_size_h
+ start_w = j * step_size_w
+ end_h = min(start_h + 224, H)
+ end_w = min(start_w + 224, W)
+ crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
+ cropped_tensors.append(crop)
+ crop_positions.append((start_h, start_w))
+
+ D = len(cropped_tensors)
+
+ # If N is greater than D, generate additional random crops
+ if N > D and H > 224 and W > 224: # check if H and W are greater than 224
+ for _ in range(N - D):
+ start_h = random.randint(0, H - 224)
+ start_w = random.randint(0, W - 224)
+ crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
+ cropped_tensors.append(crop)
+ crop_positions.append((start_h, start_w))
+
+ # Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
+ cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
+
+ return cropped_tensors, crop_positions
+
+
+def create_weighted_mask_batched(h, w):
+ y_mask = np.linspace(0, 1, h)
+ y_mask = np.minimum(y_mask, 1 - y_mask)
+ x_mask = np.linspace(0, 1, w)
+ x_mask = np.minimum(x_mask, 1 - x_mask)
+ weighted_mask = np.outer(y_mask, x_mask)
+ return torch.from_numpy(weighted_mask).float()
+
+
+def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
+ B, T, C, H, W = original_shape
+
+ # Initialize an empty tensor to store the reconstructed video
+ reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+ # Create a tensor to store the sum of weighted masks
+ weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+ # Create a weighted mask for the crops
+ weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
+ weighted_mask = weighted_mask[None, None, None, :, :] # Extend dimensions to match the cropped tensor.
+
+ for idx, crop in enumerate(cropped_tensors):
+ start_h, start_w = crop_positions[idx]
+
+ # Multiply the crop with the weighted mask
+ weighted_crop = crop * weighted_mask
+
+ # Add the weighted crop to the corresponding location in the reconstructed_video tensor
+ reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
+
+ # Update the weighted_masks_sum tensor
+ weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
+
+ # Add a small epsilon value to avoid division by zero
+ epsilon = 1e-8
+
+ # Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
+ reconstructed_video /= (weighted_masks_sum + epsilon)
+
+ return reconstructed_video
+
+
+def l2_norm(x):
+ return x.square().sum(-3, True).sqrt()
+
+
+resize = lambda x, a: F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear',
+ align_corners=False)
+
+upsample = lambda x, H, W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
+
+
+def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
+ fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
+ flow_diff_fwd = flow_fwd + fwd_bck_cycle
+
+ bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
+ flow_diff_bck = flow_bck + bck_fwd_cycle
+
+ norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
+ norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2
+
+ occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
+ occ_thresh_bck = occ_thresh * norm_bck + 0.5
+
+ occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
+ occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()
+
+ return occ_mask_fwd, occ_mask_bck
+
+def forward_backward_cycle_consistency(flow_fwd, flow_bck, niters=10):
+ # Make sure to be using axes-swapped, upsampled flows!
+ bck_flow_clone = flow_bck.clone().detach()
+ fwd_flow_clone = flow_fwd.clone().detach()
+
+ for i in range(niters):
+
+ fwd_bck_cycle_orig, _ = bblosses.backward_warp(img2=bck_flow_clone, flow=fwd_flow_clone)
+ flow_diff_fwd_orig = fwd_flow_clone + fwd_bck_cycle_orig
+
+ fwd_flow_clone = fwd_flow_clone - flow_diff_fwd_orig/2
+
+ bck_fwd_cycle_orig, _ = bblosses.backward_warp(img2=fwd_flow_clone, flow=bck_flow_clone)
+ flow_diff_bck_orig = bck_flow_clone + bck_fwd_cycle_orig
+
+
+ bck_flow_clone = bck_flow_clone - flow_diff_bck_orig/2
+
+ return fwd_flow_clone, bck_flow_clone
+
+from PIL import Image
+def resize_flow_map(flow_map, target_size):
+ """
+ Resize a flow map to a target size while adjusting the flow vectors.
+
+ Parameters:
+ flow_map (numpy.ndarray): Input flow map of shape (H, W, 2) where each pixel contains a (dx, dy) flow vector.
+ target_size (tuple): Target size (height, width) for the resized flow map.
+
+ Returns:
+ numpy.ndarray: Resized and scaled flow map of shape (target_size[0], target_size[1], 2).
+ """
+ # Get the original size
+ flow_map = flow_map[0].detach().cpu().numpy()
+ flow_map = flow_map.transpose(1, 2, 0)
+ original_size = flow_map.shape[:2]
+
+ # Separate the flow map into two channels: dx and dy
+ flow_map_x = flow_map[:, :, 0]
+ flow_map_y = flow_map[:, :, 1]
+
+ # Convert each flow channel to a PIL image for resizing
+ flow_map_x_img = Image.fromarray(flow_map_x)
+ flow_map_y_img = Image.fromarray(flow_map_y)
+
+ # Resize both channels to the target size using bilinear interpolation
+ flow_map_x_resized = flow_map_x_img.resize(target_size, Image.BILINEAR)
+ flow_map_y_resized = flow_map_y_img.resize(target_size, Image.BILINEAR)
+
+ # Convert resized PIL images back to NumPy arrays
+ flow_map_x_resized = np.array(flow_map_x_resized)
+ flow_map_y_resized = np.array(flow_map_y_resized)
+
+ # Compute the scaling factor based on the size change
+ scale_factor = target_size[0] / original_size[0] # Scaling factor for both dx and dy
+
+ # Scale the flow vectors (dx and dy) accordingly
+ flow_map_x_resized *= scale_factor
+ flow_map_y_resized *= scale_factor
+
+ # Recombine the two channels into a resized flow map
+ flow_map_resized = np.stack([flow_map_x_resized, flow_map_y_resized], axis=-1)
+
+ flow_map_resized = torch.from_numpy(flow_map_resized)[None].permute(0, 3, 1, 2)
+
+ return flow_map_resized
+
+def get_vmae_optical_flow_crop_batched_smoothed(generator,
+ mask_generator,
+ img1,
+ img2,
+ neg_back_flow=True,
+ num_scales=1,
+ min_scale=400,
+ N_mask_samples=100,
+ mask_ratio=0.8,
+ smoothing_factor=1):
+
+ ##### DEPRECATED
+ print('Deprecated. Please use scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed')
+
+ # neg_back_flow and mask_ratio are accepted for backward compatibility but ignored here,
+ # since the replacement function no longer takes these arguments.
+ return scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
+ mask_generator,
+ img1,
+ img2,
+ num_scales=num_scales,
+ min_scale=min_scale,
+ N_mask_samples=N_mask_samples,
+ smoothing_factor=smoothing_factor)
+
+
+
+def average_crops(tensor, D):
+ C, H, W = tensor.shape
+
+ # Create zero-filled tensors for the shifted crops
+ down_shifted = torch.zeros_like(tensor)
+ up_shifted = torch.zeros_like(tensor)
+ right_shifted = torch.zeros_like(tensor)
+ left_shifted = torch.zeros_like(tensor)
+
+ # Shift the tensor and store the results in the zero-filled tensors
+ down_shifted[:, :H-D, :] = tensor[:, D:, :]
+ up_shifted[:, D:, :] = tensor[:, :H-D, :]
+ right_shifted[:, :, :W-D] = tensor[:, :, D:]
+ left_shifted[:, :, D:] = tensor[:, :, :W-D]
+
+ # Average the tensor with its four crops
+ result = (tensor + down_shifted + up_shifted + right_shifted + left_shifted) / 5.0
+
+ return result
+
+
+def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(predictor,
+ mask_generator,
+ img1,
+ img2,
+ conditioning_img=None,
+ num_scales=1,
+ min_scale=400,
+ N_mask_samples=100,
+ smoothing_factor=1):
+ B = img1.shape[0]
+ assert len(img1.shape) == 4
+ assert num_scales >= 1
+
+ # For scaling
+ h1 = img2.shape[-2]
+ w1 = img2.shape[-1]
+
+
+ alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1
+
+ frame_size = 224 // predictor.patch_size[-1]
+
+ patch_size = predictor.patch_size[-1]
+
+ num_frames = predictor.num_frames
+
+ all_fwd_flows_e2d = []
+
+ s_hs = []
+ s_ws = []
+
+ for aidx in range(num_scales):
+ # print(aidx)
+
+ # print('aidx: ', aidx)
+
+ img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+ mode='bicubic', align_corners=True)
+ img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+ mode='bicubic', align_corners=True)
+
+ if conditioning_img is not None:
+ conditioning_img_scaled = F.interpolate(conditioning_img.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+ mode='bilinear', align_corners=False)
+
+ # print("img1_scaled", img1_scaled.shape, alpha, min_scale, num_scales)
+
+ h2 = img2_scaled.shape[-2]
+ w2 = img2_scaled.shape[-1]
+
+ s_h = h1 / h2
+ s_w = w1 / w2
+
+ s_hs.append(s_h)
+ s_ws.append(s_w)
+
+ if conditioning_img is not None:
+ video = torch.cat([conditioning_img_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
+ else:
+ video = torch.cat([img2_scaled.unsqueeze(1)]*(num_frames-1) + [img1_scaled.unsqueeze(1)], 1)
+
+ # Should work, even if the incoming video is already 224x224
+ crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
+
+ num_crops = len(crops1)
+
+ crop_flows_enc = []
+ crop_flows_enc2dec = []
+ N_samples = N_mask_samples
+
+ crop = torch.cat(crops1, 0).cuda()
+
+ optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
+ mask_counts = torch.zeros(frame_size, frame_size).cuda()
+
+ i = 0
+ while i < N_samples or (mask_counts == 0).any().item():
+ if i % 100 == 0:
+ pass # print(i)
+
+        # Note: every crop in the batch shares the same mask sample; acceptable for this estimator.
+ mask = mask_generator().bool().cuda()
+ mask_2f = ~mask[0, (frame_size * frame_size)*(num_frames-1):]
+ mask_counts += mask_2f.reshape(frame_size, frame_size)
+
+ with torch.cuda.amp.autocast(enabled=True):
+
+ processed_x = crop.transpose(1, 2)
+
+ encoder_out = predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
+ encoder_to_decoder = predictor.encoder_to_decoder(encoder_out)
+
+ encoder_to_decoder = encoder_to_decoder[:, (frame_size * frame_size)*(num_frames-2):, :]
+ flow_mask = mask[:, (frame_size * frame_size)*(num_frames-2):]
+
+ optical_flow_e2d = []
+ # one per batch element for now
+ for b in range(B * num_crops):
+ batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
+ # optical_flow_e2d.append(batch_flow.unsqueeze(0))
+
+ optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))
+
+ optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
+ optical_flows_enc2dec += optical_flow_e2d
+ i += 1
+
+ optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
+
+        # other function (alternative rescaling, kept for reference)
+ # scale_factor_y = video.shape[-2] / 224
+ # scale_factor_x = video.shape[-1] / 224
+ #
+ # scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
+ # scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
+ # scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h
+ #
+ # # split the crops back up
+ # crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
+
+ ###
+ #Kevin's fn
+ crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
+
+ ###
+
+ #Changed by Kevin
+ T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
+ crop_flows_enc2dec]
+ optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
+ B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+
+ #other function
+ # optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
+ # [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
+ # crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+ #
+ all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
+
+ #other function
+ # all_fwd_flows_e2d_new = []
+ #
+ # for r in all_fwd_flows_e2d:
+ # new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
+ # all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
+ # return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+ #
+ #
+ # return_flow = -return_flow
+ # all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+ #
+ # return return_flow, all_fwd_flows_e2d_new
+
+ #Kevin's method
+ all_fwd_flows_e2d_new = []
+
+ for ridx, r in enumerate(all_fwd_flows_e2d):
+ # print('ridx', ridx)
+ # print('sh', s_hs[ridx])
+ # print('sw', s_ws[ridx])
+ # print('scale_fac y', scale_ys[ridx])
+ # print('scale_fac x', scale_xs[ridx])
+
+ _sh = s_hs[ridx]
+ _sw = s_ws[ridx]
+ _sfy = predictor.patch_size[-1]
+ _sfx = predictor.patch_size[-1]
+
+ # plt.figure(figsize=(20, 20))
+
+ # plt.subplot(1,3,1)
+ # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0))
+
+ # plt.subplot(1,3,2)
+ new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])], mode='bicubic', align_corners=True)
+ # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0))
+
+ scaled_new_r = torch.zeros_like(new_r)
+ scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
+ scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
+
+ # plt.subplot(1,3,3)
+ # plt.imshow(f2rgb(-scaled_new_r).cpu().numpy()[0].transpose(1,2,0))
+
+ # plt.show()
+
+ all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
+ return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+
+ return_flow = -return_flow
+ all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+
+    return return_flow, all_fwd_flows_e2d_new
+
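+# Illustrative usage sketch (assumed names: `predictor` is a loaded CWM predictor and
+# `mask_generator` yields boolean patch masks of the expected shape):
+#
+#   flow, per_scale_flows = scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
+#       predictor, mask_generator, img1, img2,
+#       num_scales=1, N_mask_samples=100, smoothing_factor=1)
+#   # `flow` has shape [B, 2, H, W] in pixel units at the resolution of img1 / img2.
+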
+def extract_jacobians_and_flows(img1, img2,
+ flow_generator,
+ mask,
+ target_mask=None):
+
+ IMAGE_SIZE = img1.shape[-2:]
+
+ y = torch.cat([img2.unsqueeze(1), img1.unsqueeze(1)], 1)
+
+ jacobians, flows, _ = flow_generator(y, mask, target_mask)
+
+ # swap x,y flow dims
+ flows = torch.cat([flows[0, 1].unsqueeze(0), flows[0, 0].unsqueeze(0)])
+
+ # upsample to 224
+ flows = flows.unsqueeze(0).repeat_interleave(IMAGE_SIZE[0] // flows.shape[-1], -1).repeat_interleave(
+ IMAGE_SIZE[0] // flows.shape[-1], -2)
+
+ return jacobians, flows
+
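+# Illustrative usage sketch (assumed setup: `flow_generator` is a DerivativeFlowGenerator
+# from cwm/eval/Flow/generator.py wrapping the same predictor, and `mask` is a boolean patch mask):
+#
+#   jacobians, flows = extract_jacobians_and_flows(img1, img2, flow_generator, mask)
+#   # `flows` is upsampled to the input resolution; its two channels are swapped as noted above.
+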
+import matplotlib.pyplot as plt
+
+class FlowToRgb(object):
+
+ def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
+ self.max_speed = max_speed
+ self.from_image_coordinates = from_image_coordinates
+ self.from_sampling_grid = from_sampling_grid
+
+ def __call__(self, flow):
+ assert flow.size(-3) == 2, flow.shape
+ if self.from_sampling_grid:
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+ flow_y = -flow_y
+ elif not self.from_image_coordinates:
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+ else:
+ flow_h, flow_w = torch.split(flow, [1,1], dim=-3)
+ flow_x, flow_y = [flow_w, -flow_h]
+
+
+ # print("flow_x", flow_x[0, :, 0, 0], flow_y[0, :, 0, 0])
+ angle = torch.atan2(flow_y, flow_x) # in radians from -pi to pi
+ speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed
+
+ # print("angle", angle[0, :, 0, 0] * 180 / np.pi)
+
+ hue = torch.fmod(angle, torch.tensor(2 * np.pi))
+ sat = torch.ones_like(hue)
+ val = speed
+
+ hsv = torch.cat([hue, sat, val], -3)
+ rgb = kornia.color.hsv_to_rgb(hsv)
+ return rgb
+
+ def make_colorwheel(self):
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col += RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col += GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col += BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
\ No newline at end of file
diff --git a/cwm/eval/Flow/flow_utils_legacy.py b/cwm/eval/Flow/flow_utils_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..355590223e769fa77ba46ca22837662e8ea0304a
--- /dev/null
+++ b/cwm/eval/Flow/flow_utils_legacy.py
@@ -0,0 +1,152 @@
+# NOTE: this legacy module relies on helpers defined alongside the main flow utilities; the
+# import path below assumes they live in cwm.eval.Flow.flow_utils and may need adjusting.
+import torch
+import torch.nn.functional as F
+
+from cwm.eval.Flow.flow_utils import (average_crops,
+                                      compute_optical_flow,
+                                      get_minimal_224_crops_new_batched,
+                                      reconstruct_video_new_2_batched)
+
+
+def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
+ mask_generator,
+ img1,
+ img2,
+ neg_back_flow=True,
+ num_scales=1,
+ min_scale=400,
+ N_mask_samples=100,
+ mask_ratio=0.8,
+ smoothing_factor=1):
+ B = img1.shape[0]
+ assert len(img1.shape) == 4
+ assert num_scales >= 1
+
+ # For scaling
+ h1 = img2.shape[-2]
+ w1 = img2.shape[-1]
+ assert min_scale < h1 and min_scale >= 360 # Below 360p, the flows look terrible
+
+ if neg_back_flow is False:
+ print('WARNING: Not calculating negative backward flow')
+
+ alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1
+
+ frame_size = 224 // generator.patch_size[-1]
+
+ all_fwd_flows_e2d = []
+
+ s_hs = []
+ s_ws = []
+
+ for aidx in range(num_scales):
+ print(aidx)
+
+ # print('aidx: ', aidx)
+
+ img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+ mode='bicubic', align_corners=True)
+ img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+ mode='bicubic', align_corners=True)
+
+ h2 = img2_scaled.shape[-2]
+ w2 = img2_scaled.shape[-1]
+
+ s_h = h1 / h2
+ s_w = w1 / w2
+
+ s_hs.append(s_h)
+ s_ws.append(s_w)
+
+ # Because technically the compute_optical_flow function returns neg back flow
+ if neg_back_flow is True:
+ video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
+ else:
+ video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)
+
+ # Should work, even if the incoming video is already 224x224
+ crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
+
+ num_crops = len(crops1)
+
+ crop_flows_enc = []
+ crop_flows_enc2dec = []
+ N_samples = N_mask_samples
+
+ crop = torch.cat(crops1, 0).cuda()
+
+ optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
+ mask_counts = torch.zeros(frame_size, frame_size).cuda()
+
+ i = 0
+ while i < N_samples or (mask_counts == 0).any().item():
+ if i % 100 == 0:
+ pass # print(i)
+ mask_generator.mask_ratio = mask_ratio
+
+        # Note: every crop in the batch shares the same mask sample; acceptable for this estimator.
+ mask = mask_generator()[None]
+ mask_2f = ~mask[0, frame_size * frame_size:]
+ mask_counts += mask_2f.reshape(frame_size, frame_size)
+
+ with torch.cuda.amp.autocast(enabled=True):
+
+ processed_x = generator._preprocess(crop)
+
+ encoder_out = generator.predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
+ encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)
+
+ optical_flow_e2d = []
+ # one per batch element for now
+ for b in range(B * num_crops):
+ batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
+ optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))
+
+ optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
+ optical_flows_enc2dec += optical_flow_e2d
+ i += 1
+
+ optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
+
+ # split the crops back up
+ crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
+
+ T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
+ crop_flows_enc2dec]
+
+ optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
+ B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+
+ all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
+
+ all_fwd_flows_e2d_new = []
+
+ for ridx, r in enumerate(all_fwd_flows_e2d):
+ # print('ridx', ridx)
+ # print('sh', s_hs[ridx])
+ # print('sw', s_ws[ridx])
+ # print('scale_fac y', scale_ys[ridx])
+ # print('scale_fac x', scale_xs[ridx])
+
+ _sh = s_hs[ridx]
+ _sw = s_ws[ridx]
+ _sfy = generator.patch_size[-1]
+ _sfx = generator.patch_size[-1]
+
+ # plt.figure(figsize=(20, 20))
+
+ # plt.subplot(1,3,1)
+ # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0))
+
+ # plt.subplot(1,3,2)
+ new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
+ mode='bicubic', align_corners=True)
+ # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0))
+
+ scaled_new_r = torch.zeros_like(new_r)
+ scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
+ scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
+
+ # plt.subplot(1,3,3)
+ # plt.imshow(f2rgb(-scaled_new_r).cpu().numpy()[0].transpose(1,2,0))
+
+ # plt.show()
+
+ all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
+ return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+
+ if neg_back_flow is True:
+ return_flow = -return_flow
+ all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+
+ return return_flow, all_fwd_flows_e2d_new
\ No newline at end of file
diff --git a/cwm/eval/Flow/generator.py b/cwm/eval/Flow/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..954ca3ad5ba3269bf4a58e75150155528b79a2f1
--- /dev/null
+++ b/cwm/eval/Flow/generator.py
@@ -0,0 +1,579 @@
+import kornia
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+import cwm.eval.Flow.masking_flow as masking
+
+
+def boltzmann(x, beta=1, eps=1e-9):
+ if beta is None:
+ return x
+ x = torch.exp(x * beta)
+ return x / x.amax((-1,-2), keepdim=True).clamp(min=eps)
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+def imagenet_normalize(x, temporal_dim=1):
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x.device)[None,None,:,None,None].to(x)
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x.device)[None,None,:,None,None].to(x)
+ if temporal_dim == 2:
+ mean = mean.transpose(1,2)
+ std = std.transpose(1,2)
+ return (x - mean) / std
+
+def imagenet_unnormalize(x, temporal_dim=2):
+ device = x.device
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None].to(x)
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None].to(x)
+ if temporal_dim == 2:
+ mean = mean.transpose(1,2)
+ std = std.transpose(1,2)
+ x = x*std + mean
+ return x
+
+
+
+def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
+ static = False
+ if seq_length == 0:
+ static = True
+ seq_length = 1
+ B = batch_size
+ T = seq_length
+ H,W = imsize
+ ones = torch.ones([B,H,W,1], dtype=dtype_out)
+ if normalize:
+ h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=dtype_out))
+ h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
+ w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=dtype_out))
+ w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
+ else:
+ h = torch.arange(H).to(ones).view(1,H,1,1) * ones
+ w = torch.arange(W).to(ones).view(1,1,W,1) * ones
+ h = torch.stack([h]*T, 1)
+ w = torch.stack([w]*T, 1)
+ hw_ims = torch.cat([h,w], -1)
+ if static:
+ hw_ims = hw_ims[:,0]
+ return hw_ims
+
+
+def get_distribution_centroid(dist, eps=1e-9, normalize=False):
+
+ B,T,C,H,W = dist.shape
+ assert C == 1
+ dist_sum = dist.sum((-2, -1), keepdim=True).clamp(min=eps)
+ dist = dist / dist_sum
+
+ grid = coordinate_ims(B, T, [H,W], normalize=normalize).to(dist.device)
+ grid = grid.permute(0,1,4,2,3)
+ centroid = (grid * dist).sum((-2,-1))
+ return centroid
+
+
+
+class FlowToRgb(object):
+
+ def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
+ self.max_speed = max_speed
+ self.from_image_coordinates = from_image_coordinates
+ self.from_sampling_grid = from_sampling_grid
+
+ def __call__(self, flow):
+ assert flow.size(-3) == 2, flow.shape
+ if self.from_sampling_grid:
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+ flow_y = -flow_y
+ elif not self.from_image_coordinates:
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+ else:
+ flow_h, flow_w = torch.split(flow, [1,1], dim=-3)
+ flow_x, flow_y = [flow_w, -flow_h]
+
+ angle = torch.atan2(flow_y, flow_x) # in radians from -pi to pi
+ speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed
+
+ hue = torch.fmod(angle, torch.tensor(2 * np.pi))
+ sat = torch.ones_like(hue)
+ val = speed
+
+ hsv = torch.cat([hue, sat, val], -3)
+ rgb = kornia.color.hsv_to_rgb(hsv)
+ return rgb
+
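+# Illustrative usage sketch: `flow` is an assumed [B, 2, H, W] flow tensor in pixels; the
+# output maps to RGB in [0, 1] as long as speeds stay below max_speed.
+#
+#   flow2rgb = FlowToRgb(max_speed=5.0)
+#   rgb = flow2rgb(flow)    # [B, 3, H, W]
+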
+class Patchify(nn.Module):
+ """Convert a set of images or a movie into patch vectors"""
+
+ def __init__(self,
+ patch_size=(16, 16),
+ temporal_dim=1,
+ squeeze_channel_dim=True
+ ):
+ super().__init__()
+ self.set_patch_size(patch_size)
+ self.temporal_dim = temporal_dim
+ assert self.temporal_dim in [1, 2], self.temporal_dim
+ self._squeeze_channel_dim = squeeze_channel_dim
+
+ @property
+ def num_patches(self):
+ if (self.T is None) or (self.H is None) or (self.W is None):
+ return None
+ else:
+ return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)
+
+ def set_patch_size(self, patch_size):
+ self.patch_size = patch_size
+ if len(self.patch_size) == 2:
+ self.ph, self.pw = self.patch_size
+ self.pt = 1
+ self._patches_are_3d = False
+ elif len(self.patch_size) == 3:
+ self.pt, self.ph, self.pw = self.patch_size
+ self._patches_are_3d = True
+ else:
+ raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % self.patch_size)
+
+ self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
+ self.D = self.C = self.E = self.embed_dim = None
+
+ def _check_shape(self, x):
+ self.shape_inp = x.shape
+ self.rank_inp = len(self.shape_inp)
+ self.H, self.W = self.shape_inp[-2:]
+ assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
+ if (self.rank_inp == 5) and self._patches_are_3d:
+ self.T = self.shape_inp[self.temporal_dim]
+ assert (self.T % self.pt) == 0, (self.T, self.pt)
+ elif self.rank_inp == 5:
+ self.T = self.shape_inp[self.temporal_dim]
+ else:
+ self.T = 1
+
+ def split_by_time(self, x):
+ shape = x.shape
+ assert shape[1] % self.T == 0, (shape, self.T)
+ return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])
+
+ def merge_by_time(self, x):
+ shape = x.shape
+ return x.view(shape[0], shape[1] * shape[2], *shape[3:])
+
+ def video_to_patches(self, x):
+ if self.rank_inp == 4:
+ assert self.pt == 1, (self.pt, x.shape)
+ x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
+ else:
+ assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
+ dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
+ x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)
+
+ self.N, self.D, self.C = x.shape[-3:]
+ self.embed_dim = self.E = self.D * self.C
+ return x
+
+ def patches_to_video(self, x):
+ shape = x.shape
+ rank = len(shape)
+ if rank == 4:
+ B, _N, _D, _C = shape
+ else:
+ assert rank == 3, rank
+ B, _N, _E = shape
+ assert (_E % self.D == 0), (_E, self.D)
+ x = x.view(B, _N, self.D, -1)
+
+ if _N < self.num_patches:
+ masked_patches = self.get_masked_patches(
+ x,
+ num_patches=(self.num_patches - _N),
+ mask_mode=self.mask_mode)
+ x = torch.cat([x, masked_patches], 1)
+
+ x = rearrange(
+ x,
+ 'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
+ pt=self.pt, ph=self.ph, pw=self.pw,
+ t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))
+
+ if self.rank_inp == 5 and (self.temporal_dim == 1):
+ x = x.transpose(1, 2)
+ elif self.rank_inp == 4:
+ assert x.shape[2] == 1, x.shape
+ x = x[:, :, 0]
+ return x
+
+ @staticmethod
+ def get_masked_patches(x, num_patches, mask_mode='zeros'):
+ shape = x.shape
+ patches_shape = (shape[0], num_patches, *shape[2:])
+ if mask_mode == 'zeros':
+ return torch.zeros(patches_shape).to(x.device).to(x.dtype).detach()
+ elif mask_mode == 'gray':
+ return 0.5 * torch.ones(patches_shape).to(x.device).to(x.dtype).detach()
+ else:
+ raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)
+
+ def average_within_patches(self, z):
+ if len(z.shape) == 3:
+ z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
+ return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)
+
+ def forward(self, x, to_video=False, mask_mode='zeros'):
+ if not to_video:
+ self._check_shape(x)
+ x = self.video_to_patches(x)
+ return x if not self._squeeze_channel_dim else x.view(x.size(0), self.N, -1)
+
+ else: # x are patches
+ assert (self.shape_inp is not None) and (self.num_patches is not None)
+ self.mask_mode = mask_mode
+ x = self.patches_to_video(x)
+ return x
+
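+# Illustrative usage sketch: round-trip a [B, T, C, H, W] video `video` (H, W divisible by 8)
+# through Patchify and back.
+#
+#   patchify = Patchify(patch_size=(1, 8, 8), temporal_dim=1)
+#   tokens = patchify(video)                     # [B, N, 1 * 8 * 8 * C]
+#   video_rec = patchify(tokens, to_video=True)  # reconstructs the original video
+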
+
+class DerivativeFlowGenerator(nn.Module):
+ """Estimate flow of a two-frame predictor using torch autograd"""
+
+ def __init__(self,
+ predictor,
+ perturbation_patch_size=None,
+ aggregation_patch_size=None,
+ agg_power=None,
+ agg_channel_func=None,
+ num_samples=1,
+ leave_one_out_sampling=False,
+ average_jacobian=True,
+ confidence_thresh=None,
+ temporal_dim=2,
+ imagenet_normalize_inputs=True):
+
+ super(DerivativeFlowGenerator, self).__init__()
+
+ self.predictor = predictor
+
+ self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
+
+ self.set_temporal_dim(temporal_dim)
+
+ self.imagenet_normalize_inputs = imagenet_normalize_inputs
+
+ self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
+ self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
+ self.agg_patchify = Patchify(self.aggregation_patch_size,
+ temporal_dim=1,
+ squeeze_channel_dim=False)
+ self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
+ self.average_jacobian = average_jacobian
+ self.confidence_thresh = confidence_thresh
+
+ self.num_samples = num_samples
+ self.leave_one_out_sampling = leave_one_out_sampling
+ self.agg_power = agg_power
+        self.t_dim = temporal_dim
+
+        # Equalizes the number of masked tokens across a batch before calling the predictor
+        # (predict() relies on it when batch size > 1); 'min' truncation is an assumed default.
+        self.mask_rectangularizer = masking.RectangularizeMasks('min')
+
+ def _get_patch_size(self, p):
+ if p is None:
+ return None
+ elif isinstance(p, int):
+ return (1, p, p)
+ elif len(p) == 2:
+ return (1, p[0], p[1])
+ else:
+ assert len(p) == 3, p
+ return (p[0], p[1], p[2])
+
+ def set_temporal_dim(self, t_dim):
+ if t_dim == 1:
+ self.predictor.t_dim = 1
+ self.predictor.c_dim = 2
+ elif t_dim == 2:
+ self.predictor.c_dim = 1
+ self.predictor.t_dim = 2
+ else:
+ raise ValueError("temporal_dim must be 1 or 2")
+
+ @property
+ def c_dim(self):
+ if self.predictor is None:
+ return None
+ return self.predictor.c_dim
+
+ @property
+ def patch_size(self):
+ if self.predictor is None:
+ return None
+ elif hasattr(self.predictor, 'patch_size'):
+ return self.predictor.patch_size
+ elif hasattr(self.predictor.encoder.patch_embed, 'proj'):
+ return self.predictor.encoder.patch_embed.proj.kernel_size
+ else:
+ return None
+ @property
+ def S(self):
+ return self.num_samples
+
+ @property
+ def sequence_length(self):
+ if self.predictor is None:
+ return None
+ elif hasattr(self.predictor, 'sequence_length'):
+ return self.predictor.sequence_length
+ elif hasattr(self.predictor, 'num_frames'):
+ return self.predictor.num_frames
+ else:
+ return 2
+ @property
+ def mask_shape(self):
+ if self.predictor is None:
+ return None
+ elif hasattr(self.predictor, 'mask_shape'):
+ return self.predictor.mask_shape
+
+ assert self.patch_size is not None
+ pt, ph, pw = self.patch_size
+ return (self.sequence_length // pt,
+ self.inp_shape[-2] // ph,
+ self.inp_shape[-1] // pw)
+
+ @property
+ def perturbation_mask_shape(self):
+ return (
+ self.mask_shape[0],
+ self.inp_shape[-2] // self.perturbation_patch_size[-2],
+ self.inp_shape[-1] // self.perturbation_patch_size[-1]
+ )
+
+
+
+ @property
+ def p_mask_shape(self):
+ return self.perturbation_mask_shape
+
+ @property
+ def aggregation_mask_shape(self):
+ return (
+ 1,
+ self.inp_shape[-2] // self.aggregation_patch_size[-2],
+ self.inp_shape[-1] // self.aggregation_patch_size[-1]
+ )
+
+ @property
+ def a_mask_shape(self):
+ return self.aggregation_mask_shape
+
+ def get_perturbation_input(self, x):
+ self.set_input(x)
+ y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True)
+ y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1)
+ return y
+
+ def pred_patches_to_video(self, y, x, mask):
+ """input at visible positions, preds at masked positions"""
+ B, C = y.shape[0], y.shape[-1]
+ self.patchify._check_shape(x)
+ self.patchify.D = np.prod(self.patch_size)
+ x = self.patchify(x)
+ y_out = torch.zeros_like(x)
+ x_vis = x[~mask]
+
+ y_out[~mask] = x_vis.view(-1, C)
+ try:
+ y_out[mask] = y.view(-1, C)
+ except:
+ y_out[mask] = y.reshape(-1, C)
+
+ return self.patchify(y_out, to_video=True)
+
+ def set_image_size(self, *args, **kwargs):
+ assert self.predictor is not None, "Can't set the image size without a predictor"
+ if hasattr(self.predictor, 'set_image_size'):
+ self.predictor.set_image_size(*args, **kwargs)
+ else:
+ self.predictor.image_size = args[0]
+
+ def predict(self, x=None, mask=None, forward_full=False):
+ if x is None:
+ x = self.x
+ if mask is None:
+ mask = self.generate_mask(x)
+
+ self.set_image_size(x.shape[-2:])
+ y = self.predictor(
+ self._preprocess(x),
+ mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full)
+
+ y = self.pred_patches_to_video(y, x, mask=mask)
+
+ frame = -1 % y.size(1)
+ y = y[:, frame:frame + 1]
+
+ return y
+
+ def _get_perturbation_func(self, x=None, mask=None):
+
+ if (x is not None):
+ self.set_input(x, mask)
+
+ def forward_mini_image(y):
+ y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
+ y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
+ x_pred = self.predict(self.x + y, self.mask)
+ x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape)
+ return x_pred[self.targets]
+
+ return forward_mini_image
+
+ def _postprocess_jacobian(self, jac):
+ _jac = torch.zeros((self.B, *self.a_mask_shape, *jac.shape[1:])).to(jac.device).to(jac.dtype)
+ _jac[self.targets] = jac
+ jac = self.agg_channel_func(_jac)
+ assert jac.size(-3) == 1, jac.shape
+ jac = jac.squeeze(-3)[..., 0, :, :] # derivative w.r.t. first frame and agg channels
+ jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1],
+ self.B, self.p_mask_shape[-2], self.p_mask_shape[-1])
+ bs = torch.arange(0, self.B).long().to(jac.device)
+ jac = jac[bs, :, :, bs, :, :] # take diagonal
+ return jac
+
+ def _confident_jacobian(self, jac):
+ if self.confidence_thresh is None:
+ return torch.ones_like(jac[:, None, ..., 0, 0])
+ conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None]
+ return conf
+
+ def set_input(self, x, mask=None, timestamps=None):
+ shape = x.shape
+ if len(shape) == 4:
+ x = x.unsqueeze(1)
+ else:
+ assert len(shape) == 5, \
+ "Input must be a movie of shape [B,T,C,H,W]" + \
+ "or a single frame of shape [B,C,H,W]"
+
+ self.inp_shape = x.shape
+ self.x = x
+ self.B = self.inp_shape[0]
+ self.T = self.inp_shape[1]
+ self.C = self.inp_shape[2]
+ if mask is not None:
+ self.mask = mask
+
+ if timestamps is not None:
+ self.timestamps = timestamps
+
+ def _preprocess(self, x):
+ if self.imagenet_normalize_inputs:
+ x = imagenet_normalize(x)
+ if self.t_dim != 1:
+ x = x.transpose(self.t_dim, self.c_dim)
+ return x
+
+ def _jacobian_to_flows(self, jac):
+ if self.agg_power is None:
+ jac = (jac == jac.amax((-2, -1), True)).float()
+ else:
+ jac = torch.pow(jac, self.agg_power)
+
+ jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:])
+ centroids = get_distribution_centroid(jac, normalize=False).view(
+ self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2)
+ rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2],
+ self.a_mask_shape[-1] / self.p_mask_shape[-1]]
+ centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)
+
+ flows = centroids - \
+ coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
+ flows = flows.permute(0, 3, 1, 2)
+ px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
+ flows *= px_scale
+
+ return flows
+
+ def set_targets(self, targets=None, frame=-1):
+ frame = frame % self.mask_shape[0]
+ if targets is None:
+ targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
+ else:
+ assert len(targets.shape) == 4, targets.shape
+ targets = targets[:, frame:frame + 1]
+ self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])
+
+ def _get_mask_partition(self, mask):
+ mask = self.get_mask_image(mask)
+ mask_list = masking.partition_masks(
+ mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)
+ return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1)
+ for m in mask_list]
+
+ def _compute_jacobian(self, y):
+ perturbation_func = self._get_perturbation_func()
+ jac = torch.autograd.functional.jacobian(
+ perturbation_func,
+ y,
+ vectorize=False)
+ jac = self._postprocess_jacobian(jac)
+ return jac
+
+ def _upsample_mask(self, mask):
+ return masking.upsample_masks(
+ mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:])
+
+ def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
+ if shape is None:
+ shape = self.mask_shape
+ mask = mask.view(-1, *shape)
+ if upsample:
+ mask = self._upsample_mask(mask)
+ if invert:
+ mask = 1 - mask
+ return mask
+
+ def forward(self, x, mask, targets=None):
+ self.set_input(x, mask)
+ y = self.get_perturbation_input(x)
+ mask_list = self._get_mask_partition(mask)
+
+ jacobian, flows, confident = [], [], []
+ for s, mask_sample in enumerate(mask_list):
+ self.set_input(x, mask_sample)
+ self.set_targets(targets)
+
+ import time
+ t1 = time.time()
+ jac = self._compute_jacobian(y)
+ conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
+ jacobian.append(jac)
+ confident.append(conf_jac)
+ if not self.average_jacobian:
+ flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
+ masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
+ flows.append(flow)
+ t2 = time.time()
+ print(t2 - t1)
+
+ jacobian = torch.stack(jacobian, -1)
+ confident = torch.stack(confident, -1)
+ valid = torch.stack([masking.upsample_masks(
+ self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
+ valid = valid * confident
+
+ if self.average_jacobian:
+ _valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
+ jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
+ flows = self._jacobian_to_flows(jac) * \
+ masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
+ if targets is not None:
+ self.set_targets(targets)
+ flows *= self.targets
+ else:
+ flows = torch.stack(flows, -1)
+ flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)
+
+ valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)
+
+ return (jacobian, flows, valid)
\ No newline at end of file
diff --git a/cwm/eval/Flow/losses.py b/cwm/eval/Flow/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..25ee0ea28deb584aba156b9d91264aab29dac8dd
--- /dev/null
+++ b/cwm/eval/Flow/losses.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+
+def sampling_grid(height, width):
+ H, W = height, width
+ grid = torch.stack([
+ torch.arange(W).view(1, -1).repeat(H, 1),
+ torch.arange(H).view(-1, 1).repeat(1, W)
+ ], -1)
+ grid = grid.view(1, H, W, 2)
+ return grid
+
+
+def normalize_sampling_grid(coords):
+ assert len(coords.shape) == 4, coords.shape
+ assert coords.size(-1) == 2, coords.shape
+ H, W = coords.shape[-3:-1]
+ xs, ys = coords.split([1, 1], -1)
+ xs = 2 * xs / (W - 1) - 1
+ ys = 2 * ys / (H - 1) - 1
+ return torch.cat([xs, ys], -1)
+
+
+def backward_warp(img2, flow, do_mask=False):
+ """
+ Grid sample from img2 using the flow from img1->img2 to get a prediction of img1.
+
+ flow: [B,2,H',W'] in units of pixels at its current resolution. The two channels
+ should be (x,y) where larger y values correspond to lower parts of the image.
+ """
+
+ ## resize the flow to the image size.
+ ## since flow has units of pixels, its values need to be rescaled accordingly.
+ if list(img2.shape[-2:]) != list(flow.shape[-2:]):
+ scale = [img2.size(-1) / flow.size(-1), # x
+ img2.size(-2) / flow.size(-2)] # y
+ scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device)
+ flow = scale * transforms.Resize(img2.shape[-2:])(flow) # defaults to bilinear
+
+ B, C, H, W = img2.shape
+
+ ## use flow to warp sampling grid
+ grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)
+
+ ## put grid in normalized image coordinates
+ grid = normalize_sampling_grid(grid)
+
+ ## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y)
+ img1_pred = F.grid_sample(img2, grid, align_corners=True)
+
+ if do_mask:
+ mask = (grid[..., 0] > -1) & (grid[..., 0] < 1) & \
+ (grid[..., 1] > -1) & (grid[..., 1] < 1)
+ mask = mask[:, None].to(img2.dtype)
+ return (img1_pred, mask)
+
+ else:
+ return (img1_pred, torch.ones_like(grid[..., 0][:, None]).float())
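+
+
+# Illustrative usage sketch: photometric check of an estimated flow field, where `img1`, `img2`
+# are assumed [B, 3, H, W] frames and `flow` is a [B, 2, h, w] flow from img1 to img2 in pixels.
+#
+#   img1_pred, valid = backward_warp(img2, flow, do_mask=True)
+#   photometric_error = (valid * (img1_pred - img1).abs()).mean()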
diff --git a/cwm/eval/Flow/masking_flow.py b/cwm/eval/Flow/masking_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..650250d7b8e741154435ff675a3b703bf6990847
--- /dev/null
+++ b/cwm/eval/Flow/masking_flow.py
@@ -0,0 +1,375 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+
+def upsample_masks(masks, size, thresh=0.5):
+ shape = masks.shape
+ dtype = masks.dtype
+ h, w = shape[-2:]
+ H, W = size
+ if (H == h) and (W == w):
+ return masks
+ elif (H < h) and (W < w):
+ s = (h // H, w // W)
+ return masks[..., ::s[0], ::s[1]]
+
+ masks = masks.unsqueeze(-2).unsqueeze(-1)
+ masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
+ if ((H % h) == 0) and ((W % w) == 0):
+ masks = masks.view(*shape[:-2], H, W)
+ else:
+ _H = np.prod(masks.shape[-4:-2])
+ _W = np.prod(masks.shape[-2:])
+ masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
+        masks = masks.view(*shape[:-2], H, W).to(dtype)
+ return masks
+
+
+
+
+def partition_masks(masks, num_samples=2, leave_one_out=False):
+ B = masks.shape[0]
+ S = num_samples
+ masks = masks.view(B, -1)
+ partitioned = [torch.ones_like(masks) for _ in range(S)]
+ for b in range(B):
+ vis_inds = torch.where(~masks[b])[0]
+ vis_inds = vis_inds[torch.randperm(vis_inds.size(0))]
+ if leave_one_out:
+ for s in range(S):
+ partitioned[s][b][vis_inds] = 0
+ partitioned[s][b][vis_inds[s::S]] = 1
+ else:
+ for s in range(S):
+ partitioned[s][b][vis_inds[s::S]] = 0
+ return partitioned
+
+
+class RectangularizeMasks(nn.Module):
+ """Make sure all masks in a batch have same number of 1s and 0s"""
+
+ def __init__(self, truncation_mode='min'):
+ super().__init__()
+ self._mode = truncation_mode
+ assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)
+
+ def set_mode(self, mode):
+ self._mode = mode
+
+ def __call__(self, masks):
+
+ if self._mode in ['none', None]:
+ return masks
+
+ assert isinstance(masks, torch.Tensor), type(masks)
+ if self._mode == 'full':
+ return torch.ones_like(masks)
+
+ shape = masks.shape
+ masks = masks.flatten(1)
+ B, N = masks.shape
+ num_masked = masks.float().sum(-1)
+ M = {
+ 'min': torch.amin, 'max': torch.amax, 'mean': torch.mean
+ }[self._mode](num_masked).long()
+
+ num_changes = num_masked.long() - M
+
+ for b in range(B):
+ nc = num_changes[b]
+ if nc > 0:
+ inds = torch.where(masks[b])[0]
+ inds = inds[torch.randperm(inds.size(0))[:nc].to(inds.device)]
+ masks[b, inds] = 0
+ elif nc < 0:
+ inds = torch.where(~masks[b])[0]
+ inds = inds[torch.randperm(inds.size(0))[:-nc].to(inds.device)]
+ masks[b, inds] = 1
+ if list(masks.shape) != list(shape):
+ masks = masks.view(*shape)
+
+ return masks
+
+
+class UniformMaskingGenerator(object):
+ def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False):
+ self.frames = None
+ if len(input_size) == 3:
+ self.frames, self.height, self.width = input_size
+ elif len(input_size) == 2:
+ self.height, self.width = input_size
+ elif len(input_size) == 1 or isinstance(input_size, int):
+ self.height = self.width = input_size
+
+ self.clumping_factor = clumping_factor
+ self.pad_h = self.height % self.c[0]
+ self.pad_w = self.width % self.c[1]
+ self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
+ self.mask_ratio = mask_ratio
+
+ self.rng = np.random.RandomState(seed=seed)
+ self.randomize_num_visible = randomize_num_visible
+
+ @property
+ def num_masks_per_frame(self):
+ if not hasattr(self, '_num_masks_per_frame'):
+ self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
+ return self._num_masks_per_frame
+
+ @num_masks_per_frame.setter
+ def num_masks_per_frame(self, val):
+ self._num_masks_per_frame = val
+ self._mask_ratio = (val / self.num_patches_per_frame)
+
+ @property
+ def c(self):
+ if isinstance(self.clumping_factor, int):
+ return (self.clumping_factor, self.clumping_factor)
+ else:
+ return self.clumping_factor[:2]
+
+ @property
+ def mask_ratio(self):
+ return self._mask_ratio
+
+ @mask_ratio.setter
+ def mask_ratio(self, val):
+ self._mask_ratio = val
+ self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
+
+ @property
+ def num_visible(self):
+ return self.num_patches_per_frame - self.num_masks_per_frame
+
+ @num_visible.setter
+ def num_visible(self, val):
+ self.num_masks_per_frame = self.num_patches_per_frame - val
+
+ def __repr__(self):
+ repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, random num num visible? {}".format(
+ self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
+ )
+ return repr_str
+
+ def sample_mask_per_frame(self):
+ num_masks = self.num_masks_per_frame
+ if self.randomize_num_visible:
+ num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
+ mask = np.hstack([
+ np.zeros(self.num_patches_per_frame - num_masks),
+ np.ones(num_masks)])
+ self.rng.shuffle(mask)
+ if max(*self.c) > 1:
+ mask = mask.reshape(self.height // self.c[0],
+ 1,
+ self.width // self.c[1],
+ 1)
+ mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
+ mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
+ _pad_h = self.rng.choice(range(self.pad_h + 1))
+ pad_h = (self.pad_h - _pad_h, _pad_h)
+ _pad_w = self.rng.choice(range(self.pad_w + 1))
+ pad_w = (self.pad_w - _pad_w, _pad_w)
+ mask = np.pad(mask,
+ (pad_h, pad_w),
+ constant_values=1
+ ).reshape((self.height, self.width))
+ return mask
+
+ def __call__(self, num_frames=None):
+ num_frames = (num_frames or self.frames) or 1
+ masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
+ return masks
+
+
+class TubeMaskingGenerator(UniformMaskingGenerator):
+
+ def __call__(self, num_frames=None):
+ num_frames = (num_frames or self.frames) or 1
+ masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten()
+ return masks
+
+
+class RotatedTableMaskingGenerator(TubeMaskingGenerator):
+
+ def __init__(self, tube_length=None, *args, **kwargs):
+ super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
+ self.tube_length = tube_length
+
+ def __call__(self, num_frames=None):
+ num_frames = (num_frames or self.frames) or 2
+ tube_length = self.tube_length or (num_frames - 1)
+ table_thickness = num_frames - tube_length
+ assert tube_length < num_frames, (tube_length, num_frames)
+
+ tubes = super().__call__(num_frames=tube_length)
+ top = np.zeros(table_thickness * self.height * self.width).astype(tubes.dtype).flatten()
+ masks = np.concatenate([top, tubes], 0)
+ return masks
+
+
+class PytorchMaskGeneratorWrapper(nn.Module):
+ """Pytorch wrapper for numpy masking generators"""
+
+ def __init__(self,
+ mask_generator=TubeMaskingGenerator,
+ *args, **kwargs):
+ super().__init__()
+ self.mask_generator = mask_generator(*args, **kwargs)
+
+ @property
+ def mask_ratio(self):
+ return self.mask_generator.mask_ratio
+
+ @mask_ratio.setter
+ def mask_ratio(self, value):
+ self.mask_generator.mask_ratio = value
+
+ def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
+ masks = self.mask_generator(**kwargs)
+ masks = torch.tensor(masks).to(device).to(dtype_out)
+ return masks
+
+
+class MaskingGenerator(nn.Module):
+ """Pytorch base class for masking generators"""
+
+ def __init__(self,
+ input_size,
+ mask_ratio,
+ seed=0,
+ visible_frames=0,
+ clumping_factor=1,
+ randomize_num_visible=False,
+ create_on_cpu=True,
+ always_batch=False):
+ super().__init__()
+ self.frames = None
+
+ if len(input_size) == 3:
+ self.frames, self.height, self.width = input_size
+ elif len(input_size) == 2:
+ self.height, self.width = input_size
+ elif len(input_size) == 1 or isinstance(input_size, int):
+ self.height = self.width = input_size
+
+ self.clumping_factor = clumping_factor
+ self.pad_h = self.height % self.c[0]
+ self.pad_w = self.width % self.c[1]
+ self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
+
+ self.mask_ratio = mask_ratio
+ self.visible_frames = visible_frames
+ self.always_batch = always_batch
+ self.create_on_cpu = create_on_cpu
+
+ self.rng = np.random.RandomState(seed=seed)
+ self._set_torch_seed(seed)
+
+ self.randomize_num_visible = randomize_num_visible
+
+ @property
+ def num_masks_per_frame(self):
+ if not hasattr(self, '_num_masks_per_frame'):
+ self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
+ return self._num_masks_per_frame
+
+ @num_masks_per_frame.setter
+ def num_masks_per_frame(self, val):
+ self._num_masks_per_frame = val
+ self._mask_ratio = (val / self.num_patches_per_frame)
+
+ @property
+ def c(self):
+ if isinstance(self.clumping_factor, int):
+ return (self.clumping_factor,) * 2
+ else:
+ return self.clumping_factor[:2]
+
+ @property
+ def mask_ratio(self):
+ return self._mask_ratio
+
+ @mask_ratio.setter
+ def mask_ratio(self, val):
+ self._mask_ratio = val
+ self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
+
+ @property
+ def num_visible(self):
+ return self.num_patches_per_frame - self.num_masks_per_frame
+
+ @num_visible.setter
+ def num_visible(self, val):
+ self.num_masks_per_frame = self.num_patches_per_frame - val
+
+ def _set_torch_seed(self, seed):
+ self.seed = seed
+ torch.manual_seed(self.seed)
+
+ def __repr__(self):
+ repr_str = ("Class: {}\nMask: total patches per mask {},\n" + \
+ "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" + \
+ "randomize num visible? {}").format(
+ type(self).__name__, self.num_patches_per_frame,
+ self.num_masks_per_frame, self.num_visible, self.mask_ratio,
+ self.randomize_num_visible
+ )
+ return repr_str
+
+ def sample_mask_per_frame(self, *args, **kwargs):
+ num_masks = self.num_masks_per_frame
+ if self.randomize_num_visible:
+ num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
+
+ mask = torch.cat([
+ torch.zeros([self.num_patches_per_frame - num_masks]),
+ torch.ones([num_masks])], 0).bool()
+ inds = torch.randperm(mask.size(0)).long()
+ mask = mask[inds]
+
+ if max(*self.c) > 1:
+ mask = mask.view(self.height // self.c[0],
+ 1,
+ self.width // self.c[1],
+ 1)
+ mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
+ mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
+ _pad_h = self.rng.choice(range(self.pad_h + 1))
+ pad_h = (self.pad_h - _pad_h, _pad_h)
+ _pad_w = self.rng.choice(range(self.pad_w + 1))
+ pad_w = (self.pad_w - _pad_w, _pad_w)
+ mask = F.pad(mask,
+ pad_w + pad_h,
+ mode='constant',
+ value=1)
+ mask = mask.reshape(self.height, self.width)
+
+ return mask
+
+ def forward(self, x=None, num_frames=None):
+
+ num_frames = (num_frames or self.frames) or 1
+ if isinstance(x, torch.Tensor):
+ batch_size = x.size(0)
+ masks = torch.stack([
+ torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
+ for b in range(batch_size)], 0)
+ if not self.create_on_cpu:
+ masks = masks.to(x.device)
+ if batch_size == 1 and not self.always_batch:
+ masks = masks.squeeze(0)
+ else:
+ batch_size = 1
+ masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
+ if self.always_batch:
+ masks = masks[None]
+
+ if self.visible_frames > 0:
+ vis = torch.zeros((batch_size, 1, self.height, self.width), dtype=torch.bool)
+ vis = vis.view(masks.shape).to(masks.device)
+ masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)
+
+ return masks
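+
+
+# Illustrative usage sketch (example parameter values, not the defaults used in training):
+#
+#   gen = RotatedTableMaskingGenerator(tube_length=1, input_size=(3, 28, 28), mask_ratio=0.9)
+#   mask = gen()    # flat numpy array of length 3 * 28 * 28; frames 0-1 fully visible,
+#                   # ~90% of frame 2 masked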
diff --git a/cwm/eval/Flow/vis_utils.py b/cwm/eval/Flow/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2767b1c3d84b9f4ed1382463106f9de96003f8c
--- /dev/null
+++ b/cwm/eval/Flow/vis_utils.py
@@ -0,0 +1,150 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def imshow(ims, ax=None, t=0, vmin=None, vmax=None, title=None, cmap=None, fontsize=20):
+ if ax is None:
+ fig, ax = plt.subplots(1,1)
+ with torch.no_grad():
+ im = ims[t].float().cpu().numpy().transpose((1,2,0))
+ if (vmin is not None) and (vmax is not None):
+            im = ax.imshow(im, vmin=vmin, vmax=vmax, cmap=(cmap or 'viridis'))
+        else:
+            im = ax.imshow(im)
+
+ if title is not None:
+ ax.set_title(title, fontsize=fontsize)
+
+ return (im, ax)
+
+def make_colorwheel():
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+ Code follows the original C++ source code of Daniel Scharstein.
+    Code follows the Matlab source code of Deqing Sun.
+
+ Returns:
+ np.ndarray: Color wheel
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
+ col = col+RY
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
+ colorwheel[col:col+YG, 1] = 255
+ col = col+YG
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
+ col = col+GC
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+ colorwheel[col:col+CB, 2] = 255
+ col = col+CB
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
+ col = col+BM
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+ colorwheel[col:col+MR, 0] = 255
+ return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+
+ Args:
+ u (np.ndarray): Input horizontal flow of shape [H,W]
+ v (np.ndarray): Input vertical flow of shape [H,W]
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u)/np.pi
+ fk = (a+1) / 2*(ncols-1)
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:,i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1-f)*col0 + f*col1
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1-col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2-i if convert_to_bgr else i
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
+ return flow_image
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Expects a two dimensional flow image of shape.
+
+ Args:
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+ u = flow_uv[:,:,0]
+ v = flow_uv[:,:,1]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ return flow_uv_to_colors(u, v, convert_to_bgr)
+
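+# Illustrative usage sketch: `flow` is an assumed [H, W, 2] numpy array of (u, v) flow.
+#
+#   rgb = flow_to_image(flow)    # uint8 visualization of shape [H, W, 3]
+#   plt.imshow(rgb); plt.axis('off'); plt.show()
+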
+from decord import VideoReader, cpu
+from PIL import Image
+from torchvision import transforms
+def get_video(video_name, num_frames=2, delta_time=4, frame=None):
+ decord_vr = VideoReader(video_name, num_threads=1, ctx=cpu(0))
+ max_end_ind = len(decord_vr) - num_frames*delta_time - 1
+    start_frame = frame if frame is not None else np.random.randint(1, max_end_ind)
+ print("fps", decord_vr.get_avg_fps())
+ print("start frame = %d" % start_frame)
+ frame_id_list = list(range(start_frame, start_frame + num_frames*delta_time, delta_time))
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
+ video_data = [Image.fromarray(video_data[t]).convert('RGB') for t, _ in enumerate(frame_id_list)]
+ return (torch.stack([transforms.ToTensor()(im) for im in video_data], 0), start_frame)
+
+
+
diff --git a/cwm/eval/IntPhys/__init__.py b/cwm/eval/IntPhys/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Physion/__init__.py b/cwm/eval/Physion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Physion/feature_extractor.py b/cwm/eval/Physion/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e050ff61cf27102fb34093da708f4ad40a359b5e
--- /dev/null
+++ b/cwm/eval/Physion/feature_extractor.py
@@ -0,0 +1,317 @@
+import numpy as np
+from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
+from physion_evaluator.utils import DataAugmentationForVideoMAE
+
+from torch.nn import functional as F
+
+from cwm.eval.Flow.flow_utils import get_occ_masks
+
+from cwm.model.model_factory import model_factory
+import torch
+
+def load_predictor(
+ model_func_,
+ load_path_,
+ **kwargs):
+ predictor = model_func_(**kwargs).eval().requires_grad_(False)
+
+ did_load = predictor.load_state_dict(
+ torch.load(load_path_, map_location=torch.device("cpu"))['model'])
+ predictor._predictor_load_path = load_path_
+ print(did_load, load_path_)
+ return predictor
+
+
+class CWM(PhysionFeatureExtractor):
+ def __init__(self, model_name, aggregate_embeddings=False):
+ super().__init__()
+
+ self.model = model_factory.load_model(model_name).cuda().half()
+
+ self.num_frames = self.model.num_frames
+
+ self.timestamps = np.arange(self.num_frames)
+
+ ps = (224 // self.model.patch_size[1]) ** 2
+
+ self.bool_masked_pos = np.zeros([ps * self.num_frames])
+ self.bool_masked_pos[ps * (self.num_frames - 1):] = 1
+
+ self.ps = ps
+
+ self.aggregate_embeddings = aggregate_embeddings
+
+ def transform(self):
+
+ return DataAugmentationForVideoMAE(
+ imagenet_normalize=True,
+ rescale_size=224,
+ ), 150, 4
+
+ def fwd(self, videos):
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
+ x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
+ return_features=True)
+ return x_encoded
+
+ def extract_features(self, videos, for_flow=False):
+ '''
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
+ returns: [B, T, D] extracted features
+ '''
+
+ videos = videos.transpose(1, 2)
+
+ all_features = []
+
+ # repeat the last frame of the video
+ videos = torch.cat([videos, videos[:, :, -1:]], dim=2)
+
+ for x in range(0, 4, self.num_frames - 1):
+ vid = videos[:, :, x:x + self.num_frames, :, :]
+ all_features.append(self.fwd(vid))
+ if self.aggregate_embeddings:
+ feats = all_features[-1].mean(dim=1, keepdim=True)
+ all_features[-1] = feats
+ # feats = feats.view(feats.shape[0], -1, self.model.num_patches_per_frame, feats.shape[-1])
+ # feats = feats.mean(dim=2)
+ # all_features[-1] = feats
+
+ x_encoded = torch.cat(all_features, dim=1)
+
+ return x_encoded
+
+
+class CWM_Keypoints(PhysionFeatureExtractor):
+ def __init__(self, model_name):
+ super().__init__()
+
+ self.model = model_factory.load_model(model_name).cuda().half()
+
+ self.frames = [[0, 1, 2], [1, 2, 3]]
+
+ self.num_frames = self.model.num_frames
+
+ self.ps = (224 // self.model.patch_size[1]) ** 2
+
+ self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
+ self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
+
+ self.frame_gap = 150
+
+ self.num_frames_dataset = 4
+
+ self.res = 224
+
+
+ def transform(self):
+
+ return DataAugmentationForVideoMAE(
+ imagenet_normalize=True,
+ rescale_size=self.res,
+ ), self.frame_gap, self.num_frames_dataset
+
+ def fwd(self, videos):
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
+ _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
+ return_features=True)
+ return x_encoded
+
+ def extract_features(self, videos, segments=None):
+ '''
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
+ returns: [B, T, D] extracted features
+ '''
+
+ videos = videos.transpose(1, 2)
+
+ all_features = []
+
+ for x, arr in enumerate(self.frames):
+
+ #use the downsampled videos for keypoints
+ vid = videos[:, :, arr, :, :].half()
+ frame0 = vid[:, :, 0]
+ frame1 = vid[:, :, 1]
+ frame2 = vid[:, :, 2]
+
+ #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
+ mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
+
+ #reshape the features to [batch size, num_features]
+ k_feat = k_feat.view(k_feat.shape[0], -1)
+
+ all_features.append(k_feat)
+
+ x_encoded = torch.cat(all_features, dim=1)
+
+ return x_encoded
+
+
+class CWM_KeypointsFlow(PhysionFeatureExtractor):
+ def __init__(self, model_name):
+ super().__init__()
+
+ self.model = model_factory.load_model(model_name).cuda().half()
+
+ self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]
+
+ self.num_frames = self.model.num_frames
+
+ self.timestamps = np.arange(self.num_frames)
+
+ self.ps = (224 // self.model.patch_size[1]) ** 2
+
+ self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
+ self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
+
+ self.frame_gap = 50
+
+ self.num_frames_dataset = 9
+
+ self.res = 512
+
+ def transform(self):
+
+ return DataAugmentationForVideoMAE(
+ imagenet_normalize=True,
+ rescale_size=self.res,
+ ), self.frame_gap, self.num_frames_dataset
+
+ def fwd(self, videos):
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
+ _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
+ return_features=True)
+ return x_encoded
+
+ def get_forward_flow(self, videos):
+
+ fid = 6
+
+ forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1], conditioning_img=videos[:, :, fid + 2], mode='cosine')
+
+ backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid], conditioning_img=videos[:, :, fid - 1], mode='cosine')
+
+ occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]
+
+ forward_flow = forward_flow * occlusion_mask
+
+ forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)
+
+ forward_flow = forward_flow.to(videos.device)
+
+ forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')
+
+ return forward_flow
+
+ def extract_features(self, videos, segments=None):
+ '''
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
+ returns: [B, T, D] extracted features
+ Note:
+ For efficiency, the optical flow is computed and added for a single frame (300ms) as we found this to be sufficient
+ for capturing temporal dynamics in our experiments. This approach can be extended to multiple frames if needed,
+ depending on the complexity of the task.
+ '''
+
+
+ #resize to 224 to get keypoints and features
+ videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
+ videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)
+
+ #for computing flow at higher resolution
+ videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
+ videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)
+
+ videos = videos.transpose(1, 2).half()
+ videos_downsampled = videos_downsampled.transpose(1, 2).half()
+
+ # Get the forward flow for the frame at 300ms
+ forward_flow = self.get_forward_flow(videos)
+
+        # Verify that the forward flow contains no NaNs
+ assert not torch.isnan(forward_flow).any(), "Forward flow is nan"
+
+ all_features = []
+
+ for x, arr in enumerate(self.frames):
+
+ #use the downsampled videos for keypoints
+ vid = videos_downsampled[:, :, arr, :, :]
+ frame0 = vid[:, :, 0]
+ frame1 = vid[:, :, 1]
+ frame2 = vid[:, :, 2]
+
+ #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
+ mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
+
+ #for the last set of frames only use features at keypoint regions of frame2
+ if (x == 2):
+ k_feat = k_feat[:, -10:, :]
+
+ #reshape the features to [batch size, num_features]
+ k_feat = k_feat.view(k_feat.shape[0], -1)
+
+ choices_image_resolution = choices * self.model.patch_size[1]
+
+            # For the first frame triplet (x == 0), add optical-flow patches at the
+            # keypoint locations detected in frame 6 (the ~300 ms frame).
+            if x == 0:
+                # Take one copy of the flow field (the three stacked copies are
+                # identical), giving a [B, 2, 224, 224] tensor.
+                flow_keyp = forward_flow[:, 2]
+
+ # Initialize a result tensor to store the flow patches
+ # Tensor shape: [batch_size, 8x8 patch (flattened to 64) * 2 channels, 10 keypoints]
+ flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)
+
+ # Patch size shift (since 8x8 patches are being extracted)
+ shift = 8
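+                # Each keypoint contributes an 8x8 patch of 2-channel flow
+                # (8 * 8 * 2 = 128 values); with 10 keypoints this appends 1280
+                # flow features per example.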
+
+ # Loop over each element in the batch to process individual video frames
+ for b in range(flow_keyp.size(0)):
+ # Extract the x and y coordinates of the keypoint locations for this batch element
+ x_indices = choices_image_resolution[b, :, 0]
+ y_indices = choices_image_resolution[b, :, 1]
+
+ # For each keypoint (10 total keypoints in this case)
+ for ind in range(10):
+ # Extract the 8x8 patch of optical flow at each keypoint's (x, y) location
+ # Flatten the patch and assign it to the corresponding slice in the result tensor
+ flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
+ x_indices[ind]:x_indices[ind] + shift].flatten()
+
+ # Reshape the flow tensor for easier concatenation (flatten across all patches)
+ flow = flow.view(flow.shape[0], -1)
+
+ # Concatenate the extracted optical flow features with the existing feature tensor (k_feat)
+ k_feat = torch.cat([k_feat, flow], dim=1)
+
+ all_features.append(k_feat)
+
+ x_encoded = torch.cat(all_features, dim=1)
+
+ return x_encoded
+
+
+class CWM_base_8x8_3frame(CWM):
+ def __init__(self,):
+ super().__init__('vitb_8x8patch_3frames')
+
+class CWM_base_8x8_3frame_mean_embed(CWM):
+ def __init__(self,):
+ super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)
+
+# CWM* (keypoints only) 74.7
+class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
+ def __init__(self,):
+ super().__init__('vitb_8x8patch_3frames')
+
+
+# CWM* (keypoints + Flow) 75.4
+class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
+ def __init__(self,):
+ super().__init__('vitb_8x8patch_3frames')
+
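+# Illustrative usage sketch (hedged: the tensor shape and values below are made up,
+# and a CUDA device plus downloadable 'vitb_8x8patch_3frames' weights are assumed):
+#
+#   extractor = CWM_base_8x8_3frame_keypoints_flow()
+#   videos = torch.randn(1, 10, 3, 512, 512).cuda()  # [B, T, C, H, W], ImageNet-normalized
+#   feats = extractor.extract_features(videos)       # [B, D] keypoint + flow features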
diff --git a/cwm/eval/Physion/flow_utils.py b/cwm/eval/Physion/flow_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7490805e601091e8c185dcc95fa047ce04ed331
--- /dev/null
+++ b/cwm/eval/Physion/flow_utils.py
@@ -0,0 +1,279 @@
+
+import torch
+import numpy as np
+import random
+import math
+
+def create_weighted_mask_batched(h, w):
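+    """
+    Build an (h, w) "tent" weight map that peaks at the centre and falls to zero at
+    the borders; used to blend overlapping 224x224 crops when reassembling the
+    full-resolution flow field.
+    """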
+ y_mask = np.linspace(0, 1, h)
+ y_mask = np.minimum(y_mask, 1 - y_mask)
+ x_mask = np.linspace(0, 1, w)
+ x_mask = np.minimum(x_mask, 1 - x_mask)
+ weighted_mask = np.outer(y_mask, x_mask)
+ return torch.from_numpy(weighted_mask).float()
+
+def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
+ B, T, C, H, W = original_shape
+
+ # Initialize an empty tensor to store the reconstructed video
+ reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+ # Create a tensor to store the sum of weighted masks
+ weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+ # Create a weighted mask for the crops
+ weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
+ weighted_mask = weighted_mask[None, None, None, :, :] # Extend dimensions to match the cropped tensor.
+
+ for idx, crop in enumerate(cropped_tensors):
+ start_h, start_w = crop_positions[idx]
+
+ # Multiply the crop with the weighted mask
+ weighted_crop = crop * weighted_mask
+
+ # Add the weighted crop to the corresponding location in the reconstructed_video tensor
+ reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
+
+ # Update the weighted_masks_sum tensor
+ weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
+
+ # Add a small epsilon value to avoid division by zero
+ epsilon = 1e-8
+
+ # Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
+ reconstructed_video /= (weighted_masks_sum + epsilon)
+
+ return reconstructed_video
+
+import torch.nn.functional as F
+
+resize = lambda x,a: F.interpolate(x, [int(a*x.shape[-2]), int(a*x.shape[-1])], mode='bilinear', align_corners=False)
+
+upsample = lambda x,H,W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
+
+
+
+#
+def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
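+    """
+    Turn predictor embeddings into a patch-level flow field: for every unmasked patch
+    of the second frame, find the most cosine-similar patch embedding in the first
+    frame and record the (x, y) displacement between the two patch locations.
+    Returns a [2, frame_size, frame_size] tensor (channel 0: x displacement,
+    channel 1: y displacement).
+    """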
+ # Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
+ mask_unrolled = mask_tensor.view(-1)
+
+ second_frame_unmask_indices = torch.where(mask_unrolled[frame_size**2:] == False)[0]
+
+ # Divide the embedding tensor into two parts: corresponding to the first and the second frame
+ first_frame_embeddings = embedding_tensor[0, :frame_size**2, :]
+ second_frame_embeddings = embedding_tensor[0, frame_size**2:, :]
+
+ # Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
+ dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
+ norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
+ cos_sim_matrix = dot_product / norms
+
+ # Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
+ first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
+
+ # Convert the 1D pixel indices into 2D coordinates
+ second_frame_y = second_frame_unmask_indices // frame_size
+ second_frame_x = second_frame_unmask_indices % frame_size
+ first_frame_y = first_frame_most_similar_indices // frame_size
+ first_frame_x = first_frame_most_similar_indices % frame_size
+
+ # Compute the x and y displacements and convert them to float
+ displacements_x = (second_frame_x - first_frame_x).float()
+ displacements_y = (second_frame_y - first_frame_y).float()
+
+ # Initialize optical flow tensor
+ optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
+
+ # Assign the computed displacements to the corresponding pixels in the optical flow tensor
+ optical_flow[0, second_frame_y, second_frame_x] = displacements_x
+ optical_flow[1, second_frame_y, second_frame_x] = displacements_y
+
+ return optical_flow
+
+def get_minimal_224_crops_new_batched(video_tensor, N):
+ B, T, C, H, W = video_tensor.shape
+
+ # Calculate the number of crops needed in both the height and width dimensions
+ num_crops_h = math.ceil(H / 224) if H > 224 else 1
+ num_crops_w = math.ceil(W / 224) if W > 224 else 1
+
+ # Calculate the step size for the height and width dimensions
+ step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
+ step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
+
+ # Create a list to store the cropped tensors and their start positions
+ cropped_tensors = []
+ crop_positions = []
+
+ # Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
+ for i in range(num_crops_h):
+ for j in range(num_crops_w):
+ start_h = i * step_size_h
+ start_w = j * step_size_w
+ end_h = min(start_h + 224, H)
+ end_w = min(start_w + 224, W)
+ crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
+ cropped_tensors.append(crop)
+ crop_positions.append((start_h, start_w))
+
+ D = len(cropped_tensors)
+
+ # If N is greater than D, generate additional random crops
+ if N > D and H > 224 and W > 224: # check if H and W are greater than 224
+ for _ in range(N - D):
+ start_h = random.randint(0, H - 224)
+ start_w = random.randint(0, W - 224)
+ crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
+ cropped_tensors.append(crop)
+ crop_positions.append((start_h, start_w))
+
+ # Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
+ cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
+
+ return cropped_tensors, crop_positions
+
+def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
+ mask_generator,
+ img1,
+ img2,
+ img3,
+ neg_back_flow=True,
+ num_scales=1,
+ min_scale=400,
+ N_mask_samples=100,
+ mask_ratio=0.8,
+ flow_frames='23'):
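+    """
+    Estimate flow between two of the three input frames with a masked predictor:
+    the frames are tiled into 224x224 crops, random masks are drawn repeatedly
+    (at mask_ratio) until at least N_mask_samples have been used and every patch
+    has been unmasked at least once, the per-sample cosine-similarity flow
+    (compute_optical_flow) is averaged per patch, and the crops are blended back
+    to full resolution and averaged over num_scales scales. flow_frames selects
+    which frame pair ('23' or '12') the flow is computed for; when neg_back_flow
+    is True the frames are fed in reverse order and the returned flow is negated.
+    """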
+ B = img1.shape[0]
+ assert len(img1.shape) == 4
+ assert num_scales >= 1
+
+ # For scaling
+ h1 = img2.shape[-2]
+ w1 = img2.shape[-1]
+ assert min_scale < h1
+
+ if neg_back_flow is False:
+ print('WARNING: Not calculating negative backward flow')
+
+ alpha = (min_scale / img1.shape[-2]) ** (1 / 4)
+
+ frame_size = 224 // generator.patch_size[-1]
+
+ patch_size = generator.patch_size[-1]
+
+ all_fwd_flows_e2d = []
+
+ for aidx in range(num_scales):
+
+ # print('aidx: ', aidx)
+
+ img1_scaled = resize(img1.clone(), alpha ** aidx)
+ img2_scaled = resize(img2.clone(), alpha ** aidx)
+ img3_scaled = resize(img3.clone(), alpha ** aidx)
+
+ h2 = img2_scaled.shape[-2]
+ w2 = img2_scaled.shape[-1]
+
+ s_h = h1 / h2
+ s_w = w1 / w2
+
+        # compute_optical_flow effectively returns negative backward flow, so the
+        # frames are fed in reverse order when neg_back_flow is True
+ if neg_back_flow is True:
+ video = torch.cat([img3_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
+ else:
+ video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img3_scaled.unsqueeze(1)], 1)
+
+ # Should work, even if the incoming video is already 224x224
+ crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
+
+ # print(len(crops1), crops1[0].shape)
+
+ num_crops = len(crops1)
+
+ crop_flows_enc = []
+ crop_flows_enc2dec = []
+ N_samples = N_mask_samples
+
+ crop = torch.cat(crops1, 0).cuda()
+ # print(crop.shape)
+
+ optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
+ mask_counts = torch.zeros(frame_size, frame_size).cuda()
+
+        i = 0
+        # Keep sampling until at least N_samples masks have been drawn and every
+        # patch position has been left unmasked in at least one sample.
+        while i < N_samples or (mask_counts == 0).any().item():
+            mask_generator.mask_ratio = mask_ratio
+
+            # Every crop in the batch shares the same sampled mask.
+            mask = mask_generator(num_frames=3)[None]
+ mask_2f = ~mask[0, frame_size * frame_size * 2:]
+ mask_counts += mask_2f.reshape(frame_size, frame_size)
+
+ with torch.cuda.amp.autocast(enabled=True):
+
+ processed_x = crop.transpose(1, 2)
+
+ # print("crop", processed_x.max())
+
+ encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
+ encoder_to_decoder = generator.encoder_to_decoder(encoder_out)
+ # print(encoder_to_decoder.shape)
+
+ if flow_frames == '23':
+ encoder_to_decoder = encoder_to_decoder[:, frame_size * frame_size:, :]
+ flow_mask = mask[:, frame_size * frame_size:]
+ # print(encoder_to_decoder.shape)
+ elif flow_frames == '12':
+ encoder_to_decoder = encoder_to_decoder[:, :frame_size * frame_size * 2, :]
+ # print(encoder_to_decoder.shape)
+ flow_mask = mask[:, :frame_size * frame_size * 2]
+ # print(mask.shape)
+ # print(flow_mask.shape)
+ # print()
+
+ optical_flow_e2d = []
+ # one per batch element for now
+ for b in range(B * num_crops):
+ batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
+ optical_flow_e2d.append(batch_flow.unsqueeze(0))
+
+ optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
+ optical_flows_enc2dec += optical_flow_e2d
+ i += 1
+
+ optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
+
+ scale_factor_y = video.shape[-2] / 224
+ scale_factor_x = video.shape[-1] / 224
+
+ scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
+ scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
+ scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h
+
+ # split the crops back up
+ crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
+ # print(len(crop_flows_enc2dec))
+
+ optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
+ [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
+ crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+
+ all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
+
+ all_fwd_flows_e2d_new = []
+
+ for r in all_fwd_flows_e2d:
+ new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
+ all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
+ return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+
+ if neg_back_flow is True:
+ return_flow = -return_flow
+ all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+
+ return return_flow, all_fwd_flows_e2d_new
+
diff --git a/cwm/eval/Physion/run_eval.sh b/cwm/eval/Physion/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3650eaf5ff7fbe035ab6dec72a266ef58c9dccfe
--- /dev/null
+++ b/cwm/eval/Physion/run_eval.sh
@@ -0,0 +1,17 @@
+#physion_feature_extract \
+#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
+#--data_root_path /ccn2/u/rmvenkat/data/testing_physion/regenerate_from_old_commit/ \
+#--model_class feature_extractor.CWM_base_8x8_3frame \
+#--gpu 1 \
+#--batch_size 8 \
+#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release/ \
+#--mode ocp
+
+physion_train_readout \
+--train-path /ccn2/u/rmvenkat/data/physion_release/ocp/train_features.hdf5 \
+--test-path /ccn2/u/rmvenkat/data/physion_release/ocp/test_features.hdf5 \
+--model-name CWM_base_8x8_3frame \
+--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/train_json.json \
+--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/test_json.json \
+--test-scenario-map /ccn2/u/rmvenkat/data/physion_release/ocp/test_scenario_map.json \
+--save_path /ccn2/u/rmvenkat/data/physion_release/
\ No newline at end of file
diff --git a/cwm/eval/Physion/run_eval_kfflow.sh b/cwm/eval/Physion/run_eval_kfflow.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9a1a547d962f2f9d8104008fc63c55c6e52866a8
--- /dev/null
+++ b/cwm/eval/Physion/run_eval_kfflow.sh
@@ -0,0 +1,18 @@
+#physion_feature_extract \
+#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
+#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
+#--model_class feature_extractor.CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
+#--gpu 1 \
+#--batch_size 8 \
+#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_kfflow/ \
+#--mode ocp
+
+
+physion_train_readout \
+--train-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_features.hdf5 \
+--test-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_features.hdf5 \
+--model-name CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
+--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_json.json \
+--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_json.json \
+--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_scenario_map.json \
+--save_path /ccn2/u/rmvenkat/data/physion_release_kfflow/
\ No newline at end of file
diff --git a/cwm/eval/Physion/run_eval_mp4s.sh b/cwm/eval/Physion/run_eval_mp4s.sh
new file mode 100644
index 0000000000000000000000000000000000000000..156f3c801b2d9eed159585732819c3206a42b2ac
--- /dev/null
+++ b/cwm/eval/Physion/run_eval_mp4s.sh
@@ -0,0 +1,19 @@
+dir_for_saving=/ccn2/u/rmvenkat/data/physion_release/
+model_name=CWM_base_8x8_3frame
+
+physion_feature_extract \
+--data_root_path /ccn2/u/rmvenkat/data/download_test/physion_mp4s/ \
+--model_class feature_extractor.$model_name \
+--gpu 1 \
+--batch_size 8 \
+--dir_for_saving $dir_for_saving \
+--mode ocd
+
+physion_train_readout \
+--train-path ${dir_for_saving}ocd/train_features.hdf5 \
+--test-path ${dir_for_saving}ocd/test_features.hdf5 \
+--model-name $model_name \
+--train-scenario-indices ${dir_for_saving}ocd/train_json.json \
+--test-scenario-indices ${dir_for_saving}ocd/test_json.json \
+--test-scenario-map ${dir_for_saving}ocd/test_scenario_map.json \
+--save_path $dir_for_saving
diff --git a/cwm/eval/Physion/run_eval_mp4s_keyp.sh b/cwm/eval/Physion/run_eval_mp4s_keyp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..58fea4eedbf9983a0f7ef29d077d2f1d03ce76cd
--- /dev/null
+++ b/cwm/eval/Physion/run_eval_mp4s_keyp.sh
@@ -0,0 +1,17 @@
+#physion_feature_extract \
+#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
+#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
+#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints \
+#--gpu 3 \
+#--batch_size 8 \
+#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp/ \
+#--mode ocp
+
+physion_train_readout \
+--train-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_features.hdf5 \
+--test-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_features.hdf5 \
+--model-name CWM_base_8x8_3frame \
+--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_json.json \
+--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_json.json \
+--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_scenario_map.json \
+--save_path /ccn2/u/rmvenkat/data/physion_release_keyp/
\ No newline at end of file
diff --git a/cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh b/cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7c0016e7942e5dde6d9ab9b6f5accf50369c07c3
--- /dev/null
+++ b/cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh
@@ -0,0 +1,17 @@
+#physion_feature_extract \
+#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
+#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
+#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints_flow \
+#--gpu 7 \
+#--batch_size 8 \
+#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ \
+#--mode ocp
+
+physion_train_readout \
+--train-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_features.hdf5 \
+--test-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_features.hdf5 \
+--model-name CWM_base_8x8_3frame \
+--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_json.json \
+--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_json.json \
+--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_scenario_map.json \
+--save_path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/
\ No newline at end of file
diff --git a/cwm/eval/Segmentation/__init__.py b/cwm/eval/Segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Segmentation/archive/__init__.py b/cwm/eval/Segmentation/archive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Segmentation/archive/common/__init__.py b/cwm/eval/Segmentation/archive/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cwm/eval/Segmentation/archive/common/coco_loader_lsj.py b/cwm/eval/Segmentation/archive/common/coco_loader_lsj.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f3cfeec95cf641cd4a3b0671ed8abac063ebbe2
--- /dev/null
+++ b/cwm/eval/Segmentation/archive/common/coco_loader_lsj.py
@@ -0,0 +1,222 @@
+import detectron2.data.transforms as T
+from detectron2 import model_zoo
+from detectron2.config import LazyCall as L
+
+# Data using LSJ
+image_size = 512
+dataloader = model_zoo.get_config("common/data/coco.py").dataloader
+dataloader.train.mapper.augmentations = [
+ L(T.RandomFlip)(horizontal=True), # flip first
+ L(T.ResizeScale)(
+ min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
+ ),
+ L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
+]
+dataloader.train.mapper.image_format = "RGB"
+dataloader.train.total_batch_size = 64
+dataloader.train.num_workers = 0
+# recompute boxes due to cropping
+dataloader.train.mapper.recompute_boxes = True
+
+dataloader.test.mapper.augmentations = [
+ L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
+]
+
+
+
+
+import copy
+import logging
+import numpy as np
+from typing import List, Optional, Union
+import torch
+
+from detectron2.config import configurable
+
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+
+"""
+This file contains the default mapping that's applied to "dataset dicts".
+"""
+
+__all__ = ["DatasetMapper"]
+
+
+class DatasetMapper:
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by the model.
+
+ This is the default callable to be used to map your dataset dict into training data.
+ You may need to follow it to implement your own one for customized logic,
+ such as a different way to read or transform images.
+ See :doc:`/tutorials/data_loading` for details.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies cropping/geometric transforms to the image and annotations
+ 3. Prepare data and annotations to Tensor and :class:`Instances`
+ """
+
+ @configurable
+ def __init__(
+ self,
+ is_train: bool,
+ *,
+ augmentations: List[Union[T.Augmentation, T.Transform]],
+ image_format: str,
+ use_instance_mask: bool = False,
+ use_keypoint: bool = False,
+ instance_mask_format: str = "polygon",
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
+ precomputed_proposal_topk: Optional[int] = None,
+ recompute_boxes: bool = False,
+ ):
+ """
+ NOTE: this interface is experimental.
+
+ Args:
+ is_train: whether it's used in training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ use_instance_mask: whether to process instance segmentation annotations, if available
+ use_keypoint: whether to process keypoint annotations if available
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
+ masks into this format.
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
+ precomputed_proposal_topk: if given, will load pre-computed
+ proposals from dataset_dict and keep the top k proposals for each image.
+ recompute_boxes: whether to overwrite bounding box annotations
+ by computing tight bounding boxes from instance mask annotations.
+ """
+ if recompute_boxes:
+ assert use_instance_mask, "recompute_boxes requires instance masks"
+ # fmt: off
+ self.is_train = is_train
+ self.augmentations = T.AugmentationList(augmentations)
+ self.image_format = image_format
+ self.use_instance_mask = use_instance_mask
+ self.instance_mask_format = instance_mask_format
+ self.use_keypoint = use_keypoint
+ self.keypoint_hflip_indices = keypoint_hflip_indices
+ self.proposal_topk = precomputed_proposal_topk
+ self.recompute_boxes = recompute_boxes
+ # fmt: on
+ logger = logging.getLogger(__name__)
+ mode = "training" if is_train else "inference"
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
+
+ @classmethod
+ def from_config(cls, cfg, is_train: bool = True):
+ augs = utils.build_augmentation(cfg, is_train)
+ if cfg.INPUT.CROP.ENABLED and is_train:
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
+ recompute_boxes = cfg.MODEL.MASK_ON
+ else:
+ recompute_boxes = False
+
+ ret = {
+ "is_train": is_train,
+ "augmentations": augs,
+ "image_format": cfg.INPUT.FORMAT,
+ "use_instance_mask": cfg.MODEL.MASK_ON,
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
+ "recompute_boxes": recompute_boxes,
+ }
+
+ if cfg.MODEL.KEYPOINT_ON:
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
+
+ if cfg.MODEL.LOAD_PROPOSALS:
+ ret["precomputed_proposal_topk"] = (
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
+ if is_train
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
+ )
+ return ret
+
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
+ # USER: Modify this if you want to keep them for some reason.
+ for anno in dataset_dict["annotations"]:
+ if not self.use_instance_mask:
+ anno.pop("segmentation", None)
+ if not self.use_keypoint:
+ anno.pop("keypoints", None)
+
+ # USER: Implement additional transformations if you have other types of data
+ annos = [
+ utils.transform_instance_annotations(
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+ )
+ for obj in dataset_dict.pop("annotations")
+ if obj.get("iscrowd", 0) == 0
+ ]
+ instances = utils.annotations_to_instances(
+ annos, image_shape, mask_format=self.instance_mask_format
+ )
+
+ # After transforms such as cropping are applied, the bounding box may no longer
+ # tightly bound the object. As an example, imagine a triangle object
+ # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
+ # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
+ # the intersection of original bounding box and the cropping box.
+ if self.recompute_boxes:
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ # USER: Write your own image loading if it's not from a file
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
+ utils.check_image_size(dataset_dict, image)
+
+ # USER: Remove if you don't do semantic/panoptic segmentation.
+ if "sem_seg_file_name" in dataset_dict:
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
+ else:
+ sem_seg_gt = None
+
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+ transforms = self.augmentations(aug_input)
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
+
+ image_shape = image.shape[:2] # h, w
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+ if sem_seg_gt is not None:
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
+
+ # USER: Remove if you don't use pre-computed proposals.
+ # Most users would not need this feature.
+ if self.proposal_topk is not None:
+ utils.transform_proposals(
+ dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
+ )
+
+ if not self.is_train:
+ # USER: Modify this if you want to keep them for some reason.
+ dataset_dict.pop("annotations", None)
+ dataset_dict.pop("sem_seg_file_name", None)
+ return dataset_dict
+
+ if "annotations" in dataset_dict:
+ self._transform_annotations(dataset_dict, transforms, image_shape)
+
+ # Modified by Honglin Chen: change it to class-agnostic instance labels
+ dataset_dict['instances'].gt_classes *= 0
+ return dataset_dict
+
+
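+# Swap in the class-agnostic DatasetMapper defined above as the training mapper.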
+dataloader.train.mapper._target_ = DatasetMapper
diff --git a/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py b/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..757d3201e9f241dd752e7cadb92c6bc1ef8d2a8d
--- /dev/null
+++ b/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py
@@ -0,0 +1,19 @@
+import torch
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input', type=str, help='The path to the checkpoint.')
+parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
+args = parser.parse_args()
+
+state_dict = torch.load(args.input, map_location='cpu')['model']
+
+new_state_dict = {}
+for k, v in state_dict.items():
+    if 'encoder' in k and 'decoder' not in k:
+ new_k = 'backbone.net.model.' + k
+ new_state_dict[new_k] = v
+
+output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
+torch.save(new_state_dict, output_path)
+print('Save model to', output_path)
\ No newline at end of file
diff --git a/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py b/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b58caf4024bea05862a34aa3c81331ed40099e2
--- /dev/null
+++ b/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py
@@ -0,0 +1,54 @@
+import torch
+import argparse
+import sys
+sys.path.append('../../../')
+from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input', type=str, help='The path to the checkpoint.')
+parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
+args = parser.parse_args()
+
+state_dict = torch.load(args.input, map_location='cpu')['model']
+mae = True
+# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
+C = 768
+pos_embed = get_sinusoid_encoding_table(14*14, C)
+cls_token = torch.zeros(1, 1, C)
+pos_embed = torch.cat([cls_token, pos_embed], dim=1)
+
+
+new_state_dict = {'backbone.net.pos_embed': pos_embed}
+for k, v in state_dict.items():
+
+ if mae or ('encoder' in k and not 'decoder' in k or 'patch_embed' in k):
+
+ if 'patch_embed.proj.weight' in k:
+
+ if len(v.shape) == 5:
+ if v.shape[2] == 1:
+ v = v.squeeze(2) # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
+ else:
+ v = v[:, :, 0]
+
+ old_k = k
+ k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.'+k
+
+ if 'attn' in k and '_bias' in k:
+ old_attn = '.'.join(old_k.split('.')[:-1])
+ attn = '.'.join(k.split('.')[:-1])
+ k = attn + '.qkv.bias'
+ if k in new_state_dict:
+ continue
+
+ v = torch.cat([
+ state_dict[old_attn + '.q_bias'],
+ state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict else torch.zeros_like(state_dict[old_attn + '.q_bias']),
+ state_dict[old_attn + '.v_bias'],
+ ], dim=0)
+ print(k, v.shape)
+ new_state_dict[k] = v
+
+output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
+torch.save(new_state_dict, output_path)
+print('Save model to', output_path)
\ No newline at end of file
diff --git a/cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py b/cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..477efc0c06657518571cf9ee5775ef2189592299
--- /dev/null
+++ b/cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py
@@ -0,0 +1,26 @@
+import torch
+import argparse
+import sys
+sys.path.append('../../../')
+from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input', type=str, help='The path to the checkpoint.')
+parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
+args = parser.parse_args()
+state_dict = torch.load(args.input, map_location='cpu')
+
+new_state_dict = {}
+
+# Prefix every key so the weights load into the detectron2 backbone wrapper
+# ('pos_embed' keys are copied through unchanged).
+for k, v in state_dict.items():
+    k = 'backbone.net.' + k
+    new_state_dict[k] = v
+
+output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
+torch.save(new_state_dict, output_path)
+print('Save model to', output_path)
\ No newline at end of file
diff --git a/cwm/eval/Segmentation/archive/competition.py b/cwm/eval/Segmentation/archive/competition.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143ee00f0f90f9e23c777dae4e1effaf805ef1c
--- /dev/null
+++ b/cwm/eval/Segmentation/archive/competition.py
@@ -0,0 +1,673 @@
+import numpy as np
+import torch
+from torch import nn
+from torchvision import transforms
+import torch.nn.functional as F
+from torch.distributions.categorical import Categorical
+
+from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
+ normalize_kernel2d)
+
+def l2_normalize(x):
+ return F.normalize(x, p=2.0, dim=-1, eps=1e-6)
+
+def reduce_max(x, dim, keepdim=True):
+ return torch.max(x, dim=dim, keepdim=keepdim)[0]
+
+def coordinate_ims(batch_size, seq_length, imsize):
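+    """
+    Return a grid of (h, w) pixel coordinates normalized to [-1, 1], shaped
+    [B, T, H, W, 2] (or [B, H, W, 2] when seq_length == 0).
+    """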
+ static = False
+ if seq_length == 0:
+ static = True
+ seq_length = 1
+ B = batch_size
+ T = seq_length
+ H,W = imsize
+ ones = torch.ones([B,H,W,1], dtype=torch.float32)
+ h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=torch.float32))
+ h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
+ w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=torch.float32))
+ w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
+ h = torch.stack([h]*T, 1)
+ w = torch.stack([w]*T, 1)
+ hw_ims = torch.cat([h,w], -1)
+ if static:
+ hw_ims = hw_ims[:,0]
+ return hw_ims
+
+def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
+    """
+    Compute (optionally L2-normalized) dot products between queries [B, N, D] and
+    keys [B, N_k, D]; returns a [B, N, N_k] similarity matrix.
+    """
+    B,N,D_q = queries.size()
+    _B,N_k,D_k = keys.size()
+    assert D_q == D_k, (queries.shape, keys.shape)
+    if normalize:
+        queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
+        keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)
+
+    outputs = torch.matmul(queries, torch.transpose(keys, 1, 2)) # [B, N, N_k]
+
+    return outputs
+
+def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
+
+ B,H,W = probs.shape
+ P = num_points
+ N = H*W
+
+ probs = probs.reshape(B,N)
+ probs = torch.maximum(probs + eps, torch.tensor(0., device=probs.device)) / (probs.sum(dim=-1, keepdim=True) + eps)
+ dist = Categorical(probs=probs, validate_args=False)
+
+ indices = dist.sample([P]).permute(1,0).to(torch.int32) # [B,P]
+
+ indices_h = torch.minimum(torch.maximum(torch.div(indices, W, rounding_mode='floor'), torch.tensor(0)), torch.tensor(H-1))
+ indices_w = torch.minimum(torch.maximum(torch.fmod(indices, W), torch.tensor(0)), torch.tensor(W-1))
+ indices = torch.stack([indices_h, indices_w], dim=-1) # [B,P,2]
+ return indices
+
+def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
+
+ B,C,H,W = list(image.size())
+
+ # prepare kernel
+ kernel = get_spatial_gradient_kernel2d(mode, order)
+ if normalize_kernel:
+ kernel = normalize_kernel2d(kernel)
+ tmp_kernel = kernel.to(image).detach()
+ tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
+ kernel_flip = tmp_kernel.flip(-3)
+
+ # pad spatial dims of image
+ padding = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
+ out_channels = 3 if (order == 2) else 2
+ padded_image = F.pad(image.reshape(B*C, 1, H, W), padding, 'replicate')[:, :, None] # [B*C,1,1,H+p,W+p]
+ gradient_image = F.conv3d(padded_image, kernel_flip, padding=0).view(B, C, out_channels, H, W)
+ return gradient_image
+
+def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
+ """
+ Sample num_points in normalized (h,w) coordinates from the borders of the input image
+ """
+ B,C,H,W = list(image.size())
+ if mask is not None:
+ assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
+ else:
+ mask = torch.ones(size=(B,1,H,W)).to(image)
+
+ gradient_image = get_gradient_image(image * mask, mode='sobel', order=1) # [B,C,2,H,W]
+ gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
+ if sum_edges:
+ edges = gradient_magnitude.sum(1) # [B,H,W]
+ else:
+ edges = gradient_magnitude.max(1)[0]
+
+ if mask is not None:
+ edges = edges * mask[:,0]
+
+ coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
+ if normalized_coordinates:
+ coordinates = coordinates.to(torch.float32)
+ coordinates /= torch.tensor([H-1,W-1], dtype=torch.float32)[None,None].to(coordinates.device)
+ coordinates = 2.0 * coordinates - 1.0
+ return coordinates
+
+def index_into_images(images, indices, channels_last=False):
+ """
+ index into an image at P points to get its values
+
+ images: [B,C,H,W]
+ indices: [B,P,2]
+ """
+ assert indices.size(-1) == 2, indices.size()
+ if channels_last:
+ images = images.permute(0,3,1,2) # [B,C,H,W]
+ B,C,H,W = images.shape
+ _,P,_ = indices.shape
+ inds_h, inds_w = list(indices.to(torch.long).permute(2,0,1)) # [B,P] each
+ inds_b = torch.arange(B, dtype=torch.long).unsqueeze(-1).expand(-1,P).to(indices)
+ inds = torch.stack([inds_b, inds_h, inds_w], 0).to(torch.long)
+ values = images.permute(0,2,3,1)[list(inds)] # [B,P,C]
+ return values
+
+def soft_index(images, indices, scale_by_imsize=True):
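+    """
+    Bilinearly sample image values at fractional (h, w) locations: `indices` are in
+    [-1, 1] when scale_by_imsize=True, and the four surrounding pixels are blended
+    by their bilinear weights. Returns values of shape [B, P, C].
+    """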
+ assert indices.shape[-1] == 2, indices.shape
+ B,C,H,W = images.shape
+ _,P,_ = indices.shape
+
+ # h_inds, w_inds = indices.split([1,1], dim=-1)
+ h_inds, w_inds = list(indices.permute(2,0,1))
+ if scale_by_imsize:
+ h_inds = (h_inds + 1.0) * torch.tensor(H).to(h_inds) * 0.5
+ w_inds = (w_inds + 1.0) * torch.tensor(W).to(w_inds) * 0.5
+
+ h_inds = torch.maximum(torch.minimum(h_inds, torch.tensor(H-1).to(h_inds)), torch.tensor(0.).to(h_inds))
+ w_inds = torch.maximum(torch.minimum(w_inds, torch.tensor(W-1).to(w_inds)), torch.tensor(0.).to(w_inds))
+
+ h_floor = torch.floor(h_inds)
+ w_floor = torch.floor(w_inds)
+ h_ceil = torch.ceil(h_inds)
+ w_ceil = torch.ceil(w_inds)
+
+ bot_right_weight = (h_inds - h_floor) * (w_inds - w_floor)
+ bot_left_weight = (h_inds - h_floor) * (w_ceil - w_inds)
+ top_right_weight = (h_ceil - h_inds) * (w_inds - w_floor)
+ top_left_weight = (h_ceil - h_inds) * (w_ceil - w_inds)
+
+ in_bounds = (bot_right_weight + bot_left_weight + top_right_weight + top_left_weight) > 0.95
+ in_bounds = in_bounds.to(torch.float32)
+
+ top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
+ top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
+ bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
+ bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))
+
+ im_vals = top_left_vals * top_left_weight[...,None]
+ im_vals += top_right_vals * top_right_weight[...,None]
+ im_vals += bot_left_vals * bot_left_weight[...,None]
+ im_vals += bot_right_vals * bot_right_weight[...,None]
+
+ im_vals = im_vals.view(B,P,C)
+
+ return im_vals
+
+def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
+ """
+ Compute how well "fit" each agent is for the position it's at on the plateau,
+ according to its "phenotype"
+
+ positions: [B,P,2]
+ plateau: [B,H,W,Q]
+ phenotypes: [B,P,D] or None
+ availability: [B,H,W,A]
+ """
+ B,H,W,Q = plateau.shape
+ P = positions.shape[1]
+ if phenotypes is None:
+ phenotypes = soft_index(plateau, positions)
+
+ if availability is not None:
+ assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
+ A = availability.size(-1)
+ assert P % A == 0, (P, A)
+ S = P // A # population size
+        # Expand the plateau per availability channel: [B, H, W, A, Q]
+ plateau = availability[...,None] * plateau[...,None,:] # [B,H,W,A,Q]
+ plateau = plateau.view(B,H,W,A*Q)
+
+ plateau_values = soft_index(plateau.permute(0,3,1,2), positions, scale_by_imsize=True)
+ if noise > 0:
+ plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)
+
+ if availability is not None:
+ plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
+ inds = torch.tile(torch.eye(A)[None].expand(B,-1,-1), (1,S,1))[...,None] # [B,P,A,1]
+ plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2) # [B,P,Q]
+ else:
+ plateau_values = l2_normalize(plateau_values)
+
+ compatibility = torch.sum(
+ l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True) # [B,P,1]
+
+ return compatibility
+
+def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
+ """Find overlaps between masks"""
+ B,N,P = masks.shape
+ if masks_target is None:
+ masks_target = masks
+ if mask_thresh is not None:
+ masks = (masks > mask_thresh).to(torch.float32)
+ masks_target = (masks_target > mask_thresh).to(torch.float32)
+
+ ## union and intersection
+ overlaps = masks[...,None] * masks_target[...,None,:] # [B,N,P,P]
+ I = overlaps.sum(dim=1)
+ U = torch.maximum(masks[...,None], masks_target[...,None,:]).sum(dim=1)
+ iou = I / torch.maximum(U, torch.tensor(eps, dtype=torch.float32)) # [B,P,P]
+
+ return iou
+
+def compete_agents(masks, fitnesses, alive,
+ mask_thresh=0.5, compete_thresh=0.2,
+ sticky_winners=True):
+ """
+ Kill off agents (which mask dimensions are "alive") based on mask overlap and fitnesses of each
+
+ args:
+ masks: [B,N,P]
+ fitnesses: [B,P,1]
+ alive: [B,P,1]
+
+ returns:
+ still_alive: [B,P,1]
+
+ """
+ B,N,P = masks.shape
+ assert list(alive.shape) == [B,P,1], alive.shape
+ assert list(fitnesses.shape) == [B,P,1], fitnesses.shape
+
+ ## find territorial disputes
+ overlaps = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)
+ disputes = overlaps > compete_thresh # [B,P,P]