diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d3f2e7e969ff7f05a8f2b7e840ba56cf6078a9e3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.txt filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6001e403779d845b4ffbf7b40030fe73dddd0116 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea + +.pth \ No newline at end of file diff --git a/README.md b/README.md index d69abd77ea7f147f1e074bf9d70354486e123baa..09987e4997af467217294891e791f16e8ebb7962 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,47 @@ ---- -title: Counterfactual World Models -emoji: 📊 -colorFrom: purple -colorTo: yellow -sdk: gradio -sdk_version: 4.44.0 -app_file: app.py -pinned: false -license: mit -short_description: Vision foundation model that unifies vision structures ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+

Understanding Physical Dynamics with Counterfactual World Modeling

+ +[**Rahul Venkatesh***](https://rahulvenkk.github.io/)1 · [**Honglin Chen***](https://web.stanford.edu/~honglinc/)1* · [**Kevin Feigelis***](https://neuroscience.stanford.edu/people/kevin-t-feigelis)1 · [**Daniel M. Bear**](https://twitter.com/recursus?lang=en)1 · [**Khaled Jedoui**](https://web.stanford.edu/~thekej/)1 · [**Klemen Kotar**](https://klemenkotar.github.io/)1 · [**Felix Binder**](https://ac.felixbinder.net/)2 · [**Wanhee Lee**](https://www.linkedin.com/in/wanhee-lee-31102820b/)1 · [**Sherry Liu**](https://neuroailab.github.io/cwm-physics/)1 · [**Kevin A. Smith**](https://www.mit.edu/~k2smith/)3 · [**Judith E. Fan**](https://cogtoolslab.github.io/)1 · [**Daniel L. K. Yamins**](https://stanford.edu/~yamins/)1 + +(* equal contribution) + +1Stanford    2UCSD    3MIT + + + + +Paper PDF +Project Page + + +
+ +This work presents the Counterfactual World Modeling (CWM) framework. CWM is capable of counterfactual prediction and extraction of vision structures useful for understanding physical dynamics. + +![](assets/cwm_teaser.gif) + +## 📣 News + +- 2024-06-01: Release [project page](https://neuroailab.github.io) and [codes](https://github.com/rahulvenkk/cwm_release.git) + +## 🔨 Installation + +``` +git clone https://github.com/rahulvenkk/cwm_release.git +pip install -e . +``` + +## ✨ Usage +To download and use a pre-trained model, run the following +``` +from cwm.model.model_factory import model_factory +model = model_factory.load_model('vitbase_8x8patch_3frames_1tube') +``` +This will automatically initialize the appropriate model class and download the specified weights to your `$CACHE` directory. + +## 🔄 Pre-training +To train the model run the following script + +``` +./scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh +``` diff --git a/assets/color_wheel.png b/assets/color_wheel.png new file mode 100644 index 0000000000000000000000000000000000000000..84cdae7bb397f27457ea8b65457847e79c1f716e Binary files /dev/null and b/assets/color_wheel.png differ diff --git a/assets/cwm_teaser.gif b/assets/cwm_teaser.gif new file mode 100644 index 0000000000000000000000000000000000000000..ff0a7ff2592aebf693544f447058a3ff96e17c86 --- /dev/null +++ b/assets/cwm_teaser.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fac6e545660c695f81f360a87f1060c44eea95f3ae7dbcf1fecbe2b097e3b6a +size 12896275 diff --git a/assets/desk_1.jpg b/assets/desk_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e7e1c5fdea5cd0dcc0125e9125b3d4afe0defb51 --- /dev/null +++ b/assets/desk_1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84a3bfdd40841e8d291b4eb638ecc29b99f054ae6d3ea51b4cdc3090741987c8 +size 4731808 diff --git a/assets/flow_test_videos/libby.mp4 b/assets/flow_test_videos/libby.mp4 new file mode 100644 index 
0000000000000000000000000000000000000000..d5086b20a9a14f6c3f14e7df3eec2a947ea7688b --- /dev/null +++ b/assets/flow_test_videos/libby.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d1887bc796d1883e8a63405398125476257572c0c8d7f1862bf309e422b4828 +size 671950 diff --git a/assets/flow_test_videos/weight_lifter.mp4 b/assets/flow_test_videos/weight_lifter.mp4 new file mode 100755 index 0000000000000000000000000000000000000000..12a2d354ee2e002c858d346e2731ad6c9168c02f --- /dev/null +++ b/assets/flow_test_videos/weight_lifter.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:525988b314936079f904236c614e2f987ba64fd5c69f4f12dd5b6b9076311854 +size 1176790 diff --git a/cwm/__init__.py b/cwm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cwm/data/__init__.py b/cwm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cwm/data/dataset.py b/cwm/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4d0aafd109ce17f8a220f67fce80994f1145f8 --- /dev/null +++ b/cwm/data/dataset.py @@ -0,0 +1,453 @@ +import os +import decord +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class VideoMAE(torch.utils.data.Dataset): + """Load your own video classification dataset. + Parameters + ---------- + root : str, required. + Path to the root folder storing the dataset. + setting : str, required. + A text file describing the dataset, each line per video sample. + There are three items in each line: (1) video path; (2) video length and (3) video label. + train : bool, default True. + Whether to load the training or validation set. + test_mode : bool, default False. + Whether to perform evaluation on the test set. 
+ Usually there is three-crop or ten-crop evaluation strategy involved. + name_pattern : str, default None. + The naming pattern of the decoded video frames. + For example, img_00012.jpg. + video_ext : str, default 'mp4'. + If video_loader is set to True, please specify the video format accordinly. + is_color : bool, default True. + Whether the loaded image is color or grayscale. + modality : str, default 'rgb'. + Input modalities, we support only rgb video frames for now. + Will add support for rgb difference image and optical flow image later. + num_segments : int, default 1. + Number of segments to evenly divide the video into clips. + A useful technique to obtain global video-level information. + Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. + num_crop : int, default 1. + Number of crops for each image. default is 1. + Common choices are three crops and ten crops during evaluation. + new_length : int, default 1. + The length of input video clip. Default is a single image, but it can be multiple video frames. + For example, new_length=16 means we will extract a video clip of consecutive 16 frames. + new_step : int, default 1. + Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. + new_step=2 means we will extract a video clip of every other frame. + temporal_jitter : bool, default False. + Whether to temporally jitter if new_step > 1. + video_loader : bool, default False. + Whether to use video loader to load data. + use_decord : bool, default True. + Whether to use Decord video loader to load data. Otherwise use mmcv video loader. + transform : function, default None. + A function that takes data and label and transforms them. + data_aug : str, default 'v1'. + Different types of data augmentation auto. Supports v1, v2, v3 and v4. + lazy_init : bool, default False. + If set to True, build a dataset instance without loading any dataset. 
+ """ + + def __init__(self, + root, + setting, + train=True, + test_mode=False, + name_pattern='img_%05d.jpg', + video_ext='mp4', + is_color=True, + modality='rgb', + num_segments=1, + num_crop=1, + new_length=1, + new_step=1, + randomize_interframes=False, + transform=None, + temporal_jitter=False, + video_loader=False, + use_decord=False, + lazy_init=False, + is_video_dataset=True): + + super(VideoMAE, self).__init__() + self.root = root + self.setting = setting + self.train = train + self.test_mode = test_mode + self.is_color = is_color + self.modality = modality + self.num_segments = num_segments + self.num_crop = num_crop + self.new_length = new_length + + self.randomize_interframes = randomize_interframes + self._new_step = new_step # If randomize_interframes is True, then this is the max, otherwise it's just the skip + # self._skip_length = self.new_length * self.new_step # If randomize_interframes is True, then this isn't used, otherwise it's used as calculated + self.temporal_jitter = temporal_jitter + self.name_pattern = name_pattern + self.video_loader = video_loader + self.video_ext = video_ext + self.use_decord = use_decord + self.transform = transform + self.lazy_init = lazy_init + + if (not self.lazy_init) and is_video_dataset: + self.clips = self._make_dataset(root, setting) + if len(self.clips) == 0: + raise (RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" + "Check your data directory (opt.data-dir).")) + + def __getitem__(self, index): + + directory, target = self.clips[index] + + if self.video_loader: + if '.' in directory.split('/')[-1]: + # data in the "setting" file already have extension, e.g., demo.mp4 + video_name = directory + else: + # data in the "setting" file do not have extension, e.g., demo + # So we need to provide extension (i.e., .mp4) to complete the file name. 
+ video_name = '{}.{}'.format(directory, self.video_ext) + + try: + decord_vr = decord.VideoReader(video_name, num_threads=1) + except: + # return video_name + return (self.__getitem__(index + 1)) + duration = len(decord_vr) + + segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration) + + images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets, + new_step, skip_length) + + process_data, mask = self.transform((images, None)) # T*C,H,W + process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, + 1) # T*C,H,W -> T,C,H,W -> C,T,H,W + + return (process_data, mask) + + def __len__(self): + return len(self.clips) + + def _make_dataset(self, directory, setting): + if not os.path.exists(setting): + raise (RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) + clips = [] + with open(setting) as split_f: + data = split_f.readlines() + for line in data: + line_info = line.split(' ') + # line format: video_path, video_duration, video_label + if len(line_info) < 2: + raise (RuntimeError('Video input format is not correct, missing one or more element. %s' % line)) + elif len(line_info) > 2: + line_info = (' '.join(line_info[:-1]), line_info[-1]) # filename has spaces + clip_path = os.path.join(line_info[0]) + target = int(line_info[1]) + item = (clip_path, target) + clips.append(item) + # import torch_xla.core.xla_model as xm + # print = xm.master_print + # print("Dataset created. 
Number of clips: ", len(clips)) + return clips + + def _sample_train_indices(self, num_frames): + if self.randomize_interframes is False: + new_step = self._new_step + else: + new_step = np.random.randint(1, self._new_step + 1) + + skip_length = self.new_length * new_step + + average_duration = (num_frames - skip_length + 1) // self.num_segments + if average_duration > 0: + offsets = np.multiply(list(range(self.num_segments)), + average_duration) + offsets = offsets + np.random.randint(average_duration, + size=self.num_segments) + elif num_frames > max(self.num_segments, skip_length): + offsets = np.sort(np.random.randint( + num_frames - skip_length + 1, + size=self.num_segments)) + else: + offsets = np.zeros((self.num_segments,)) + + if self.temporal_jitter: + skip_offsets = np.random.randint( + new_step, size=skip_length // new_step) + else: + skip_offsets = np.zeros( + skip_length // new_step, dtype=int) + return offsets + 1, skip_offsets, new_step, skip_length + + def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step, + skip_length): + sampled_list = [] + frame_id_list = [] + for seg_ind in indices: + offset = int(seg_ind) + for i, _ in enumerate(range(0, skip_length, new_step)): + if offset + skip_offsets[i] <= duration: + frame_id = offset + skip_offsets[i] - 1 + else: + frame_id = offset - 1 + frame_id_list.append(frame_id) + if offset + new_step < duration: + offset += new_step + try: + video_data = video_reader.get_batch(frame_id_list).asnumpy() + sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in + enumerate(frame_id_list)] + except: + raise RuntimeError( + 'Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, + duration)) + return sampled_list + + +class ContextAndTargetVideoDataset(VideoMAE): + """ + A video dataset whose provided videos consist of (1) a "context" sequence of length Tc + and (2) a "target" sequence Tt. 
+ + These two sequences have the same frame rate (specificiable in real units) but are + separated by a specified gap (which may vary for different examples.) + + The main use case is for training models to predict ahead by some variable amount, + given the context. + """ + + standard_fps = [12, 24, 30, 48, 60, 100] + + def __init__(self, + root, + setting, + train=True, + test_mode=False, + transform=None, + step_units='ms', + new_step=150, + start_frame=0, + context_length=2, + target_length=1, + channels_first=True, + generate_masks=True, + mask_generator=None, + context_target_gap=[400, 600], + normalize_timestamps=True, + default_fps=30, + min_fps=0.1, + seed=0, + *args, + **kwargs): + super(ContextAndTargetVideoDataset, self).__init__( + root=root, + setting=setting, + train=train, + test_mode=test_mode, + transform=transform, + new_length=context_length, + use_decord=True, + lazy_init=False, + video_loader=True, + *args, **kwargs) + + # breakpoint() + + self.context_length = self.new_length + self.target_length = target_length + + ## convert from fps and step size to frames + self._fps = None + self._min_fps = min_fps + self._default_fps = default_fps + self._step_units = step_units + self.new_step = new_step + + ## sampling for train and test + self._start_frame = start_frame + self.gap = context_target_gap + self.seed = seed + self.rng = np.random.RandomState(seed=seed) + + # breakpoint() + + ## output formatting + self._channels_first = channels_first + self._normalize_timestamps = normalize_timestamps + self._generate_masks = generate_masks + self.mask_generator = mask_generator + + + def _get_frames_per_t(self, t): + if self._step_units == 'frames' or (self._step_units is None): + return int(t) + + assert self._fps is not None + t_per_frame = 1 / self._fps + if self._step_units in ['ms', 'milliseconds']: + t_per_frame *= 1000.0 + + return max(int(np.round(t / t_per_frame)), 1) + + @property + def new_step(self): + if self._fps is None: + return None + 
else: + return self._get_frames_per_t(self._new_step) + + @new_step.setter + def new_step(self, v): + self._new_step = v + + @property + def gap(self): + if self._fps is None: + return [1, 2] + else: + gap = [self._get_frames_per_t(self._gap[0]), + self._get_frames_per_t(self._gap[1])] + gap[1] = max(gap[1], gap[0] + 1) + return gap + + @gap.setter + def gap(self, v): + if v is None: + v = self._new_step + if not isinstance(v, (list, tuple)): + v = [v, v] + self._gap = v + + def _get_video_name(self, directory): + if ''.join(['.', self.video_ext]) in directory.split('/')[-1]: + # data in the "setting" file has extension, e.g. demo.mpr + video_name = directory + else: + # data doesn't have an extension + video_name = '{}.{}'.format(directory, self.video_ext) + return video_name + + def _set_fps(self, reader): + """click fps to a standard""" + if self._step_units == 'frames' or self._step_units is None: + self._fps = None + else: + self._fps = None + fps = reader.get_avg_fps() + for st in self.standard_fps: + if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st): + self._fps = st + if self._fps is None: + self._fps = int(np.round(fps)) + + if self._fps < self._min_fps: + self._fps = self._default_fps + + def _get_step_and_gap(self): + step = self.new_step + if self.randomize_interframes and self.train: + step = self.rng.randint(1, step + 1) + + if self.train: + gap = self.rng.randint(*self.gap) + else: + gap = sum(self.gap) // 2 + return (step, gap) + + def _sample_frames(self): + step, gap = self._get_step_and_gap() + + ## compute total length of sample + ## e.g. 
if context_length = 2, step = 1, gap = 10, target_length = 2: + ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13 + ## so len(video) must be >= 13 + self._total_length = self.context_length * step + gap + (self.target_length - 1) * step + if self._total_length > (self._num_frames - self._start_frame): + if self.train: + return None + else: + raise ValueError( + "movie of length %d starting at fr=%d is too long for video of %d frames" % \ + (self._total_length, self._start_frame, self._num_frames)) + + ## sample the frames randomly (if training) or from the start frame (if test) + if self.train: + self.start_frame_now = self.rng.randint( + min(self._start_frame, self._num_frames - self._total_length), + self._num_frames - self._total_length + 1) + else: + self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length) + + frames = [self.start_frame_now + i * step for i in range(self.context_length)] + frames += [frames[-1] + gap + i * step for i in range(self.target_length)] + + # breakpoint() + + return frames + + def _decode_frame_images(self, reader, frames): + try: + video_data = reader.get_batch(frames).asnumpy() + video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB') + for t, _ in enumerate(frames)] + except: + raise RuntimeError( + "Error occurred in reading frames {} from video {} of duration {}".format( + frames, self.index, self._num_frames)) + return video_data + + def __getitem__(self, index): + + self.index = index + self.directory, target = self.clips[index] + + self.video_name = self._get_video_name(self.directory) + + ## build decord loader + try: + decord_vr = decord.VideoReader(self.video_name, num_threads=1) + self._set_fps(decord_vr) + except: + # return self.video_name + return (self.__getitem__(index + 1)) + + ## sample the video + self._num_frames = len(decord_vr) + self.frames = self._sample_frames() + if self.frames is None: + print("no movie of length %d for video idx=%d" % (self._total_length, self.index)) 
+ return self.__getitem__(index + 1) + + ## decode to PIL.Image + image_list = self._decode_frame_images(decord_vr, self.frames) + + ## postproc to torch.Tensor and mask generation + if self.transform is None: + image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0) + else: + image_tensor = self.transform((image_list, None)) + + image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:]) + + ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W] + if self._channels_first: + image_tensor = image_tensor.transpose(0, 1) + + if self._generate_masks and self.mask_generator is not None: + mask = self.mask_generator() + return image_tensor, mask.bool() + else: + return image_tensor diff --git a/cwm/data/dataset_utils.py b/cwm/data/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db9057b48eee5df1c91ed00f86405f1b1a619f30 --- /dev/null +++ b/cwm/data/dataset_utils.py @@ -0,0 +1,73 @@ +from torchvision import transforms +from cwm.data.transforms import * +from cwm.data.dataset import ContextAndTargetVideoDataset +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from cwm.data.masking_generator import RotatedTableMaskingGenerator + +class DataAugmentationForVideoMAE(object): + def __init__(self, augmentation_type, input_size, augmentation_scales): + + transform_list = [] + + self.scale = GroupScale(input_size) + transform_list.append(self.scale) + + if augmentation_type == 'multiscale': + self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales)) + elif augmentation_type == 'center': + self.train_augmentation = GroupCenterCrop(input_size) + + transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)]) + + # Normalize input images + normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + transform_list.append(normalize) + + self.transform = 
transforms.Compose(transform_list) + + def __call__(self, images): + process_data, _ = self.transform(images) + return process_data + + def __repr__(self): + repr = "(DataAugmentationForVideoMAE,\n" + repr += " transform = %s,\n" % str(self.transform) + repr += ")" + return repr + + +def build_pretraining_dataset(args): + + dataset_list = [] + data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales) + + mask_generator = RotatedTableMaskingGenerator( + input_size=args.mask_input_size, + mask_ratio=args.mask_ratio, + tube_length=args.tubelet_size, + batch_size=args.batch_size, + mask_type=args.mask_type + ) + + for data_path in [args.data_path] if args.data_path_list is None else args.data_path_list: + dataset = ContextAndTargetVideoDataset( + root=None, + setting=data_path, + video_ext='mp4', + is_color=True, + modality='rgb', + context_length=args.context_frames, + target_length=args.target_frames, + step_units=args.temporal_units, + new_step=args.sampling_rate, + context_target_gap=args.context_target_gap, + transform=data_transform, + randomize_interframes=False, + channels_first=True, + temporal_jitter=False, + train=True, + mask_generator=mask_generator, + ) + dataset_list.append(dataset) + dataset = torch.utils.data.ConcatDataset(dataset_list) + return dataset diff --git a/cwm/data/masking_generator.py b/cwm/data/masking_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..19fdf4b2444966a0b9517024f1bcf6cb18d3f125 --- /dev/null +++ b/cwm/data/masking_generator.py @@ -0,0 +1,86 @@ +import numpy as np +import torch + +def get_tubes(masks_per_frame, tube_length): + rp = torch.randperm(len(masks_per_frame)) + masks_per_frame = masks_per_frame[rp] + + tubes = [masks_per_frame] + for x in range(tube_length - 1): + masks_per_frame = masks_per_frame.clone() + rp = torch.randperm(len(masks_per_frame)) + masks_per_frame = masks_per_frame[rp] + tubes.append(masks_per_frame) + + tubes = 
torch.vstack(tubes) + + return tubes + +class RotatedTableMaskingGenerator: + def __init__(self, + input_size, + mask_ratio, + tube_length, + batch_size, + mask_type='rotated_table', + seed=None, + randomize_num_visible=False): + + self.batch_size = batch_size + + self.mask_ratio = mask_ratio + self.tube_length = tube_length + + self.frames, self.height, self.width = input_size + self.num_patches_per_frame = self.height * self.width + self.total_patches = self.frames * self.num_patches_per_frame + + self.seed = seed + self.randomize_num_visible = randomize_num_visible + + self.mask_type = mask_type + + def __repr__(self): + repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format( + self.total_patches, self.tube_length, self.randomize_num_visible, self.seed + ) + return repr_str + + def __call__(self, m=None): + + if self.mask_type == 'rotated_table_magvit': + self.mask_ratio = np.random.uniform(low=0.0, high=1) + self.mask_ratio = np.cos(self.mask_ratio * np.pi / 2) + elif self.mask_type == 'rotated_table_maskvit': + self.mask_ratio = np.random.uniform(low=0.5, high=1) + + all_masks = [] + for b in range(self.batch_size): + + self.num_masks_per_frame = max(0, int(self.mask_ratio * self.num_patches_per_frame)) + self.total_masks = self.tube_length * self.num_masks_per_frame + + num_masks = self.num_masks_per_frame + + if self.randomize_num_visible: + raise NotImplementedError("Randomize num visible Not implemented") + num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1)) + + if self.mask_ratio == 0: + mask_per_frame = torch.hstack([ + torch.zeros(self.num_patches_per_frame - num_masks), + ]) + else: + mask_per_frame = torch.hstack([ + torch.zeros(self.num_patches_per_frame - num_masks), + torch.ones(num_masks), + ]) + + tubes = get_tubes(mask_per_frame, self.tube_length) + top = torch.zeros(self.height * self.width).to(tubes.dtype) + + top = torch.tile(top, (self.frames - self.tube_length, 1)) + mask = 
torch.cat([top, tubes]) + mask = mask.flatten() + all_masks.append(mask) + return torch.stack(all_masks) diff --git a/cwm/data/transforms.py b/cwm/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f16a7373a9e1d7abb1c91fd041827b380b5392ed --- /dev/null +++ b/cwm/data/transforms.py @@ -0,0 +1,206 @@ +import torch +import torchvision.transforms.functional as F +import warnings +import random +import numpy as np +import torchvision +from PIL import Image, ImageOps +import numbers + + +class GroupRandomCrop(object): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img_tuple): + img_group, label = img_tuple + + w, h = img_group[0].size + th, tw = self.size + + out_images = list() + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + + for img in img_group: + assert(img.size[0] == w and img.size[1] == h) + if w == tw and h == th: + out_images.append(img) + else: + out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) + + return (out_images, label) + + +class GroupCenterCrop(object): + def __init__(self, size): + self.worker = torchvision.transforms.CenterCrop(size) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in img_group], label) + + +class GroupNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor_tuple): + tensor, label = tensor_tuple + rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) + rep_std = self.std * (tensor.size()[0]//len(self.std)) + + # TODO: make efficient + for t, m, s in zip(tensor, rep_mean, rep_std): + t.sub_(m).div_(s) + + return (tensor,label) + + +class GroupGrayScale(object): + def __init__(self, size): + self.worker = torchvision.transforms.Grayscale(size) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in 
img_group], label) + + +class GroupScale(object): + """ Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + self.worker = torchvision.transforms.Resize(size, interpolation) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in img_group], label) + + +class GroupMultiScaleCrop(object): + + def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): + self.scales = scales if scales is not None else [1, .875, .75, .66] + self.max_distort = max_distort + self.fix_crop = fix_crop + self.more_fix_crop = more_fix_crop + self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] + self.interpolation = Image.BILINEAR + + def __call__(self, img_tuple): + img_group, label = img_tuple + + im_size = img_group[0].size + + crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) + crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] + ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] + return (ret_img_group, label) + + def _sample_crop_size(self, im_size): + image_w, image_h = im_size[0], im_size[1] + + # find a crop size + base_size = min(image_w, image_h) + crop_sizes = [int(base_size * x) for x in self.scales] + crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] + crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] + + pairs = [] + for i, h in enumerate(crop_h): + for j, w in enumerate(crop_w): + if abs(i - j) <= self.max_distort: + pairs.append((w, h)) + + 
crop_pair = random.choice(pairs) + if not self.fix_crop: + w_offset = random.randint(0, image_w - crop_pair[0]) + h_offset = random.randint(0, image_h - crop_pair[1]) + else: + w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) + + return crop_pair[0], crop_pair[1], w_offset, h_offset + + def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): + offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) + return random.choice(offsets) + + @staticmethod + def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): + w_step = (image_w - crop_w) // 4 + h_step = (image_h - crop_h) // 4 + + ret = list() + ret.append((0, 0)) # upper left + ret.append((4 * w_step, 0)) # upper right + ret.append((0, 4 * h_step)) # lower left + ret.append((4 * w_step, 4 * h_step)) # lower right + ret.append((2 * w_step, 2 * h_step)) # center + + if more_fix_crop: + ret.append((0, 2 * h_step)) # center left + ret.append((4 * w_step, 2 * h_step)) # center right + ret.append((2 * w_step, 4 * h_step)) # lower center + ret.append((2 * w_step, 0 * h_step)) # upper center + + ret.append((1 * w_step, 1 * h_step)) # upper left quarter + ret.append((3 * w_step, 1 * h_step)) # upper right quarter + ret.append((1 * w_step, 3 * h_step)) # lower left quarter + ret.append((3 * w_step, 3 * h_step)) # lower righ quarter + return ret + + +class Stack(object): + + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_tuple): + img_group, label = img_tuple + + if img_group[0].mode == 'L': + return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) + elif img_group[0].mode == 'RGB': + if self.roll: + return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) + else: + return (np.concatenate(img_group, axis=2), label) + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a 
torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + def __init__(self, div=True): + self.div = div + + def __call__(self, pic_tuple): + pic, label = pic_tuple + + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + return (img.float().div(255.) if self.div else img.float(), label) + + +class IdentityTransform(object): + + def __call__(self, data): + return data diff --git a/cwm/data/video_file_lists/kinetics_400_train_list.txt b/cwm/data/video_file_lists/kinetics_400_train_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..35c95966b7313044367975e20aeb24c444e324e6 --- /dev/null +++ b/cwm/data/video_file_lists/kinetics_400_train_list.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65e14c0735b4c90c57022add2407a8524d246cc09b3d5a7e83b963ac3b231032 +size 19539143 diff --git a/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt b/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt new file mode 100644 index 0000000000000000000000000000000000000000..af8496578845cf64678d43290bec5fb4bf7f1b86 --- /dev/null +++ b/cwm/data/video_file_lists/kinetics_400_train_list_sing.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b18ccdce4616fb32a0aababc2342a640a11f7d73439f49358a16cc99e7eaed3 +size 1943 diff --git a/cwm/engine_for_pretraining.py b/cwm/engine_for_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..854755c3eab5980335ff7d4abf4f8694b2e87400 --- /dev/null +++ b/cwm/engine_for_pretraining.py @@ -0,0 +1,92 @@ +import json +import os +import time +from typing import 
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    start_steps=None,
                    lr_schedule_values=None,
                    wd_schedule_values=None,
                    global_rank=None,
                    args=None,
                    loss_func = nn.MSELoss(),
                    ):
    """Run one epoch of masked-patch reconstruction over ``data_loader``.

    Each batch is a ``(videos, bool_masked_pos)`` pair. The target is built by
    undoing the ImageNet normalisation of ``videos``, patchifying the result
    with ``utils.patchify`` and keeping only the patches selected by the mask.
    The model output is compared against that target with ``loss_func``.

    Args:
        model: Network to run. ``model.module.encoder.patch_size`` is read, so
            the model is presumably wrapped (e.g. by DistributedDataParallel)
            -- TODO confirm against the caller.
        data_loader: Iterable yielding ``(videos, bool_masked_pos)`` batches.
        optimizer: Optimizer whose param groups receive the scheduled values.
        device: Device the batch tensors are moved to.
        epoch: Epoch index, used only for the log header.
        loss_scaler: Callable that performs backward and, when ``update_grad``
            is true, the optimizer step.
        start_steps: Global iteration index of the first step of this epoch.
            ``None`` is treated as 0.
        lr_schedule_values: Optional per-iteration learning rates, multiplied
            by each param group's ``"lr_scale"``.
        wd_schedule_values: Optional per-iteration weight decay, applied only
            to param groups whose current weight decay is positive.
        global_rank: Unused in this function.
        args: Namespace providing ``eval``, ``print_freq`` and ``accum_iter``.
        loss_func: Loss called as ``loss_func(input=outputs, target=labels)``.

    Returns:
        dict mapping every meter name to its global average.
    """
    metric_logger = utils.MetricLogger(delimiter=" ")

    if args.eval:
        model.eval()
    else:
        model.train()
        metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = f'Epoch [{epoch}]'
    patch_size = model.module.encoder.patch_size[-2:]
    tubelet_size = model.module.encoder.patch_size[0]
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]

    # The signature defaults start_steps to None; adding None to the step
    # index would raise TypeError, so fall back to a zero offset.
    first_step = 0 if start_steps is None else start_steps

    for step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):

        # assign learning rate & weight decay for each iteration
        it = first_step + step  # global training iteration
        if (lr_schedule_values is not None or wd_schedule_values is not None) and (step % args.accum_iter == 0):
            for param_group in optimizer.param_groups:
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        # prepare input
        videos, bool_masked_pos = batch
        videos = videos.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1)

        # prepare target
        with torch.no_grad():
            unnorm_videos = videos * std + mean  # in [0, 1]
            videos_patch = utils.patchify(unnorm_videos, tubelet_size, patch_size)
            B, _, C = videos_patch.shape
            labels = videos_patch[bool_masked_pos].reshape(B, -1, C)

        # feedforward
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(videos, bool_masked_pos)
            loss = loss_func(input=outputs, target=labels)

        # Logged value is taken before the loss is divided for accumulation.
        loss_value = loss.item()

        # backward
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss /= args.accum_iter
        loss_scaler(loss, optimizer, clip_grad=None,
                    parameters=model.parameters(), create_graph=is_second_order,
                    update_grad=(step + 1) % args.accum_iter == 0)

        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        # Gradients are cleared only at the end of an accumulation window.
        if (step + 1) % args.accum_iter == 0:
            optimizer.zero_grad()

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def writeFlo5File(flow, filename):
    """Write `flow` to `filename` as the ``"flow"`` dataset of an HDF5 file.

    The file is opened in ``"w"`` mode and the dataset is stored with gzip
    compression at level 5.
    """
    dataset_options = {"compression": "gzip", "compression_opts": 5}
    with h5py.File(filename, "w") as handle:
        handle.create_dataset("flow", data=flow, **dataset_options)
= importlib.import_module(module_name) + + model = getattr(module, class_name) + model = model().cuda().eval() + + folder = args.folder.split('/')[-1] + + import os + import matplotlib.pyplot as plt + + import torch + import torchvision.transforms as transforms + + # import smurf # Assuming this is your custom inference module + + # Path for the dataset + dataset_path = '/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/' + + save_data_path = args.save_data_path + + if not os.path.exists(save_data_path): + os.makedirs(save_data_path) + + resize_crop = transforms.Compose([ + transforms.ToTensor(), + ]) + + import numpy as np + + def l2norm(x): + return np.sqrt((x ** 2).sum(-1)) + + all_epe = [] + # Create a new HDF5 file + + TAG_FLOAT = 202021.25 + + # Iterate over each folder in the dataset directory + for dir in ['FW', 'BW']: + for stereo in ['left', 'right']: + files = sorted(os.listdir(os.path.join(dataset_path, folder, f'frame_{stereo}'))) + output_folder = os.path.join(save_data_path, folder) + output_folder = os.path.join(output_folder, f'flow_{dir}_{stereo}') + + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + for ct_f in range(len(files) - 1): + # Read images + if dir == 'FW': + f1 = files[ct_f] + f2 = files[ct_f + 1] + else: + f2 = files[ct_f] + f1 = files[ct_f + 1] + t = time.time() + image1_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f1) + image2_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f2) + + idx = image1_path.split('/')[-1].split('.')[0].split('_')[-1] + flow_save_path = os.path.join(output_folder, f'flow_{dir}_{stereo}_' + idx + '.flo5') + + # if os.path.exists(flow_save_path): + # try: + # with h5py.File(flow_save_path, 'r+') as f: + # if f['flow'][:].shape[0] == 2: + # flow = f['flow'][:].transpose([1, 2, 0]) + # del f['flow'] + # f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) + # continue + # else: + # continue + # except: + # pass + + image1_ = 
def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Build forward and backward consistency masks from a pair of flows.

    For each direction the opposite flow is warped with
    ``bblosses.backward_warp`` and added to the flow itself. A location gets
    mask value 0 when the squared norm of that sum exceeds
    ``occ_thresh * (|flow|^2 + |warped|^2) + 0.5`` and 1 otherwise.

    Returns:
        ``(occ_mask_fwd, occ_mask_bck)`` as float tensors of zeros and ones.
    """

    def _consistency_mask(flow, opposite):
        # Warp the opposite-direction flow using `flow`; their sum is the
        # cycle residual that is thresholded below.
        warped, _ = bblosses.backward_warp(img2=opposite, flow=flow)
        residual = flow + warped
        magnitude = l2_norm(flow) ** 2 + l2_norm(warped) ** 2
        tolerance = occ_thresh * magnitude + 0.5
        return 1 - (l2_norm(residual) ** 2 > tolerance).float()

    occ_mask_fwd = _consistency_mask(flow_fwd, flow_bck)
    occ_mask_bck = _consistency_mask(flow_bck, flow_fwd)
    return occ_mask_fwd, occ_mask_bck
class CWM_8x8(CWM):
    """CWM flow extractor using the ``vitb_8x8patch_3frames`` model, patch size 8.

    Args:
        weights_path: Checkpoint file handed to ``CWM.__init__``. Defaults to
            the path that was previously hard-coded, so ``CWM_8x8()`` keeps
            working unchanged, while other environments can point the class
            at their own checkpoint.
    """

    def __init__(self,
                 weights_path='/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth'):
        super().__init__('vitb_8x8patch_3frames', 8, weights_path)
def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    """Estimate a patch-level flow by matching token embeddings across two frames.

    The first ``frame_size ** 2`` tokens of ``embedding_tensor[0]`` are taken
    as the first frame and the remaining tokens as the second frame. Every
    second-frame token is matched to the first-frame token with the highest
    cosine similarity, and the difference of their grid coordinates
    (second minus first) is written at the second-frame position.

    Args:
        embedding_tensor: Token embeddings; only batch index 0 is read.
        mask_tensor: Mask that is flattened; entries equal to ``False`` after
            the first ``frame_size ** 2`` positions mark the second-frame
            locations that receive a displacement.
        frame_size: Side length of the square token grid.

    Returns:
        Tensor of shape ``(2, frame_size, frame_size)``: channel 0 holds the
        x displacement, channel 1 the y displacement. Locations that are not
        unmasked stay zero.
    """
    tokens_per_frame = frame_size ** 2

    tokens = embedding_tensor[0]
    src_tokens = tokens[:tokens_per_frame, :]
    dst_tokens = tokens[tokens_per_frame:, :]

    # Positions (within the second frame) whose mask entry is False.
    flat_mask = mask_tensor.view(-1)
    visible_dst = torch.where(flat_mask[tokens_per_frame:] == False)[0]

    # Cosine similarity of every second-frame token to every first-frame token.
    similarity = torch.matmul(dst_tokens, src_tokens.T)
    scale = torch.norm(dst_tokens, dim=1)[:, None] * torch.norm(src_tokens, dim=1)[None, :]
    best_src = (similarity / scale).argmax(dim=-1)

    # Flat token indices -> (row, column) on the token grid.
    dst_row, dst_col = visible_dst // frame_size, visible_dst % frame_size
    src_row, src_col = best_src // frame_size, best_src % frame_size

    flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
    flow[0, dst_row, dst_col] = (dst_col - src_col).float()
    flow[1, dst_row, dst_col] = (dst_row - src_row).float()
    return flow
def create_weighted_mask_batched(h, w):
    """Return an ``(h, w)`` float32 blending mask that peaks at the centre.

    Each axis gets a tent profile ``min(t, 1 - t)`` over ``t`` in ``[0, 1]``
    and the mask is the outer product of the two profiles.
    """

    def _tent(length):
        ramp = np.linspace(0, 1, length)
        # NOTE: the ramp endpoints are 0 and 1, so the first and last entry
        # of the profile are exactly 0.
        return np.minimum(ramp, 1 - ramp)

    return torch.from_numpy(np.outer(_tent(h), _tent(w))).float()
def l2_norm(x):
    """Euclidean norm of `x` over dimension -3, kept as a singleton dimension."""
    squared = x.square()
    return squared.sum(dim=-3, keepdim=True).sqrt()
def resize_flow_map(flow_map, target_size):
    """
    Resize a flow map to a target size while rescaling the flow vectors.

    Parameters:
    flow_map (torch.Tensor): Flow of shape (B, 2, H, W). Only the first batch
        element is used. Channel 0 is treated as dx and channel 1 as dy.
    target_size (tuple): Target size (height, width) for the resized flow map.

    Returns:
    torch.Tensor: float32 CPU tensor of shape (1, 2, target_size[0], target_size[1]).
        dx is multiplied by the width ratio and dy by the height ratio.

    Raises:
    ValueError: if `flow_map` is not 4-dimensional with at least two channels.
    """
    if flow_map.dim() != 4 or flow_map.shape[1] < 2:
        raise ValueError(
            f"flow_map must have shape (B, 2, H, W), got {tuple(flow_map.shape)}")

    target_h, target_w = int(target_size[0]), int(target_size[1])

    # Only the first batch element and its first two channels are used.
    flow = flow_map[:1, :2].detach().cpu().float()
    source_h, source_w = flow.shape[-2:]

    # The size is given explicitly as (height, width) so non-square targets
    # are not transposed. antialias=True is used so downscaling is filtered
    # rather than point-sampled.
    resized = F.interpolate(flow, size=(target_h, target_w), mode='bilinear',
                            align_corners=False, antialias=True)

    # Displacements are measured in pixels, so each component is rescaled by
    # the size ratio of its own axis.
    scale = torch.tensor([target_w / source_w, target_h / source_h],
                         dtype=resized.dtype).view(1, 2, 1, 1)
    return resized * scale
Please use scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed') + + return scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator, + mask_generator, + img1, + img2, + neg_back_flow=neg_back_flow, + num_scales=num_scales, + min_scale=min_scale, + N_mask_samples=N_mask_samples, + mask_ratio=mask_ratio, + smoothing_factor=smoothing_factor) + + + +def average_crops(tensor, D): + C, H, W = tensor.shape + + # Create zero-filled tensors for the shifted crops + down_shifted = torch.zeros_like(tensor) + up_shifted = torch.zeros_like(tensor) + right_shifted = torch.zeros_like(tensor) + left_shifted = torch.zeros_like(tensor) + + # Shift the tensor and store the results in the zero-filled tensors + down_shifted[:, :H-D, :] = tensor[:, D:, :] + up_shifted[:, D:, :] = tensor[:, :H-D, :] + right_shifted[:, :, :W-D] = tensor[:, :, D:] + left_shifted[:, :, D:] = tensor[:, :, :W-D] + + # Average the tensor with its four crops + result = (tensor + down_shifted + up_shifted + right_shifted + left_shifted) / 5.0 + + return result + + +def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(predictor, + mask_generator, + img1, + img2, + conditioning_img=None, + num_scales=1, + min_scale=400, + N_mask_samples=100, + smoothing_factor=1): + B = img1.shape[0] + assert len(img1.shape) == 4 + assert num_scales >= 1 + + # For scaling + h1 = img2.shape[-2] + w1 = img2.shape[-1] + + + alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1 + + frame_size = 224 // predictor.patch_size[-1] + + patch_size = predictor.patch_size[-1] + + num_frames = predictor.num_frames + + all_fwd_flows_e2d = [] + + s_hs = [] + s_ws = [] + + for aidx in range(num_scales): + # print(aidx) + + # print('aidx: ', aidx) + + img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)], + mode='bicubic', align_corners=True) + img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)], + 
mode='bicubic', align_corners=True) + + if conditioning_img is not None: + conditioning_img_scaled = F.interpolate(conditioning_img.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)], + mode='bilinear', align_corners=False) + + # print("img1_scaled", img1_scaled.shape, alpha, min_scale, num_scales) + + h2 = img2_scaled.shape[-2] + w2 = img2_scaled.shape[-1] + + s_h = h1 / h2 + s_w = w1 / w2 + + s_hs.append(s_h) + s_ws.append(s_w) + + if conditioning_img is not None: + video = torch.cat([conditioning_img_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1) + else: + video = torch.cat([img2_scaled.unsqueeze(1)]*(num_frames-1) + [img1_scaled.unsqueeze(1)], 1) + + # Should work, even if the incoming video is already 224x224 + crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1) + + num_crops = len(crops1) + + crop_flows_enc = [] + crop_flows_enc2dec = [] + N_samples = N_mask_samples + + crop = torch.cat(crops1, 0).cuda() + + optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda() + mask_counts = torch.zeros(frame_size, frame_size).cuda() + + i = 0 + while i < N_samples or (mask_counts == 0).any().item(): + if i % 100 == 0: + pass # print(i) + + # This would be that every sample has the same mask. 
For now that's okay I think + mask = mask_generator().bool().cuda() + mask_2f = ~mask[0, (frame_size * frame_size)*(num_frames-1):] + mask_counts += mask_2f.reshape(frame_size, frame_size) + + with torch.cuda.amp.autocast(enabled=True): + + processed_x = crop.transpose(1, 2) + + encoder_out = predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1)) + encoder_to_decoder = predictor.encoder_to_decoder(encoder_out) + + encoder_to_decoder = encoder_to_decoder[:, (frame_size * frame_size)*(num_frames-2):, :] + flow_mask = mask[:, (frame_size * frame_size)*(num_frames-2):] + + optical_flow_e2d = [] + # one per batch element for now + for b in range(B * num_crops): + batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size) + # optical_flow_e2d.append(batch_flow.unsqueeze(0)) + + optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0)) + + optical_flow_e2d = torch.cat(optical_flow_e2d, 0) + optical_flows_enc2dec += optical_flow_e2d + i += 1 + + optical_flows_enc2dec = optical_flows_enc2dec / mask_counts + + #other fucntion + # scale_factor_y = video.shape[-2] / 224 + # scale_factor_x = video.shape[-1] / 224 + # + # scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec) + # scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w + # scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h + # + # # split the crops back up + # crop_flows_enc2dec = scaled_optical_flow.split(B, 0) + + ### + #Kevin's fn + crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0) + + ### + + #Changed by Kevin + T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in + crop_flows_enc2dec] + optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, ( + B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1) + + #other function + # optical_flows_enc2dec_joined = 
reconstruct_video_new_2_batched( + # [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in + # crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1) + # + all_fwd_flows_e2d.append(optical_flows_enc2dec_joined) + + #other function + # all_fwd_flows_e2d_new = [] + # + # for r in all_fwd_flows_e2d: + # new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1]) + # all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1)) + # return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1) + # + # + # return_flow = -return_flow + # all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new] + # + # return return_flow, all_fwd_flows_e2d_new + + #Kevin's method + all_fwd_flows_e2d_new = [] + + for ridx, r in enumerate(all_fwd_flows_e2d): + # print('ridx', ridx) + # print('sh', s_hs[ridx]) + # print('sw', s_ws[ridx]) + # print('scale_fac y', scale_ys[ridx]) + # print('scale_fac x', scale_xs[ridx]) + + _sh = s_hs[ridx] + _sw = s_ws[ridx] + _sfy = predictor.patch_size[-1] + _sfx = predictor.patch_size[-1] + + # plt.figure(figsize=(20, 20)) + + # plt.subplot(1,3,1) + # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0)) + + # plt.subplot(1,3,2) + new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])], mode='bicubic', align_corners=True) + # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0)) + + scaled_new_r = torch.zeros_like(new_r) + scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw + scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh + + # plt.subplot(1,3,3) + # plt.imshow(f2rgb(-scaled_new_r).cpu().numpy()[0].transpose(1,2,0)) + + # plt.show() + + all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1)) + return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1) + + return_flow = -return_flow + all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new] + + return return_flow , all_fwd_flows_e2d_new + +def 
class FlowToRgb(object):
    """Render a two-channel flow field as an RGB image through HSV.

    Direction becomes hue, saturation is fixed at 1 and the flow magnitude
    divided by ``max_speed`` becomes value.
    """

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        self.max_speed = max_speed
        self.from_image_coordinates = from_image_coordinates
        self.from_sampling_grid = from_sampling_grid

    def __call__(self, flow):
        """Convert `flow` (two channels along dim -3) into an RGB tensor."""
        assert flow.size(-3) == 2, flow.shape
        first, second = torch.split(flow, [1, 1], dim=-3)

        if self.from_sampling_grid:
            # Channels are (x, y); y is negated.
            flow_x, flow_y = first, -second
        elif not self.from_image_coordinates:
            # Channels are already (x, y).
            flow_x, flow_y = first, second
        else:
            # Channels are (h, w): x is the w component, y the negated h one.
            flow_x, flow_y = second, -first

        angle = torch.atan2(flow_y, flow_x)  # in radians from -pi to pi
        speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed

        # NOTE(review): fmod keeps the sign of `angle`, so hue can be
        # negative here -- confirm kornia's hsv_to_rgb handles that.
        hue = torch.fmod(angle, torch.tensor(2 * np.pi))
        channels = [hue, torch.ones_like(hue), speed]
        return kornia.color.hsv_to_rgb(torch.cat(channels, -3))

    def make_colorwheel(self):
        """
        Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)

        Returns a ``(55, 3)`` array of RGB rows with values in ``[0, 255]``.
        """
        # (length, channel held at 255, ramping channel, ramp rises?) for the
        # segments RY, YG, GC, CB, BM, MR in that order.
        segments = (
            (15, 0, 1, True),
            (6, 1, 0, False),
            (4, 1, 2, True),
            (11, 2, 1, False),
            (13, 2, 0, True),
            (6, 0, 2, False),
        )

        colorwheel = np.zeros((sum(segment[0] for segment in segments), 3))
        start = 0
        for length, full_channel, ramp_channel, rising in segments:
            ramp = np.floor(255 * np.arange(0, length) / length)
            rows = slice(start, start + length)
            colorwheel[rows, full_channel] = 255
            colorwheel[rows, ramp_channel] = ramp if rising else 255 - ramp
            start += length
        return colorwheel
s_ws = [] + + for aidx in range(num_scales): + print(aidx) + + # print('aidx: ', aidx) + + img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)], + mode='bicubic', align_corners=True) + img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)], + mode='bicubic', align_corners=True) + + h2 = img2_scaled.shape[-2] + w2 = img2_scaled.shape[-1] + + s_h = h1 / h2 + s_w = w1 / w2 + + s_hs.append(s_h) + s_ws.append(s_w) + + # Because technically the compute_optical_flow function returns neg back flow + if neg_back_flow is True: + video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1) + else: + video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1) + + # Should work, even if the incoming video is already 224x224 + crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1) + + num_crops = len(crops1) + + crop_flows_enc = [] + crop_flows_enc2dec = [] + N_samples = N_mask_samples + + crop = torch.cat(crops1, 0).cuda() + + optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda() + mask_counts = torch.zeros(frame_size, frame_size).cuda() + + i = 0 + while i < N_samples or (mask_counts == 0).any().item(): + if i % 100 == 0: + pass # print(i) + mask_generator.mask_ratio = mask_ratio + + # This would be that every sample has the same mask. 
For now that's okay I think + mask = mask_generator()[None] + mask_2f = ~mask[0, frame_size * frame_size:] + mask_counts += mask_2f.reshape(frame_size, frame_size) + + with torch.cuda.amp.autocast(enabled=True): + + processed_x = generator._preprocess(crop) + + encoder_out = generator.predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1)) + encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out) + + optical_flow_e2d = [] + # one per batch element for now + for b in range(B * num_crops): + batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size) + optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0)) + + optical_flow_e2d = torch.cat(optical_flow_e2d, 0) + optical_flows_enc2dec += optical_flow_e2d + i += 1 + + optical_flows_enc2dec = optical_flows_enc2dec / mask_counts + + # split the crops back up + crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0) + + T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in + crop_flows_enc2dec] + + optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, ( + B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1) + + all_fwd_flows_e2d.append(optical_flows_enc2dec_joined) + + all_fwd_flows_e2d_new = [] + + for ridx, r in enumerate(all_fwd_flows_e2d): + # print('ridx', ridx) + # print('sh', s_hs[ridx]) + # print('sw', s_ws[ridx]) + # print('scale_fac y', scale_ys[ridx]) + # print('scale_fac x', scale_xs[ridx]) + + _sh = s_hs[ridx] + _sw = s_ws[ridx] + _sfy = generator.patch_size[-1] + _sfx = generator.patch_size[-1] + + # plt.figure(figsize=(20, 20)) + + # plt.subplot(1,3,1) + # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0)) + + # plt.subplot(1,3,2) + new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])], + mode='bicubic', align_corners=True) + # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0)) + 
def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
    """Build per-pixel (h, w) coordinate images.

    Args:
        batch_size: size B of the leading dimension.
        seq_length: number of frames T; 0 requests a static grid, in which
            case the time dimension is dropped from the result.
        imsize: (H, W) pair.
        normalize: when True coordinates run linearly from -1 to 1 along
            each axis, otherwise they are the integer indices 0..H-1 / 0..W-1
            (stored as ``dtype_out``).
        dtype_out: dtype of the returned tensor.

    Returns:
        Tensor of shape [B, T, H, W, 2] (or [B, H, W, 2] when
        ``seq_length == 0``); the last dimension holds (h, w).
    """
    static = seq_length == 0
    B = batch_size
    T = 1 if static else seq_length
    H, W = imsize

    def axis_coords(n):
        coords = torch.arange(n, dtype=dtype_out)
        if normalize:
            # max(..., 1) keeps a size-1 axis finite: the original divided
            # 0 by 0 there and filled the grid with NaN.
            coords = 2.0 * (coords / max(n - 1, 1) - 0.5)
        return coords

    h = axis_coords(H).view(1, 1, H, 1, 1).expand(B, T, H, W, 1)
    w = axis_coords(W).view(1, 1, 1, W, 1).expand(B, T, H, W, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
class Patchify(nn.Module):
    """Convert a set of images or a movie into patch vectors

    ``forward(x)`` cuts a [B,C,H,W] image batch or a 5-D movie into patches;
    ``forward(patches, to_video=True)`` reassembles patches into the layout of
    the most recently patchified input (its shape is cached on the module).

    Args:
        patch_size: (ph, pw) for 2-D patches or (pt, ph, pw) for 3-D patches.
        temporal_dim: position of the time axis in 5-D inputs, 1 or 2.
        squeeze_channel_dim: when True ``forward`` flattens the per-patch
            (pixels, channels) dimensions into one.
    """

    def __init__(self,
                 patch_size=(16, 16),
                 temporal_dim=1,
                 squeeze_channel_dim=True
                 ):
        super().__init__()
        self.set_patch_size(patch_size)
        self.temporal_dim = temporal_dim
        assert self.temporal_dim in [1, 2], self.temporal_dim
        self._squeeze_channel_dim = squeeze_channel_dim
        # Default fill for missing patches, so patches_to_video() also works
        # when it is called directly rather than through forward().
        self.mask_mode = 'zeros'

    @property
    def num_patches(self):
        """Number of patches of the cached input shape, or None before any input was seen."""
        if (self.T is None) or (self.H is None) or (self.W is None):
            return None
        else:
            return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)

    def set_patch_size(self, patch_size):
        """Store the patch size and reset every cached input-shape attribute.

        Raises:
            ValueError: if ``patch_size`` does not have 2 or 3 entries.
        """
        self.patch_size = patch_size
        if len(self.patch_size) == 2:
            self.ph, self.pw = self.patch_size
            self.pt = 1
            self._patches_are_3d = False
        elif len(self.patch_size) == 3:
            self.pt, self.ph, self.pw = self.patch_size
            self._patches_are_3d = True
        else:
            # Wrap in a 1-tuple: "%s" % some_tuple would try to spread the
            # tuple over the format and raise TypeError instead.
            raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % (self.patch_size,))

        self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
        self.D = self.C = self.E = self.embed_dim = None

    def _check_shape(self, x):
        """Cache shape, rank, H, W and T of ``x`` and assert divisibility by the patch size."""
        self.shape_inp = x.shape
        self.rank_inp = len(self.shape_inp)
        self.H, self.W = self.shape_inp[-2:]
        assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
        if (self.rank_inp == 5) and self._patches_are_3d:
            self.T = self.shape_inp[self.temporal_dim]
            assert (self.T % self.pt) == 0, (self.T, self.pt)
        elif self.rank_inp == 5:
            self.T = self.shape_inp[self.temporal_dim]
        else:
            self.T = 1

    def split_by_time(self, x):
        """Reshape [B, T*n, ...] into [B, T, n, ...] using the cached T."""
        shape = x.shape
        assert shape[1] % self.T == 0, (shape, self.T)
        return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])

    def merge_by_time(self, x):
        """Inverse of ``split_by_time``: merge dimensions 1 and 2."""
        shape = x.shape
        return x.view(shape[0], shape[1] * shape[2], *shape[3:])

    def video_to_patches(self, x):
        """Rearrange an image batch or movie into [B, N, D, C] patches.

        Relies on ``_check_shape`` having been called on ``x`` first; caches
        N (patches), D (pixels per patch), C (channels) and E = D * C.
        """
        if self.rank_inp == 4:
            assert self.pt == 1, (self.pt, x.shape)
            x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
        else:
            assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
            dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
            x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)

        self.N, self.D, self.C = x.shape[-3:]
        self.embed_dim = self.E = self.D * self.C
        return x

    def patches_to_video(self, x):
        """Reassemble [B, N, D, C] or [B, N, E] patches into the cached input layout.

        When fewer than ``num_patches`` patches are given, the missing ones
        are appended at the end, filled according to ``self.mask_mode``.
        """
        shape = x.shape
        rank = len(shape)
        if rank == 4:
            B, _N, _D, _C = shape
        else:
            assert rank == 3, rank
            B, _N, _E = shape
            assert (_E % self.D == 0), (_E, self.D)
            x = x.view(B, _N, self.D, -1)

        if _N < self.num_patches:
            masked_patches = self.get_masked_patches(
                x,
                num_patches=(self.num_patches - _N),
                mask_mode=self.mask_mode)
            x = torch.cat([x, masked_patches], 1)

        x = rearrange(
            x,
            'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
            pt=self.pt, ph=self.ph, pw=self.pw,
            t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))

        if self.rank_inp == 5 and (self.temporal_dim == 1):
            x = x.transpose(1, 2)
        elif self.rank_inp == 4:
            assert x.shape[2] == 1, x.shape
            x = x[:, :, 0]
        return x

    @staticmethod
    def get_masked_patches(x, num_patches, mask_mode='zeros'):
        """Return ``num_patches`` filler patches matching ``x`` in device, dtype and trailing shape.

        Raises:
            NotImplementedError: for a ``mask_mode`` other than 'zeros' or 'gray'.
        """
        shape = x.shape
        patches_shape = (shape[0], num_patches, *shape[2:])
        if mask_mode == 'zeros':
            return torch.zeros(patches_shape, device=x.device).to(x.dtype)
        elif mask_mode == 'gray':
            return (0.5 * torch.ones(patches_shape, device=x.device)).to(x.dtype)
        else:
            raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)

    def average_within_patches(self, z):
        """Replace every pixel of a patch by the per-channel mean over that patch."""
        if len(z.shape) == 3:
            z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
        return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)

    def forward(self, x, to_video=False, mask_mode='zeros'):
        """Patchify ``x`` or, with ``to_video=True``, turn patches back into a video.

        ``mask_mode`` is only used in the ``to_video`` direction, as the fill
        for patches that are missing from ``x``.
        """
        if not to_video:
            self._check_shape(x)
            x = self.video_to_patches(x)
            return x if not self._squeeze_channel_dim else x.view(x.size(0), self.N, -1)

        else:  # x are patches
            assert (self.shape_inp is not None) and (self.num_patches is not None)
            self.mask_mode = mask_mode
            x = self.patches_to_video(x)
            return x
or 2") + + @property + def c_dim(self): + if self.predictor is None: + return None + return self.predictor.c_dim + + @property + def patch_size(self): + if self.predictor is None: + return None + elif hasattr(self.predictor, 'patch_size'): + return self.predictor.patch_size + elif hasattr(self.predictor.encoder.patch_embed, 'proj'): + return self.predictor.encoder.patch_embed.proj.kernel_size + else: + return None + @property + def S(self): + return self.num_samples + + @property + def sequence_length(self): + if self.predictor is None: + return None + elif hasattr(self.predictor, 'sequence_length'): + return self.predictor.sequence_length + elif hasattr(self.predictor, 'num_frames'): + return self.predictor.num_frames + else: + return 2 + @property + def mask_shape(self): + if self.predictor is None: + return None + elif hasattr(self.predictor, 'mask_shape'): + return self.predictor.mask_shape + + assert self.patch_size is not None + pt, ph, pw = self.patch_size + return (self.sequence_length // pt, + self.inp_shape[-2] // ph, + self.inp_shape[-1] // pw) + + @property + def perturbation_mask_shape(self): + return ( + self.mask_shape[0], + self.inp_shape[-2] // self.perturbation_patch_size[-2], + self.inp_shape[-1] // self.perturbation_patch_size[-1] + ) + + + + @property + def p_mask_shape(self): + return self.perturbation_mask_shape + + @property + def aggregation_mask_shape(self): + return ( + 1, + self.inp_shape[-2] // self.aggregation_patch_size[-2], + self.inp_shape[-1] // self.aggregation_patch_size[-1] + ) + + @property + def a_mask_shape(self): + return self.aggregation_mask_shape + + def get_perturbation_input(self, x): + self.set_input(x) + y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True) + y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1) + return y + + def pred_patches_to_video(self, y, x, mask): + """input at visible positions, preds at masked positions""" + B, C = y.shape[0], y.shape[-1] + 
self.patchify._check_shape(x) + self.patchify.D = np.prod(self.patch_size) + x = self.patchify(x) + y_out = torch.zeros_like(x) + x_vis = x[~mask] + + y_out[~mask] = x_vis.view(-1, C) + try: + y_out[mask] = y.view(-1, C) + except: + y_out[mask] = y.reshape(-1, C) + + return self.patchify(y_out, to_video=True) + + def set_image_size(self, *args, **kwargs): + assert self.predictor is not None, "Can't set the image size without a predictor" + if hasattr(self.predictor, 'set_image_size'): + self.predictor.set_image_size(*args, **kwargs) + else: + self.predictor.image_size = args[0] + + def predict(self, x=None, mask=None, forward_full=False): + if x is None: + x = self.x + if mask is None: + mask = self.generate_mask(x) + + self.set_image_size(x.shape[-2:]) + y = self.predictor( + self._preprocess(x), + mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full) + + y = self.pred_patches_to_video(y, x, mask=mask) + + frame = -1 % y.size(1) + y = y[:, frame:frame + 1] + + return y + + def _get_perturbation_func(self, x=None, mask=None): + + if (x is not None): + self.set_input(x, mask) + + def forward_mini_image(y): + y = y.repeat_interleave(self.perturbation_patch_size[-2], -2) + y = y.repeat_interleave(self.perturbation_patch_size[-1], -1) + x_pred = self.predict(self.x + y, self.mask) + x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape) + return x_pred[self.targets] + + return forward_mini_image + + def _postprocess_jacobian(self, jac): + _jac = torch.zeros((self.B, *self.a_mask_shape, *jac.shape[1:])).to(jac.device).to(jac.dtype) + _jac[self.targets] = jac + jac = self.agg_channel_func(_jac) + assert jac.size(-3) == 1, jac.shape + jac = jac.squeeze(-3)[..., 0, :, :] # derivative w.r.t. 
first frame and agg channels + jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], + self.B, self.p_mask_shape[-2], self.p_mask_shape[-1]) + bs = torch.arange(0, self.B).long().to(jac.device) + jac = jac[bs, :, :, bs, :, :] # take diagonal + return jac + + def _confident_jacobian(self, jac): + if self.confidence_thresh is None: + return torch.ones_like(jac[:, None, ..., 0, 0]) + conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None] + return conf + + def set_input(self, x, mask=None, timestamps=None): + shape = x.shape + if len(shape) == 4: + x = x.unsqueeze(1) + else: + assert len(shape) == 5, \ + "Input must be a movie of shape [B,T,C,H,W]" + \ + "or a single frame of shape [B,C,H,W]" + + self.inp_shape = x.shape + self.x = x + self.B = self.inp_shape[0] + self.T = self.inp_shape[1] + self.C = self.inp_shape[2] + if mask is not None: + self.mask = mask + + if timestamps is not None: + self.timestamps = timestamps + + def _preprocess(self, x): + if self.imagenet_normalize_inputs: + x = imagenet_normalize(x) + if self.t_dim != 1: + x = x.transpose(self.t_dim, self.c_dim) + return x + + def _jacobian_to_flows(self, jac): + if self.agg_power is None: + jac = (jac == jac.amax((-2, -1), True)).float() + else: + jac = torch.pow(jac, self.agg_power) + + jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:]) + centroids = get_distribution_centroid(jac, normalize=False).view( + self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2) + rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2], + self.a_mask_shape[-1] / self.p_mask_shape[-1]] + centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2) + + flows = centroids - \ + coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device) + flows = flows.permute(0, 3, 1, 2) + px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1) + flows *= px_scale + + return flows + + 
def set_targets(self, targets=None, frame=-1): + frame = frame % self.mask_shape[0] + if targets is None: + targets = self.get_mask_image(self.mask)[:, frame:frame + 1] + else: + assert len(targets.shape) == 4, targets.shape + targets = targets[:, frame:frame + 1] + self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:]) + + def _get_mask_partition(self, mask): + mask = self.get_mask_image(mask) + mask_list = masking.partition_masks( + mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling) + return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1) + for m in mask_list] + + def _compute_jacobian(self, y): + perturbation_func = self._get_perturbation_func() + jac = torch.autograd.functional.jacobian( + perturbation_func, + y, + vectorize=False) + jac = self._postprocess_jacobian(jac) + return jac + + def _upsample_mask(self, mask): + return masking.upsample_masks( + mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:]) + + def get_mask_image(self, mask, upsample=False, invert=False, shape=None): + if shape is None: + shape = self.mask_shape + mask = mask.view(-1, *shape) + if upsample: + mask = self._upsample_mask(mask) + if invert: + mask = 1 - mask + return mask + + def forward(self, x, mask, targets=None): + self.set_input(x, mask) + y = self.get_perturbation_input(x) + mask_list = self._get_mask_partition(mask) + + jacobian, flows, confident = [], [], [] + for s, mask_sample in enumerate(mask_list): + self.set_input(x, mask_sample) + self.set_targets(targets) + + import time + t1 = time.time() + jac = self._compute_jacobian(y) + conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:]) + jacobian.append(jac) + confident.append(conf_jac) + if not self.average_jacobian: + flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \ + masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:]) + flows.append(flow) + t2 = time.time() + 
print(t2 - t1) + + jacobian = torch.stack(jacobian, -1) + confident = torch.stack(confident, -1) + valid = torch.stack([masking.upsample_masks( + self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1) + valid = valid * confident + + if self.average_jacobian: + _valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2) + jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1) + flows = self._jacobian_to_flows(jac) * \ + masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:]) + if targets is not None: + self.set_targets(targets) + flows *= self.targets + else: + flows = torch.stack(flows, -1) + flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1) + + valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1) + + return (jacobian, flows, valid) \ No newline at end of file diff --git a/cwm/eval/Flow/losses.py b/cwm/eval/Flow/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..25ee0ea28deb584aba156b9d91264aab29dac8dd --- /dev/null +++ b/cwm/eval/Flow/losses.py @@ -0,0 +1,60 @@ +import torch +import torch.nn.functional as F +from torchvision import transforms + + +def sampling_grid(height, width): + H, W = height, width + grid = torch.stack([ + torch.arange(W).view(1, -1).repeat(H, 1), + torch.arange(H).view(-1, 1).repeat(1, W) + ], -1) + grid = grid.view(1, H, W, 2) + return grid + + +def normalize_sampling_grid(coords): + assert len(coords.shape) == 4, coords.shape + assert coords.size(-1) == 2, coords.shape + H, W = coords.shape[-3:-1] + xs, ys = coords.split([1, 1], -1) + xs = 2 * xs / (W - 1) - 1 + ys = 2 * ys / (H - 1) - 1 + return torch.cat([xs, ys], -1) + + +def backward_warp(img2, flow, do_mask=False): + """ + Grid sample from img2 using the flow from img1->img2 to get a prediction of img1. + + flow: [B,2,H',W'] in units of pixels at its current resolution. 
def upsample_masks(masks, size, thresh=0.5):
    """Resize masks over their last two dimensions to ``size``.

    Behavior by case:
      * same size: ``masks`` is returned unchanged;
      * both target dimensions smaller: strided subsampling with integer
        steps ``h // H`` and ``w // W``;
      * target dimensions integer multiples of the input: each entry is
        repeated (nearest-neighbour upsampling);
      * otherwise: entries are repeated ``H // h`` / ``W // w`` times, the
        result is bilinearly resized to ``size`` and thresholded at
        ``thresh``.

    Args:
        masks: tensor of shape [..., h, w].
        size: target (H, W).
        thresh: threshold applied after bilinear resizing.

    Returns:
        Tensor of shape [..., H, W] with the dtype of ``masks``.
    """
    import torch.nn.functional as F

    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        return masks
    elif (H < h) and (W < w):
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]

    # [..., h, w] -> [..., h, 1, w, 1] -> [..., h, H//h, w, W//w]
    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        _H = h * (H // h)
        _W = w * (W // w)
        # Bilinear interpolation is not defined for bool/integer tensors, so
        # resize in float32 and threshold afterwards.
        resized = F.interpolate(
            masks.reshape(-1, 1, _H, _W).to(torch.float32),
            size=(int(H), int(W)), mode='bilinear', align_corners=False)
        # Restore all leading dimensions (not just the first two) and the
        # caller's dtype; the comparison itself yields bool.
        masks = (resized > thresh).view(*shape[:-2], H, W).to(dtype)
    return masks
mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False): + self.frames = None + if len(input_size) == 3: + self.frames, self.height, self.width = input_size + elif len(input_size) == 2: + self.height, self.width = input_size + elif len(input_size) == 1 or isinstance(input_size, int): + self.height = self.width = input_size + + self.clumping_factor = clumping_factor + self.pad_h = self.height % self.c[0] + self.pad_w = self.width % self.c[1] + self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1]) + self.mask_ratio = mask_ratio + + self.rng = np.random.RandomState(seed=seed) + self.randomize_num_visible = randomize_num_visible + + @property + def num_masks_per_frame(self): + if not hasattr(self, '_num_masks_per_frame'): + self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame) + return self._num_masks_per_frame + + @num_masks_per_frame.setter + def num_masks_per_frame(self, val): + self._num_masks_per_frame = val + self._mask_ratio = (val / self.num_patches_per_frame) + + @property + def c(self): + if isinstance(self.clumping_factor, int): + return (self.clumping_factor, self.clumping_factor) + else: + return self.clumping_factor[:2] + + @property + def mask_ratio(self): + return self._mask_ratio + + @mask_ratio.setter + def mask_ratio(self, val): + self._mask_ratio = val + self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame) + + @property + def num_visible(self): + return self.num_patches_per_frame - self.num_masks_per_frame + + @num_visible.setter + def num_visible(self, val): + self.num_masks_per_frame = self.num_patches_per_frame - val + + def __repr__(self): + repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, random num num visible? 
{}".format( + self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible + ) + return repr_str + + def sample_mask_per_frame(self): + num_masks = self.num_masks_per_frame + if self.randomize_num_visible: + num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1)) + mask = np.hstack([ + np.zeros(self.num_patches_per_frame - num_masks), + np.ones(num_masks)]) + self.rng.shuffle(mask) + if max(*self.c) > 1: + mask = mask.reshape(self.height // self.c[0], + 1, + self.width // self.c[1], + 1) + mask = np.tile(mask, (1, self.c[0], 1, self.c[1])) + mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w)) + _pad_h = self.rng.choice(range(self.pad_h + 1)) + pad_h = (self.pad_h - _pad_h, _pad_h) + _pad_w = self.rng.choice(range(self.pad_w + 1)) + pad_w = (self.pad_w - _pad_w, _pad_w) + mask = np.pad(mask, + (pad_h, pad_w), + constant_values=1 + ).reshape((self.height, self.width)) + return mask + + def __call__(self, num_frames=None): + num_frames = (num_frames or self.frames) or 1 + masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten() + return masks + + +class TubeMaskingGenerator(UniformMaskingGenerator): + + def __call__(self, num_frames=None): + num_frames = (num_frames or self.frames) or 1 + masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten() + return masks + + +class RotatedTableMaskingGenerator(TubeMaskingGenerator): + + def __init__(self, tube_length=None, *args, **kwargs): + super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs) + self.tube_length = tube_length + + def __call__(self, num_frames=None): + num_frames = (num_frames or self.frames) or 2 + tube_length = self.tube_length or (num_frames - 1) + table_thickness = num_frames - tube_length + assert tube_length < num_frames, (tube_length, num_frames) + + tubes = super().__call__(num_frames=tube_length) + top = np.zeros(table_thickness * self.height * 
class MaskingGenerator(nn.Module):
    """Pytorch base class for masking generators

    Samples boolean masks (True = masked) with a fixed number of masked
    entries per frame.

    Args:
        input_size: an int (square grid), a 1-sequence ``(size,)``, a
            2-sequence ``(height, width)`` or a 3-sequence
            ``(frames, height, width)``.
        mask_ratio: fraction of patches masked in every frame.
        seed: seeds the numpy RandomState owned by the generator and, via
            ``torch.manual_seed``, the global torch RNG.
        visible_frames: number of all-False frames prepended to every mask.
        clumping_factor: int or pair; masks are sampled on a grid coarsened
            by this factor and tiled back up.
        randomize_num_visible: when True the number of masked patches is
            drawn between the nominal count and all patches.
        create_on_cpu: when False, batched masks are moved to the device of
            the tensor passed to ``forward``.
        always_batch: when True a leading batch dimension is always kept.
    """

    def __init__(self,
                 input_size,
                 mask_ratio,
                 seed=0,
                 visible_frames=0,
                 clumping_factor=1,
                 randomize_num_visible=False,
                 create_on_cpu=True,
                 always_batch=False):
        super().__init__()
        self.frames = None

        # Test for int first: len() on an int raises TypeError, which made
        # the int form unusable in the original ordering.
        if isinstance(input_size, int):
            self.height = self.width = input_size
        elif len(input_size) == 3:
            self.frames, self.height, self.width = input_size
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 1:
            self.height = self.width = input_size[0]
        else:
            raise ValueError("input_size must be an int or have 1, 2 or 3 entries, but is %s" % (input_size,))

        self.clumping_factor = clumping_factor
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])

        # The mask_ratio setter needs num_patches_per_frame, so keep this order.
        self.mask_ratio = mask_ratio
        self.visible_frames = visible_frames
        self.always_batch = always_batch
        self.create_on_cpu = create_on_cpu

        self.rng = np.random.RandomState(seed=seed)
        self._set_torch_seed(seed)

        self.randomize_num_visible = randomize_num_visible

    @property
    def num_masks_per_frame(self):
        """Number of masked patches per frame."""
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        """Clumping factor as a (height, width) pair."""
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor,) * 2
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        """Fraction of patches masked per frame."""
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        """Number of unmasked patches per frame."""
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def _set_torch_seed(self, seed):
        # NOTE: this reseeds the process-wide torch RNG, not a private generator.
        self.seed = seed
        torch.manual_seed(self.seed)

    def __repr__(self):
        repr_str = ("Class: {}\nMask: total patches per mask {},\n" + \
                    "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" + \
                    "randomize num visible? {}").format(
            type(self).__name__, self.num_patches_per_frame,
            self.num_masks_per_frame, self.num_visible, self.mask_ratio,
            self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self, *args, **kwargs):
        """Sample one frame's boolean mask.

        Returns a flat tensor of ``num_patches_per_frame`` entries when no
        clumping is used, otherwise a [height, width] tensor in which each
        sampled entry is tiled over a clump and the remainder rows/columns
        are padded with True.
        """
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))

        mask = torch.cat([
            torch.zeros([self.num_patches_per_frame - num_masks]),
            torch.ones([num_masks])], 0).bool()
        inds = torch.randperm(mask.size(0)).long()
        mask = mask[inds]

        if max(*self.c) > 1:
            mask = mask.view(self.height // self.c[0],
                             1,
                             self.width // self.c[1],
                             1)
            mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
            # Split the leftover rows/columns randomly between the two sides.
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            mask = F.pad(mask,
                         pad_w + pad_h,
                         mode='constant',
                         value=1)
            mask = mask.reshape(self.height, self.width)

        return mask

    def forward(self, x=None, num_frames=None):
        """Sample masks for ``num_frames`` frames (default: ``self.frames`` or 1).

        Args:
            x: optional tensor; when given, one mask is sampled per element
                of its first dimension.
            num_frames: number of masked frames to sample.

        Returns:
            Bool tensor whose last dimension holds ``visible_frames``
            all-False frames followed by the sampled frames, flattened. A
            leading batch dimension is present when ``x`` has more than one
            element or ``always_batch`` is set.
        """
        num_frames = (num_frames or self.frames) or 1
        if isinstance(x, torch.Tensor):
            batch_size = x.size(0)
            masks = torch.stack([
                torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
                for _ in range(batch_size)], 0)
            if not self.create_on_cpu:
                masks = masks.to(x.device)
            if batch_size == 1 and not self.always_batch:
                masks = masks.squeeze(0)
        else:
            batch_size = 1
            masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
            if self.always_batch:
                masks = masks[None]

        if self.visible_frames > 0:
            # One fully visible frame with the same leading dimensions as
            # masks. Viewing a single frame as masks.shape (the original
            # approach) only works when exactly one frame is masked.
            vis = torch.zeros((*masks.shape[:-1], self.height * self.width),
                              dtype=torch.bool, device=masks.device)
            masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)

        return masks
def imshow(ims, ax=None, t=0, vmin=None, vmax=None, title=None, cmap=None, fontsize=20):
    """Show frame ``t`` of a channel-first image stack on a matplotlib axis.

    Args:
        ims: indexable of tensors; ``ims[t]`` is moved to CPU and its axes are
            permuted with ``(1, 2, 0)`` before plotting.
        ax: axis to draw on; a new figure is created when omitted.
        t: index of the frame to show.
        vmin, vmax: colour limits. They (and ``cmap``) are only applied when
            both are given.
        title: optional axis title.
        cmap: colormap used together with ``vmin``/``vmax`` (default 'viridis').
        fontsize: title font size.

    Returns:
        Tuple ``(image_artist, ax)``.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    with torch.no_grad():
        frame = ims[t].float().cpu().numpy().transpose((1, 2, 0))
    if (vmin is not None) and (vmax is not None):
        im = ax.imshow(frame, vmin=vmin, vmax=vmax, cmap=(cmap or 'viridis'))
    else:
        im = ax.imshow(frame)

    if title is not None:
        ax.set_title(title, fontsize=fontsize)

    return (im, ax)


def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3] with values in [0, 255]
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
    colorwheel[col:col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
    colorwheel[col:col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
    colorwheel[col:col + MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    # Direction mapped to [-1, 1], then to a fractional colour-wheel index.
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0
    in_range = (rad <= 1)
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        # Small magnitudes fade towards white; magnitudes above 1 are dimmed.
        col[in_range] = 1 - rad[in_range] * (1 - col[in_range])
        col[~in_range] = col[~in_range] * 0.75  # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2 - i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Converts a two-channel flow field of shape [H,W,2] to a colour image.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip flow components to
            [-clip_flow, clip_flow]. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Clip symmetrically: a lower bound of 0 would erase all leftward and
        # upward motion.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5  # keeps the normalisation finite for an all-zero flow
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)


def get_video(video_name, num_frames=2, delta_time=4, frame=None, rng=None):
    """Read ``num_frames`` frames spaced ``delta_time`` apart from a video.

    Args:
        video_name: path handed to ``decord.VideoReader``.
        num_frames: number of frames to read.
        delta_time: stride, in frames, between consecutive reads.
        frame: index of the first frame; drawn at random when None.
        rng: optional object with a ``randint(low, high)`` method (for example
            ``np.random.RandomState``) used for the random start; the global
            numpy generator is used when omitted.

    Returns:
        Tuple ``(frames, start_frame)`` where ``frames`` stacks the
        ``ToTensor``-converted RGB frames along a new first axis.
    """
    # Imported here so the module can be loaded without the video stack.
    from decord import VideoReader, cpu
    from PIL import Image
    from torchvision import transforms

    decord_vr = VideoReader(video_name, num_threads=1, ctx=cpu(0))
    max_end_ind = len(decord_vr) - num_frames * delta_time - 1
    if frame is not None:
        start_frame = frame
    else:
        sampler = np.random if rng is None else rng
        start_frame = sampler.randint(1, max_end_ind)
    print("fps", decord_vr.get_avg_fps())
    print("start frame = %d" % start_frame)
    frame_id_list = list(range(start_frame, start_frame + num_frames * delta_time, delta_time))
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = [Image.fromarray(video_data[t]).convert('RGB') for t, _ in enumerate(frame_id_list)]
    return (torch.stack([transforms.ToTensor()(im) for im in video_data], 0), start_frame)
def load_predictor(
        model_func_,
        load_path_,
        **kwargs):
    """Build a frozen, eval-mode model and load its weights from a checkpoint.

    The checkpoint is read on CPU and its 'model' entry is passed to
    ``load_state_dict``; the load result and the path are printed.
    """
    predictor = model_func_(**kwargs).eval().requires_grad_(False)

    did_load = predictor.load_state_dict(
        torch.load(load_path_, map_location=torch.device("cpu"))['model'])
    # Remember where the weights came from.
    predictor._predictor_load_path = load_path_
    print(did_load, load_path_)
    return predictor


class CWM(PhysionFeatureExtractor):
    """Physion feature extractor built on the full forward pass of a CWM model."""

    def __init__(self, model_name, aggregate_embeddings=False):
        super().__init__()

        # The model is moved to the GPU in half precision.
        self.model = model_factory.load_model(model_name).cuda().half()

        self.num_frames = self.model.num_frames

        self.timestamps = np.arange(self.num_frames)

        # Number of patches per frame for a 224x224 input.
        ps = (224 // self.model.patch_size[1]) ** 2

        # Mask that is zero everywhere except for the last frame's patches.
        self.bool_masked_pos = np.zeros([ps * self.num_frames])
        self.bool_masked_pos[ps * (self.num_frames - 1):] = 1

        self.ps = ps

        self.aggregate_embeddings = aggregate_embeddings

    def transform(self):
        # Returns the augmentation plus two dataset settings; the sibling
        # extractors name them frame_gap and num_frames_dataset
        # (presumably the same meaning here - confirm against physion_evaluator).
        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=224,
        ), 150, 4

    def fwd(self, videos):
        """Run the model on `videos` with the stored mask repeated per batch element."""
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                               return_features=True)
        return x_encoded

    def extract_features(self, videos, for_flow=False):
        '''
        videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
        returns: features of all windows concatenated along dim 1

        Note: `for_flow` is accepted but not used.
        '''

        # [B, T, C, H, W] -> [B, C, T, H, W]
        videos = videos.transpose(1, 2)

        all_features = []

        # repeat the last frame of the video
        videos = torch.cat([videos, videos[:, :, -1:]], dim=2)

        # Windows of num_frames frames, starting every num_frames - 1 frames;
        # the upper bound 4 is hard-coded.
        for x in range(0, 4, self.num_frames - 1):
            vid = videos[:, :, x:x + self.num_frames, :, :]
            all_features.append(self.fwd(vid))
            if self.aggregate_embeddings:
                # Average over dim 1 of the window's features.
                feats = all_features[-1].mean(dim=1, keepdim=True)
                all_features[-1] = feats

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_Keypoints(PhysionFeatureExtractor):
    """Physion feature extractor that uses the model's keypoint features."""

    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # Frame triplets fed to the keypoint extraction.
        self.frames = [[0, 1, 2], [1, 2, 3]]

        self.num_frames = self.model.num_frames

        # Number of patches per frame for a 224x224 input.
        self.ps = (224 // self.model.patch_size[1]) ** 2

        # Mask that is zero everywhere except for the last frame's patches.
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 150

        self.num_frames_dataset = 4

        self.res = 224


    def transform(self):

        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        """Run the model with the stored mask; only the second output is returned."""
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
        returns: [B, D] keypoint features of all frame triplets, concatenated

        Note: `segments` is accepted but not used.
        '''

        # [B, T, C, H, W] -> [B, C, T, H, W]
        videos = videos.transpose(1, 2)

        all_features = []

        for x, arr in enumerate(self.frames):

            # select the three frames of this triplet
            vid = videos[:, :, arr, :, :].half()
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            #reshape the features to [batch size, num_features]
            k_feat = k_feat.view(k_feat.shape[0], -1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_KeypointsFlow(PhysionFeatureExtractor):
    """Keypoint feature extractor that also appends optical-flow patches."""

    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # Frame triplets fed to the keypoint extraction (the last one repeats frame 9).
        self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]

        self.num_frames = self.model.num_frames

        self.timestamps = np.arange(self.num_frames)

        # Number of patches per frame for a 224x224 input.
        self.ps = (224 // self.model.patch_size[1]) ** 2

        # Mask that is zero everywhere except for the last frame's patches.
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 50

        self.num_frames_dataset = 9

        self.res = 512

    def transform(self):

        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        """Run the model with the stored mask; only the second output is returned."""
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def get_forward_flow(self, videos):
        """Occlusion-masked forward flow from frame 6 to frame 7, resized to 224x224.

        `videos` is indexed along dim 2 for frames 5 to 8.
        """

        fid = 6

        forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1], conditioning_img=videos[:, :, fid + 2], mode='cosine')

        backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid], conditioning_img=videos[:, :, fid - 1], mode='cosine')

        # Only the first value returned by get_occ_masks is used.
        occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]

        forward_flow = forward_flow * occlusion_mask

        # Three identical copies stacked along a new dim 1.
        forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)

        forward_flow = forward_flow.to(videos.device)

        # NOTE(review): the size (2, 224, 224) implies a 5-D input, i.e. a
        # 4-D flow before stacking - confirm against model.get_flow.
        forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')

        return forward_flow

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W], normalized with imagenet norm. Frame indices
                up to 9 are read, so T must be at least 10.
        returns: [B, D] keypoint features of all frame triplets, with flow
                 patches appended to the first triplet's features
        Note:
            For efficiency, the optical flow is computed and added for a single frame (300ms) as we found this to be sufficient
            for capturing temporal dynamics in our experiments. This approach can be extended to multiple frames if needed,
            depending on the complexity of the task.
            `segments` is accepted but not used.
        '''


        #resize to 224 to get keypoints and features
        videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
        videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)

        #for computing flow at higher resolution
        videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
        videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)

        # [B, T, C, H, W] -> [B, C, T, H, W], in half precision
        videos = videos.transpose(1, 2).half()
        videos_downsampled = videos_downsampled.transpose(1, 2).half()

        # Get the forward flow for the frame at 300ms
        forward_flow = self.get_forward_flow(videos)

        # Verify that the forward flow contains no NaNs
        assert not torch.isnan(forward_flow).any(), "Forward flow is nan"

        all_features = []

        for x, arr in enumerate(self.frames):

            #use the downsampled videos for keypoints
            vid = videos_downsampled[:, :, arr, :, :]
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            #for the last set of frames only use features at keypoint regions of frame2
            if (x == 2):
                k_feat = k_feat[:, -10:, :]

            #reshape the features to [batch size, num_features]
            k_feat = k_feat.view(k_feat.shape[0], -1)

            # Keypoint locations scaled from patch units to pixels.
            choices_image_resolution = choices * self.model.patch_size[1]

            # Add optical flow patches at the detected keypoint locations,
            # for the first frame triplet only (x == 0)
            if x == 0:
                # Take the last of the three identical copies stacked in get_forward_flow
                flow_keyp = forward_flow[:, 2]

                # Initialize a result tensor to store the flow patches
                # Tensor shape: [batch_size, 8x8 patch (flattened to 64) * 2 channels, 10 keypoints]
                flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)

                # Patch size shift (since 8x8 patches are being extracted)
                shift = 8

                # Loop over each element in the batch
                for b in range(flow_keyp.size(0)):
                    # Column 0 is used as x and column 1 as y for this batch element
                    x_indices = choices_image_resolution[b, :, 0]
                    y_indices = choices_image_resolution[b, :, 1]

                    # For each keypoint (10 total keypoints in this case)
                    for ind in range(10):
                        # Extract the 8x8 patch of optical flow at each keypoint's (x, y) location
                        # Flatten the patch and assign it to the corresponding slice in the result tensor
                        flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
                                          x_indices[ind]:x_indices[ind] + shift].flatten()

                # Flatten the flow patches to [batch size, num_features]
                flow = flow.view(flow.shape[0], -1)

                # Concatenate the extracted optical flow features with the existing feature tensor (k_feat)
                k_feat = torch.cat([k_feat, flow], dim=1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


# Whole-model features from 'vitb_8x8patch_3frames'
class CWM_base_8x8_3frame(CWM):
    def __init__(self,):
        super().__init__('vitb_8x8patch_3frames')

# Same model, with features averaged over dim 1 per window
class CWM_base_8x8_3frame_mean_embed(CWM):
    def __init__(self,):
        super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)

# CWM* (keypoints only) 74.7
class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
    def __init__(self,):
        super().__init__('vitb_8x8patch_3frames')


# CWM* (keypoints + Flow) 75.4
class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
    def __init__(self,):
        super().__init__('vitb_8x8patch_3frames')
def create_weighted_mask_batched(h, w):
    """Return an ``[h, w]`` float32 blending window.

    The window is the outer product of two triangular ramps, so it is zero on
    the border and largest in the centre.
    """
    y_mask = np.linspace(0, 1, h)
    y_mask = np.minimum(y_mask, 1 - y_mask)
    x_mask = np.linspace(0, 1, w)
    x_mask = np.minimum(x_mask, 1 - x_mask)
    weighted_mask = np.outer(y_mask, x_mask)
    return torch.from_numpy(weighted_mask).float()


def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
    """Blend overlapping crops back into a tensor of ``original_shape``.

    Args:
        cropped_tensors: list of ``[B, T, C, h, w]`` crops of equal size
            (the size is taken from the first crop).
        crop_positions: ``(start_h, start_w)`` of every crop.
        original_shape: target shape ``(B, T, C, H, W)``.

    Returns:
        Weighted average of the crops. Pixels that only ever receive zero
        weight (for example on the outer border) come out as zero.
    """
    B, T, C, H, W = original_shape
    device = cropped_tensors[0].device
    crop_h, crop_w = cropped_tensors[0].shape[-2:]

    # Accumulators for the weighted crops and for the weights themselves.
    reconstructed_video = torch.zeros((B, T, C, H, W)).to(device)
    weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(device)

    weighted_mask = create_weighted_mask_batched(crop_h, crop_w).to(device)
    weighted_mask = weighted_mask[None, None, None, :, :]  # broadcast over B, T, C

    for idx, crop in enumerate(cropped_tensors):
        start_h, start_w = crop_positions[idx]
        rows = slice(start_h, start_h + crop_h)
        cols = slice(start_w, start_w + crop_w)
        reconstructed_video[:, :, :, rows, cols] += crop * weighted_mask
        weighted_masks_sum[:, :, :, rows, cols] += weighted_mask

    # Small epsilon to avoid division by zero where no weight was accumulated.
    epsilon = 1e-8
    reconstructed_video /= (weighted_masks_sum + epsilon)

    return reconstructed_video


import torch.nn.functional as F


def resize(x, a):
    """Bilinearly rescale the last two axes of ``x`` by the factor ``a``."""
    return F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear', align_corners=False)


def upsample(x, H, W):
    """Bilinearly resize the last two axes of ``x`` to ``(H, W)``."""
    return F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)


def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    """Estimate patch-level displacement by nearest-neighbour token matching.

    Args:
        embedding_tensor: ``[1, frame_size**2 + n, D]``; the first
            ``frame_size**2`` tokens belong to the first frame, the remaining
            ``n`` tokens to the unmasked positions of the second frame.
        mask_tensor: mask over both frames (flattened internally); entries
            equal to False/0 in the second half mark the unmasked positions.
        frame_size: number of patches per side.

    Returns:
        ``[2, frame_size, frame_size]`` tensor holding the x (index 0) and
        y (index 1) displacement ``second - first`` at every unmasked
        second-frame position, zero elsewhere.
    """
    tokens_per_frame = frame_size ** 2
    mask_unrolled = mask_tensor.view(-1)

    # Positions of the second frame that were visible to the encoder.
    second_frame_unmask_indices = torch.where(mask_unrolled[tokens_per_frame:] == 0)[0]

    first_frame_embeddings = embedding_tensor[0, :tokens_per_frame, :]
    second_frame_embeddings = embedding_tensor[0, tokens_per_frame:, :]

    # Cosine similarity between every second-frame token and every first-frame token.
    dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
    norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
    cos_sim_matrix = dot_product / norms

    # Best match in the first frame for each unmasked second-frame token.
    first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)

    # Flat patch indices -> (row, column).
    second_frame_y = second_frame_unmask_indices // frame_size
    second_frame_x = second_frame_unmask_indices % frame_size
    first_frame_y = first_frame_most_similar_indices // frame_size
    first_frame_x = first_frame_most_similar_indices % frame_size

    displacements_x = (second_frame_x - first_frame_x).float()
    displacements_y = (second_frame_y - first_frame_y).float()

    optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
    optical_flow[0, second_frame_y, second_frame_x] = displacements_x
    optical_flow[1, second_frame_y, second_frame_x] = displacements_y

    return optical_flow


def get_minimal_224_crops_new_batched(video_tensor, N, crop_size=224):
    """Cover a video with the fewest evenly spaced square crops.

    Args:
        video_tensor: ``[B, T, C, H, W]`` tensor with ``H, W >= crop_size``.
        N: desired number of crops; if it exceeds the size of the regular grid
            (and both H and W are larger than ``crop_size``) the remainder is
            filled with crops at random positions drawn from ``random``.
        crop_size: side length of the crops (default 224).

    Returns:
        Tuple ``(crops, positions)``: a list of ``[B, T, C, crop_size,
        crop_size]`` tensors and the matching ``(start_h, start_w)`` list.

    Raises:
        ValueError: if the video is smaller than ``crop_size`` in H or W.
    """
    B, T, C, H, W = video_tensor.shape
    if H < crop_size or W < crop_size:
        raise ValueError(
            "video of size {}x{} is smaller than the crop size {}".format(H, W, crop_size))

    # Number of crops per axis and the stride that spreads them evenly.
    num_crops_h = math.ceil(H / crop_size) if H > crop_size else 1
    num_crops_w = math.ceil(W / crop_size) if W > crop_size else 1
    step_size_h = 0 if H <= crop_size else max(0, (H - crop_size) // (num_crops_h - 1))
    step_size_w = 0 if W <= crop_size else max(0, (W - crop_size) // (num_crops_w - 1))

    cropped_tensors = []
    crop_positions = []

    for i in range(num_crops_h):
        for j in range(num_crops_w):
            start_h = i * step_size_h
            start_w = j * step_size_w
            end_h = min(start_h + crop_size, H)
            end_w = min(start_w + crop_size, W)
            cropped_tensors.append(video_tensor[:, :, :, start_h:end_h, start_w:end_w])
            crop_positions.append((start_h, start_w))

    D = len(cropped_tensors)

    # If N is greater than D, generate additional random crops.
    if N > D and H > crop_size and W > crop_size:
        for _ in range(N - D):
            start_h = random.randint(0, H - crop_size)
            start_w = random.randint(0, W - crop_size)
            cropped_tensors.append(
                video_tensor[:, :, :, start_h:(start_h + crop_size), start_w:(start_w + crop_size)])
            crop_positions.append((start_h, start_w))

    cropped_tensors = [crop.reshape(B, T, C, crop_size, crop_size) for crop in cropped_tensors]

    return cropped_tensors, crop_positions


def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
                                                      mask_generator,
                                                      img1,
                                                      img2,
                                                      img3,
                                                      neg_back_flow=True,
                                                      num_scales=1,
                                                      min_scale=400,
                                                      N_mask_samples=100,
                                                      mask_ratio=0.8,
                                                      flow_frames='23'):
    """Estimate flow from encoder token matches, averaged over random masks.

    For every scale the three frames are stacked, cut into 224x224 crops and
    encoded under repeatedly re-sampled masks. The per-mask patch flows from
    ``compute_optical_flow`` are summed, divided by how often each patch was
    unmasked, scaled to pixels and blended back to the full frame. Requires
    CUDA (tensors are moved with ``.cuda()``).

    Args:
        generator: model exposing ``patch_size``, ``encoder`` and
            ``encoder_to_decoder``.
        mask_generator: callable as ``mask_generator(num_frames=3)`` with a
            settable ``mask_ratio``.
        img1, img2, img3: ``[B, C, H, W]`` frames.
        neg_back_flow: if True the frames are fed in reverse order and the
            result is negated.
        num_scales: number of scales; scale ``k`` resizes by ``alpha ** k``
            with ``alpha = (min_scale / H) ** 0.25``.
        min_scale: must be smaller than the frame height.
        N_mask_samples: minimum number of masks sampled per scale; sampling
            continues until every patch has been unmasked at least once.
        mask_ratio: ratio assigned to ``mask_generator`` for every sample.
        flow_frames: '23' or '12', selecting which pair of token groups is
            matched.

    Returns:
        Tuple ``(flow, per_scale)``: the mean flow over scales and the list of
        per-scale flows (each with a trailing singleton axis).

    Raises:
        ValueError: if ``flow_frames`` is neither '23' nor '12'.
    """
    if flow_frames not in ('23', '12'):
        raise ValueError("flow_frames must be '23' or '12', got {!r}".format(flow_frames))

    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # For scaling
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    alpha = (min_scale / img1.shape[-2]) ** (1 / 4)

    frame_size = 224 // generator.patch_size[-1]

    patch_size = generator.patch_size[-1]

    all_fwd_flows_e2d = []

    for aidx in range(num_scales):

        img1_scaled = resize(img1.clone(), alpha ** aidx)
        img2_scaled = resize(img2.clone(), alpha ** aidx)
        img3_scaled = resize(img3.clone(), alpha ** aidx)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]

        # Factors that map flow at this scale back to the input resolution.
        s_h = h1 / h2
        s_w = w1 / w2

        # Because technically the compute_optical_flow function returns neg back flow
        if neg_back_flow is True:
            video = torch.cat([img3_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img3_scaled.unsqueeze(1)], 1)

        # Should work, even if the incoming video is already 224x224
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)

        num_crops = len(crops1)

        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        while i < N_mask_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio

            # Every sample in the batch shares the same mask.
            mask = mask_generator(num_frames=3)[None]
            # Count how often each patch of the third token group was unmasked.
            mask_2f = ~mask[0, frame_size * frame_size * 2:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            # NOTE(review): the extent of this autocast block is reconstructed;
            # confirm that the matching below is meant to run inside it.
            with torch.cuda.amp.autocast(enabled=True):

                processed_x = crop.transpose(1, 2)

                encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.encoder_to_decoder(encoder_out)

                if flow_frames == '23':
                    encoder_to_decoder = encoder_to_decoder[:, frame_size * frame_size:, :]
                    flow_mask = mask[:, frame_size * frame_size:]
                elif flow_frames == '12':
                    encoder_to_decoder = encoder_to_decoder[:, :frame_size * frame_size * 2, :]
                    flow_mask = mask[:, :frame_size * frame_size * 2]

                optical_flow_e2d = []
                # one per batch element for now
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
                    optical_flow_e2d.append(batch_flow.unsqueeze(0))

                optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
                optical_flows_enc2dec += optical_flow_e2d
            i += 1

        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        scale_factor_y = video.shape[-2] / 224
        scale_factor_x = video.shape[-1] / 224

        scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
        scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
        scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h

        # split the crops back up
        crop_flows_enc2dec = scaled_optical_flow.split(B, 0)

        # Expand every patch to patch_size x patch_size pixels, then blend the crops.
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
             crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    all_fwd_flows_e2d_new = []

    # Bring every scale to the resolution of the first one before averaging.
    for r in all_fwd_flows_e2d:
        new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
        all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new
/ccn2/u/rmvenkat/data/physion_release/ocp/train_features.hdf5 \ +--test-path /ccn2/u/rmvenkat/data/physion_release/ocp/test_features.hdf5 \ +--model-name CWM_base_8x8_3frame \ +--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/train_json.json \ +--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/test_json.json \ +--test-scenario-map /ccn2/u/rmvenkat/data/physion_release/ocp/test_scenario_map.json \ +--save_path /ccn2/u/rmvenkat/data/physion_release/ \ No newline at end of file diff --git a/cwm/eval/Physion/run_eval_kfflow.sh b/cwm/eval/Physion/run_eval_kfflow.sh new file mode 100644 index 0000000000000000000000000000000000000000..9a1a547d962f2f9d8104008fc63c55c6e52866a8 --- /dev/null +++ b/cwm/eval/Physion/run_eval_kfflow.sh @@ -0,0 +1,18 @@ +#physion_feature_extract \ +#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \ +#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \ +#--model_class feature_extractor.CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \ +#--gpu 1 \ +#--batch_size 8 \ +#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_kfflow/ \ +#--mode ocp + + +physion_train_readout \ +--train-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_features.hdf5 \ +--test-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_features.hdf5 \ +--model-name CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \ +--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_json.json \ +--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_json.json \ +--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_scenario_map.json \ +--save_path /ccn2/u/rmvenkat/data/physion_release_kfflow/ \ No newline at end of file diff --git a/cwm/eval/Physion/run_eval_mp4s.sh b/cwm/eval/Physion/run_eval_mp4s.sh new file mode 100644 index 
0000000000000000000000000000000000000000..156f3c801b2d9eed159585732819c3206a42b2ac --- /dev/null +++ b/cwm/eval/Physion/run_eval_mp4s.sh @@ -0,0 +1,19 @@ +dir_for_saving=/ccn2/u/rmvenkat/data/physion_release/ +model_name=CWM_base_8x8_3frame + +physion_feature_extract \ +--data_root_path /ccn2/u/rmvenkat/data/download_test/physion_mp4s/ \ +--model_class feature_extractor.$model_name \ +--gpu 1 \ +--batch_size 8 \ +--dir_for_saving $dir_for_saving \ +--mode ocd + +physion_train_readout \ +--train-path ${dir_for_saving}ocd/train_features.hdf5 \ +--test-path ${dir_for_saving}ocd/test_features.hdf5 \ +--model-name $model_name \ +--train-scenario-indices ${dir_for_saving}ocd/train_json.json \ +--test-scenario-indices ${dir_for_saving}ocd/test_json.json \ +--test-scenario-map ${dir_for_saving}ocd/test_scenario_map.json \ +--save_path $dir_for_saving diff --git a/cwm/eval/Physion/run_eval_mp4s_keyp.sh b/cwm/eval/Physion/run_eval_mp4s_keyp.sh new file mode 100644 index 0000000000000000000000000000000000000000..58fea4eedbf9983a0f7ef29d077d2f1d03ce76cd --- /dev/null +++ b/cwm/eval/Physion/run_eval_mp4s_keyp.sh @@ -0,0 +1,17 @@ +#physion_feature_extract \ +#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \ +#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \ +#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints \ +#--gpu 3 \ +#--batch_size 8 \ +#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp/ \ +#--mode ocp + +physion_train_readout \ +--train-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_features.hdf5 \ +--test-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_features.hdf5 \ +--model-name CWM_base_8x8_3frame \ +--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_json.json \ +--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_json.json \ +--test-scenario-map 
#!/usr/bin/env bash
# Train and evaluate the Physion readout (mode "ocp") on features stored under
# the keypoints+flow release directory.
#
# The feature directory defaults to the original hard-coded location and can be
# overridden with PHYSION_SAVE_DIR. It is expected to already contain the
# ocp/*.hdf5 and ocp/*.json files read below (presumably produced by the
# commented-out `physion_feature_extract` command -- confirm before relying on it).
set -euo pipefail

dir_for_saving="${PHYSION_SAVE_DIR:-/ccn2/u/rmvenkat/data/physion_release_keyp_flow/}"
model_name=CWM_base_8x8_3frame
# Strip a trailing slash (if any) so the joined paths never contain "//".
feature_dir="${dir_for_saving%/}/ocp"

# Feature extraction step, kept for reference:
#physion_feature_extract \
#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints_flow \
#--gpu 7 \
#--batch_size 8 \
#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ \
#--mode ocp

physion_train_readout \
  --train-path "${feature_dir}/train_features.hdf5" \
  --test-path "${feature_dir}/test_features.hdf5" \
  --model-name "${model_name}" \
  --train-scenario-indices "${feature_dir}/train_json.json" \
  --test-scenario-indices "${feature_dir}/test_json.json" \
  --test-scenario-map "${feature_dir}/test_scenario_map.json" \
  --save_path "${dir_for_saving}"
"""COCO dataloader config using LSJ augmentation with class-agnostic instance labels.

This file contains the default mapping that's applied to "dataset dicts",
modified so that every instance is assigned class 0, plus the dataloader
settings that install that mapper for training.
"""
import copy
import logging
from typing import List, Optional, Union

import numpy as np
import torch

import detectron2.data.transforms as T
from detectron2 import model_zoo
from detectron2.config import LazyCall as L
from detectron2.config import configurable
from detectron2.data import detection_utils as utils

__all__ = ["DatasetMapper"]


class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    4. Overwrite the class of every instance with 0 (class-agnostic labels)
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train               = is_train
        self.augmentations          = T.AugmentationList(augmentations)
        self.image_format           = image_format
        self.use_instance_mask      = use_instance_mask
        self.instance_mask_format   = instance_mask_format
        self.use_keypoint           = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk          = precomputed_proposal_topk
        self.recompute_boxes        = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """Translate a yacs-style ``cfg`` into the keyword arguments of ``__init__``."""
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """Apply ``transforms`` to the annotations and store them as ``dataset_dict["instances"]``.

        Crowd annotations (``iscrowd != 0``) are dropped, and the "annotations"
        key is removed from ``dataset_dict``.
        """
        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept. In training
            mode, when annotations are present, ``gt_classes`` of the resulting
            instances are all zero.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

            # Modified by Honglin Chen: change it to class-agnostic instance labels.
            # "instances" only exists when annotations were present, so this must
            # stay inside the branch; doing it unconditionally raised KeyError for
            # samples without annotations.
            dataset_dict["instances"].gt_classes *= 0
        return dataset_dict


# Data using LSJ
image_size = 512
dataloader = model_zoo.get_config("common/data/coco.py").dataloader
dataloader.train.mapper.augmentations = [
    L(T.RandomFlip)(horizontal=True),  # flip first
    L(T.ResizeScale)(
        min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
    ),
    L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
]
dataloader.train.mapper.image_format = "RGB"
dataloader.train.total_batch_size = 64
dataloader.train.num_workers = 0
# recompute boxes due to cropping
dataloader.train.mapper.recompute_boxes = True

dataloader.test.mapper.augmentations = [
    L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
]

# Use the class-agnostic mapper defined above for training.
dataloader.train.mapper._target_ = DatasetMapper
"""Convert a CWM checkpoint into a state dict with ``backbone.net.`` key prefixes.

Besides renaming keys, the script:
  * adds a ``backbone.net.pos_embed`` entry built from a sinusoid table with a
    zero vector prepended for the class token,
  * reduces 5-D patch-embedding weights to 4-D by dropping the third axis,
  * concatenates the separate ``q_bias`` / ``k_bias`` / ``v_bias`` entries of an
    attention layer into a single ``qkv.bias`` (zeros stand in for a missing
    ``k_bias``).
"""
import argparse
import os
import sys

import torch

sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
# Required: without it torch.load(None) fails with an unrelated-looking error.
parser.add_argument('--input', type=str, required=True, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']
# With mae=True every key is kept and simply prefixed with 'backbone.net.'.
mae = True
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768
pos_embed = get_sinusoid_encoding_table(14 * 14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)

new_state_dict = {'backbone.net.pos_embed': pos_embed}
for k, v in state_dict.items():
    # Parentheses spell out the precedence the original expression relied on.
    is_encoder_key = ('encoder' in k and 'decoder' not in k) or 'patch_embed' in k
    if not (mae or is_encoder_key):
        continue

    if 'patch_embed.proj.weight' in k and len(v.shape) == 5:
        if v.shape[2] == 1:
            v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
        else:
            v = v[:, :, 0]

    old_k = k
    k = 'backbone.net.' + k if mae else k.replace('encoder.', 'backbone.net.')

    if 'attn' in k and '_bias' in k:
        old_attn = '.'.join(old_k.split('.')[:-1])
        attn = '.'.join(k.split('.')[:-1])
        k = attn + '.qkv.bias'
        # q_bias and v_bias of one layer both map to the same fused key; build it once.
        if k in new_state_dict:
            continue

        q_bias = state_dict[old_attn + '.q_bias']
        k_bias = state_dict.get(old_attn + '.k_bias', torch.zeros_like(q_bias))
        v = torch.cat([q_bias, k_bias, state_dict[old_attn + '.v_bias']], dim=0)

    print(k, v.shape)
    new_state_dict[k] = v

if args.output is None:
    # Insert '-encoder' before the extension. str.replace('.pth', ...) returned the
    # input path unchanged when it had no '.pth', overwriting the source checkpoint.
    root, ext = os.path.splitext(args.input)
    output_path = root + '-encoder' + ext
else:
    output_path = args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
"""Prefix every key of a checkpoint's state dict with ``backbone.net.`` and save it."""
import argparse
import os
import sys

import torch

sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
# Required: without it torch.load(None) fails with an unrelated-looking error.
parser.add_argument('--input', type=str, required=True, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()

# The leftover breakpoint() calls were removed: they dropped into the debugger
# before loading and on every 'pos_embed' key, stalling non-interactive runs.
state_dict = torch.load(args.input, map_location='cpu')

new_state_dict = {'backbone.net.' + k: v for k, v in state_dict.items()}

if args.output is None:
    # Insert '-encoder' before the extension. str.replace('.pth', ...) returned the
    # input path unchanged when it had no '.pth', overwriting the source checkpoint.
    root, ext = os.path.splitext(args.input)
    output_path = root + '-encoder' + ext
else:
    output_path = args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
def l2_normalize(x):
    """Scale ``x`` to unit L2 norm along its last dimension (eps=1e-6)."""
    return F.normalize(x, p=2.0, dim=-1, eps=1e-6)


def reduce_max(x, dim, keepdim=True):
    """Return the maximum of ``x`` along ``dim``, discarding the argmax indices."""
    return torch.max(x, dim=dim, keepdim=keepdim)[0]


def coordinate_ims(batch_size, seq_length, imsize):
    """
    Build a grid of normalized (h, w) coordinates in [-1, 1].

    args:
        batch_size: B
        seq_length: T; pass 0 to get a static grid without a time dimension
        imsize: (H, W)

    returns:
        float32 tensor of shape [B,T,H,W,2], or [B,H,W,2] when seq_length == 0.
        The last dimension holds (h, w); index 0 maps to -1 and index H-1 (W-1)
        maps to +1. A dimension of size 1 maps to -1 (the original divided by
        zero there and returned NaN).
    """
    static = seq_length == 0
    B = batch_size
    T = 1 if static else seq_length
    H, W = imsize

    # max(..., 1) guards the size-1 case; for every other size the values are unchanged.
    h = 2.0 * (torch.arange(H, dtype=torch.float32) / max(H - 1, 1) - 0.5)
    w = 2.0 * (torch.arange(W, dtype=torch.float32) / max(W - 1, 1) - 0.5)
    h = h.view(1, 1, H, 1, 1).expand(B, T, H, W, 1)
    w = w.view(1, 1, 1, W, 1).expand(B, T, H, W, 1)
    hw_ims = torch.cat([h, w], -1)  # materializes a contiguous [B,T,H,W,2] tensor
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims


def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
    """
    Compute the (optionally normalized) dot product between two PyTorch tensors

    args:
        queries: [B,N,D]
        keys: [B,N_k,D]
        normalize: if True, L2-normalize both inputs along the last dimension
            first, so the result is a cosine similarity
        eps: epsilon passed to the normalization

    returns:
        [B,N,N_k] tensor of dot products
    """
    D_q = queries.size(-1)
    D_k = keys.size(-1)
    assert D_q == D_k, (queries.shape, keys.shape)
    if normalize:
        queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
        keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)

    return torch.matmul(queries, torch.transpose(keys, 1, 2))  # [B, N, N_k]
def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
    """
    Sample pixel indices from per-image (unnormalized) probability maps.

    args:
        probs: [B,H,W] non-negative weights
        num_points: number of points P to draw per image (with replacement)
        eps: added to the weights so an all-zero map can still be sampled

    returns:
        int32 tensor [B,P,2] of (h, w) indices, clamped to the image bounds
    """
    B, H, W = probs.shape

    flat = probs.reshape(B, H * W)
    flat = torch.clamp(flat + eps, min=0.0) / (flat.sum(dim=-1, keepdim=True) + eps)
    dist = Categorical(probs=flat, validate_args=False)

    indices = dist.sample([num_points]).permute(1, 0).to(torch.int32)  # [B,P]

    indices_h = torch.clamp(torch.div(indices, W, rounding_mode='floor'), 0, H - 1)
    indices_w = torch.clamp(torch.fmod(indices, W), 0, W - 1)
    return torch.stack([indices_h, indices_w], dim=-1)  # [B,P,2]


def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
    """
    Convolve every channel of ``image`` with a spatial-gradient kernel.

    args:
        image: [B,C,H,W]
        mode, order: forwarded to ``get_spatial_gradient_kernel2d``
        normalize_kernel: if True, pass the kernel through ``normalize_kernel2d``

    returns:
        [B,C,K,H,W] with K = 3 when order == 2 and K = 2 otherwise
    """
    B, C, H, W = list(image.size())

    # prepare kernel
    kernel = get_spatial_gradient_kernel2d(mode, order)
    if normalize_kernel:
        kernel = normalize_kernel2d(kernel)
    tmp_kernel = kernel.to(image).detach()
    tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
    kernel_flip = tmp_kernel.flip(-3)

    # pad spatial dims of image so the output keeps the input's H and W
    padding = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
    out_channels = 3 if (order == 2) else 2
    padded_image = F.pad(image.reshape(B * C, 1, H, W), padding, 'replicate')[:, :, None]  # [B*C,1,1,H+p,W+p]
    gradient_image = F.conv3d(padded_image, kernel_flip, padding=0).view(B, C, out_channels, H, W)
    return gradient_image


def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
    """
    Sample num_points in normalized (h,w) coordinates from the borders of the input image

    args:
        image: [B,C,H,W]
        num_points: number of points P sampled per image
        mask: optional [B,1,H,W] weights; the image is multiplied by it before
            edge detection and the edge map is multiplied by it afterwards
        sum_edges: sum the gradient magnitude over channels if True, else take the max
        normalized_coordinates: if True return float32 coordinates in [-1, 1],
            otherwise the raw int32 pixel indices

    returns:
        [B,P,2] coordinates in (h, w) order
    """
    B, C, H, W = list(image.size())
    if mask is not None:
        assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
    else:
        mask = torch.ones(size=(B, 1, H, W)).to(image)

    gradient_image = get_gradient_image(image * mask, mode='sobel', order=1)  # [B,C,2,H,W]
    gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
    if sum_edges:
        edges = gradient_magnitude.sum(1)  # [B,H,W]
    else:
        edges = gradient_magnitude.max(1)[0]

    # mask is always a tensor at this point, so it is applied unconditionally
    edges = edges * mask[:, 0]

    coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
    if normalized_coordinates:
        # max(..., 1) avoids dividing by zero when a spatial dimension has size 1
        scale = torch.tensor([max(H - 1, 1), max(W - 1, 1)],
                             dtype=torch.float32, device=coordinates.device)
        coordinates = 2.0 * (coordinates.to(torch.float32) / scale) - 1.0
    return coordinates
def index_into_images(images, indices, channels_last=False):
    """
    index into an image at P points to get its values

    args:
        images: [B,C,H,W] (or [B,H,W,C] when channels_last is True)
        indices: [B,P,2] pixel indices in (h, w) order; cast to long before use
        channels_last: set when ``images`` is laid out as [B,H,W,C]

    returns:
        values: [B,P,C]
    """
    assert indices.size(-1) == 2, indices.size()
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [B,C,H,W]
    B, C, H, W = images.shape
    _, P, _ = indices.shape
    inds_h, inds_w = indices.to(torch.long).unbind(-1)  # [B,P] each
    inds_b = torch.arange(B, dtype=torch.long, device=indices.device).unsqueeze(-1).expand(-1, P)
    # Explicit index tensors instead of indexing with a Python list of tensors.
    return images.permute(0, 2, 3, 1)[inds_b, inds_h, inds_w]  # [B,P,C]


def soft_index(images, indices, scale_by_imsize=True):
    """
    Bilinearly interpolate ``images`` at P floating-point locations.

    args:
        images: [B,C,H,W]
        indices: [B,P,2] floating-point (h, w) positions
        scale_by_imsize: if True, ``indices`` are normalized coordinates and are
            mapped to pixels with (x + 1) * size / 2; if False they are already
            pixel coordinates. Either way they are clamped to [0, size - 1].

    returns:
        im_vals: [B,P,C]
    """
    assert indices.shape[-1] == 2, indices.shape
    B, C, H, W = images.shape
    _, P, _ = indices.shape

    h_inds, w_inds = indices.unbind(-1)
    if scale_by_imsize:
        h_inds = (h_inds + 1.0) * H * 0.5
        w_inds = (w_inds + 1.0) * W * 0.5

    h_inds = h_inds.clamp(0, H - 1)
    w_inds = w_inds.clamp(0, W - 1)

    h_floor = torch.floor(h_inds)
    w_floor = torch.floor(w_inds)
    # Take the next pixel rather than ceil(): at integer positions ceil == floor,
    # which made all four weights zero and silently returned zeros (this also hit
    # every position clamped to the image border).
    h_ceil = (h_floor + 1).clamp(max=H - 1)
    w_ceil = (w_floor + 1).clamp(max=W - 1)

    h_frac = h_inds - h_floor
    w_frac = w_inds - w_floor

    top_left_weight = (1.0 - h_frac) * (1.0 - w_frac)
    top_right_weight = (1.0 - h_frac) * w_frac
    bot_left_weight = h_frac * (1.0 - w_frac)
    bot_right_weight = h_frac * w_frac

    top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
    top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
    bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
    bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))

    im_vals = top_left_vals * top_left_weight[..., None]
    im_vals += top_right_vals * top_right_weight[..., None]
    im_vals += bot_left_vals * bot_left_weight[..., None]
    im_vals += bot_right_vals * bot_right_weight[..., None]

    return im_vals.view(B, P, C)
def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
    """
    Compute how well "fit" each agent is for the position it's at on the plateau,
    according to its "phenotype"

    args:
        positions: [B,P,2] agent positions in normalized coordinates
        plateau: [B,H,W,Q]
        phenotypes: [B,P,Q] or None; when None they are read off the plateau at
            ``positions``
        availability: [B,H,W,A] or None; when given, P must be a multiple of A
        noise: scale of the uniform [0, 1) noise added to the sampled plateau
            values before normalization; no noise is added when <= 0

    returns:
        compatibility: [B,P,1] dot product between the normalized phenotypes and
        the normalized plateau values at each position
    """
    B, H, W, Q = plateau.shape
    P = positions.shape[1]
    if phenotypes is None:
        # soft_index expects channels-first input; the channels-last plateau was
        # previously passed as-is, so H was treated as the channel dimension.
        phenotypes = soft_index(plateau.permute(0, 3, 1, 2), positions, scale_by_imsize=True)

    if availability is not None:
        assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
        A = availability.size(-1)
        assert P % A == 0, (P, A)
        S = P // A  # population size
        plateau = availability[..., None] * plateau[..., None, :]  # [B,H,W,A,Q]
        plateau = plateau.reshape(B, H, W, A * Q)

    plateau_values = soft_index(plateau.permute(0, 3, 1, 2), positions, scale_by_imsize=True)
    if noise > 0:
        plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)

    if availability is not None:
        plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
        # One-hot selector: agent p reads the plateau copy weighted by availability channel p % A.
        inds = torch.tile(torch.eye(A)[None].expand(B, -1, -1), (1, S, 1))[..., None]  # [B,P,A,1]
        plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2)  # [B,P,Q]
    else:
        plateau_values = l2_normalize(plateau_values)

    compatibility = torch.sum(
        l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True)  # [B,P,1]

    return compatibility
def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
    """
    Intersection-over-union between every pair of mask channels.

    args:
        masks: [B,N,P]
        masks_target: [B,N,P'] or None to compare ``masks`` against itself
        mask_thresh: if given, both inputs are binarized with ``> mask_thresh``
        eps: lower bound on the union, so empty pairs yield 0 instead of NaN

    returns:
        iou: [B,P,P']
    """
    B, N, P = masks.shape
    if masks_target is None:
        masks_target = masks
    if mask_thresh is not None:
        masks = (masks > mask_thresh).to(torch.float32)
        masks_target = (masks_target > mask_thresh).to(torch.float32)

    source = masks[..., None]            # [B,N,P,1]
    target = masks_target[..., None, :]  # [B,N,1,P']
    intersection = (source * target).sum(dim=1)
    union = torch.maximum(source, target).sum(dim=1)
    union_floor = torch.tensor(eps, dtype=torch.float32)
    return intersection / torch.maximum(union, union_floor)  # [B,P,P']


def compete_agents(masks, fitnesses, alive,
                   mask_thresh=0.5, compete_thresh=0.2,
                   sticky_winners=True):
    """
    Decide which agents survive, given their masks and fitnesses.

    Two agents are in dispute when the IoU of their masks exceeds
    ``compete_thresh``; the one with the lower fitness is killed. With
    ``sticky_winners``, agents alive on entry cannot be killed by agents that
    were dead on entry, and a previously dead agent is killed by any dispute
    with a previously alive one.

    args:
        masks: [B,N,P]
        fitnesses: [B,P,1]
        alive: [B,P,1]

    returns:
        still_alive: [B,P,1] float32, 1.0 for survivors and 0.0 otherwise
    """
    B, N, P = masks.shape
    assert list(alive.shape) == [B, P, 1], alive.shape
    assert list(fitnesses.shape) == [B, P, 1], fitnesses.shape

    iou = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)

    # an agent never disputes with itself
    not_self = ~torch.eye(P, dtype=torch.bool, device=iou.device).unsqueeze(0).expand(B, -1, -1)
    disputes = (iou > compete_thresh) & not_self  # [B,P,P]

    # row agent loses to column agent when its fitness is lower
    killed = disputes & (fitnesses < fitnesses.transpose(1, 2))

    if sticky_winners:
        winners = alive > 0.5
        losers = ~winners

        # previous winners are immune to previous losers
        protected = winners & losers.transpose(1, 2)  # [B,P,P]
        killed = killed & ~protected

        # previous losers always lose a dispute with a previous winner
        overruled = losers & winners.transpose(1, 2) & disputes
        killed = killed | overruled

    # losing to any single competitor is fatal
    dead = killed.any(dim=2, keepdim=True)
    return (~dead).to(torch.float32)
def compute_distance_weighted_vectors(vector_map, positions, mask=None, beta=1.0, eps=1e-8):
    """
    Average ``vector_map`` around each position, weighting pixels by a softmax
    over their scaled inverse distance to that position.

    args:
        vector_map: [B,H,W,D]
        positions: [B,P,2] in normalized (h, w) coordinates
        mask: optional [B,H,W,1]; multiplies the pre-softmax scores
        beta: softmax sharpness
        eps: added under the square root and to the distance for stability

    returns:
        [B,P,D] weighted means
    """
    B, H, W, D = vector_map.shape
    assert positions.size(-1) == 2, positions.size()
    B, P, _ = positions.shape
    N = H * W

    if mask is not None:
        assert list(mask.shape) == [B, H, W, 1]
    else:
        mask = torch.ones_like(vector_map[..., 0:1]).to(vector_map.device)

    pixel_coords = coordinate_ims(B, 0, [H, W]).view(B, N, 2).to(vector_map.device)
    offsets = pixel_coords[:, None] - positions[:, :, None]  # [B,P,N,2]
    distances = torch.sqrt(offsets.pow(2).sum(dim=-1) + eps)  # [B,P,N]

    ## max distance is 2*sqrt(2)
    closeness = (2.0 * np.sqrt(2.0)) / (distances + eps)
    weights = F.softmax(beta * closeness * mask.view(B, 1, N), dim=-1)  # [B,P,N]
    weighted = vector_map.view(B, 1, N, D) * weights[..., None]
    return torch.sum(weighted, dim=2, keepdim=False)  # [B,P,D]


def masks_from_phenotypes(plateau, phenotypes, normalize=True):
    """
    Compute one non-negative mask per phenotype.

    args:
        plateau: [B,H,W,Q]
        phenotypes: [B,M,Q]
        normalize: forwarded to ``dot_product_attention``

    returns:
        masks: [B,H*W,M], the ReLU of the dot products between each plateau
        vector and each phenotype
    """
    B, H, W, Q = plateau.shape
    flat_plateau = plateau.view(B, H * W, Q)
    similarity = dot_product_attention(
        queries=flat_plateau,
        keys=phenotypes,
        normalize=normalize)
    return F.relu(similarity)
def reshape_batch_time(self, x, merge=True):
    """
    Merge the leading batch and time dimensions of ``x``, or split them again.

    args:
        x: [B,T,...] when merging, [B*T,...] when splitting
        merge: True to reshape to [B*T,...]; False to reshape to [B,T,...]
            using the B and T recorded on ``self``

    Records B, T and BT on ``self`` when they are not set yet and asserts that
    ``x`` agrees with any values already recorded.
    """
    if merge:
        self.is_temporal = True
        B, T = x.size()[0:2]
        if self.B:
            assert (B == self.B), (B, self.B)
        else:
            self.B = B

        if self.T:
            assert (T == self.T), (T, self.T)
        else:
            self.T = T

        assert B * T == (self.B * self.T), (B * T, self.B * self.T)
        if self.BT is None:
            self.BT = self.B * self.T

        return torch.reshape(x, [self.BT] + list(x.size())[2:])

    # split
    BT = x.size()[0]
    assert self.B and self.T, (self.B, self.T)
    if self.BT is not None:
        assert BT == self.BT, (BT, self.BT)
    else:
        self.BT = BT

    return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])


def process_plateau_input(self, plateau):
    """
    Reshape a plateau map to [BT,H,W,Q] and record its dimensions on ``self``.

    Accepted layouts:
        [B,T,H,W,Q]
        [B,H,W,Q]    when ``self.size`` is None (T is set to 1)
        [B,T,N,Q]    when ``self.size`` is set; N is unfolded to ``self.size``
        [B,N,Q]      requires ``self.size``; N is unfolded to ``self.size``

    raises:
        AssertionError: for a 3-D input when ``self.size`` is not set
        ValueError: for any other number of dimensions
    """
    shape = plateau.size()
    if len(shape) == 5:
        self.is_temporal = True
        self.B, self.T, self.H, self.W, self.Q = shape
        self.N = self.H * self.W
        self.BT = self.B * self.T
        plateau = self.reshape_batch_time(plateau)
    elif (len(shape) == 4) and (self.size is None):
        self.is_temporal = False
        self.B, self.H, self.W, self.Q = shape
        self.N = self.H * self.W
        self.T = 1
        self.BT = self.B * self.T
    elif (len(shape) == 4) and (self.size is not None):
        self.is_temporal = True
        self.B, self.T, self.N, self.Q = shape
        self.BT = self.B * self.T
        self.H, self.W = self.size
        plateau = self.reshape_batch_time(plateau)
        plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
    elif len(shape) == 3:
        # The shape is wrapped in a 1-tuple: torch.Size is itself a tuple, so
        # "%s" % shape tried to consume one format field per dimension and
        # raised TypeError instead of showing this message.
        assert self.size is not None, \
            "You need to specify an image size to reshape the plateau of shape %s" % (tuple(shape),)
        self.is_temporal = False
        self.B, self.N, self.Q = shape
        self.T = 1
        self.BT = self.B
        self.H, self.W = self.size
        plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
    else:
        raise ValueError(
            "input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % (tuple(shape),))

    return plateau
self.T = 1 + self.BT = self.B + self.H, self.W = self.size + plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q]) + else: + raise ValueError("input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % shape) + + return plateau + + def forward(self, + plateau, + agents=None, + alive=None, + phenotypes=None, + compete=True, + update_pointers=True, + yoke_phenotypes_to_agents=True, + noise=0.1 + ): + """ + Find the uniform regions within the plateau map + by competition between visual "indices." + + args: + plateau: [B,[T],H,W,Q] feature map with smooth "plateaus" + + returns: + masks: [B, [T], H, W, M] one mask in each of M channels + agents: [B, [T], M, 2] positions of agents in normalized coordinates + alive: [B, [T], M] binary vector indicating which masks are valid + phenotypes: [B, [T], M, Q] + unharvested: [B, [T], H, W] map of regions that weren't covered + + """ + + ## preprocess + plateau = self.process_plateau_input(plateau) # [BT,H,W,Q] + plateau = self.normalization_func(plateau) + + ## sample initial indices ("agents") from borders of the plateau map + if agents is None: + agents = sample_coordinates_at_borders( + plateau.permute(0,3,1,2), + num_points=self.M, + mask=None, + sum_edges=self.sum_edges) + else: + if self.is_temporal: + agents = agents.view(self.BT, *agents.shape[2:]) + + ## the agents have "phenotypes" depending on where they're situated on the plateau map + if phenotypes is None: + phenotypes = self.sg_phenotypes_func( + self.normalization_func( + soft_index(plateau.permute(0,3,1,2), + agents, scale_by_imsize=True))) + elif self.is_temporal: + phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:]) + + ## the "fitness" of an agent -- how likely it is to survive competition -- + ## is how well its phenotype matches the plateau vector at its current position + ## initially all of these agents are "alive" + if alive is None: + alive = torch.ones_like(agents[...,-1:]) # [BT,M,1] + fitnesses = 
compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) + alive_mask = None + else: + if self.is_temporal: + alive = alive.view(self.BT, *alive.shape[2:]) + alive_mask = (alive > 0.5).float() + fitnesses = alive_mask + compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) * (1 - alive_mask) + + alive_t = torch.transpose(alive, 1, 2) # [BT, 1, M] + + ## compute the masks at initialization + masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True) + + ## find the "unharvested" regions of the plateau map not covered by agents + unharvested = torch.minimum(self.reduce_func(masks_pred, dim=-1, keepdim=True), torch.tensor(1.0)) + unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1) + + if alive_mask is not None: + new_agents = sample_coordinates_at_borders( + plateau.permute(0,3,1,2), num_points=self.M, + mask=unharvested.permute(0,3,1,2), + sum_edges=self.sum_edges) + agents = agents * alive_mask + new_agents * (1.0 - alive_mask) + + new_phenotypes = self.sg_phenotypes_func( + self.normalization_func( + soft_index(plateau.permute(0,3,1,2), + new_agents, scale_by_imsize=True))) + phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask) + + for r in range(self.num_competition_rounds): + # print("Evolution round {}".format(r+1)) + + ## compute the "availability" of the plateau map for each agent (i.e. 
where it can harvest from) + alive_t = torch.transpose(alive, 1, 2) # [BT, 1, M] + # availability = alive_t * masks_pred + (1.0 - alive_t) * unharvested.view(self.BT, self.N, 1) + # availability = availability.view(self.BT, self.H, self.W, self.M) + + ## update the fitnesses + if update_pointers and compete: + fitnesses = compute_compatibility( + positions=agents, + plateau=plateau, + phenotypes=phenotypes, + # availability=availability) + availability=None, + noise=noise + ) + + + ## kill agents that have wandered off the map + in_bounds = torch.all( + torch.logical_and(agents < 1.0, agents > -1.0), + dim=-1, keepdim=True) # [BT,M,1] + fitnesses *= in_bounds.to(fitnesses) + + ## break ties in fitness + fitnesses -= 0.001 * torch.arange(self.M, dtype=torch.float32)[None,:,None].expand(self.BT,-1,-1).to(fitnesses.device) + + ## recompute the masks (why?) + if yoke_phenotypes_to_agents: + occupied_regions = self.sg_phenotypes_func( + soft_index(plateau.permute(0,3,1,2), agents, scale_by_imsize=True)) + masks_pred = masks_from_phenotypes(plateau, occupied_regions, normalize=True) # [BT,N,M] + + ## have each pair of agents compete. 
+ ## If their masks overlap, the winner is the one with higher fitness + if compete: + alive = compete_agents(masks_pred, fitnesses, alive, + mask_thresh=self.mask_thresh, + compete_thresh=self.compete_thresh, + sticky_winners=self.sticky_winners) + + alive *= in_bounds.to(alive) + alive_t = torch.transpose(alive, 1, 2) + + # print("Num alive masks", alive.sum(), "which ones --> ", np.where(alive[0,:,0].detach().cpu().numpy())) + if not yoke_phenotypes_to_agents: + masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True) + + ## update which parts of the plateau are "unharvested" + unharvested = torch.minimum(self.reduce_func(masks_pred * alive_t, dim=-1, keepdim=True), + torch.tensor(1.0, dtype=torch.float32)) + unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1) + + + ## update phenotypes of the winners + if update_pointers: + if self.mask_thresh is not None: + winner_phenotypes = (masks_pred[...,None] > self.mask_thresh).to(plateau) + if self.selection_strength > 0: + winner_phenotypes = winner_phenotypes * plateau.view(self.BT, self.N, 1, self.Q) + winner_phenotypes = self.normalization_func(winner_phenotypes.mean(dim=1)) # [BT,M,Q] + phenotypes += (alive * winner_phenotypes) * self.selection_strength + + ## reinitialize losing agent positions + alive_mask = (alive > 0.5).to(torch.float32) + loser_agents = sample_coordinates_at_borders( + plateau.permute(0,3,1,2), num_points=self.M, + mask=unharvested.permute(0,3,1,2), + sum_edges=self.sum_edges) + agents = agents * alive_mask + loser_agents * (1.0 - alive_mask) + + + ## reinitialize loser agent phenotypes + loser_phenotypes = self.normalization_func( + compute_distance_weighted_vectors(plateau, agents, mask=unharvested, beta=self.homing_strength)) + phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes + phenotypes = self.normalization_func(phenotypes) + + ## that's it for this round! 
+ # print("round %d" % r, alive.shape, torch.where(alive[0,:,0])) + + ## run a final competition between the surviving masks + if self.mask_beta is not None: + masks_pred = F.softmax( + self.mask_beta * masks_pred * alive_t - \ + self.mask_beta * (1.0 - alive_t), dim=-1) + if self.mask_dead_segments: + masks_pred *= alive_t + + masks_pred = masks_pred.view(self.BT,self.H,self.W,self.M) + if self.is_temporal: + masks_pred = self.reshape_batch_time(masks_pred, merge=False) + agents = self.reshape_batch_time(agents, merge=False) + alive = self.reshape_batch_time(alive, merge=False) + phenotypes = self.reshape_batch_time(phenotypes, merge=False) + unharvested = self.reshape_batch_time(unharvested, merge=False) + + return (masks_pred, agents, alive, phenotypes, unharvested) + + @staticmethod + def masks_to_segments(masks): + return masks.argmax(-1) + + @staticmethod + def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True): + B,M,_ = alive.shape + Q = plateau.shape[-1] + if flatten_masks: + masks = F.one_hot((alive[...,None,None,:,0] * masks).argmax(-1), num_classes=M).float() + + flat_plateau = torch.zeros_like(plateau) + phenotypes = torch.zeros((B,M,Q), device=plateau.device).float() + for b in range(B): + m_inds = torch.where(alive[b,:,0])[0] + masks_b = masks[b,...,m_inds] + num_px = masks_b.sum((0,1)).clamp(min=1)[:,None] # [K,1] + phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px # [K,Q] + flat_plateau_b = (masks_b[...,None] * phenos_b[None,None]).sum(-2) # [H,W,Q] + + phenotypes[b,m_inds,:] = phenos_b + flat_plateau[b] = flat_plateau_b + + _norm = lambda x: F.normalize(x, p=2, dim=-1) + return (_norm(flat_plateau), _norm(phenotypes)) + + @staticmethod + def plot_agents(agents, alive, size=[128,128]): + B,M,_ = alive.shape + agent_map = -1 * torch.ones((B,*size), device=alive.device, dtype=torch.long) + for b in range(B): + inds = torch.where(alive[b,:,0]) + for i in inds[0]: + pos = agents[b,i]*0.5 + 0.5 + pos = pos * 
torch.tensor(size, device=pos.device) + hmin, wmin = list(torch.floor(pos).long()) + hmax, wmax = list(torch.ceil(pos).long()) + agent_map[b,[hmin,hmin,hmax,hmax],[wmin,wmax,wmin,wmax]] = i + + return agent_map + +if __name__ == '__main__': + + Comp = Competition(num_masks=32, num_competition_rounds=5) + + left = torch.ones(size=(32,8)).unsqueeze(-1) * torch.tensor([1.,0.2,0.]) + middle = torch.ones(size=(32,16)).unsqueeze(-1) * torch.tensor([0.,1.,0.2]) + right = torch.ones(size=(32,8)).unsqueeze(-1) * torch.tensor([0.1,0.,1.]) + plateau = torch.cat([left, middle, right], dim=-2).unsqueeze(0) + masks, agents, alive, phenotypes, unharvested = Comp(plateau) + mask_inds = np.where(alive[0,:,0].numpy())[0] + print(np.argmax(masks[0,...], axis=-1)) + for ind in mask_inds: + print("num pixels in mask %d ---> %d" % (ind, (np.argmax(masks[0], -1) == ind).sum())) diff --git a/cwm/eval/Segmentation/archive/configs/__init__.py b/cwm/eval/Segmentation/archive/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py b/cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py new file mode 100644 index 0000000000000000000000000000000000000000..34c89e7b446d7efd91c1940d35de782d9872c75a --- /dev/null +++ b/cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py @@ -0,0 +1,56 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.config import CfgNode, LazyConfig +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +from ..common.coco_loader_lsj import dataloader + + +# model = model_zoo.get_config("./models/mask_rcnn_cwm.py").model + +cfg_file = "./models/mask_rcnn_cwm.py" +model = 
LazyConfig.load(cfg_file).model + +# url = get_checkpoint_url(config_path) +# if "train" in cfg and "init_checkpoint" in cfg.train: +# cfg.train.init_checkpoint = url +# else: +# raise NotImplementedError + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = ( + #"/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping/checkpoint-799-encoder.pth" + './output/model_0004999.pth' +) +train.eval_period = 1e9 + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# train.max_iter = 184375 +# milestones = [163889, 177546] + +# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep +train.max_iter = 61458 +milestones = [54629, 59182] + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=milestones, + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000000000000000000000000000000000000..4937e3401476f9f7d2b06d86c0d05091ddd8fdf5 --- /dev/null +++ b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,59 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate +import os +from ..common.coco_loader_lsj 
import dataloader +from detectron2.data.datasets import register_coco_instances +from detectron2.config import CfgNode, LazyConfig +# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model +# model.backbone.square_pad = 512 # change input size to 512x512 + +cfg_file = "./models/mask_rcnn_vitdet_v2.py" +model = LazyConfig.load(cfg_file).model +model.backbone.square_pad = 512 # change input size to 512x512 + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = ( + #"detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True" + #"/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.98/checkpoint-799-encoder.pth" + "/ccn2/u/honglinc/cwm_checkpoints/mae_vitb/mae_pretrain_vit_base-encoder.pth" +) +train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3" + +root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) +register_coco_instances("cls_agnostic_coco", {}, + os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"), + os.path.join(root, "coco/val2017") + ) +dataloader.test.dataset.names = 'cls_agnostic_coco' +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# train.max_iter = 184375 +# milestones = [163889, 177546] + +# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep +train.max_iter = 61458 +milestones = [54629, 59182] + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=milestones, + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) 
+optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..d94d204e188e3631f31a81933d29fddb360abfaa --- /dev/null +++ b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py @@ -0,0 +1,56 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.config import CfgNode, LazyConfig +from detectron2.solver import WarmupParamScheduler + +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate +import os +from ..common.coco_loader_lsj import dataloader + +# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model +# model.backbone.square_pad = 512 # change input size to 512x512 + +cfg_file = "./models/mask_rcnn_cwm.py" +model = LazyConfig.load(cfg_file).model + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = ( + '/home/honglinc/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth' +) +train.output_dir = '/ccn2/u/honglinc/cwm_checkpoints/dinov2_coco_finetune_512' + +# model.backbone.net.window_size = 0 +# model.backbone.net.window_block_indexes = [] +# model.backbone.net.use_rel_pos = False +# model.backbone.net.drop_path_rate = 0. 
+ +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# train.max_iter = 184375 +# milestones = [163889, 177546] + +# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep +train.max_iter = 61458 +milestones = [54629, 59182] + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=milestones, + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..6686314dd0f7b4fee2412aa4accf04ea37faa015 --- /dev/null +++ b/cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py @@ -0,0 +1,59 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.config import CfgNode, LazyConfig +from detectron2.solver import WarmupParamScheduler + +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate +import os +from ..common.coco_loader_lsj import dataloader +from detectron2.data.datasets import register_coco_instances +# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model +# model.backbone.square_pad = 512 # change input size to 512x512 + +cfg_file = "./models/mask_rcnn_cwm.py" +model = LazyConfig.load(cfg_file).model + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = 
True +train.ddp.fp16_compression = True +train.init_checkpoint = ( + # "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True" + '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.90/checkpoint-799-encoder.pth' +) +train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3" +train.eval_period = 1e9 + +root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) +register_coco_instances("cls_agnostic_coco", {}, + os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"), + os.path.join(root, "coco/val2017") + ) +dataloader.test.dataset.names = 'cls_agnostic_coco' +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +# train.max_iter = 184375 +# milestones = [163889, 177546] + +# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep +train.max_iter = 61458 +milestones = [54629, 59182] + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=milestones, + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/connected_component.py b/cwm/eval/Segmentation/archive/connected_component.py new file mode 100644 index 0000000000000000000000000000000000000000..fa89f784ae497a5f247138fa311b49207b246a17 --- /dev/null +++ b/cwm/eval/Segmentation/archive/connected_component.py @@ -0,0 +1,105 @@ +from kornia.contrib import connected_components +import torch +import pdb +import matplotlib.pyplot as plt +import time + +# def reorder_int_labels(x): +# _, y = torch.unique(x, 
return_inverse=True) +# y -= y.min() +# return y + +# def label_connected_component(labels, max_area=500, min_area=20, max_ccs=128, num_iterations=500): + +# assert len(labels.size()) == 2 + +# # per-label binary mask +# unique_labels = torch.unique(labels).reshape(-1, 1, 1) # [?, 1, 1] +# binary_masks = (labels.unsqueeze(0) == unique_labels).float() # [?, H, W] + +# # label connected components +# cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations) # [?, 1, H, W] +# cc = reorder_int_labels(cc) +# bincount = torch.bincount(cc.long().flatten()) + +# # find all connected components (id, mask, area, valid) +# # cc_id = torch.nonzero(bincount) # [num_cc] +# cc_id = torch.argsort(bincount)[-max_ccs:] +# cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W] +# cc_area = bincount[cc_id] # [num_cc] +# valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc] +# valid = valid.reshape(-1, 1, 1) # [num_cc, 1, 1] + +# # final labels for connected component +# out = valid * cc_mask +# out = out.argmax(0) +# return out + +def reorder_int_labels(x): + _, y = torch.unique(x, return_inverse=True) + y -= y.min() + return y + + +def label_connected_component(labels, min_area=20, topk=256): + size = labels.size() + assert len(size) == 2 + max_area = size[0] * size[1] - 1 + + # per-label binary mask + unique_labels = torch.unique(labels).reshape(-1, 1, 1) # [?, 1, 1], where ? 
is the number of unique id + binary_masks = (labels.unsqueeze(0) == unique_labels).float() # [?, H, W] + + # label connected components + # cc is an integer tensor, each unique id represents a single connected component + cc = connected_components(binary_masks.unsqueeze(1), num_iterations=500) # [?, 1, H, W] + + # reorder indices in cc so that cc_area tensor below is a smaller + cc = reorder_int_labels(cc) + + # area of each connected components + cc_area = torch.bincount(cc.long().flatten().cpu()).cuda() # bincount on GPU is much slower + num_cc = cc_area.shape[0] + valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc] + + if num_cc < topk: + selected_cc = torch.arange(num_cc).cuda() + else: + _, selected_cc = torch.topk(cc_area, k=topk) + valid = valid[selected_cc] + + # collapse the 0th dimension, since there is only matched one connected component (across 0th dimension) + cc_mask = (cc == selected_cc.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W] + cc_mask = cc_mask * valid.reshape(-1, 1, 1) + out = cc_mask.argmax(0) + return out + + +# def reorder_int_labels(x): +# _, y = torch.unique(x, return_inverse=True) +# y -= y.min() +# return y + +# def label_connected_component(labels, max_area=500, min_area=20): + +# assert len(labels.size()) == 2 + +# # per-label binary mask +# unique_labels = torch.unique(labels).reshape(-1, 1, 1) # [?, 1, 1] +# binary_masks = (labels.unsqueeze(0) == unique_labels).float() # [?, H, W] + +# # label connected components +# cc = connected_components(binary_masks.unsqueeze(1)) # [?, 1, H, W] +# cc = reorder_int_labels(cc) +# bincount = torch.bincount(cc.long().flatten()) + +# # find all connected components (id, mask, area, valid) +# cc_id = torch.nonzero(bincount) # [num_cc] +# cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W] +# cc_area = bincount[cc_id] # [num_cc] +# valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc] +# valid = valid.reshape(-1, 1, 1) # [num_cc, 1, 1] + +# # final labels 
for connected component +# out = valid * cc_mask +# out = out.argmax(0) diff --git a/cwm/eval/Segmentation/archive/generate_zero_shot_segments.py b/cwm/eval/Segmentation/archive/generate_zero_shot_segments.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe82fc0db281c732190b976bb90f6b5e70519a4 --- /dev/null +++ b/cwm/eval/Segmentation/archive/generate_zero_shot_segments.py @@ -0,0 +1,251 @@ +import torch +import os +import glob +import time +from torchvision.io import read_image +import matplotlib.pyplot as plt +from scipy import ndimage +from PIL import Image +import bbnet.trainval.validator as validator +import modeling_pretrain_cleaned as vmae_transformers +import modeling_pretrain as vmae_transformers_old +import positional_vmae as pos_transformers +import big_models as big_transformers +import bbnet.models.preprocessor as preprocessor +import bbnet.models.error as error_generator +from functools import partial +import bbnet.models.teachers as teachers +from tqdm import tqdm +from torch.nn import functional as F +import argparse +import sys +import numpy as np +import json +import pycocotools.mask as mask_util +sys.path.append('/ccn2/u/honglinc/CutLER') +sys.path.append('/ccn2/u/honglinc/CutLER/maskcut') +sys.path.append('/ccn2/u/honglinc/CutLER/third_party') +import dino +from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix +from third_party.TokenCut.unsupervised_saliency_detection import utils, metric +from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box +from crf import densecrf +#from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners +# DINO hyperparameters +vit_arch = 'base' +vit_feat = 'k' +patch_size = 8 +# DINO pre-trained model +url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" +feat_dim = 768 +dino_backbone 
= dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size) +dino_backbone = dino_backbone.eval().requires_grad_(False).cuda() + + +def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]): + input_dino = images + input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device) + input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device) + # input_dino = images.tensor + input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear') + features = dino_backbone(input_dino) + + predominence_map = [] + + for i in range(features.shape[0]): + feats = features[i] + if current_mask == None: + painting = torch.from_numpy(np.zeros(dims)) + painting = painting.to(feats) + else: + feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0]) + + A, D = get_affinity_matrix(feats, tau=0.15) + # get the second-smallest eigenvector + _, second_smallest_vec = second_smallest_eigenvector(A, D) + # get salient area + bipartition = get_salient_areas(second_smallest_vec) + + # check if we should reverse the partition based on: + # 1) peak of the 2nd smallest eigvec 2) object centric bias + seed = np.argmax(np.abs(second_smallest_vec)) + nc = check_num_fg_corners(bipartition, dims) + if nc >= 2: + reverse = True + else: + reverse = bipartition[seed] != 1 + if reverse: + second_smallest_vec = 1 - second_smallest_vec + second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous() + map = torch.nn.functional.interpolate(second_smallest_vec.reshape(1, 1, dims[0], dims[1]), size=img_size, + mode='bilinear') + map -= map.min() + map /= map.max() + predominence_map.append(map) + init_dist = torch.cat(predominence_map, dim=0).detach() + return init_dist, A, feats, painting + + + + +def interpolate_pos_encoding(pos_embed, n_frames, h, w): + N = pos_embed.shape[1] + if N == (h * w * n_frames): + 
return pos_embed + old_h = old_w = int((N / n_frames) ** 0.5) + patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2) + + patch_pos_embed = F.interpolate( + patch_pos_embed, + size=(h, w), + mode='bicubic', + ) + return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0) + + + +def vis_results(x, targets_dict, annotation, name): + img = x[0, 0].permute(1, 2, 0).cpu() + fig, axs = plt.subplots(1, 1+len(targets_dict), figsize=(3*len(targets_dict), 3)) + axs[0].imshow(img) + axs[0].set_title('Image') + + for i, v in enumerate(targets_list): + v = v[0, 0] # .cpu() + axs[1+i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img))) + axs[1+i].set_title(f'Segment {i}', fontsize=10) + + for ax in axs: + ax.set_axis_off() + + plt.show() + plt.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False) + parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images') + parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results') + parser.add_argument('--num_iter', default=1, type=int, help='number of iterations') + parser.add_argument('--visualize', action='store_true', help='Visualize the results') + args = parser.parse_args() + + ## Prepare for the extraction + image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern + thresh = 0.5 + visualize = args.visualize + save_dict = {} + image_size = [480, 480] + patch_size = 8 + dims = [int(s / patch_size) for s in image_size] + + ## Load pretrained model + default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/' + model_func = vmae_transformers.vitb_8x8patch_3frames + ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400' # the original IMU-conditioned 4x4 + label = '3 frame 8x8' + teacher_func = 
teachers.iteration_segment_teacher_with_filter + + teacher = teacher_func( + model_func=model_func, + model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1), + visualization_mode=visualize, + initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1}, + ).requires_grad_(False).cuda() + + teacher.predictor.encoder.pos_embed = interpolate_pos_encoding( + teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1]) + teacher.predictor.pos_embed = interpolate_pos_encoding( + teacher.predictor.pos_embed, 3, dims[0], dims[1]) + teacher.predictor.image_size = image_size + + ## Start extracting segments + start = time.time() + + + if os.path.exists(args.output): + print('Load partial results from: ', args.output) + save_dict = torch.load(args.output) + print('Length of existing dict: ', len(save_dict)) + + for image_path in image_list: + + # Prepare input + image_name = image_path.split('/')[-1] + + if image_name in save_dict: + continue + + image = read_image(image_path) + if image.shape[0] == 1: + image = image.expand(3, -1, -1) + + x = torch.stack([image] * 3, dim=0) + x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255. 
+ _x = x.to(torch.float16).cuda() + + targets_list = [] + # extract segments iteratively + for n in range(args.num_iter): + + # Compute predominance map from dino + if n == 0: + predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size) + else: + predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), + current_mask=current_mask.cuda(), + painting=painting, dims=dims, + img_size=image_size) + + if visualize: + plt.imshow(predominance[0, 0].cpu()) + plt.title(f'Predominance (max:{predominance[0, 0].max()})') + plt.show() + + # mask out segments that are already extracted + if n > 0: + for mask in targets_list: + predominance[0, 0][mask[0, 0].cuda()] = 0 + + + # extract segments given predominance map + with torch.cuda.amp.autocast(enabled=True): + targets = teacher(_x, sampling_distribution=predominance)[0] + if n == 0: + targets_list = [targets.cpu() >= thresh] + else: + ratio = targets.mean() + mask = targets.cpu() >= thresh + iou = 0 + match_idx = None + + for idx, existing_mask in enumerate(targets_list): + _iou = metric.IoU(mask[0, 0], existing_mask[0, 0]) + if _iou > iou: + iou = _iou + match_idx = idx + + # remove segments if it has large IoU + if iou > 0.2 or ratio <= 0.01: + mask = torch.zeros_like(mask) + # elif iou > 0.1: + # mask[0, 0][targets_list[match_idx][0, 0]] = 0 + + targets_list.append(mask) + + current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh + + vid_name = image_path + save_dict[image_name] = targets_list + if visualize: + vis_results(x, targets_list, None, vid_name.split('/')[-2] + '.png') + + if (len(save_dict) + 1) % 1 == 0: + total = len(image_list) + num_completed = len(save_dict) + avg_time = (time.time() - start) / num_completed + eta = (total - num_completed) * avg_time / 60. + print(f'{num_completed} / {total} completed, avg. 
time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins') + torch.save(save_dict, args.output) + ## Save the results + torch.save(save_dict, args.output) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/generate_zero_shot_segments_parallel.sh b/cwm/eval/Segmentation/archive/generate_zero_shot_segments_parallel.sh new file mode 100644 index 0000000000000000000000000000000000000000..d62243711054d76330ca9d4305de3bed126d724f --- /dev/null +++ b/cwm/eval/Segmentation/archive/generate_zero_shot_segments_parallel.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + exit 1 +fi + +directory="$1" +pattern="$2" +N="$3" +num_iters="$4" +output_dir="$5" + +# Check if the directory exists +if [ ! -d "$directory" ]; then + echo "Error: Directory not found." + exit 1 +fi + +# Create the output directory if it doesn't exist +mkdir -p "$output_dir" + +# Get the list of files and subdirectories in the directory +contents=($(find "$directory" -type f -name "$pattern")) + +# Calculate the total number of items +num_items=${#contents[@]} + +# Calculate the approximate number of items in each list +items_per_list=$((num_items / N)) +remainder=$((num_items % N)) + +# Initialize variables +start_index=0 +background_pids=() + +# Split the items into N lists and execute the Python command in parallel +for ((i = 1; i <= N; i++)); do + end_index=$((start_index + items_per_list - 1)) + + # If there's a remainder, distribute the remaining items + if [ $i -le $remainder ]; then + end_index=$((end_index + 1)) + fi + + # Create a sublist from start_index to end_index + sublist=("${contents[@]:start_index:end_index - start_index + 1}") + + # Generate the output path + output_path="$output_dir/$i.pt" + + # Generate the Python command using the sublist, num_iters, and output_path + cuda_visible_devices=$((i - 1)) + cmd="CUDA_VISIBLE_DEVICES=$cuda_visible_devices python generate_zero_shot_segments.py --input_pattern ${sublist[*]} --num_iter $num_iters 
--output $output_path" + + # Execute the command in the background +# if [ $i -eq 3 ]; then +# eval "$cmd" & +# fi + eval "$cmd" & + + # Record the background process ID + background_pids+=($!) + + # Update the start_index for the next sublist + start_index=$((end_index + 1)) +done + +# Wait for all background processes to finish +for pid in "${background_pids[@]}"; do + wait "$pid" +done diff --git a/cwm/eval/Segmentation/archive/generate_zero_shot_segments_v2.py b/cwm/eval/Segmentation/archive/generate_zero_shot_segments_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f25dafa58e42a86da3d826f24b0be7d936fb7600 --- /dev/null +++ b/cwm/eval/Segmentation/archive/generate_zero_shot_segments_v2.py @@ -0,0 +1,300 @@ +import torch +import os +import glob +import time +from torchvision.io import read_image +import matplotlib.pyplot as plt +from scipy import ndimage +from PIL import Image + +#import bbnet.trainval.validator as validator +import modeling_pretrain_cleaned as vmae_transformers +import modeling_pretrain as vmae_transformers_old +#import positional_vmae as pos_transformers +#import big_models as big_transformers +import bbnet.models.preprocessor as preprocessor +import bbnet.models.error as error_generator +from functools import partial +#import bbnet.models.teachers as teachers +from tqdm import tqdm +from torch.nn import functional as F +import argparse +import sys +import numpy as np +import json +import pycocotools.mask as mask_util +sys.path.append('/ccn2/u/honglinc/CutLER') +sys.path.append('/ccn2/u/honglinc/CutLER/maskcut') +sys.path.append('/ccn2/u/honglinc/CutLER/third_party') +import dino +import maskcut +# from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix +from third_party.TokenCut.unsupervised_saliency_detection import utils, metric +from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box +from crf 
import densecrf +#from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners +# DINO hyperparameters +vit_arch = 'base' +vit_feat = 'k' +patch_size = 8 +# DINO pre-trained model +url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" +feat_dim = 768 +dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size) +dino_backbone = dino_backbone.eval().requires_grad_(False).cuda() + +def get_affinity_matrix(feats, tau, eps=1e-5): + # get affinity matrix via measuring patch-wise cosine similarity + feats = F.normalize(feats, p=2, dim=1) + A = (feats.permute(0, 2, 1) @ feats) + # convert the affinity matrix to a binary one. + A = A > tau + eps = torch.ones_like(A) * eps + A = torch.where(A.float() == 0, eps, A) + d_i = A.sum(-1) + D = torch.diag_embed(d_i) + return A, D + +def second_smallest_eigenvector(A, D): + # get the second smallest eigenvector from affinity matrix + _, eigenvectors = torch.lobpcg(D - A, B=D, k=2, largest=False) + second_smallest_vec = eigenvectors[:, :, 1] + return -second_smallest_vec + +def get_salient_areas(second_smallest_vec): + # get the area corresponding to salient objects. 
+ avg = second_smallest_vec.mean(-1, keepdims=True) + bipartition = second_smallest_vec > avg + return bipartition + +def check_num_fg_corners(bipartition, dims): + # check number of corners belonging to the foreground + dims = [bipartition.shape[0]] + dims + bipartition_ = bipartition.reshape(dims) + top_l, top_r, bottom_l, bottom_r = bipartition_[:,0,0], bipartition_[:,0,-1], bipartition_[:,-1,0], bipartition_[:,-1,-1] + nc = top_l.int() + top_r.int() + bottom_l.int() + bottom_r.int() + return nc + +def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]): + input_dino = images + input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device) + input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device) + # input_dino = images.tensor + input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear') + feats = dino_backbone(input_dino) # [B, C, N] + B = feats.shape[0] + + predominence_map = [] + if current_mask == None: + painting = torch.from_numpy(np.zeros(dims)) + painting = painting.to(feats) + else: + feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0]) + + + + A, D = get_affinity_matrix(feats, tau=0.15) + # get the second-smallest eigenvector + + + #_second_smallest_vec = maskcut.second_smallest_eigenvector(A[10].cpu(), D[10].cpu()) + second_smallest_vec = second_smallest_eigenvector(A, D) + + # get salient area + bipartition = get_salient_areas(second_smallest_vec) + + # check if we should reverse the partition based on: + # 1) peak of the 2nd smallest eigvec 2) object centric bias + batch_inds = torch.arange(second_smallest_vec.shape[0]).to(second_smallest_vec).unsqueeze(0) + seed = torch.argmax(second_smallest_vec.abs(), dim=-1).unsqueeze(0) + seed = torch.cat([batch_inds, seed], dim=0).long() + + reverse = bipartition[list(seed)] !=1 + + nc = 
check_num_fg_corners(bipartition, dims) + reverse[nc >= 2] = True + second_smallest_vec[reverse] = 1 - second_smallest_vec[reverse] + + second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous() + map = torch.nn.functional.interpolate(second_smallest_vec.reshape(B, 1, dims[0], dims[1]), size=img_size, + mode='bicubic') + map -= map.min() + map /= map.max() + predominence_map.append(map) + init_dist = torch.cat(predominence_map, dim=0).detach().contiguous() + + return init_dist, A, feats, painting + + + + +def interpolate_pos_encoding(pos_embed, n_frames, h, w): + N = pos_embed.shape[1] + if N == (h * w * n_frames): + return pos_embed + old_h = old_w = int((N / n_frames) ** 0.5) + patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2) + + patch_pos_embed = F.interpolate( + patch_pos_embed, + size=(h, w), + mode='bicubic', + ) + return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0) + + + +def vis_results(x, targets_dict, predominance, annotation, name): + B = x.shape[0] + + fig, axs = plt.subplots(B, 2+len(targets_dict), figsize=(3*len(targets_dict), 2*B)) + + for b in range(B): + img = x[b, 0].permute(1, 2, 0).cpu() + axs[b, 0].imshow(img) + axs[b, 0].set_title('Image') + axs[b, 1].imshow(predominance[b, 0].cpu()) + axs[b, 1].set_title('Predominace') + + for i, v in enumerate(targets_list): + v = v[b, 0] # .cpu() + axs[b, 1+i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img))) + axs[b, 1+i].set_title(f'Segment {i}', fontsize=10) + + for ax in axs: + for a in ax: + a.set_axis_off() + + plt.show() + plt.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False) + parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images') + parser.add_argument('--output', default='./output.pt', type=str, 
help='output path for saving the results') + parser.add_argument('--num_iter', default=1, type=int, help='number of iterations') + parser.add_argument('--visualize', action='store_true', help='Visualize the results') + args = parser.parse_args() + + ## Prepare for the extraction + image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern + thresh = 0.5 + visualize = args.visualize + save_dict = {} + image_size = [480, 480] + patch_size = 8 + dims = [int(s / patch_size) for s in image_size] + batch_size = 10 + + ## Load pretrained model + default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/' + model_func = vmae_transformers.vitb_8x8patch_3frames + ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400' # the original IMU-conditioned 4x4 + label = '3 frame 8x8' + teacher_func = teachers.iteration_segment_teacher_with_filter + + teacher = teacher_func( + model_func=model_func, + model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1), + visualization_mode=visualize, + initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1}, + ).requires_grad_(False).cuda() + + teacher.predictor.encoder.pos_embed = interpolate_pos_encoding( + teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1]) + teacher.predictor.pos_embed = interpolate_pos_encoding( + teacher.predictor.pos_embed, 3, dims[0], dims[1]) + teacher.predictor.image_size = image_size + + ## Start extracting segments + start = time.time() + batch = [] + image_names = [] + import pdb;pdb.set_trace() + for image_path in sorted(image_list): + + # Prepare input + image_name = image_path.split('/')[-1] + image = read_image(image_path) + if image.shape[0] == 1: + image = image.expand(3, -1, -1) + + x = torch.stack([image] * 3, dim=0) + x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255. 
+ + print('length', len(batch)) + if len(batch) < batch_size: + batch.append(x) + image_names.append(image_name) + continue + else: + x = torch.cat(batch, dim=0) + batch = [] + image_names = [] + + _x = x.to(torch.float16).cuda() + + targets_list = [] + # extract segments iteratively + for n in range(args.num_iter): + + # Compute predominance map from dino + if n == 0: + predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size) + else: + raise ValueError('Not implemented') + predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), + current_mask=current_mask.cuda(), + painting=painting, dims=dims, + img_size=image_size) + + # mask out segments that are already extracted + if n > 0: + for mask in targets_list: + predominance[0, 0][mask[0, 0].cuda()] = 0 + + + # extract segments given predominance map + with torch.cuda.amp.autocast(enabled=True): + targets = teacher(_x, sampling_distribution=predominance)[0] + print('targets.shape', targets.shape) + if n == 0: + targets_list = [targets.cpu() >= thresh] + else: + ratio = targets.mean() + mask = targets.cpu() >= thresh + iou = 0 + match_idx = None + + for idx, existing_mask in enumerate(targets_list): + _iou = metric.IoU(mask[0, 0], existing_mask[0, 0]) + if _iou > iou: + iou = _iou + match_idx = idx + + # remove segments if it has large IoU + if iou > 0.2 or ratio <= 0.01: + mask = torch.zeros_like(mask) + # elif iou > 0.1: + # mask[0, 0][targets_list[match_idx][0, 0]] = 0 + + targets_list.append(mask) + + current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh + + vid_name = image_path + save_dict[image_name] = targets_list + if visualize: + vis_results(x, targets_list, predominance, None, vid_name.split('/')[-2] + '.png') + + if (len(save_dict) + 1) % 1 == 0: + total = len(image_list) + num_completed = len(save_dict) + avg_time = (time.time() - start) / num_completed + eta = (total - num_completed) * avg_time / 60. 
+ print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins') + print('remove save') + #torch.save(save_dict, args.output) + ## Save the results + torch.save(save_dict, args.output) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/lazyconfig_train_net.py b/cwm/eval/Segmentation/archive/lazyconfig_train_net.py new file mode 100644 index 0000000000000000000000000000000000000000..75ead1f00a632f9e4e5587eaa9e0b6fc76985a61 --- /dev/null +++ b/cwm/eval/Segmentation/archive/lazyconfig_train_net.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. + +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. + +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. 
+""" +import logging + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + AMPTrainer, + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.utils import comm +import warnings +warnings.filterwarnings("ignore") +logger = logging.getLogger("detectron2") + + +def do_test(cfg, model): + if "evaluator" in cfg.dataloader: + ret = inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) + print_csv_format(ret) + return ret + + +def do_train(args, cfg): + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `configs/common/train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info("Model:\n{}".format(model)) + model.to(cfg.train.device) + + cfg.optimizer.params.model = model + optim = instantiate(cfg.optimizer) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + + trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim) + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + trainer=trainer, + ) + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + 
hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) + if comm.is_main_process() + else None, + hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) + if comm.is_main_process() + else None, + ] + ) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + if args.eval_only: + model = instantiate(cfg.model) + model.to(cfg.train.device) + model = create_ddp_model(model) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + print(do_test(cfg, model)) + else: + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/linear_probing.py b/cwm/eval/Segmentation/archive/linear_probing.py new file mode 100644 index 0000000000000000000000000000000000000000..02d343319bb5ce5f5547b676dd86f78206013d67 --- /dev/null +++ b/cwm/eval/Segmentation/archive/linear_probing.py @@ -0,0 +1,59 @@ +import os +import argparse +import sys +import torch +import warnings +warnings.filterwarnings("ignore") +torch.multiprocessing.set_sharing_strategy('file_system') +# Set environment variables +# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6' +os.environ['OMP_NUM_THREADS'] = '1' +os.environ['DETECTRON2_DATASETS'] = '/ccn2/u/honglinc/datasets' + +# Add 
necessary path +MASK2FORMER_PATH = '/ccn2/u/honglinc/Mask2Former' +BBNET_PATH = '/home/honglinc/BBNet' +sys.path.append(os.path.join(BBNET_PATH, 'bbnet/models/VideoMAE-main/')) +sys.path.append(BBNET_PATH) +sys.path.append(MASK2FORMER_PATH) + +# BBNet import +import modeling_pretrain as vmae_tranformers +from evaluate_segmentation_readout_helper_v2 import CWMSegmentPredictorV2 + +import detectron2.utils.comm as comm +from detectron2.evaluation import verify_results +from train_net import setup, Trainer, DetectionCheckpointer +from detectron2.engine import default_argument_parser, launch + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + if cfg.TEST.AUG.ENABLED: + res.update(Trainer.test_with_TTA(cfg, model)) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/linear_probing_helper.py b/cwm/eval/Segmentation/archive/linear_probing_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..a89ecf7ba73ce15fa3b443feeb86dc645765ebad --- /dev/null +++ b/cwm/eval/Segmentation/archive/linear_probing_helper.py @@ -0,0 +1,535 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from typing import Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.ops import batched_nms, masks_to_boxes +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.postprocessing import sem_seg_postprocess +from detectron2.structures import Boxes, ImageList, Instances, BitMasks +from detectron2.utils.memory import retry_if_cuda_oom +from mask2former.modeling.criterion import SetCriterion +from mask2former.modeling.matcher import HungarianMatcher +import modeling_pretrain as vmae_tranformers +import matplotlib.pyplot as plt +from detectron2.utils.visualizer import Visualizer +import os +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import register_coco_instances + +root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) +register_coco_instances("cls_agnostic_coco", {}, + os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"), + os.path.join(root, "coco/val2017") + ) + + +@META_ARCH_REGISTRY.register() +class CWMSegmentPredictorV2(nn.Module): + """ + Main class for mask classification semantic segmentation architectures. 
+ """ + + @configurable + def __init__( + self, + *, + criterion: nn.Module, + num_queries: int, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + # inference + semantic_on: bool, + panoptic_on: bool, + instance_on: bool, + test_topk_per_image: int, + output_dir: str, + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + num_queries: int, number of queries + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. + sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. 
+ pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + semantic_on: bool, whether to output semantic segmentation prediction + instance_on: bool, whether to output instance segmentation prediction + panoptic_on: bool, whether to output panoptic segmentation prediction + test_topk_per_image: int, instance segmentation parameter, keep topk instances per image + """ + super().__init__() + + self.criterion = criterion + self.num_queries = num_queries + self.overlap_threshold = overlap_threshold + self.object_mask_threshold = object_mask_threshold + self.metadata = metadata + if size_divisibility < 0: + # use backbone size_divisibility if not set + size_divisibility = self.backbone.size_divisibility + self.size_divisibility = size_divisibility + self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + # additional args + self.semantic_on = semantic_on + self.instance_on = instance_on + self.panoptic_on = panoptic_on + self.test_topk_per_image = test_topk_per_image + + if not self.semantic_on: + assert self.sem_seg_postprocess_before_inference + + # Load CWM predictor + self.output_dir = output_dir + if 'cwm' in output_dir: + model_func = vmae_tranformers.base_8x8patch_2frames_1tube_flash + predictor = model_func().cuda() + + load_path = '/ccn2/u/feigelis/model_checkpoints/kevin_checkpoints/' + \ + 'fulltrain_kinetics_8x8patch_rotated_table_distributed_with_ddp' + \ + '_copied_from_oldnode/checkpoint-3199.pth' + + did_load = predictor.load_state_dict(torch.load(load_path, map_location=torch.device("cpu"))['model']) + print('Load CWM pretrained predictor', did_load) + self.predictor = predictor.eval().requires_grad_(False) + self.num_patches = self.predictor.encoder.num_patches + 
self.patch_size = self.predictor.encoder.patch_size[-1] + self.mask_ratio = 0.99 + num_hidden_layers = 4 + hidden_dim = 1024 + input_dim = self.predictor.decoder.embed_dim + + decoder_layers = [torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU()] + for i in range(num_hidden_layers): + decoder_layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + decoder_layers.append(torch.nn.ReLU()) + decoder_layers.append(torch.nn.Linear(hidden_dim, num_queries)) + + self.decoder = torch.nn.Sequential(*decoder_layers).cuda() + + @classmethod + def from_config(cls, cfg): + + # Loss parameters: + no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT + + # loss weights + class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT + dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT + mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT + + # building criterion + matcher = HungarianMatcher( + cost_class=class_weight, + cost_mask=mask_weight, + cost_dice=dice_weight, + num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, + ) + + weight_dict = {"loss_mask": mask_weight, "loss_dice": dice_weight} + + losses = ["masks"] + + criterion = SetCriterion( + num_classes=80, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=no_object_weight, + losses=losses, + num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, + oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO, + importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO, + ) + + return { + "criterion": criterion, + "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, + "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD, + "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD, + "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), + "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, + "sem_seg_postprocess_before_inference": ( + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE + or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON + or 
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + # inference + "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON, + "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON, + "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON, + "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE, + "output_dir": cfg.OUTPUT_DIR, + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. + * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". 
+ """ + + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + ### + # image_size = images.image_sizes[0] + # processed_results = [] + # input_per_image = batched_inputs[0] + # height = input_per_image.get("height", image_size[0]) + # width = input_per_image.get("width", image_size[1]) + # + # gt_instances = [x["instances"] for x in batched_inputs] + # targets = [] + # for targets_per_image in gt_instances: + # # pad gt + # try: + # gt_masks = targets_per_image.gt_masks + # except: + # print('NO GT MASKS') + # gt_masks = torch.zeros(1, height, width) + # + # targets.append( + # { + # "labels": targets_per_image.gt_classes, + # "masks": gt_masks, + # } + # ) + # + # mask_cls_results = torch.ones(1, self.num_queries, 81)#.to(self.device) + # mask_pred_result = targets[0]['masks']#.to(self.device) + # + # processed_results.append({}) + # if self.instance_on: + # instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_results[0], mask_pred_result) + # processed_results[-1]["instances"] = instance_r + # return processed_results + ### + + with torch.cuda.amp.autocast(enabled=True): + with torch.no_grad(): + if not self.training: + # resize to patch size + x = F.interpolate(images.tensor, size=(224, 224), mode="bilinear", align_corners=False) + x = x.to(torch.float16).unsqueeze(2).expand(-1, -1, 2, -1, -1) + else: + x = images.tensor.to(torch.float16).unsqueeze(2).expand(-1, -1, 2, -1, -1) + + # mask out the second frame + mask = torch.zeros([x.shape[0], self.num_patches]).to(x.device).bool() + mask[:, int(self.num_patches // 2):] = 1 + + # num_visibles = int((1 - self.mask_ratio) * int(self.num_patches // 2)) + 1 + # rand_idx = torch.randint(low=int(self.num_patches//2), high=self.num_patches, size=(x.shape[0], int(num_visibles))) + # for i in range(x.shape[0]): + # mask[i, rand_idx[i]] = 0 + + feature = 
self.predictor.encoder(x, mask=mask) + feature = self.predictor.encoder_to_decoder(feature) + # out = self.predictor(x, mask) + + logits = self.decoder(feature).float() + B, N, _ = logits.shape + pred_masks = logits.view(B, int(N ** 0.5), int(N ** 0.5), self.num_queries).permute(0, 3, 1, + 2) # [B, num_queries, H, W] + outputs = {"pred_masks": pred_masks} + + if self.training: + # mask classification target + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances, images) + else: + targets = None + + # bipartite matching-based loss + losses = self.criterion(outputs, targets) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + return losses + else: + # mask_cls_results = outputs["pred_logits"] + mask_cls_results = torch.ones(x.shape[0], self.num_queries, 81).to(self.device) + mask_pred_results = outputs["pred_masks"] + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + # if "instances" in batched_inputs[0]: + # gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + # targets = self.prepare_targets(gt_instances, images) + # else: + # targets = None + + del outputs + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + processed_results.append({}) + + if self.sem_seg_postprocess_before_inference: + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width + ) + mask_cls_result = 
mask_cls_result.to(mask_pred_result) + + # semantic segmentation inference + if self.semantic_on: + r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result) + if not self.sem_seg_postprocess_before_inference: + r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width) + processed_results[-1]["sem_seg"] = r + + # panoptic segmentation inference + if self.panoptic_on: + panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result) + processed_results[-1]["panoptic_seg"] = panoptic_r + + # instance segmentation inference + if self.instance_on: + instance_r, nms_idx = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result) + processed_results[-1]["instances"] = instance_r + + # Visualization + ''' + + rgb_image = F.interpolate(images.tensor.float(), size=(height, width), mode='bilinear') + visualizer = Visualizer(rgb_image.cpu().detach()[0].permute(1,2,0)) + visualizer = visualizer.draw_instance_predictions(instance_r) + + recon = torch.zeros(1, self.num_patches, self.patch_size ** 2 * 3) + recon[mask] = out.float().cpu().detach() + recon = self.unpatchify(recon[:, int(self.num_patches // 2):]) + recon = recon[0].permute(1, 2, 0).float().clamp(0, 1) + # fig, axs = plt.subplots(1, 7, figsize=(20, 3)) + # + # axs[0].imshow(images.tensor.float()[0].permute(1, 2, 0).cpu().detach()) + # axs[1].imshow(images.tensor.float()[0].permute(1, 2, 0).cpu().detach()) + # # axs[1].imshow(batched_inputs[0]['instances'].gt_masks.argmax(0)) + # axs[2].imshow(recon) + # axs[3].imshow(feature[0].view(28, 28, -1)[..., 0:3].cpu().detach().float()) + # axs[4].imshow(feature[0].view(28, 28, -1)[..., 100:103].cpu().detach().float()) + # axs[5].imshow(feature[0].view(28, 28, -1)[..., 200:203].cpu().detach().float()) + # axs[6].imshow(visualizer.get_image()) + file_name = batched_inputs[0]['file_name'].split('/')[-1].split('.jpg')[0] + # for a in axs: + # a.set_axis_off() + fig, axs = plt.subplots(1, 2, 
figsize=(16, 6)) + axs[0].imshow(images.tensor.float()[0].permute(1, 2, 0).cpu().detach()) + axs[1].imshow(visualizer.get_image()) + plt.savefig(f"/ccn2/u/honglinc/temp/{file_name}.png", bbox_inches='tight') + + fig, axs = plt.subplots(10, 10, figsize=(10, 10)) + for a in axs: + for _a in a: + _a.set_axis_off() + + for i in range(mask_pred_result.shape[0]): + # print(mask_pred_result.shape, height, width) + mask_area_ratio = mask_pred_result[i].sigmoid().float().flatten().sum() / (height * width) + axs[i // 10, i % 10].imshow(mask_pred_result[i].cpu().detach() > 0) + nms = 1 if i in nms_idx else -1 + axs[i // 10, i % 10].set_title(f'{mask_area_ratio.item():.2f}, {nms}', fontsize=11) + plt.savefig(f"/ccn2/u/honglinc/temp/{file_name}_mask.png", bbox_inches='tight') + ''' + + return processed_results + + def prepare_targets(self, targets, images): + h_pad, w_pad = images.tensor.shape[-2:] + new_targets = [] + for targets_per_image in targets: + # pad gt + gt_masks = targets_per_image.gt_masks + padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) + padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks + new_targets.append( + { + "labels": targets_per_image.gt_classes, + "masks": padded_masks, + } + ) + return new_targets + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg + + def panoptic_inference(self, mask_cls, mask_pred): + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + cur_mask_cls = mask_cls[keep] + cur_mask_cls = cur_mask_cls[:, :-1] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = 
cur_masks.shape[-2:] + panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) + segments_info = [] + + current_segment_id = 0 + + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + return panoptic_seg, segments_info + else: + # take argmax + cur_mask_ids = cur_prob_masks.argmax(0) + stuff_memory_list = {} + for k in range(cur_classes.shape[0]): + pred_class = cur_classes[k].item() + isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() + mask_area = (cur_mask_ids == k).sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: + if mask_area / original_area < self.overlap_threshold: + continue + + # merge stuff regions + if not isthing: + if int(pred_class) in stuff_memory_list.keys(): + panoptic_seg[mask] = stuff_memory_list[int(pred_class)] + continue + else: + stuff_memory_list[int(pred_class)] = current_segment_id + 1 + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + + segments_info.append( + { + "id": current_segment_id, + "isthing": bool(isthing), + "category_id": int(pred_class), + } + ) + + return panoptic_seg, segments_info + + def instance_inference(self, mask_cls, mask_pred): + # mask_pred is already processed to have the same shape as original input + image_size = mask_pred.shape[-2:] + + mask_area_ratio = (mask_pred > 0).float().flatten(1, 2).sum(1) / (image_size[0] * image_size[1]) + mask_area_filter = (mask_area_ratio > 0.01) & (mask_area_ratio < 0.9) + mask_pred = mask_pred[mask_area_filter] + original_idx = torch.arange(mask_area_filter.shape[0])[mask_area_filter] + try: + box = masks_to_boxes(mask_pred > 0) + scores = (mask_pred.sigmoid().flatten(1) * (mask_pred > 0).flatten(1)).sum(1) / ( + (mask_pred > 0).flatten(1).sum(1) + 1e-6) + nms_idx = batched_nms(box, scores, torch.zeros(box.shape[0]).long(), 0.3) + + mask_pred = 
mask_pred[nms_idx] + box = box[nms_idx] + + except Exception as e: + import pdb; + pdb.set_trace() + print(e, mask_pred.shape, mask_area_filter.sum()) + box = torch.zeros(mask_pred.shape[0], 4).to(mask_pred) + + nms_idx = original_idx[nms_idx] + mask_pred = mask_pred.cpu() + + result = Instances(image_size) + # mask (before sigmoid) + result.pred_masks = (mask_pred > 0).float() + result.pred_boxes = Boxes(box.cpu()) + # Uncomment the following to get boxes from masks (this is slow) + # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() + + # calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( + result.pred_masks.flatten(1).sum(1) + 1e-6) + scores_per_image = torch.ones(mask_pred.size(0)).to(mask_pred.device) + labels_per_image = torch.zeros(mask_pred.size(0)).to(mask_pred.device) + + result.scores = scores_per_image * mask_scores_per_image + result.pred_classes = labels_per_image + return result, nms_idx + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_size + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/merge_results_into_json.py b/cwm/eval/Segmentation/archive/merge_results_into_json.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe9fb52cb1e8572e49d7c31bf9908362ea047db --- /dev/null +++ b/cwm/eval/Segmentation/archive/merge_results_into_json.py @@ -0,0 +1,186 @@ +import os +import json +import glob +import torch +import datetime +import argparse +import torch.nn.functional as F +import numpy as np +import pycocotools.mask as mask_util +def create_image_info(image_id, file_name, image_size, + date_captured=datetime.datetime.utcnow().isoformat(' '), 
+ license_id=1, coco_url="", flickr_url=""): + """Return image_info in COCO style + Args: + image_id: the image ID + file_name: the file name of each image + image_size: image size in the format of (width, height) + date_captured: the date this image info is created + license: license of this image + coco_url: url to COCO images if there is any + flickr_url: url to flickr if there is any + """ + image_info = { + "id": image_id, + "file_name": file_name, + "width": image_size[0], + "height": image_size[1], + "date_captured": date_captured, + "license": license_id, + "coco_url": coco_url, + "flickr_url": flickr_url + } + return image_info + + +def create_annotation_info(annotation_id, image_id, category_info, binary_mask, + image_size=None, bounding_box=None): + """Return annotation info in COCO style + Args: + annotation_id: the annotation ID + image_id: the image ID + category_info: the information on categories + binary_mask: a 2D binary numpy array where '1's represent the object + file_name: the file name of each image + image_size: image size in the format of (width, height) + bounding_box: the bounding box for detection task. If bounding_box is not provided, + we will generate one according to the binary mask. 
+ """ + upper = np.max(binary_mask) + lower = np.min(binary_mask) + thresh = upper / 2.0 + binary_mask[binary_mask > thresh] = upper + binary_mask[binary_mask <= thresh] = lower + if image_size is not None: + binary_mask = resize_binary_mask(binary_mask.astype(np.uint8), image_size) + + binary_mask_encoded = mask_util.encode(np.asfortranarray(binary_mask.astype(np.uint8))) + + area = mask_util.area(binary_mask_encoded) + if area < 1: + return None + + if bounding_box is None: + bounding_box = mask_util.toBbox(binary_mask_encoded) + + rle = mask_util.encode(np.array(binary_mask[...,None], order="F", dtype="uint8"))[0] + rle['counts'] = rle['counts'].decode('ascii') + segmentation = rle + + annotation_info = { + "id": annotation_id, + "image_id": image_id, + "category_id": category_info["id"], + "iscrowd": 0, + "area": area.tolist(), + "bbox": bounding_box.tolist(), + "segmentation": segmentation, + "width": binary_mask.shape[1], + "height": binary_mask.shape[0], + } + + return annotation_info + +# necessay info used for coco style annotations +INFO = { + "description": "ImageNet-1K: pseudo-masks with MaskCut", + "url": "https://github.com/facebookresearch/CutLER", + "version": "1.0", + "year": 2023, + "contributor": "Xudong Wang", + "date_created": datetime.datetime.utcnow().isoformat(' ') +} + +LICENSES = [ + { + "id": 1, + "name": "Apache License", + "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE" + } +] + +# only one class, i.e. 
foreground +CATEGORIES = [ + { + 'id': 1, + 'name': 'fg', + 'supercategory': 'fg', + }, +] + +convert = lambda text: int(text) if text.isdigit() else text.lower() +natrual_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] + +output = { + "info": INFO, + "licenses": LICENSES, + "categories": CATEGORIES, + "images": [], + "annotations": []} + +category_info = { + "is_crowd": 0, + "id": 1 +} + +if __name__ == "__main__": + + parser = argparse.ArgumentParser('Merge pytorch results file into json') + + parser.add_argument('--base-dir', type=str, + default='annotations/', + help='Dir to the generated annotation .pt files with CWM') + parser.add_argument('--save-path', type=str, default="coco_train_fixsize480_N3.json", + help='Path to save the merged annotation file') + args = parser.parse_args() + + file_list = glob.glob(os.path.join(args.base_dir, '*', '*')) + + ann_file = '/ccn2/u/honglinc/datasets/coco/annotations/instances_train2017.json' + with open(ann_file, 'r') as file: + gt_json = json.load(file) + + image_id, segmentation_id = 1, 1 + image_names = [] + for file_name in file_list: + print('processing file name', file_name) + + data = torch.load(file_name) + + for img_name, mask_list in data.items(): + + for img in gt_json['images']: + if img['file_name'] == img_name: + height = img['height'] + width = img['width'] + break + + flag = img_name not in image_names + if flag: + image_info = create_image_info( + image_id, img_name, (height, width, 3)) + output["images"].append(image_info) + image_names.append(img_name) + + + for mask in mask_list: + # create coco-style annotation info + + if mask.sum() == 0: + continue + pseudo_mask = F.interpolate(mask.float(), size=(height, width), mode='bicubic') > 0.5 + pseudo_mask = pseudo_mask[0,0].numpy() + annotation_info = create_annotation_info( + segmentation_id, image_id, category_info, pseudo_mask.astype(np.uint8), None) + if annotation_info is not None: + output["annotations"].append(annotation_info) + 
segmentation_id += 1 + if flag: + image_id += 1 + print(image_id, segmentation_id) + + # save annotations + with open(args.save_path, 'w') as output_json_file: + json.dump(output, output_json_file) + print(f'dumping {args.save_path}') + print("Done: {} images; {} anns.".format(len(output['images']), len(output['annotations']))) diff --git a/cwm/eval/Segmentation/archive/models/__init__.py b/cwm/eval/Segmentation/archive/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cwm/eval/Segmentation/archive/models/mask_rcnn_cwm.py b/cwm/eval/Segmentation/archive/models/mask_rcnn_cwm.py new file mode 100644 index 0000000000000000000000000000000000000000..203a1fe7a38bf816e7616ace6d57d8583e88c7e1 --- /dev/null +++ b/cwm/eval/Segmentation/archive/models/mask_rcnn_cwm.py @@ -0,0 +1,261 @@ +from functools import partial +import torch.nn as nn +from detectron2.config import LazyCall as L +from detectron2.modeling import ViT +from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid +from detectron2.modeling.backbone.fpn import LastLevelMaxPool +from detectron2.layers import CNNBlockBase, Conv2d, get_norm +import sys +sys.path.append('../../') +from modeling_pretrain_cleaned import PretrainVisionTransformer +from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous +from models.mask_rcnn_fpn_v2 import model, constants +from detectron2.modeling.backbone import Backbone +import torch +import math +import torch.nn.functional as F +import time + +model.pixel_mean = constants['imagenet_rgb256_mean'] +model.pixel_std = constants['imagenet_rgb256_std'] +model.input_format = "RGB" + +class ViT(Backbone): + def __init__(self, + img_size=224, + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_num_classes=0, + decoder_embed_dim=384, + decoder_num_heads=16, + decoder_depth=8, + mlp_ratio=4, + qkv_bias=True, + k_bias=True, + 
norm_layer=partial(nn.LayerNorm, eps=1e-6), + patch_size=(8, 8), + num_frames=3, + tubelet_size=1, + use_flash_attention=True, + return_detectron_format=True, + out_feature='last_feat' + ): + super().__init__() + self.model = PretrainVisionTransformer( # Single-scale ViT backbone + img_size=img_size, + encoder_embed_dim=encoder_embed_dim, + encoder_depth=encoder_depth, + encoder_num_heads=encoder_num_heads, + encoder_num_classes=encoder_num_classes, + decoder_embed_dim=decoder_embed_dim, + decoder_num_heads=decoder_num_heads, + decoder_depth=decoder_depth, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + k_bias=k_bias, + norm_layer=norm_layer, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + use_flash_attention=use_flash_attention, + return_detectron_format=return_detectron_format, + out_feature=out_feature + ) + self._out_features = [out_feature] + self._out_feature_channels = {out_feature: encoder_embed_dim * 2} + self._out_feature_strides = {out_feature: patch_size[0]} + self.patch_hw = 512 // patch_size[0] + self.num_frames = num_frames + pos_embed = self.get_abs_pos(self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw]) + self.model.encoder.pos_embed = pos_embed[:, 0:self.patch_hw**2 * (self.num_frames - 1), :] + + def forward(self, x): + B = x.shape[0] + x = x.unsqueeze(2).expand(-1, -1, self.num_frames-1, -1, -1) + mask = torch.zeros(B, self.patch_hw**2 * (self.num_frames - 1), dtype=torch.bool).to(x.device) + return self.model(x, mask) + + def get_abs_pos(self, abs_pos, num_frames, hw): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. 
+ + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + + h, w = hw + + xy_num = abs_pos.shape[1] // num_frames + size = int(math.sqrt(xy_num)) + assert size * size * num_frames == abs_pos.shape[1] + abs_pos = abs_pos.view(num_frames, xy_num, -1) + + if size != h or size != w: + new_abs_pos = torch.nn.functional.interpolate( + abs_pos.reshape(num_frames, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0) + else: + return abs_pos + + +class SimpleFeaturePyramid(BaseSimpleFeaturePyramid): + """ + This module implements SimpleFeaturePyramid in :paper:`vitdet`. + It creates pyramid features built on top of the input feature map. + """ + + def __init__( + self, + net, + in_feature, + out_channels, + scale_factors, + top_block=None, + norm="LN", + square_pad=0, + ): + """ + Args: + net (Backbone): module representing the subnetwork backbone. + Must be a subclass of :class:`Backbone`. + in_feature (str): names of the input feature maps coming + from the net. + out_channels (int): number of channels in the output feature maps. + scale_factors (list[float]): list of scaling factors to upsample or downsample + the input features for creating pyramid features. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + pyramid output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra pyramid levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). + norm (str): the normalization to use. + square_pad (int): If > 0, require input images to be padded to specific square size. 
+ """ + super(BaseSimpleFeaturePyramid, self).__init__() + assert isinstance(net, Backbone) + + self.scale_factors = scale_factors + + input_shapes = net.output_shape() + strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors] + _assert_strides_are_log2_contiguous(strides) + + dim = input_shapes[in_feature].channels + self.stages = [] + use_bias = norm == "" + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + get_norm(norm, dim // 2), + nn.GELU(), + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2d(kernel_size=2, stride=2)] + elif scale == 0.25: + layers = [nn.MaxPool2d(kernel_size=4, stride=4)] + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + layers.extend( + [ + Conv2d( + out_dim, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + ] + ) + layers = nn.Sequential(*layers) + + stage = int(math.log2(strides[idx])) + self.add_module(f"simfp_{stage}", layers) + self.stages.append(layers) + + self.net = net + self.in_feature = in_feature + self.top_block = top_block + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. 
+ if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + + +# Base +embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1 +# Creates Simple Feature Pyramid from ViT backbone +model.backbone = L(SimpleFeaturePyramid)( + net=L(ViT)( + img_size=224, + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_num_classes=0, + decoder_embed_dim=384, + decoder_num_heads=16, + decoder_depth=8, + mlp_ratio=4, + qkv_bias=True, + k_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + patch_size=(16, 16), #(8, 8), + num_frames=3, + tubelet_size=1, + return_detectron_format=True, + use_flash_attention=True, + out_feature='last_feat' + ), + in_feature="${.net.out_feature}", + out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25), + top_block=L(LastLevelMaxPool)(), + norm="LN", + square_pad=512, +) + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN" + +# 2conv in RPN: +model.proposal_generator.head.conv_dims = [-1, -1] + +# 4conv1fc box head +model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] +model.roi_heads.box_head.fc_dims = [1024] \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn.py b/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..0e82d4b10987e102e466e4d4ce1116627e22a698 --- /dev/null +++ b/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn.py @@ -0,0 +1,103 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator +from 
detectron2.modeling.backbone.fpn import LastLevelMaxPool +from detectron2.modeling.backbone import BasicStem, FPN, ResNet +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.proposal_generator import RPN, StandardRPNHead +from detectron2.modeling.roi_heads import ( + StandardROIHeads, + FastRCNNOutputLayers, + MaskRCNNConvUpsampleHead, + FastRCNNConvFCHead, +) + +constants = dict( + imagenet_rgb256_mean=[123.675, 116.28, 103.53], + imagenet_rgb256_std=[58.395, 57.12, 57.375], + imagenet_bgr256_mean=[103.530, 116.280, 123.675], + # When using pre-trained models in Detectron1 or any MSRA models, + # std has been absorbed into its conv1 weights, so the std needs to be set 1. + # Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) + imagenet_bgr256_std=[1.0, 1.0, 1.0], +) + +model = L(GeneralizedRCNN)( + backbone=L(FPN)( + bottom_up=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=True, + norm="FrozenBN", + ), + out_features=["res2", "res3", "res4", "res5"], + ), + in_features="${.bottom_up.out_features}", + out_channels=256, + top_block=L(LastLevelMaxPool)(), + ), + proposal_generator=L(RPN)( + in_features=["p2", "p3", "p4", "p5", "p6"], + head=L(StandardRPNHead)(in_channels=256, num_anchors=3), + anchor_generator=L(DefaultAnchorGenerator)( + sizes=[[32], [64], [128], [256], [512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + offset=0.0, + ), + anchor_matcher=L(Matcher)( + thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True + ), + box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), + batch_size_per_image=256, + positive_fraction=0.5, + pre_nms_topk=(2000, 1000), + post_nms_topk=(1000, 1000), + nms_thresh=0.7, + ), + roi_heads=L(StandardROIHeads)( + num_classes=2, + 
batch_size_per_image=512, + positive_fraction=0.25, + proposal_matcher=L(Matcher)( + thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False + ), + box_in_features=["p2", "p3", "p4", "p5"], + box_pooler=L(ROIPooler)( + output_size=7, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + box_head=L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[], + fc_dims=[1024, 1024], + ), + box_predictor=L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), + num_classes="${..num_classes}", + ), + mask_in_features=["p2", "p3", "p4", "p5"], + mask_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + mask_head=L(MaskRCNNConvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_classes="${..num_classes}", + conv_dims=[256, 256, 256, 256, 256], + ), + ), + pixel_mean=constants['imagenet_bgr256_mean'], + pixel_std=constants['imagenet_bgr256_std'], + input_format="BGR", +) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn_v2.py b/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..31bd382cf93a1f29a6064d5f55e83743d0c92bb2 --- /dev/null +++ b/cwm/eval/Segmentation/archive/models/mask_rcnn_fpn_v2.py @@ -0,0 +1,106 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator +from detectron2.modeling.backbone.fpn import LastLevelMaxPool +from detectron2.modeling.backbone import BasicStem, FPN, ResNet +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher 
import Matcher +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.proposal_generator import RPN, StandardRPNHead +from detectron2.modeling.roi_heads import ( + StandardROIHeads, + FastRCNNOutputLayers, + MaskRCNNConvUpsampleHead, + FastRCNNConvFCHead, +) +import sys +sys.path.append('/ccn2/u/honglinc/CutLER/cutler') +from modeling.roi_heads import CustomStandardROIHeads + +constants = dict( + imagenet_rgb256_mean=[123.675, 116.28, 103.53], + imagenet_rgb256_std=[58.395, 57.12, 57.375], + imagenet_bgr256_mean=[103.530, 116.280, 123.675], + # When using pre-trained models in Detectron1 or any MSRA models, + # std has been absorbed into its conv1 weights, so the std needs to be set 1. + # Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) + imagenet_bgr256_std=[1.0, 1.0, 1.0], +) + +model = L(GeneralizedRCNN)( + backbone=L(FPN)( + bottom_up=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=True, + norm="FrozenBN", + ), + out_features=["res2", "res3", "res4", "res5"], + ), + in_features="${.bottom_up.out_features}", + out_channels=256, + top_block=L(LastLevelMaxPool)(), + ), + proposal_generator=L(RPN)( + in_features=["p2", "p3", "p4", "p5", "p6"], + head=L(StandardRPNHead)(in_channels=256, num_anchors=3), + anchor_generator=L(DefaultAnchorGenerator)( + sizes=[[32], [64], [128], [256], [512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + offset=0.0, + ), + anchor_matcher=L(Matcher)( + thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True + ), + box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), + batch_size_per_image=256, + positive_fraction=0.5, + pre_nms_topk=(2000, 1000), + post_nms_topk=(1000, 1000), + nms_thresh=0.7, + ), + roi_heads=L(CustomStandardROIHeads)( + num_classes=1, + batch_size_per_image=512, + positive_fraction=0.25, + proposal_matcher=L(Matcher)( + 
thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False + ), + box_in_features=["p2", "p3", "p4", "p5"], + box_pooler=L(ROIPooler)( + output_size=7, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + box_head=L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[], + fc_dims=[1024, 1024], + ), + box_predictor=L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), + num_classes="${..num_classes}", + ), + mask_in_features=["p2", "p3", "p4", "p5"], + mask_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + mask_head=L(MaskRCNNConvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_classes="${..num_classes}", + conv_dims=[256, 256, 256, 256, 256], + ), + ), + pixel_mean=constants['imagenet_bgr256_mean'], + pixel_std=constants['imagenet_bgr256_std'], + input_format="BGR", +) \ No newline at end of file diff --git a/cwm/eval/Segmentation/archive/models/mask_rcnn_vitdet_v2.py b/cwm/eval/Segmentation/archive/models/mask_rcnn_vitdet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ec2fafecb980a03b1f70d5e8d45ceb6064bc72 --- /dev/null +++ b/cwm/eval/Segmentation/archive/models/mask_rcnn_vitdet_v2.py @@ -0,0 +1,57 @@ +from functools import partial +import torch.nn as nn +from detectron2.config import LazyCall as L +from detectron2.modeling import ViT, SimpleFeaturePyramid +from detectron2.modeling.backbone.fpn import LastLevelMaxPool +from models.mask_rcnn_fpn_v2 import model, constants + +model.pixel_mean = constants['imagenet_rgb256_mean'] +model.pixel_std = constants['imagenet_rgb256_std'] +model.input_format = "RGB" + +# Base +embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1 +# Creates Simple Feature Pyramid from ViT backbone 
+model.backbone = L(SimpleFeaturePyramid)( + net=L(ViT)( # Single-scale ViT backbone + img_size=1024, + patch_size=16, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + drop_path_rate=dp, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True, + out_feature="last_feat", + ), + in_feature="${.net.out_feature}", + out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5), + top_block=L(LastLevelMaxPool)(), + norm="LN", + square_pad=512, +) + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN" + +# 2conv in RPN: +model.proposal_generator.head.conv_dims = [-1, -1] + +# 4conv1fc box head +model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] +model.roi_heads.box_head.fc_dims = [1024] \ No newline at end of file diff --git a/cwm/eval/Segmentation/extract_segment.py b/cwm/eval/Segmentation/extract_segment.py new file mode 100644 index 0000000000000000000000000000000000000000..732703901401b889c7b619128273802d733ba82d --- /dev/null +++ b/cwm/eval/Segmentation/extract_segment.py @@ -0,0 +1,36 @@ +import torch +from torch import nn +import cwm.eval.Segmentation.utils as utils +from external.raft_interface import RAFTInterface + +class SegmentExtractor(nn.Module): + def __init__(self, num_segments=1, iters=4, motion_range=4): + self.num_segments = num_segments + self.iters = iters + self.motion_range = motion_range + self.flow_interface = RAFTInterface() + + def get_sampling_dist(self, x, model): + pass + + def forward(self, x, model, sampling_dist=None): + """ + x: [B, 3, H, W] a batch of imagenet-normalized image tensor + model: a pre-trained CWM model + """ + if not sampling_dist: + sampling_dist = self.get_sampling_dist(x, model) + + ## Step 1: sample initial moving and static locations from the distribution + moving_pos = 
utils.sample_positions_from_dist(num=1, dist=sampling_dist) # [B, num, 2] + static_pos = utils.sample_positions_from_dist(num=1, dist=(1-sampling_dist)) # [B, num, 2] + movement = torch.randint(-self.motion_range, self.motion_range, (B, 1, 2)) # [B, 1, 2] + + ## Step 2: compute initial flow maps + pred = model.get_counterfactual(x, mask, moving_pos=moving_pos, static_pos=static_pos, movement=movement) + flow = self.flow_interface(x[:, :, 0], pred) + + ## Step 3: iterate to add more moving and static motions + + + diff --git a/cwm/eval/Segmentation/utils.py b/cwm/eval/Segmentation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5bdef535d8d2563c353fbe0da843bc0ac84f1f6 --- /dev/null +++ b/cwm/eval/Segmentation/utils.py @@ -0,0 +1,3 @@ +import torch + + diff --git a/cwm/eval/__init__.py b/cwm/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cwm/model/keypoint_utils.py b/cwm/model/keypoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..608fd20b50f2e7659b778e6f45eb2f0bd7a20d23 --- /dev/null +++ b/cwm/model/keypoint_utils.py @@ -0,0 +1,186 @@ +from einops import rearrange + +import torch + +import numpy as np + +from torchvision import transforms + +def unpatchify(labels, norm=True): + # Define the input tensor + B = labels.shape[0] # batch size + N_patches = int(np.sqrt(labels.shape[1])) # number of patches along each dimension + patch_size = int(np.sqrt(labels.shape[2] / 3)) # patch size along each dimension + channels = 3 # number of channels + + rec_imgs = rearrange(labels, 'b n (p c) -> b n p c', c=3) + # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch. 
+ rec_imgs = rearrange(rec_imgs, + 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', + p0=1, + p1=patch_size, + p2=patch_size, + h=N_patches, + w=N_patches) + if norm: + MEAN = torch.from_numpy(np.array((0.485, 0.456, 0.406))[None, :, None, None, None]).cuda().half() + STD = torch.from_numpy(np.array((0.229, 0.224, 0.225))[None, :, None, None, None]).cuda().half() + + rec_imgs = (rec_imgs - MEAN) / STD + + return rec_imgs + +def upsample_masks(masks, size, thresh=0.5): + shape = masks.shape + dtype = masks.dtype + h, w = shape[-2:] + H, W = size + if (H == h) and (W == w): + return masks + elif (H < h) and (W < w): + s = (h // H, w // W) + return masks[..., ::s[0], ::s[1]] + + masks = masks.unsqueeze(-2).unsqueeze(-1) + masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w) + if ((H % h) == 0) and ((W % w) == 0): + masks = masks.view(*shape[:-2], H, W) + else: + _H = np.prod(masks.shape[-4:-2]) + _W = np.prod(masks.shape[-2:]) + masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh + masks = masks.view(*shape[:2], H, W).to(masks.dtype) + return masks + +def get_keypoints_batch(model, x, + n_samples, + n_rounds, + frac=0.25, + mask=None, + pool='avg', + ): + """x = image pair tensor + n_samples = number of potential candidates to look at on each round + (produces one new unmasked per round) + n_rounds = total number of unmasked patches + frac = how often to do random sampling vs error-based sampling + mask = initial mask + + """ + # .half() + + + B = x.shape[0] + + IMAGE_SIZE = [224, 224] + predictor = model + patch_size = predictor.patch_size[-1] + num_frames = predictor.num_frames + patch_num = IMAGE_SIZE[0] // patch_size + # this is setup for getting per-patch error + if pool == 'avg': + pool_op = torch.nn.AvgPool2d(patch_size, stride=patch_size) + elif pool == 'max': + pool_op = torch.nn.MaxPool2d(patch_size, stride=patch_size) + + # initiazing rng + rng = np.random.RandomState(seed=0) + + n_patches = patch_num * patch_num + + # 
initializing mask at the fully masked state + mshape = num_frames * patch_num * patch_num + mshape_masked = (num_frames - 1) * patch_num * patch_num + + if mask is None: + mask = torch.ones([B, mshape], dtype=torch.bool) + mask[:, :mshape_masked] = False + + err_array = [] + choices = [] + + # flows = [] + for round_num in range(n_rounds): + # print(round_num) + # get the current prediction with current state of the mask + # .... produces out_flow b/c it's with head-motion condition + out = unpatchify(predictor(x, mask, forward_full=True)) + + # print(out.shape) + keypoint_recon = out.clone() + # flow = teacher.predict_flow(out) + # flows.append(flow) + + # get the error map + err_mat = (out[:, :, 0] - x[:, :, -1]).abs().mean(1) + # pool it to patch-size + pooled_err = pool_op(err_mat[:, None]) + # flatten the rror + flat_pooled_error = pooled_err.flatten(1, 3) + # set error to be zero where the mask is unmasked so it doesn't interfere + flat_pooled_error[mask[:, -n_patches:] == False] = 0 + # sort patches by where the error is highest + err_sort = torch.argsort(flat_pooled_error, -1) + new_mask = mask.clone().detach() + errors = [] + tries = [] + err_choices = 0 + + # look at various candidates to reveal in the next round + for sample_num in range(n_samples): + # if sample_num % 10 == 0: + # print("%d/%d" % (sample_num, n_samples)) + # either randomly sample + + err_choices += 1 + new_try = (num_frames - 1) * n_patches + err_sort[:, -1 * err_choices] + tries.append(new_try) + + for k in range(B): + new_mask[k, new_try[k]] = False + + reshaped_new_mask = upsample_masks( + new_mask.view(B, num_frames, IMAGE_SIZE[1] // patch_size, IMAGE_SIZE[1] // patch_size)[:, (num_frames - 1):], + IMAGE_SIZE)[:, 0] + + # print(reshaped_new_mask.sum()) + out = unpatchify(predictor(x, new_mask, forward_full=True)) + + abs_error = (out[:, :, 0] - x[:, :, -1]).abs().sum(1).cpu() + + masked_abs_error = abs_error * reshaped_new_mask + error = masked_abs_error.flatten(1, 2).sum(-1) + 
errors.append(error) + + # take the best one + for k in range(B): + new_mask[k, new_try[k]] = True + + errors = torch.stack(errors, 1) + tries = torch.stack(tries, 1) + best_ind = torch.argmin(errors, dim=-1) + best = torch.tensor([tries[k, best_ind[k]] for k in range(B)]) + choices.append(best) + err_array.append(errors) + # print(best) + for k in range(B): + mask[k, best[k]] = False + + feat = predictor(x, mask, forward_full=True, return_features=True) + + feat = feat#[:, :784*2] + + choices = torch.stack(choices, 1) + + #get x y coordinates of the keypoints + + choices = choices % mshape_masked + choices_x = choices % (patch_num) + choices_y = choices // (patch_num) + choices = torch.stack([choices_x, choices_y], 2) + + out = unpatchify(predictor(x, mask, forward_full=True), norm=False) + + keypoint_recon = out[0, :, 0].permute(1, 2, 0).detach().cpu().numpy() * 255 + + return mask, choices, err_array, feat, keypoint_recon.astype('uint8') \ No newline at end of file diff --git a/cwm/model/model_factory.py b/cwm/model/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..e27a8ed96c5b22c81a7153baa2f5f7f7bab490ad --- /dev/null +++ b/cwm/model/model_factory.py @@ -0,0 +1,93 @@ +""" +Model Factory for loading a checkpoint from gcloud, then initializing the model and the configuration +""" +import os +import requests +import torch +import tqdm + +from cwm.model import model_pretrain + +GCLOUD_BUCKET_NAME = "stanford_neuroai_models" +GCLOUD_URL_NAME = "https://storage.googleapis.com/stanford_neuroai_models" +CACHE_PATH = f"{os.getenv('CACHE')}/stanford_neuroai_models" if os.getenv('CACHE') is not None else ".cache/stanford_neuroai_models" +_model_catalogue ={ + "vitb_8x8patch_3frames": { + "path": "cwm/3frame_cwm_8x8.pth", + "init_fn": model_pretrain.vitb_8x8patch_3frames, + }, + "vitb_4x4patch_2frames": { + "path": "cwm/2frame_cwm_4x4.pth", + "init_fn": model_pretrain.vitb_4x4patch_2frames, + }, + + "vitb_8x8patch_2frames": { + "path": 
"cwm/2frame_cwm_8x8.pth", + "init_fn": model_pretrain.vitb_8x8patch_2frames, + }, + +} + + +class ModelFactory: + + def __init__(self, bucket_name: str = GCLOUD_BUCKET_NAME): + self.bucket_name = bucket_name + + def get_catalog(self): + """ + Get the list of available models + """ + # Initialize the storage client + return _model_catalogue.keys() + + def load_model(self, model_name: str, force_download=False): + """ + Load the model given the name + + Args: + model_name: str + Name of the model to load + force_download: bool (optional) + Whether to force the download of the freshest weights from gcloud + + Returns: + model: torch.nn.Module + Model initialized from the checkpoint + """ + # Find cache dir, use a directory inside of it as the checkpoint path + checkpoint_path = os.path.join(CACHE_PATH, _model_catalogue[model_name]["path"]) + # Make checkpoint directory if it does not exist + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + # Construct gcloud url + gcloud_url = os.path.join(GCLOUD_URL_NAME, _model_catalogue[model_name]['path']) + # Download the model from google cloud using requests (with tqdm timer) + response = requests.get(gcloud_url, stream=True) + total_size_in_bytes= int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + + # If force_download is true or the model has not yet been downloaded, grab it from gcloud + if force_download or not os.path.exists(checkpoint_path): + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + + print(f"Saving model to cache: {CACHE_PATH}") + with open(checkpoint_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + + # Initialize the model and the configuration from the checkpoint + print("checkpoint_path", checkpoint_path) + ckpt = torch.load(checkpoint_path, map_location="cpu") + + # Initialize the model from the specified initialization function + model = 
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Bilinearly resize a per-frame positional-embedding grid to ``h x w``.

    ``pos_embed`` is indexed as ``[1, N, C]`` where ``N`` is split into
    ``n_frames`` square grids. When ``N`` already equals ``n_frames * h * w``
    the input object is returned untouched; otherwise each frame's grid is
    resized independently and the result has shape ``[1, n_frames * h * w, C]``.
    """
    num_tokens = pos_embed.shape[1]
    if num_tokens == n_frames * h * w:
        # Already at the requested resolution.
        return pos_embed

    # Side length of the (square) source grid of a single frame.
    side = int((num_tokens / n_frames) ** 0.5)

    # [1, N, C] -> [n_frames, C, side, side] so frames act as the batch axis.
    grid = pos_embed.view(1, n_frames, side, side, -1)
    channels_first = grid.flatten(0, 1).permute(0, 3, 1, 2)

    resized = F.interpolate(
        channels_first,
        size=(h, w),
        mode='bilinear',
    )

    # Back to token layout: [n_frames, h, w, C] -> [1, n_frames * h * w, C].
    channels_last = resized.permute(0, 2, 3, 1)
    return channels_last.flatten(0, 2).unsqueeze(0)
drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, + num_frames=16, block_func=Block, k_bias=False, use_learnable_pos_emb=False, block_kwargs={}): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.patch_size = (tubelet_size,) + patch_size + self.pt, self.ph, self.pw = self.patch_size + self.h = int(img_size / self.ph) + self.w = int(img_size / self.pw) + self.hw = self.h * self.w + self.dims = [self.h, self.w] + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + tubelet_size=tubelet_size, + num_frames=num_frames + ) + num_patches = self.patch_embed.num_patches + + self.num_patches = num_patches + self.num_frames = num_frames + + if use_learnable_pos_emb: + self.use_learnable_pos_emb = True + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.pos_embed, std=.02) + else: + # sine-cosine positional embeddings + self.use_learnable_pos_emb = False + self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + block_func( + dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, **block_kwargs, k_bias=k_bias, + xla_flash=True) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 
0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def _get_pos_embed(self): + return self.pos_embed + + def forward_block(self, x, idx): + return self.blocks[idx](x) + + def forward_features(self, x, mask, move_pos=None, static_pos=None, movement=None, res=1): + + T = x.shape[2] + + x = embed = self.patch_embed(x) + pos_embed = self._get_pos_embed().type_as(x).to(x.device).clone() + + if not self.use_learnable_pos_emb: + pos_embed = pos_embed.detach() + + if res != 1: + print("res") + p0 = self.patch_size[-2] + p1 = self.patch_size[-1] + pos_embed = interpolate_pos_encoding(self.pos_embed, T, int(224 // p0 * res), int(224 // p1 * res)) + + x = x + pos_embed + B, _, C = x.shape + x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible + + if move_pos is not None: + h, w = self.h, self.w + first_frame_emb = embed[:, :self.hw].view(B, h, w, C) # [B, h, w, C] + last_frame_pos_emb = pos_embed[:, -self.hw:].view(1, h, w, C).expand(B, -1, -1, -1) # [B, h, w, C] + denominator = torch.tensor([self.h, self.w]).view(1, 1, 2).to(x.device) + + new_pos = move_pos + movement # [B, P, 2] + move_pos = move_pos / denominator * 2 - 1 + new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1 # handle special case where new_pos is out of bounds + static_pos = static_pos / denominator * 2 - 1 + + moving_emb = utils.sample_embedding(first_frame_emb, move_pos, mode='nearest') # [B, P, C] + moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest') # [B, P, C] + + static_emb = utils.sample_embedding(first_frame_emb, static_pos, 
class PretrainVisionTransformerDecoder(nn.Module):
    """Transformer decoder: a stack of blocks, a final norm and a linear head.

    The decoder receives a token sequence ``[B, N, embed_dim]`` and returns
    ``num_classes`` values per returned token (the head is an identity when
    ``num_classes`` is not positive).
    """

    def __init__(self, patch_size=(16, 16), num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, block_func=Block, block_kwargs=None,
                 k_bias=False
                 ):
        super().__init__()

        # ``None`` replaces the former shared ``{}`` default so that no dict
        # object is shared between instances; passing a dict works as before.
        if block_kwargs is None:
            block_kwargs = {}

        self.num_classes = num_classes

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_size = patch_size

        # Stochastic depth decay rule: drop-path rate grows linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            block_func(
                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, **block_kwargs, k_bias=k_bias,
                xla_flash=True)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-initialise Linear weights; zero biases; unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names that are reported as exempt from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the prediction head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted for signature compatibility and is unused.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_block(self, x, idx):
        """Run only the block at position ``idx`` on ``x``."""
        return self.blocks[idx](x)

    def get_last_tokens(self, x, return_token_num):
        """Apply norm + head (no blocks) and select trailing tokens.

        * ``return_token_num > 0``: the last ``return_token_num`` tokens.
        * ``return_token_num == 0``: an empty slice along the token axis
          (``x[:, -0:]`` would select everything, hence the explicit case).
        * ``return_token_num < 0``: every token.
        """
        if return_token_num > 0:
            return self.head(self.norm(x[:, -return_token_num:]))
        elif return_token_num == 0:
            return self.head(self.norm(x))[:, x.size(1):]
        else:
            return self.head(self.norm(x))

    def forward(self, x, return_token_num):
        """Run all blocks, then norm + head.

        With ``return_token_num > 0`` only the last ``return_token_num``
        tokens are normalised and projected; otherwise all tokens are.
        """
        for blk in self.blocks:
            x = blk(x)

        if return_token_num > 0:
            x = self.head(self.norm(x[:, -return_token_num:]))  # only return the mask tokens predict pixels
        else:
            x = self.head(self.norm(x))

        return x
patch_size=(16, 16), + encoder_func=PretrainVisionTransformerEncoder, + encoder_in_chans=3, + encoder_num_classes=0, + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_block_func=Block, + encoder_block_kwargs={}, + decoder_num_classes=None, + # For pretraining this parameter isn't relevant but must be set according to tube&patch size + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=8, + decoder_block_func=Block, + decoder_block_kwargs={}, + mlp_ratio=4., + qkv_bias=False, + k_bias=False, + qk_scale=None, + num_frames=16, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + init_values=0., + tubelet_size=2, + use_flash_attention=False, + use_learnable_pos_emb=False, + **kwargs + ): + super().__init__() + + encoder_block_kwargs.update({'flash_attention': use_flash_attention}) + decoder_block_kwargs.update({'flash_attention': use_flash_attention}) + + self.tubelet_size = tubelet_size + num_classes = 3 * tubelet_size * ( + patch_size[0] * patch_size[1]) if decoder_num_classes is None else decoder_num_classes + + self.encoder = encoder_func( + img_size=img_size, + patch_size=patch_size, + in_chans=encoder_in_chans, + num_classes=encoder_num_classes, + embed_dim=encoder_embed_dim, + depth=encoder_depth, + num_heads=encoder_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_layer=norm_layer, + init_values=init_values, + tubelet_size=tubelet_size, + num_frames=num_frames, + block_func=encoder_block_func, + block_kwargs=encoder_block_kwargs, + use_learnable_pos_emb=use_learnable_pos_emb, + k_bias=k_bias, + **kwargs) + + self.decoder = PretrainVisionTransformerDecoder( + patch_size=patch_size, + num_classes=num_classes, + embed_dim=decoder_embed_dim, + depth=decoder_depth, + num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + 
drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_layer=norm_layer, + init_values=init_values, + block_func=decoder_block_func, + k_bias=k_bias, + block_kwargs=decoder_block_kwargs) + + self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=k_bias) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + trunc_normal_(self.mask_token, std=.02) + + if use_learnable_pos_emb: + self.use_learnable_pos_emb = True + self.pos_embed = nn.Parameter(torch.zeros(self.encoder.num_patches, decoder_embed_dim)) + trunc_normal_(self.pos_embed, std=.02) + else: + self.use_learnable_pos_emb = False + self.pos_embed = get_sinusoid_encoding_table(self.encoder.num_patches, decoder_embed_dim) + + self.num_frames = num_frames + self.num_patches = self.encoder.num_patches + + if self.num_frames is not None: + self.num_patches_per_frame = self.num_patches // self.num_frames + else: + self.num_patches_per_frame = self.num_patches + + self.patch_size = self.encoder.patch_size + + if isinstance(img_size, int): + self.image_size = (img_size, img_size) + else: + assert hasattr(img_size, '__len__'), img_size + self.image_size = img_size + + # self.flow_interface = RAFTInterface() + + @property + def mask_size(self): + return (self.num_frames // self.patch_size[0], + self.image_size[-2] // self.patch_size[-2], + self.image_size[-1] // self.patch_size[-1]) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'mask_token'} + + def adjust_input_resolution(self, H, W): + if self.image_size == [H, W]: + return + + patch_size = 
self.encoder.patch_size[-2:] + self.image_size = [H, W] + self.encoder.h = int(H / self.encoder.ph) + self.encoder.w = int(W / self.encoder.pw) + self.encoder.hw = self.encoder.h * self.encoder.w + self.encoder.dims = [self.encoder.h, self.encoder.w] + dims = [int(s / p) for s, p in zip(self.image_size, patch_size)] + self.encoder.pos_embed = utils.interpolate_pos_encoding(self.encoder.pos_embed, 3, dims[0], dims[1]) + print('pos_embed', self.encoder.pos_embed.shape) + self.pos_embed = utils.interpolate_pos_encoding(self.pos_embed, 3, dims[0], dims[1]) + + def forward(self, x, mask, forward_full=False, return_features=False, res=1, *args, **kwargs): + + _, _, T, _, _ = x.shape + self.device = x.device + + num_patches_per_frame = (x.shape[-1] // self.encoder.patch_size[1]) ** 2 + + x_vis = self.encoder(x, mask, res=res, *args, **kwargs) + + if return_features: + return x_vis + + x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d] + B, N, C = x_vis.shape + + # add pos embedding + # if res != 1: + # p0 = self.patch_size[-2] + # p1 = self.patch_size[-1] + # pos_embed = interpolate_pos_encoding(self.pos_embed.unsqueeze(0), T, int(224 // p0 * res), int(224 // p1 * res)) + # else: + # pos_embed = self.pos_embed.unsqueeze(0) + + expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone() + + if not self.use_learnable_pos_emb: + expand_pos_embed = expand_pos_embed.detach() + + pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C) + pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C) + + nctx = num_patches_per_frame * (self.num_frames - 1) + + x_vis = x_vis + pos_emd_vis + + x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d] + + if forward_full: + x_full = torch.cat([x_vis, self.mask_token + expand_pos_embed[:, nctx:]], dim=1) # [B, N, C_d] + x_all = self.decoder(x_full, num_patches_per_frame) + x = x_all + else: + x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16] + + return x + + def 
get_counterfactual(self, x, move_patches): + ''' + :param x: input tensor [1, C, T, H, W]: support only batch size 1 for now + :param move_patches: torch tensor [N, 4] sized array where each row contains patch motion [x1, y1, x2, y2] in pixel coordinates + :return: + ''' + B, _, T, H, H = x.shape + + mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool() + mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False + + move_patches = (move_patches/H)*self.encoder.h + move_patches = move_patches.to(torch.int64) + + for x1, y1, x2, y2 in move_patches: + idx2 = x2*self.encoder.w + y2 + (self.encoder.num_frames - 1) * (self.encoder.h * self.encoder.w) + mask[:, idx2] = False + im_x1 = x1*self.encoder.ph + im_y1 = y1*self.encoder.pw + im_x2 = x2*self.encoder.ph + im_y2 = y2*self.encoder.pw + x[:, :, -1, im_x2:im_x2+self.encoder.ph, im_y2:im_y2+self.encoder.pw] = x[:, :, -2, im_x1:im_x1+self.encoder.ph, im_y1:im_y1+self.encoder.pw] + + prediction = self.forward(x, mask, forward_full=True) + + prediction = utils.unpatchify_cwm( + prediction, + patch_size=self.encoder.patch_size[-1], + ) # reshape the output to an image + + return prediction + + + def get_directional_counterfactual(self, x, mask=None, move_pos=None, static_pos=None, movement=None, max_movement=None): + B, _, T, _, _ = x.shape + + if mask is None: # default mask: all visible but the last frame + mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool() + mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False + + if movement is None: # generate random motion if movement is not specified + assert max_movement is not None and move_pos is not None + movement = torch.randint(-max_movement, max_movement, move_pos.shape).to(x.device) # [B, num_samples, 2] + + x_vis = self.encoder(x, mask, move_pos=move_pos, static_pos=static_pos, movement=movement) # [B, N_vis, C_e] + x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d] + B, N, C = 
x_vis.shape + expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() + pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C) + + if move_pos is not None: + h, w = self.encoder.h, self.encoder.w + last_frame_pos_emb = expand_pos_embed[:, -(h * w):].view(B, h, w, C) # [B, h, w, C] + + # compute new locations of the moved patche, snormalize positions to range [-1, 1] + new_pos = move_pos + movement # [B, P, 2] + denominator = torch.tensor([h, w]).view(1, 1, 2).to(x.device) + new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1 + static_pos = static_pos / denominator * 2 - 1 + + # sample the position embeddings of the moved and static patches + moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest') # [B, P, C] + static_pos_emb = utils.sample_embedding(last_frame_pos_emb, static_pos, mode='nearest') # [B, P, C] + + # concatenate with the position embeddings to the visible patches + pos_emd_vis = torch.cat([pos_emd_vis, moving_pos_emb, static_pos_emb], dim=1) + + # assert B == 1, "Only support batch size 1 for now" + # offset = (self.encoder.patch_embed.num_frames - 1) * (self.encoder.h * self.encoder.w) + # for (px, py) in move_patches: + # dx, dy = delta + # nx, ny = px + dx, py + dy + # new_idx = nx * self.encoder.w + ny + offset + # pos_emb = expand_pos_embed[:, new_idx] + # pos_emd_vis = torch.cat([pos_emd_vis, pos_emb[None]], 1) + # + # if static_patches is not None: + # for (px, py) in static_patches: + # new_idx = px * self.encoder.w + py + offset + # pos_emb = expand_pos_embed[:, new_idx] + # pos_emd_vis = torch.cat([pos_emd_vis, pos_emb[None]], 1) + + pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C) + x_vis = x_vis + pos_emd_vis + x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d] + x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16] + + prediction = utils.unpatchify_cwm( + x, + patch_size=self.encoder.patch_size[-1], + mask=mask[:, 
-self.encoder.hw:] + ) # reshape the output to an image + + return prediction + + + @torch.no_grad() + def get_segment(self, x, mask=None, sampling_dist=None, num_segments=1, num_iters=4, num_samples=4, max_movement=4, + vis=True): + B, C, T, H, W = x.shape + N = num_samples + patch_size = self.encoder.patch_size[-1] + + self.adjust_input_resolution(H, W) + + # ## Step 0: define the sampling distribution for moving and static locations + # if sampling_dist is None: + # sampling_dist = utils.get_dino_predominance(x[:, :, 0], dims=self.encoder.dims, img_size=self.image_size)[0] + # sampling_dist = F.interpolate(sampling_dist, self.encoder.dims, mode='bilinear', align_corners=False) + # sampling_dist = sampling_dist.squeeze(1) ** 4 + # + # if vis: + # print('sampling_dist', sampling_dist.shape) + # plt.imshow(sampling_dist[0].cpu().numpy()) + # plt.title(f'Sampling distribution (max:{sampling_dist.max():.3f})') + # plt.show() + + ## Step 1: sample initial moving and static locations from the distribution + init_move_pos, init_static_pos, init_flow_mag, max_score = None, None, None, 0 + + # Sample multiple positions for each segment and select the one with consistent outputs + for _ in range(N): + # sample one moving position per example in the batch + move_pos = utils.sample_positions_from_dist(size=[1, 1], dist=sampling_dist).repeat(N, 1, 1) # [BN, 1, 2] + + # each move position has N static positions and movement directions + static_pos = utils.sample_positions_from_dist(size=[B * N, 1], dist=-sampling_dist) # [BN, 1, 2] + + ## compute initial flow maps + _x = x.repeat(N, 1, 1, 1, 1) # [BN, C, T, H, W] + pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos, max_movement=max_movement) + flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True) + flow_mag = flow_mag.view(B, N, H, W) + scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1) # [B, N] + print('scores', scores, flow_mag.shape) 
+ if scores.mean(-1) > max_score: + init_move_pos, init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores + + # visualize samples + if vis: + fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples)) + + for i in range(num_samples): + move = move_pos[i, 0].cpu() + static = static_pos[i, 0].cpu() + flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0)) + axs[i].imshow(flow_rgb) + axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20) + axs[i].set_axis_off() + axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20) + fig.subplots_adjust(wspace=0.01, hspace=0.01) # Change these values to adjust space + + plt.show() + plt.close() + + ## Step 2: iteratively add more moving and static locations to refine the segment + prev_flow_mag = init_flow_mag + prev_move_pos = init_move_pos + prev_static_pos = init_static_pos + npos_per_iter = 1 + for it in range(num_iters): + print('Iteration', it) + sampling_dist = F.interpolate(prev_flow_mag, size=self.encoder.dims, mode='bilinear').mean(1) + # sample one moving position per example in the batch + move_pos = utils.sample_positions_from_dist(size=[1, npos_per_iter], dist=sampling_dist).repeat(N, 1, 1) + move_pos = torch.cat([prev_move_pos, move_pos], dim=1) + + # each move position has N static positions and movement directions + static_pos = utils.sample_positions_from_dist(size=[B * N, npos_per_iter], + dist=-sampling_dist) # [BN, 1, 2] + static_pos = torch.cat([prev_static_pos, static_pos], dim=1) + + pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos, max_movement=max_movement) + flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True) + flow_mag = flow_mag.view(B, N, H, W) + scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1) # [B, N] + print('scores', scores, flow_mag.shape) + if scores.mean(-1) > max_score: + init_move_pos, 
init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores + + # visualize samples + if vis: + fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples)) + + for i in range(num_samples): + flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0)) + + axs[i].imshow(flow_rgb) + axs[i].set_axis_off() + for k in range(move_pos.shape[1]): + move = move_pos[i, k].cpu() + static = static_pos[i, k].cpu() + axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20) + axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20) + fig.subplots_adjust(wspace=0.01, hspace=0.01) # Change these values to adjust space + + plt.show() + plt.close() + + ## Step 3: iterate to add more moving and static motions + return None + + @torch.no_grad() + def get_flow(self, img1, img2, conditioning_img=None, mode='jacobian', perturbation_patch_size=8, aggregation_patch_size=8, mask_ratio=0.0, num_scales=1, num_mask_samples=1): + ''' + :param img1: input image 1 [B, C, H, W] + :param img2: input image 2 [B, C, H, W] + :param mode: which flow extraction method to use: 'jacobian' or 'optical_flow + :param mask_ratio: what frame2 mask ratio to use when extracting flow + :return: forward flow [B, 2, H, W] + ''' + + + + if mode == 'jacobian': + frame_size = 224 // self.patch_size[-1] + DFG = generator.DerivativeFlowGenerator( + predictor=self, + perturbation_patch_size=perturbation_patch_size, + aggregation_patch_size=aggregation_patch_size, + agg_power=None, + agg_channel_func=lambda x: F.relu(x.sum(-3, True)), + num_samples=5, + average_jacobian=False, + leave_one_out_sampling=False, + imagenet_normalize_inputs=False, + temporal_dim=2, + confidence_thresh=None + ).to(img1.device) + + maskgen_uniform = masking.PytorchMaskGeneratorWrapper( + mask_generator=masking.RotatedTableMaskingGenerator, + input_size=(self.num_frames, frame_size, frame_size), + mask_ratio=mask_ratio + ).to(img1.device) + + + jac_fwd, 
def pretrain_vit_base_224_scaffold(img_size=224, **kwargs):
    """Build a ``PretrainVisionTransformer`` with the base scaffold settings.

    The fixed settings are a 768-wide, 12-layer, 12-head encoder and a
    512-wide, 8-layer, 16-head decoder with q/k/v biases enabled. Any extra
    keyword arguments are forwarded to the constructor unchanged, and the
    default config from ``_cfg()`` is attached as ``model.default_cfg``.
    """
    scaffold_settings = dict(
        img_size=img_size,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double unpacking keeps the "multiple values for keyword" TypeError when
    # a caller passes one of the fixed settings again.
    model = PretrainVisionTransformer(**scaffold_settings, **kwargs)
    model.default_cfg = _cfg()
    return model
def _cfg(url='', **kwargs):
    """Return the default model config dict, with ``kwargs`` overriding or extending it."""
    config = dict(
        url=url,
        num_classes=400,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    )
    # Caller-supplied entries win over the defaults; unknown keys are appended.
    config.update(kwargs)
    return config
class Attention(nn.Module):
    """Multi-head self-attention with optional, separately stored q/k/v biases.

    The q/k/v projection weight lives in ``self.qkv`` (created without a bias);
    when ``qkv_bias`` is set, ``q_bias`` and ``v_bias`` are learned parameters and
    the key bias is learned only if ``k_bias`` is also set (otherwise it is zero).

    ``flash_attention``, ``legacy`` and ``xla_flash`` are accepted for interface
    compatibility; ``forward`` always uses ``F.scaled_dot_product_attention``.
    NOTE(review): ``self.scale`` (and therefore ``qk_scale``) is stored but not
    passed to the attention kernel, which applies its own default scaling.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, flash_attention=False, k_bias=False, legacy=True, xla_flash=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.legacy = legacy

        self.xla_flash = xla_flash
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
            if k_bias:
                self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            else:
                self.k_bias = None
        else:
            self.q_bias = None
            self.v_bias = None
            self.k_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, N, C] and return a tensor of shape [B, N, dim]."""
        B, N, C = x.shape

        # Assemble the fused q/k/v bias; a missing key bias is filled with zeros.
        qkv_bias = None
        if self.q_bias is not None:
            if self.k_bias is not None:
                qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
            else:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # The functional attention kernel does not know about module train/eval
        # state, so dropout must be disabled explicitly outside of training.
        dropout_p = self.attn_drop.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        x = x.transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if (init_values or 0) > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=(16, 16), in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): + super().__init__() + img_size = to_2tuple(img_size) + + self.tubelet_size = int(tubelet_size) + if num_frames is not None: + self.num_frames = int(num_frames) + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) + else: + self.num_frames = None + self.num_patches = None + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, + kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]), + stride=(self.tubelet_size, patch_size[0], patch_size[1])) + + def forward(self, x, **kwargs): + # B, C, T, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
def get_sinusoid_encoding_table(positions,
                                d_hid,
                                apply_sinusoid=True):
    ''' Sinusoid position encoding table.

    :param positions: either an int ``P`` (positions ``0..P-1`` are used) or a
        sized sequence of position values
    :param d_hid: number of encoding channels per position
    :param apply_sinusoid: if True, even channels hold ``sin`` and odd channels
        hold ``cos`` of the position angle; if False the raw angles are returned
    :return: float32 tensor of shape [1, len(positions), d_hid]; an empty set of
        positions yields an empty table of shape [1, 0, d_hid]
    '''
    if isinstance(positions, int):
        position_values = np.arange(positions, dtype=np.float64)
    else:
        assert hasattr(positions, '__len__')
        position_values = np.asarray(positions, dtype=np.float64)

    # The per-channel divisor 10000^(2*(j//2)/d_hid) does not depend on the
    # position, so compute it once and broadcast over all positions.
    channel_index = np.arange(d_hid)
    divisors = np.power(10000, 2 * (channel_index // 2) / d_hid)
    sinusoid_table = position_values[:, None] / divisors[None, :]

    if apply_sinusoid:
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.from_numpy(sinusoid_table).to(torch.float32).unsqueeze(0)
class LayerDecayValueAssigner:
    """Look up a per-layer scale value and map parameter names to layer ids."""

    def __init__(self, values):
        # One scale value per layer id; indexed directly by get_scale.
        self.values = values

    def get_scale(self, layer_id):
        """Return the scale stored for ``layer_id``."""
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        """Return the layer id of ``var_name``, using the number of stored values as the layer count."""
        num_max_layer = len(self.values)
        return get_num_layer_for_vit(var_name, num_max_layer)
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
    """Build the optimizer named by ``args.opt`` for ``model``.

    ``args`` must provide ``opt``, ``lr`` and ``weight_decay``; ``momentum`` is
    read for the SGD/RMSprop families, and ``opt_eps`` / ``opt_betas`` are
    forwarded when present and not None. The optimizer name may be prefixed
    with ``lookahead_`` (e.g. ``lookahead_adam``) to wrap the result in
    ``Lookahead``.

    When weight decay is non-zero and ``filter_bias_and_bn`` is True, parameters
    are split into groups by ``get_parameter_groups`` (which applies the decay
    per group), and the optimizer-level weight decay is set to 0.

    :raises ValueError: if ``args.opt`` does not name a supported optimizer
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if skip_list is not None:
            skip = skip_list
        elif hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
        # Decay is now carried by the parameter groups themselves.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    # Only the last '_'-separated token names the optimizer; a leading token
    # (e.g. 'lookahead') is handled after construction.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    # elif opt_lower == 'novograd':
    #     optimizer = NovoGrad(parameters, **opt_args)
    # elif opt_lower == 'nvnovograd':
    #     optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Previously an always-failing assert fired first (with no message),
        # so the intended ValueError was never reached.
        raise ValueError("Invalid optimizer: %r" % (args.opt,))

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
class CorrBlock:
    """All-pairs correlation volume between two feature maps, with a pooled pyramid.

    Construction computes the correlation of every location in ``fmap1`` with
    every location in ``fmap2`` and stores it, together with ``num_levels - 1``
    successively 2x average-pooled copies, in ``self.corr_pyramid``. Calling the
    block samples a ``(2*radius+1)^2`` window from every pyramid level around
    the given coordinates.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample the pyramid around ``coords`` ([B, 2, H, W]).

        :return: contiguous float tensor of shape
            [B, num_levels * (2*radius+1)**2, H, W]
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The window of offsets is the same at every level, so build it once.
        # indexing='ij' is the historical default and is spelled out to avoid
        # the deprecation warning for the implicit form.
        dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
        dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1)
        delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
        centroids = coords.reshape(batch*h1*w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]

            # Coordinates shrink by 2x per level to match the pooled volume.
            centroid_lvl = centroids / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the scaled dot product of every location pair.

        :param fmap1: feature map [B, D, H, W]
        :param fmap2: feature map [B, D, H, W]
        :return: tensor of shape [B, H, W, 1, H, W], divided by sqrt(D)
        """
        batch, dim, ht, wd = fmap1.shape
        # reshape (not view) so non-contiguous feature maps are accepted too
        fmap1 = fmap1.reshape(batch, dim, ht*wd)
        fmap2 = fmap2.reshape(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return 
class FlyingChairs(FlowDataset):
    """FlyingChairs image pairs with their ground-truth flow files.

    Images (``*.ppm``) and flows (``*.flo``) are globbed from ``root``; two
    consecutive images belong to each flow file. Membership in a split is read
    from ``chairs_split.txt`` in the current working directory, where ``1``
    marks a training sample and ``2`` a validation sample.

    :param aug_params: forwarded to ``FlowDataset``
    :param split: ``'training'`` (``'train'`` is accepted as an alias) or
        ``'validation'``; any other value selects no samples
    :param root: directory containing the ``.ppm`` / ``.flo`` files
    """

    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        # The default value 'train' never matched the 'training' check below,
        # which silently produced an empty dataset; treat it as an alias.
        if split == 'train':
            split = 'training'

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(flows))

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding training set.

    :param args: must provide ``stage`` (``'chairs'``, ``'things'``, ``'sintel'``
        or ``'kitti'``), ``image_size`` (used as the augmentation crop size) and
        ``batch_size``
    :param TRAIN_DS: dataset mix for the ``'sintel'`` stage, either
        ``'C+T+K+S+H'`` or ``'C+T+K/S'``; ignored for the other stages
    :return: a shuffling ``DataLoader`` over the selected training set
    :raises ValueError: if ``args.stage`` (or, for the ``'sintel'`` stage,
        ``TRAIN_DS``) is not one of the supported values
    """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        # Fail before touching the file system: an unsupported mix used to
        # leave train_dataset unbound and crash later with UnboundLocalError.
        if TRAIN_DS not in ('C+T+K+S+H', 'C+T+K/S'):
            raise ValueError('Unknown TRAIN_DS for sintel stage: %r' % (TRAIN_DS,))

        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            # Integer multiplication repeats a dataset's sample lists to re-weight the mix.
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    else:
        raise ValueError('Unknown training stage: %r' % (args.stage,))

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
class BottleneckBlock(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand.

    The two inner convolutions work on ``planes // 4`` channels. When
    ``stride != 1`` the shortcut is a strided 1x1 convolution followed by its
    own normalisation layer; otherwise the input is added unchanged.

    ``norm_fn`` selects the normalisation: ``'group'``, ``'batch'``,
    ``'instance'`` or ``'none'`` (an empty ``nn.Sequential``).
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        inner_planes = planes // 4
        self.conv1 = nn.Conv2d(in_planes, inner_planes, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(inner_planes, inner_planes, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(inner_planes, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # One constructor per supported norm_fn, each taking a channel count.
        norm_builders = {
            'group': lambda channels: nn.GroupNorm(num_groups=num_groups, num_channels=channels),
            'batch': lambda channels: nn.BatchNorm2d(channels),
            'instance': lambda channels: nn.InstanceNorm2d(channels),
            'none': lambda channels: nn.Sequential(),
        }
        build_norm = norm_builders.get(norm_fn)
        if build_norm is not None:
            self.norm1 = build_norm(inner_planes)
            self.norm2 = build_norm(inner_planes)
            self.norm3 = build_norm(planes)
            if stride != 1:
                self.norm4 = build_norm(planes)

        if stride != 1:
            # Projection shortcut so the residual matches the strided output.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(shortcut(x) + bottleneck(x))``."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.relu(self.norm3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + out)
class SmallEncoder(nn.Module):
    """Small convolutional feature encoder built from ``BottleneckBlock`` stages.

    A stride-2 7x7 stem is followed by three two-block stages (32, 64 and 96
    channels; the last two stages downsample by 2 each) and a 1x1 output
    convolution to ``output_dim`` channels.

    :param output_dim: number of output channels
    :param norm_fn: ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``
    :param dropout: if > 0, ``Dropout2d`` with this probability is applied to
        the output while training
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates in_planes as stages are stacked.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                # Norm layers without affine parameters have weight/bias set to None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage of two bottleneck blocks ending at ``dim`` channels."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of tensors in one batched pass.

        For list/tuple input the tensors are concatenated along the batch
        dimension, encoded together, and split back so that output ``i`` has
        the batch size of input ``i``. Any number of inputs is supported.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            # Remember each input's batch size so the output can be split back
            # exactly, instead of assuming two equally sized inputs.
            batch_sizes = [item.shape[0] for item in x]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, batch_sizes, dim=0)

        return x
def get_args(cmd=None):
    """Parse the RAFT command-line options.

    Args:
        cmd: sequence of argument strings to parse. ``None`` makes argparse
            read the process arguments instead.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # Plain valued options: (flag, type, default).
    valued_options = (
        ('--corr_levels', int, 4),
        ('--corr_radius', int, 4),
        ('--dropout', float, 0.0),
    )
    for flag, kind, fallback in valued_options:
        parser.add_argument(flag, type=kind, default=fallback)

    # Boolean switches that default to False.
    for flag in ('--mixed_precision', '--small'):
        parser.add_argument(flag, action='store_true')

    parser.add_argument('--gpus', type=int, nargs='+', default=[0])

    # parse_args(None) is the same call as parse_args() with no argument.
    return parser.parse_args(cmd)
def get_raft_flow(x, raft_model, iters=24, backward=False, t_dim=1):
    """Run ``raft_model`` on the first two frames of a 5-D tensor.

    Args:
        x: 5-D tensor with at least two entries along ``t_dim``; it is
            multiplied by 255 before being handed to the model.
        raft_model: callable invoked as
            ``raft_model(a, b, test_mode=True, iters=iters)``.
        iters: forwarded to the model.
        backward: when true the two frames are passed in swapped order.
        t_dim: index of the frame dimension.

    Returns:
        The last element of whatever ``raft_model`` returns.
    """
    assert len(x.shape) == 5, x.shape
    assert x.shape[t_dim] >= 2, x.shape

    scaled = x * 255.0
    # Keep frames 0 and 1 along the time axis and separate them.
    first, second = scaled.narrow(t_dim, 0, 2).unbind(t_dim)
    source, target = (second, first) if backward else (first, second)

    # NOTE(review): the *last* output of the model is returned; confirm that
    # this is the flow tensor for the model variant being used.
    return raft_model(source, target, test_mode=True, iters=iters)[-1]
def upsample_flow(self, flow, mask):
    """Upsample a flow field by 8x with a learned convex combination.

    Every output pixel is a softmax-weighted sum over the 3x3
    neighbourhood of the (8x scaled) coarse flow.

    Args:
        flow: tensor of shape ``(N, 2, H, W)``.
        mask: tensor viewable as ``(N, 1, 9, 8, 8, H, W)``.

    Returns:
        Tensor of shape ``(N, 2, 8*H, 8*W)``.
    """
    batch, _, rows, cols = flow.shape

    # Normalise the 9 neighbourhood weights of every fine pixel.
    weights = torch.softmax(mask.view(batch, 1, 9, 8, 8, rows, cols), dim=2)

    # Gather the zero-padded 3x3 neighbourhood of each coarse pixel.
    patches = F.unfold(8 * flow, [3, 3], padding=1)
    patches = patches.view(batch, 2, 9, 1, 1, rows, cols)

    blended = (weights * patches).sum(dim=2)

    # Interleave the 8x8 sub-pixel axes with the coarse grid axes.
    blended = blended.permute(0, 1, 4, 2, 5, 3)
    return blended.reshape(batch, 2, 8 * rows, 8 * cols)
def forward(self, *args, **kwargs):
    """Dispatch to pairwise or multi-frame flow estimation.

    When ``self.multiframe`` is false, all arguments are passed straight
    to ``_forward_two_images``. Otherwise ``args[0]`` must be a 5-D tensor
    with at least two frames along dimension 1; each consecutive frame
    pair is processed and the results are stacked along dimension 1.

    Returns:
        The pairwise result, or ``(flows, motion_features)`` in
        multi-frame mode.
    """
    if not self.multiframe:
        return self._forward_two_images(*args, **kwargs)

    video, extra = args[0], args[1:]
    if self.scale_inputs:
        video = video * 255.0
    assert len(video.shape) == 5, video.shape
    assert video.shape[1] >= 2, video.shape

    reverse = kwargs.get('backward', False)
    flows, motion_features = [], []
    for t in range(video.size(1) - 1):
        earlier, later = video[:, t], video[:, t + 1]
        pair = (later, earlier) if reverse else (earlier, later)
        _, flow, features = self._forward_two_images(*pair, *extra, **kwargs)

        # NOTE(review): in backward mode flows are prepended while motion
        # features are still appended, so the two stacked outputs run in
        # opposite frame order -- confirm this is intended.
        if reverse:
            flows.insert(0, flow)
        else:
            flows.append(flow)
        motion_features.append(features)

    return torch.stack(flows, 1), torch.stack(motion_features, 1)
class ConvGRU(nn.Module):
    """Convolutional GRU cell using 3x3 convolutions for all three gates."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        # Every gate sees the hidden state concatenated with the input.
        joint_channels = hidden_dim + input_dim
        self.convz = nn.Conv2d(joint_channels, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(joint_channels, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(joint_channels, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the next hidden state for hidden state ``h`` and input ``x``."""
        stacked = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(stacked))
        reset_gate = torch.sigmoid(self.convr(stacked))

        # The candidate state is computed from the reset-gated hidden state.
        gated = torch.cat([reset_gate * h, x], dim=1)
        candidate = torch.tanh(self.convq(gated))

        return (1 - update_gate) * h + update_gate * candidate
class SmallUpdateBlock(nn.Module):
    """Recurrent update operator of the small RAFT variant.

    Combines a motion encoder, a ``ConvGRU`` and a flow head. It produces
    no upsampling mask, so ``None`` is returned in that slot.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """Run one refinement step.

        Args:
            net: current hidden state.
            inp: context features, concatenated with the motion features
                along dim 1 to form the GRU input.
            corr: correlation features passed to the motion encoder.
            flow: current flow estimate passed to the motion encoder.

        Returns:
            ``(net, None, delta_flow, motion_features)``. This is the same
            4-tuple layout ``BasicUpdateBlock.forward`` returns, so RAFT can
            unpack both variants identically.
        """
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # The RAFT forward pass unpacks four values; returning only three
        # made the small model fail on the first iteration.
        return net, None, delta_flow, motion_features
class FlowAugmentor:
    """Photometric, occlusion and spatial augmentation for dense-flow samples.

    Calling the instance applies, in order, ``color_transform``,
    ``eraser_transform`` and ``spatial_transform`` and returns contiguous
    arrays.

    Args:
        crop_size: ``(height, width)`` of the final crop.
        min_scale: lower bound of the log2 scale that is sampled uniformly.
        max_scale: upper bound of the log2 scale that is sampled uniformly.
        do_flip: enable random horizontal / vertical flips.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric: jitter each image with its own call
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric: stack both images so a single jitter call covers them
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=(50, 100)):
        """ Occlusion augmentation

        With probability ``eraser_aug_prob``, one or two rectangles of
        ``img2`` are overwritten in place with its mean colour. Rectangle
        side lengths are drawn from ``[bounds[0], bounds[1])``.
        """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        """Randomly rescale, flip and crop the image pair and its flow.

        Flow vectors are multiplied by the applied scale factors and
        negated along flipped axes so they stay consistent with the images.

        Raises:
            ValueError: from ``np.random.randint`` if the (possibly
                rescaled) images are smaller than ``crop_size``.
        """
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # Smallest scale that still leaves an 8-pixel margin around the crop.
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob:  # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob:  # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # offset reachable and keeps an image that already has exactly the
        # crop size from raising (randint(0, 0) is an error).
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + 1)
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1] + 1)

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        """Apply all augmentations and return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    """Rescale a sparse flow map by scattering its valid vectors.

    Each pixel whose ``valid`` value is >= 1 is moved to its scaled,
    rounded coordinate and its flow vector is multiplied by ``[fx, fy]``.
    Pixels that land outside the resized grid are dropped.

    Args:
        flow: array of shape ``(H, W, 2)``.
        valid: array of shape ``(H, W)``.
        fx: horizontal scale factor.
        fy: vertical scale factor.

    Returns:
        ``(flow_img, valid_img)``: a float32 array of shape
        ``(round(H*fy), round(W*fx), 2)`` and an int32 mask of the same
        spatial size that is 1 where a vector was written.
    """
    ht, wd = flow.shape[:2]
    coords = np.meshgrid(np.arange(wd), np.arange(ht))
    coords = np.stack(coords, axis=-1)

    coords = coords.reshape(-1, 2).astype(np.float32)
    flow = flow.reshape(-1, 2).astype(np.float32)
    valid = valid.reshape(-1).astype(np.float32)

    keep = valid >= 1
    coords0 = coords[keep]
    flow0 = flow[keep]

    ht1 = int(round(ht * fy))
    wd1 = int(round(wd * fx))

    coords1 = coords0 * [fx, fy]
    flow1 = flow0 * [fx, fy]

    xx = np.round(coords1[:, 0]).astype(np.int32)
    yy = np.round(coords1[:, 1]).astype(np.int32)

    # Index 0 is a valid row/column: use ">= 0" so measurements on the first
    # row and column are not silently discarded.
    v = (xx >= 0) & (xx < wd1) & (yy >= 0) & (yy < ht1)
    xx = xx[v]
    yy = yy[v]
    flow1 = flow1[v]

    flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
    valid_img = np.zeros([ht1, wd1], dtype=np.int32)

    flow_img[yy, xx] = flow1
    valid_img[yy, xx] = 1

    return flow_img, valid_img
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization


# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3] with values in [0, 255]
    """

    # Number of entries in each of the six hue transitions.
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    # Direction as a fraction of pi, mapped onto a fractional wheel index.
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around to the first wheel entry
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        # Linear interpolation between the two neighbouring wheel entries.
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        # Fade towards white as the magnitude goes to zero.
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Converts a two-channel flow field into a color image.

    The flow is normalized by its largest magnitude before the color wheel
    is applied.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip each flow component to
            [-clip_flow, clip_flow]. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Clip symmetrically: a lower bound of 0 would erase every negative
        # (leftward / upward) flow component.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5  # avoids division by zero for an all-zero flow field
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
def readPFM(file):
    """Read a PFM image file into a numpy array.

    Args:
        file: path of the file to read.

    Returns:
        Array of shape ``(height, width, 3)`` for a ``PF`` header or
        ``(height, width)`` for a ``Pf`` header, flipped vertically relative
        to the order stored in the file.

    Raises:
        Exception: if the first line is neither ``PF`` nor ``Pf``, or the
            dimension line is malformed.
    """
    # The context manager closes the handle on every path, including the
    # header-validation errors below.
    with open(file, 'rb') as handle:
        header = handle.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', handle.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # Only the sign of the scale line is used: negative means the
        # payload is little-endian, otherwise big-endian.
        scale = float(handle.readline().rstrip())
        endian = '<' if scale < 0 else '>'

        data = np.fromfile(handle, endian + 'f')

    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data
def read_gen(file_name, pil=False):
    """Load an image or flow file, choosing the reader by file extension.

    Args:
        file_name: path of the file to read.
        pil: accepted for compatibility; not used by this function.

    Returns:
        A PIL image for ``.png/.jpeg/.ppm/.jpg``, the result of ``np.load``
        for ``.bin/.raw``, a float32 array for ``.flo`` and ``.pfm`` (the
        last channel is dropped from 3-D PFM data), or an empty list for any
        other extension.
    """
    suffix = splitext(file_name)[-1]

    if suffix in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)

    if suffix in ('.bin', '.raw'):
        return np.load(file_name)

    if suffix == '.flo':
        return readFlow(file_name).astype(np.float32)

    if suffix == '.pfm':
        pfm = readPFM(file_name).astype(np.float32)
        # 2-D data is returned unchanged; otherwise drop the last channel.
        return pfm if len(pfm.shape) == 2 else pfm[:, :, :-1]

    return []
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]

        # Amount needed to reach the next multiple of 8 (0 if already aligned).
        extra_rows = -self.ht % 8
        extra_cols = -self.wd % 8

        left = extra_cols // 2
        right = extra_cols - left
        if mode == 'sintel':
            # Split the vertical padding between top and bottom.
            top = extra_rows // 2
            bottom = extra_rows - top
        else:
            # Put all vertical padding at the bottom.
            top, bottom = 0, extra_rows

        # Same order F.pad expects: [left, right, top, bottom].
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Return a list with every input replicate-padded."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop the padding added by ``pad`` off the last two dimensions."""
        rows, cols = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:rows - bottom, left:cols - right]
def get_args():
    """Parse the command-line options of the CWM pre-training script.

    Returns:
        argparse.Namespace: the parsed options. ``mask_kwargs`` is decoded
        from a JSON string and is an empty dict when the flag is omitted.
    """
    parser = argparse.ArgumentParser('CWM pre-training script', add_help=False)

    # training parameters
    parser.add_argument('--batch_size', default=64, type=int, help='per-GPU batch-size')
    parser.add_argument('--epochs', default=800, type=int, help='number of training epochs')
    parser.add_argument('--save_ckpt_freq', default=50, type=int, help='save checkpoint frequency')
    parser.add_argument('--print_freq', default=1, type=int, help='frequency of printing training stats')
    parser.add_argument('--accum_iter', default=1, type=int, help='number of steps to accumulate gradients')
    parser.add_argument('--eval', action='store_true', help='evaluation mode')
    parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None, help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--val_after', default=50, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')

    # Model parameters
    parser.add_argument('--model', default='vitb_8x8patch_3frames', type=str, help='Name of model to train')
    parser.add_argument('--context_frames', type=int, default=2, help='number of frames model will see densely')
    parser.add_argument('--target_frames', type=int, default=1, help='number of frames model will see sparsely')
    parser.add_argument('--temporal_units', type=str, default='ms', help='the units in which time is defined')
    parser.add_argument('--sampling_rate', type=int, default=150, help='temporal gap between context/target frames')
    parser.add_argument('--context_target_gap', type=int, nargs='+', default=[150, 150], help='gap between context/target')

    # Masking and target parameters
    parser.add_argument('--mask_type', default='rotated_table', type=str, help='masked strategy')
    parser.add_argument('--mask_ratio', default=0.75, type=float, help='masking ratio')
    # argparse runs `type` on string defaults, so the default must itself be
    # valid JSON; an empty string made every run without this flag exit with
    # "invalid loads value: ''".
    parser.add_argument('--mask_kwargs', default='{}', type=json.loads, help='extra arguments for masking generator')
    parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', help='Drop path rate (default: 0.0)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default:adamw)')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer epsilon')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=0.05, help='Final value of the weight decay.')
    parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR', help='learning rate (default: 1.5e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate')
    parser.add_argument('--min_lr', type=float, default=0, metavar='LR', help='lower lr bound for cyclic schedulers)')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', help='steps to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str, help='dataset path')
    parser.add_argument('--data_path_list', type=str, nargs='+', default=None, help='[path1, path2, path3, ...]')
    parser.add_argument('--num_workers', default=10, type=int)

    # Augmentation parameters
    parser.add_argument('--augmentation_type', type=str, default='multiscale', choices=['multiscale', 'center', 'none'])
    parser.add_argument('--augmentation_scales', type=float, nargs='+', default=[1.0, 0.875, 0.75, 0.66])

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    return parser.parse_args()


def export_model_parameters(model):
    """Write one ``"<name> <size>"`` line per named parameter of ``model``.

    The listing goes to ``model_parameters.txt`` in the current working
    directory, overwriting any existing file.
    """
    with open('model_parameters.txt', 'w') as f:
        for name, param in model.named_parameters():
            f.write(f"{name} {param.size()}\n")


def main(args):
    """Run CWM pre-training for ``args.epochs`` epochs.

    Sets up the distributed process group, builds the model, dataset and
    optimizer, resumes from a checkpoint when one is found, then trains,
    checkpoints and logs once per epoch.

    Args:
        args: namespace produced by :func:`get_args`. Several fields are
            added or rescaled in place (``input_size``, ``tubelet_size``,
            ``mask_input_size``, ``lr``, ``min_lr``, ``warmup_lr``).
    """
    ## Setup distributed training
    utils.init_distributed_mode(args)
    cudnn.benchmark = True
    device = torch.device(args.device)
    num_tasks = utils.get_world_size()
    sampler_rank = global_rank = utils.get_rank()
    world_size = utils.get_world_size()

    ## Fix the seed for reproducibility (offset by rank so ranks differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    ## Initialize model
    model = getattr(model_pretrain, args.model)()
    args.input_size = int(model.encoder.patch_embed.img_size[0])
    args.tubelet_size = model.patch_size[0]

    args.mask_input_size = (
        (args.context_frames + args.target_frames) // args.tubelet_size,
        args.input_size // model.patch_size[-2],
        args.input_size // model.patch_size[-1],
    )

    ## Prepare datasets
    dataset_train = build_pretraining_dataset(args)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train,
        num_replicas=num_tasks,
        rank=sampler_rank,
        shuffle=True,
        drop_last=True
    )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True, drop_last=True,
        worker_init_fn=utils.seed_worker,
    )

    num_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks

    n_params, n_params_str = utils.get_model_num_parameters(model)

    total_batch_size = args.batch_size * world_size * args.accum_iter

    ## Dump the parameter listing, then wrap the model for DDP
    export_model_parameters(model)

    model = DDP(model.to(device), device_ids=[args.gpu], find_unused_parameters=False)

    ## Optimizer, loss scaler
    optimizer = create_optimizer(args, model.module)
    loss_scaler = NativeScaler()

    ## LR scheduler, WD scheduler (learning rates scale linearly with batch size / 256)
    args.lr = args.lr * total_batch_size / 256
    args.min_lr = args.min_lr * total_batch_size / 256
    args.warmup_lr = args.warmup_lr * total_batch_size / 256

    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )

    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_steps_per_epoch
    )

    ## Resume from checkpoint, if any
    utils.auto_load_model(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler)

    ## Print training arguments
    print("world size: %d" % args.world_size)
    print("model: %s" % args.model)
    print("image size: %s" % str(args.input_size))
    print("patch size: %s" % str(model.module.encoder.patch_embed.patch_size[-2:]))
    print("context frames: %s" % str(args.context_frames))
    print("target frames: %s" % str(args.target_frames))
    # Previously this line printed the total batch size under the per-device label.
    print("per-device batch size: %d" % args.batch_size)
    print("total batch size: %d" % total_batch_size)
    print("grad accumulation: %d" % args.accum_iter)
    print("dataset length: %d" % len(dataset_train))
    print("steps per epoch: %d" % num_steps_per_epoch)
    print("num parameters: %s" % n_params_str)
    print("lr: %.8f" % args.lr)

    ## Setup logging
    if args.use_wandb and utils.is_main_process():
        wandb.init(project="cwm", name=args.output_dir.split('/')[-1], config=args)

    print(f'start training at epoch {args.start_epoch} for {args.epochs} epochs')
    start_time = time.time()

    # The parser defines no --use_xla flag, so fall back to False rather
    # than raising AttributeError after the first epoch.
    use_xla = getattr(args, 'use_xla', False)

    for epoch in range(args.start_epoch, args.epochs):
        # Per-epoch timer; `start_time` stays untouched so the final
        # "Training time" covers the whole run.
        epoch_start_time = time.time()

        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        # Run one epoch
        train_stats = train_one_epoch(
            model, data_loader_train, optimizer, device, epoch, loss_scaler,
            start_steps=epoch * num_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            args=args,
            global_rank=global_rank,
        )
        epoch_time = time.time() - epoch_start_time

        # Save checkpoint
        if args.output_dir and ((epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs):
            utils.save_model(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch)

        # Logging
        do_write = (global_rank == 0) if use_xla else utils.is_main_process()
        if args.output_dir and do_write:
            log_stats = {
                **{f'train/{k}': v for k, v in train_stats.items()},
                'epoch': epoch,
                'params': n_params,
                'epoch_time': epoch_time
            }

            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

            if args.use_wandb:
                wandb.log(log_stats)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def _imagenet_stats(x, temporal_dim=2):
    """Return ImageNet (mean, std) tensors shaped to broadcast against ``x``.

    The statistics are cast to ``x``'s dtype and device. The channel axis is
    chosen from the rank of ``x``:

    * 3-D: axis 0 when ``x.shape[0] == 3`` (channel-first), else the last axis
    * 4-D: axis 1
    * 5-D: axis 1 when ``temporal_dim == 2``, otherwise axis 2

    :raises ValueError: if ``x`` is not 3-, 4- or 5-dimensional
    """
    ndim = x.dim()
    if ndim == 3:
        shape = (-1, 1, 1) if x.shape[0] == 3 else (1, 1, -1)
    elif ndim == 4:
        shape = (1, -1, 1, 1)
    elif ndim == 5:
        # Only 5-D inputs carry a temporal axis. Swapping axes for lower
        # ranks (as the previous code did) moved the channel axis to the
        # wrong position and broke broadcasting.
        shape = (1, -1, 1, 1, 1) if temporal_dim == 2 else (1, 1, -1, 1, 1)
    else:
        raise ValueError(
            "expected a 3-, 4- or 5-dimensional tensor, got shape {}".format(tuple(x.shape)))

    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x).view(shape)
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x).view(shape)
    return mean, std


def imagenet_unnormalize(x, temporal_dim=2):
    """Undo ImageNet normalisation: ``x * std + mean``.

    :param x: 3-, 4- or 5-dimensional tensor with 3 channels
    :param temporal_dim: for 5-D input only; 2 means the channel axis is 1,
        any other value means the channel axis is 2
    :return: tensor with the same shape as ``x``
    """
    mean, std = _imagenet_stats(x, temporal_dim)
    return x * std + mean


def imagenet_normalize(x, temporal_dim=2):
    """Apply ImageNet normalisation: ``(x - mean) / std``.

    :param x: 3-, 4- or 5-dimensional tensor with 3 channels
    :param temporal_dim: for 5-D input only; 2 means the channel axis is 1,
        any other value means the channel axis is 2
    :return: tensor with the same shape as ``x``
    """
    mean, std = _imagenet_stats(x, temporal_dim)
    return (x - mean) / std
class MetricLogger(object):
    """Collect named meters and print them periodically while iterating.

    Meters are created on first use as ``SmoothedValue`` instances and can
    be read back as attributes (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def _record(self, values):
        """Feed every non-None entry of ``values`` into its meter."""
        for k, v in values.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def update(self, **kwargs):
        """Update meters from keyword arguments (``name=value``)."""
        self._record(kwargs)

    def update2(self, kwargs):
        """Update meters from a mapping passed as a single argument."""
        self._record(kwargs)

    def __getattr__(self, attr):
        # Only called when normal lookup fails. Read `meters` through
        # __dict__ so that an instance without it (copy, pickle, __new__)
        # raises AttributeError instead of recursing without end.
        meters = self.__dict__.get('meters', {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = [
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Ask every meter to synchronise its totals across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress as it goes.

        A status line is printed every ``print_freq`` items and on the last
        item; a total-time summary is printed once the iterable is exhausted.

        :param iterable: sized iterable (``len`` must be supported)
        :param print_freq: number of items between status lines
        :param header: optional prefix for every printed line
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.2f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if use_cuda:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero on an empty iterable.
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))
def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from the launcher's environment.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment,
    stores them on ``args`` (``rank``, ``gpu``, ``world_size``) together with
    ``distributed=True`` and ``dist_backend='nccl'``, creates the process
    group using ``args.dist_url``, waits on a barrier and finally silences
    ``print`` on every rank except rank 0.

    :param args: namespace providing ``dist_url``; mutated in place
    :raises KeyError: if one of the environment variables is missing
    """
    args.distributed = True
    args.rank = int(os.environ["RANK"])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.dist_backend = 'nccl'

    # Bind this process to its own GPU before any NCCL communication.
    # Without this every rank's current device stays at index 0, so a plain
    # 'cuda' device places all replicas on GPU 0.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)

    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
    """AMP loss scaler that also reports the gradient norm of each step.

    Calling the instance back-propagates the scaled loss. When
    ``update_grad`` is true it unscales the gradients, measures (and
    optionally clips) their norm, steps the optimizer and updates the scale.
    """

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Back-propagate ``loss`` and, if requested, take an optimizer step.

        :param clip_grad: max gradient norm; ``None`` disables clipping
        :param parameters: parameters whose gradient norm is measured;
            required when ``clip_grad`` is given
        :param update_grad: when false only the backward pass runs (used for
            gradient accumulation) and ``None`` is returned
        :return: the gradient norm, or ``None`` when no step was taken
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)

        if not update_grad:
            return None

        clipping = clip_grad is not None
        if clipping:
            assert parameters is not None

        # Gradients must be unscaled before their norm means anything.
        self._scaler.unscale_(optimizer)
        if clipping:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = get_grad_norm_(parameters)

        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scaler's state."""
        self._scaler.load_state_dict(state_dict)
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warm-up, then cosine decay.

    :param base_value: value reached at the end of warm-up / start of decay
    :param final_value: value the cosine decay approaches
    :param epochs: number of epochs covered by the schedule
    :param niter_per_ep: iterations per epoch
    :param warmup_epochs: warm-up length in epochs
    :param start_warmup_value: first value of the warm-up ramp
    :param warmup_steps: warm-up length in iterations; when positive it
        overrides ``warmup_epochs``
    :return: 1-D numpy array of length ``epochs * niter_per_ep``
    :raises AssertionError: if the warm-up is longer than the whole schedule
    """
    total_iters = epochs * niter_per_ep

    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps

    # Decide from the resolved iteration count, not from `warmup_epochs`:
    # otherwise `warmup_steps > 0` with `warmup_epochs == 0` shortens the
    # cosine part but emits no warm-up, and the length check below fails.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    iters = np.arange(total_iters - warmup_iters)
    progress = iters / max(len(iters), 1)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * progress))

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == total_iters
    return schedule
+ to_save['model_ema'] = get_state_dict(model_ema) + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + if model_ema is not None: + client_state['model_ema'] = get_state_dict(model_ema) + model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) + + +def auto_load_model(args, model, optimizer, loss_scaler, model_ema=None, global_rank=None): + output_dir = Path(args.output_dir) + if loss_scaler is not None: + # torch.amp + if len(args.resume) == 0: + import glob + if global_rank is None: + all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) + else: + all_checkpoints = glob.glob(os.path.join(output_dir, f'checkpoint-*-rank-{global_rank}.pth')) + latest_ckpt = -1 + for ckpt in all_checkpoints: + if global_rank is None: + t = ckpt.split('-')[-1].split('.')[0] + else: + t = ckpt.split('checkpoint-')[1].split('-')[0] + if t.isdigit(): + latest_ckpt = max(int(t), latest_ckpt) + if latest_ckpt >= 0: + if global_rank is None: + args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) + else: + args.resume = os.path.join(output_dir, 'checkpoint-%d-rank-%d.pth' % (latest_ckpt, global_rank)) + if args.resume: + print("Auto resume checkpoint: %s" % args.resume) + + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model.module.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if hasattr(args, 'model_ema') and args.model_ema: + _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + + else: + # deepspeed, 
def unpatchify(x, patch_size):
    """Reassemble flattened square patches into images.

    x: (N, L, patch_size**2 * 3), with L a perfect square
    imgs: (N, 3, H, W), where H = W = sqrt(L) * patch_size
    """
    side = int(x.shape[1] ** .5)
    assert side * side == x.shape[1]

    batch = x.shape[0]
    extent = side * patch_size

    # (N, h, w, p, q, c) -> (N, c, h, p, w, q), then merge (h, p) and (w, q).
    grid = x.reshape(batch, side, side, patch_size, patch_size, 3)
    grid = grid.permute(0, 5, 1, 3, 2, 4)
    return grid.reshape(batch, 3, extent, extent)
def sample_positions_from_dist(size, dist):
    """
    Samples (row, column) positions from unnormalized 2-D distributions.

    Parameters:
    size (tuple): ``(batch_size, num_samples)``; ``batch_size * num_samples``
                  indices are drawn (with replacement) from every map in ``dist``.
    dist (torch.Tensor): A float tensor of shape [B, H, W] holding unnormalized
                         weights. If it contains negative values, a shifted copy
                         (minimum moved to 0) is sampled from; the caller's tensor
                         is left untouched.

    Returns:
    torch.Tensor: An integer tensor of shape [batch_size, num_samples, 2] whose
                  last axis is (row, column). The final reshape only succeeds
                  when B == 1.
    """
    assert dist.dim() == 3, "dist should be a 3D tensor with shape [B, H, W]."
    assert len(size) == 2, "size should be a 2D tuple (batch_size, num_samples)"
    B, H, W = dist.shape

    new_B, num_samples = size

    lowest = dist.min()
    if lowest < 0:
        # Out-of-place on purpose: `dist -= ...` would silently rewrite the
        # tensor owned by the caller.
        dist = dist - lowest

    # Flatten the last two dimensions to make it [B, H*W]; reshape (unlike
    # view) also accepts non-contiguous inputs.
    flattened_dist = dist.reshape(B, -1)

    # Sample flat indices in proportion to the weights
    sampled_indices = torch.multinomial(flattened_dist, new_B * num_samples, replacement=True)

    # Convert the flattened indices back to 2D indices
    sampled_row_indices = sampled_indices // W
    sampled_col_indices = sampled_indices % W

    # Stack the row and column indices
    samples = torch.stack((sampled_row_indices, sampled_col_indices), dim=-1)
    samples = samples.view(new_B, num_samples, 2)

    return samples
+# +# input_dino = images +# # input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device) +# # input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device) +# # input_dino = images.tensor +# input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear') +# features = dino_backbone(input_dino) +# +# predominence_map = [] +# +# for i in range(features.shape[0]): +# feats = features[i] +# if current_mask == None: +# painting = torch.from_numpy(np.zeros(dims)) +# painting = painting.to(feats) +# else: +# feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0]) +# +# A, D = get_affinity_matrix(feats, tau=0.15) +# # get the second-smallest eigenvector +# _, second_smallest_vec = second_smallest_eigenvector(A, D) +# # get salient area +# bipartition = get_salient_areas(second_smallest_vec) +# +# # check if we should reverse the partition based on: +# # 1) peak of the 2nd smallest eigvec 2) object centric bias +# seed = np.argmax(np.abs(second_smallest_vec)) +# nc = check_num_fg_corners(bipartition, dims) +# if nc >= 2: +# reverse = True +# else: +# reverse = bipartition[seed] != 1 +# if reverse: +# second_smallest_vec = 1 - second_smallest_vec +# second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous() +# map = torch.nn.functional.interpolate(second_smallest_vec.reshape(1, 1, dims[0], dims[1]), size=img_size, +# mode='bilinear') +# map -= map.min() +# map /= map.max() +# predominence_map.append(map) +# init_dist = torch.cat(predominence_map, dim=0).detach() +# return init_dist, A, feats, painting + + +def interpolate_pos_encoding(pos_embed, n_frames, h, w): + N = pos_embed.shape[1] + if N == (h * w * n_frames): + return pos_embed + old_h = old_w = int((N / n_frames) ** 0.5) + patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2) + + patch_pos_embed = 
F.interpolate( + patch_pos_embed, + size=(h, w), + mode='bicubic', + ) + return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0) + + +def flow_to_rgb(vec, flow_mag_range=None, white_bg=False): + height, width = vec.shape[:2] + scaling = 50. / (height**2 + width**2)**0.5 + direction = (np.arctan2(vec[..., 0], vec[..., 1]) + np.pi) / (2 * np.pi) + norm = np.linalg.norm(vec, axis=-1) + if flow_mag_range is None: + flow_mag_range = norm.min(), norm.max() + magnitude = np.clip((norm - flow_mag_range[0]) * scaling, 0., 1.) + if white_bg == True: + value = np.ones_like(direction) + hsv = np.stack([direction, magnitude, value], axis=-1) # H=direction, S=magnitude, V=1: zero flow renders white + else: + saturation = np.ones_like(direction) + hsv = np.stack([direction, saturation , magnitude], axis=-1) + rgb = matplotlib.colors.hsv_to_rgb(hsv) + return rgb \ No newline at end of file diff --git a/external/__init__.py b/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7154a9d18789f292b58f6c7683f1ddefaa871462 --- /dev/null +++ b/external/__init__.py @@ -0,0 +1,38 @@ +import os +import subprocess +import torch + +def setup_raft(): + print('RAFT is not installed. Auto-install RAFT to ~/.cache/torch/RAFT') + # Store the current working directory + initial_directory = os.getcwd() + os.makedirs(os.path.join(os.environ['HOME'], '.cache/torch/'), exist_ok=True) + os.chdir(os.path.join(os.environ['HOME'], '.cache/torch/')) + subprocess.run(['git', 'clone', 'https://github.com/princeton-vl/RAFT.git'], check=True) + os.chdir('RAFT') + subprocess.run(['./download_models.sh'], check=True) + # Change back to the initial directory + os.chdir(initial_directory) + +def setup_cutler(): + print('CUTLER is not installed. 
Auto-install CUTLER to ~/.cache/torch/CUTLER') + # Store the current working directory + initial_directory = os.getcwd() + os.makedirs(os.path.join(os.environ['HOME'], '.cache/torch/'), exist_ok=True) + os.chdir(os.path.join(os.environ['HOME'], '.cache/torch/')) + subprocess.run(['git', 'clone', '--recursive', 'https://github.com/facebookresearch/CutLER.git'], check=True) + subprocess.run(['git', 'clone', 'https://github.com/facebookresearch/detectron2.git'], check=True) + os.chdir('detectron2') + subprocess.run(['pip', 'install', '-e', '.'], check=True) + subprocess.run(['pip', 'install', 'git+https://github.com/cocodataset/panopticapi.git'], check=True) + subprocess.run(['pip', 'install', 'git+https://github.com/mcordts/cityscapesScripts.git'], check=True) + subprocess.run(['pip', 'install', 'git+https://github.com/lucasb-eyer/pydensecrf.git'], check=True) + + # Change back to the initial directory + os.chdir(initial_directory) + +if not os.path.isdir(os.path.join(os.environ['HOME'], '.cache/torch/', 'RAFT')): + setup_raft() + +if not os.path.isdir(os.path.join(os.environ['HOME'], '.cache/torch/', 'CutLER')): + setup_cutler() diff --git a/external/app.py b/external/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d759404634be3798908bdd3718c25b4f6f6a8d30 --- /dev/null +++ b/external/app.py @@ -0,0 +1,71 @@ +import cv2 +import numpy as np +import gradio as gr + +from pydantic import BaseModel + + +# Points color and marker +color = (0, 255, 0) # Green color for all points +marker_type = 1 # Cross marker + +with gr.Blocks() as demo: + with gr.Row(): + gr.Markdown('''# Annotate Points!🚀 + Upload an image and click to annotate points on it. 
+ ''') + + # Annotating points on an image + with gr.Tab(label='Image'): + with gr.Row():#.style(equal_height=True): + with gr.Column(): + # Input image + original_image = gr.State(value=None) # store original image without points + input_image = gr.Image(type="numpy", label="Upload Image") + + # Annotate points + selected_points = gr.State([]) # store points + with gr.Row(): + gr.Markdown('Click on the image to select points.') + undo_button = gr.Button('Undo point') + + # Show the image with the annotated points + with gr.Tab(label='Image+Points'): + output_image = gr.Image(type='numpy') + + + # Store the original image once uploaded + def store_img(img): + return img, [] # Reset selected points when a new image is uploaded + + + input_image.upload(store_img, [input_image], [original_image, selected_points]) + + + # Get points when clicked on the image + def get_point(img, sel_pix, evt: gr.SelectData): + sel_pix.append(evt.index) # Append the point's location (coordinates) + + # Draw points on the image + for point in sel_pix: + cv2.drawMarker(img, point, color, markerType=marker_type, markerSize=80, thickness=20) + return img if isinstance(img, np.ndarray) else np.array(img) + + + input_image.select(get_point, [input_image, selected_points], [input_image]) + + + # Undo the last selected point + def undo_points(orig_img, sel_pix): + temp = orig_img.copy() + if len(sel_pix) != 0: + sel_pix.pop() # Remove the last point + for point in sel_pix: + cv2.drawMarker(temp, point, color, markerType=marker_type, markerSize=20, thickness=5) + return temp if isinstance(temp, np.ndarray) else np.array(temp) + + + undo_button.click(undo_points, [original_image, selected_points], [input_image]) + + # Launch the app +demo.queue().launch(inbrowser=True) diff --git a/external/gradio_app.py b/external/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..d16e32280f9ffc352b9dc2dff80659bb7e901a3f --- /dev/null +++ b/external/gradio_app.py @@ -0,0 +1,82 @@ 
+import cv2 +import gradio as gr +import numpy as np + +# Global variable to store the list of arrows (start and end points) +arrows = [] +start_point = None + + +# Function to draw all arrows (including zero-length arrows) on the image +def draw_arrow(image, click_coords): + global start_point, arrows + + # Convert the image to a numpy array if it's not already + img_array = np.array(image, dtype=np.uint8) + + # Get the current point from the user (click coordinates as (x, y)) + current_point = (int(click_coords[0]), int(click_coords[1])) # Convert float coords to int + + # If start point is not set, set it as the current click position + if start_point is None: + start_point = current_point + return img_array # No arrow yet, just return the original image + + # If start point is already set, add an arrow (including zero-length ones) + end_point = current_point + arrows.append((start_point, end_point)) # Save the arrow + + # Reset start_point for the next arrow + start_point = None + + # Draw all arrows from the saved list + for arrow in arrows: + start, end = arrow + color = (0, 255, 0) # Green arrow + thickness = 3 + + # Draw the arrow (even if it's zero-length, i.e., start == end) + cv2.arrowedLine(img_array, start, end, color, thickness) + + return img_array + + +# Function to reset the canvas (clearing all arrows) +def reset_canvas(): + global arrows, start_point + arrows = [] + start_point = None + return load_image() # Return a fresh image + + +# Load an image for the user to interact with +def load_image(): + img = np.ones((400, 400, 3), dtype=np.uint8) * 255 # White background image + return img + + +# Define Gradio interface using Blocks +def interactive_arrow_interface(): + with gr.Blocks() as demo: + image_input = gr.Image(value=load_image(), interactive=True, + label="Click to specify the arrow's start and end points") + output_image = gr.Image(label="Image with Arrows") + reset_button = gr.Button("Reset") + + # Set up interaction: Handle click events 
with 'handle_click' + def handle_click(image, evt: gr.SelectData): + print(f"Click coordinates: {evt.index}") + # Pass click coordinates to draw_arrow function and update the image + updated_image = draw_arrow(image, (evt.index[0], evt.index[1])) + return updated_image + + image_input.select(handle_click, [image_input], output_image) + + # Set up the reset button to clear all arrows + reset_button.click(fn=reset_canvas, inputs=None, outputs=image_input) + + return demo + + +# Launch the interactive demo +interactive_arrow_interface().launch(inbrowser=True) diff --git a/external/raft_interface.py b/external/raft_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..72f84450cd25e6150ad60dfc8732f36a6ede8148 --- /dev/null +++ b/external/raft_interface.py @@ -0,0 +1,65 @@ +import sys +import os +sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core')) +from raft import RAFT +from utils import flow_viz +sys.path = sys.path[1:] # remove the first path to RAFT +import torch +from cwm.utils import imagenet_unnormalize +from torch import nn +import argparse + + +class Args: + model = os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth') + small = False + path = None + mixed_precision = False + alternate_corr = False + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value + +class RAFTInterface(nn.Module): + def __init__(self): + super().__init__() + args = Args() + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu'))) + self.model = model.module + self.model.eval() + + for p in self.model.parameters(): + p.requires_grad = False + + @staticmethod + def prepare_inputs(x): + # make sure the input is in the correct format for RAFT + if x.max() <= 1.0 and x.min() >= 0.: # range(0, 1) + x = x * 255. + elif x.min() < 0: # imagenet normalized: + x = imagenet_unnormalize(x) + x = x * 255. 
+ + return x + + def forward(self, x0, x1, return_magnitude=False): + # x0: imagenet-normalized image 0 [B, C, H, W] + # x1: imagenet-normalized image 1 [B, C, H, W] + + # ensure inputs in + x0 = self.prepare_inputs(x0) + x1 = self.prepare_inputs(x1) + with torch.no_grad(): + _, flow_up = self.model(x0, x1, iters=20, test_mode=True) + + if return_magnitude: + flow_magnitude = flow_up.norm(p=2, dim=1) # [B, H, W] + return flow_up, flow_magnitude + + return flow_up + + def viz(self, flow): + flow_rgb = flow_viz.flow_to_image(flow[0].permute(1,2,0).cpu().numpy()) + return flow_rgb \ No newline at end of file diff --git a/gradio_app_intervention.py b/gradio_app_intervention.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5c3d8a7367d0b3854db219df1d6b508bd855e1 --- /dev/null +++ b/gradio_app_intervention.py @@ -0,0 +1,234 @@ +import cv2 +import numpy as np +import gradio as gr +import cwm.utils as utils + +# Points color and arrow properties +arrow_color = (0, 255, 0) # Green color for all arrows +dot_color = (0, 255, 0) # Green color for the dots at start and end +dot_color_fixed = (255, 0, 0) # Red color for zero-length vectors +thickness = 4 # Thickness of the arrow +tip_length = 0.3 # The length of the arrow tip relative to the arrow length +dot_radius = 10 # Radius for the dots +dot_thickness = -1 # Thickness for solid circle (-1 fills the circle) +from PIL import Image +import torch +#load model +from cwm.model.model_factory import model_factory + +from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Load CWM 3-frame model (automatically download pre-trained checkpoint) +model = model_factory.load_model('vitb_8x8patch_3frames').to(device) + +model.requires_grad_(False) +model.eval() + +model = model.to(torch.float16) + + +import matplotlib.pyplot as plt +from matplotlib.patches import FancyArrowPatch +from PIL import Image +import 
numpy as np + +def draw_arrows_matplotlib(img, selected_points, zero_length): + """ + Draw arrows on the image using matplotlib for better quality arrows and dots. + """ + fig, ax = plt.subplots() + ax.imshow(img) + + for i in range(0, len(selected_points), 2): + start_point = selected_points[i] + end_point = selected_points[i + 1] + + if start_point == end_point or zero_length: + # Draw a dot for zero-length vectors or if only one point is clicked + ax.scatter(start_point[0], start_point[1], color='red', s=100) # Red dot for zero-length vector + else: + # Draw arrows + arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]), + color='green', linewidth=2, arrowstyle='->', mutation_scale=15) + ax.add_patch(arrow) + + # Optionally, draw a small circle (dot) at the start and end points + ax.scatter(start_point[0], start_point[1], color='green', s=100) # Green dot at start + ax.scatter(end_point[0], end_point[1], color='green', s=100) # Green dot at end + + # Save the image to a numpy array + fig.canvas.draw() + img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close(fig) + return img_array + + +with gr.Blocks() as demo: + with gr.Row(): + gr.Markdown('''# Generate interventions!🚀 + Upload an image and click to select the start and end points for arrows. Dots will be shown at the beginning and end of each arrow. You can also create zero-length vectors (just a dot) by enabling the toggle below. 
+ ''') + + # Annotating arrows on an image + with gr.Tab(label='Image'): + with gr.Row(): + with gr.Column(): + # Input image + original_image = gr.State(value=None) # store original image without arrows + input_image = gr.Image(type="numpy", label="Upload Image") + + # Annotate arrows + selected_points = gr.State([]) # store points + zero_length_toggle = gr.Checkbox(label="Enable zero-length vectors", value=False) # Toggle for zero-length vectors + with gr.Row(): + gr.Markdown('Click on the image to select the start and end points for each arrow. If zero-length vectors are enabled, clicking once will draw a dot.') + undo_button = gr.Button('Undo last action') + clear_button = gr.Button('Clear All') + + # Run model button + run_model_button = gr.Button('Run Model') + + # Show the image with the annotated arrows + with gr.Tab(label='Intervention'): + output_image = gr.Image(type='numpy') + + def resize_to_square(img, size=512): + img = Image.fromarray(img) + img = img.resize((size, size)) + return np.array(img) + + # Store the original image and resize to square size once uploaded + def store_img(img): + resized_img = resize_to_square(img) # Resize the uploaded image to a square + print(f"Image uploaded with shape: {resized_img.shape}") + return resized_img, resized_img, [] + + input_image.upload(store_img, [input_image], [input_image, original_image, selected_points]) + + # Get points and draw arrows or zero-length vectors based on the toggle + def get_point(img, sel_pix, zero_length, evt: gr.SelectData): + sel_pix.append(evt.index) # Append the point's location (coordinates) + + # Zero-length vector case: Draw a single dot at the clicked point + if zero_length: + point = sel_pix[-1] # Last point clicked + cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness) # Draw a dot at the point + sel_pix.append(evt.index) + else: + # Regular case: two clicks for an arrow + # Check if this is the first point (start point for the arrow) + if len(sel_pix) % 2 == 1: 
+ # Draw a dot at the start point to give feedback + start_point = sel_pix[-1] # Last point is the start + cv2.circle(img, start_point, dot_radius, dot_color, dot_thickness) + + # Check if two points have been selected (start and end points for an arrow) + if len(sel_pix) % 2 == 0: + # Draw an arrow between the last two points + start_point = sel_pix[-2] # Second last point is the start + end_point = sel_pix[-1] # Last point is the end + + # Draw arrow + cv2.arrowedLine(img, start_point, end_point, arrow_color, thickness, tipLength=tip_length) + + # Draw a dot at the end point + cv2.circle(img, end_point, dot_radius, dot_color, dot_thickness) + + return img if isinstance(img, np.ndarray) else np.array(img) + + input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image]) + + # Undo the last selected action + def undo_arrows(orig_img, sel_pix, zero_length): + temp = orig_img.copy() + # if zero_length: + # # Undo the last zero-length vector (just the last dot) + # if len(sel_pix) >= 1: + # sel_pix.pop() # Remove the last point + # else: + if len(sel_pix) >= 2: + sel_pix.pop() # Remove the last end point + sel_pix.pop() # Remove the last start point + + # Redraw all remaining arrows and dots + for i in range(0, len(sel_pix), 2): + start_point = sel_pix[i] + end_point = sel_pix[i + 1] + if start_point == end_point: + # Zero-length vector: Draw a dot + color = dot_color_fixed + else: + cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length) + color = arrow_color + # Draw arrow + + # Draw dots at start and end points + cv2.circle(temp, start_point, dot_radius, color, dot_thickness) + cv2.circle(temp, end_point, dot_radius, color, dot_thickness) + + # If there is an odd number of points (e.g., only a start point), draw a dot for it + if len(sel_pix) == 1: + start_point = sel_pix[0] + cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness) + + return temp if isinstance(temp, np.ndarray) else 
np.array(temp) + + undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image]) + + # Clear all points and reset the image + def clear_all_points(orig_img, sel_pix): + sel_pix.clear() # Clear all points + return orig_img # Reset image to original + + clear_button.click(clear_all_points, [original_image, selected_points], [input_image]) + + # Run the CWM model on the selected arrows/points and return the counterfactual image + def run_model_on_points(points, input_image, original_image): + H = input_image.shape[0] + W = input_image.shape[1] + factor = 224/H + # Group the clicked points into rows of 4 (start pair, end pair) and rescale them from H pixels to 224 + points = torch.from_numpy(np.array(points).reshape(-1, 4)) * factor + + points = points[:, [1, 0, 3, 2]] # swap the two coordinates within each point + + print(points) + + img = Image.fromarray(original_image) + + img = img.resize((224, 224)) + + img = np.array(img) + + # NOTE: debug dump np.save("img.npy", original_image) disabled; it wrote to the working directory on every run + + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 + + img = img[None] + + # reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224 + x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16) + + # Imagenet-normalize the inputs (standardization) + x = utils.imagenet_normalize(x).to(device) + with torch.no_grad(): + counterfactual = model.get_counterfactual(x, points) + + counterfactual = counterfactual.squeeze() + + counterfactual = counterfactual.clamp(0, 1).permute(1,2,0).detach().cpu().numpy() + + # for i in range(0, len(points), 2): + # # Draw rectangles on the points as model output example + # cv2.rectangle(processed_image, points[i], points[i + 1], (255, 0, 0), 3) + return counterfactual + + # Run model when the button is clicked + run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image]) + + # Launch the app +demo.queue().launch(inbrowser=True, share=True) diff --git a/gradio_app_intervention_better_circles.py b/gradio_app_intervention_better_circles.py new 
file mode 100644 index 0000000000000000000000000000000000000000..56dda3b5c19a29ce2825affeac7893caf50745de --- /dev/null +++ b/gradio_app_intervention_better_circles.py @@ -0,0 +1,289 @@ +import cv2 +import numpy as np +import gradio as gr +import cwm.utils as utils + +# Points color and arrow properties +arrow_color = (0, 255, 0) # Green color for all arrows +dot_color = (0, 255, 0) # Green color for the dots at start and end +dot_color_fixed = (255, 0, 0) # Red color for zero-length vectors +thickness = 3 # Thickness of the arrow +tip_length = 0.3 # The length of the arrow tip relative to the arrow length +dot_radius = 7 # Radius for the dots +dot_thickness = -1 # Thickness for solid circle (-1 fills the circle) +from PIL import Image +import torch +#load model +from cwm.model.model_factory import model_factory + +from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Load CWM 3-frame model (automatically download pre-trained checkpoint) +model = model_factory.load_model('vitb_8x8patch_3frames').to(device) + +model.requires_grad_(False) +model.eval() + +model = model.to(torch.float16) + + +import matplotlib.pyplot as plt +from matplotlib.patches import FancyArrowPatch +from PIL import Image +import numpy as np + +from torchvision import transforms + +def draw_arrows_matplotlib(img, selected_points, zero_length): + """ + Draw arrows on the image using matplotlib for better quality arrows and dots. 
+ """ + fig, ax = plt.subplots() + ax.imshow(img) + + for i in range(0, len(selected_points), 2): + start_point = selected_points[i] + end_point = selected_points[i + 1] + + if start_point == end_point or zero_length: + # Draw a dot for zero-length vectors or if only one point is clicked + ax.scatter(start_point[0], start_point[1], color='red', s=100) # Red dot for zero-length vector + else: + # Draw arrows + arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]), + color='green', linewidth=2, arrowstyle='->', mutation_scale=15) + ax.add_patch(arrow) + + # Optionally, draw a small circle (dot) at the start and end points + ax.scatter(start_point[0], start_point[1], color='green', s=100) # Green dot at start + ax.scatter(end_point[0], end_point[1], color='green', s=100) # Green dot at end + + # Save the image to a numpy array + fig.canvas.draw() + img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close(fig) + return img_array + +import os +# def load_preuploaded_images(): +# image_folder = "assets" +# images = [] +# for img_file in os.listdir(image_folder): +# img_path = os.path.join(image_folder, img_file) +# if img_file.endswith(('png', 'jpg', 'jpeg')): +# images.append(Image.open(img_path)) +# return images +# +# # Function to transfer image from gallery to the input image section + + +# +# preloaded_images = load_preuploaded_images() +# +# print("Preloaded images:", preloaded_images) +with gr.Blocks() as demo: + with gr.Row(): + gr.Markdown('''# Generate interventions!🚀 + Upload an image and click to select the start and end points for arrows. Dots will be shown at the beginning and end of each arrow. You can also create zero-length vectors (just a dot) by enabling the toggle below. 
+ ''') + + # Annotating arrows on an image + with gr.Tab(label='Image'): + with gr.Row(): + with gr.Column(): + # Input image + original_image = gr.State(value=None) # store original image without arrows + input_image = gr.Image(type="numpy", label="Upload Image") + + # Annotate arrows + selected_points = gr.State([]) # store points + zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False) # Toggle for zero-length vectors + with gr.Row(): + gr.Markdown('Click on the image to select the start and end points for each arrow. If zero-length vectors are enabled, clicking once will draw a dot.') + undo_button = gr.Button('Undo last action') + clear_button = gr.Button('Clear All') + + # Run model button + run_model_button = gr.Button('Run Model') + + # Show the image with the annotated arrows + with gr.Tab(label='Intervention'): + output_image = gr.Image(type='numpy') + + # Store the original image and resize to square size once uploaded + def resize_to_square(img, size=448): + print("Resizing image to square") + img = Image.fromarray(img) + transform = transforms.Compose([ + transforms.Resize(size), + transforms.CenterCrop(size) + ]) + img = transform(img) # .transpose(1, 2, 0) + + return np.array(img) + + + def load_img(evt: gr.SelectData): + img_path = evt.value['image']['path'] + img = np.array(Image.open(img_path)) + # print(f"Image uploaded with shape: {input.shape}") + resized_img = resize_to_square(img) + return resized_img, resized_img, [] + + + def store_img(img): + resized_img = resize_to_square(img) # Resize the uploaded image to a square + print(f"Image uploaded with shape: {resized_img.shape}") + return resized_img, resized_img, [] + + + with gr.Row(): + with gr.Column(): + gallery = gr.Gallery( ["./assets/desk_1.jpg", "./assets/color_wheel.png", "./assets/desk_1.jpg", "./assets/desk_1.jpg", "./assets/desk_1.jpg"], columns=5, allow_preview=False, label="Select an example image to test") + # examples = gr.Examples( + # 
examples=[ + # ["./assets/desk_1.jpg", "./assets/desk_1.jpg"], + # ], + # inputs=[input_image, original_image], + # # fn=load_img, + # # outputs=[input_image, original_image], + # # cache_examples=True, + # # run_on_click=True, + # # label="Select an example image to test" + # ) + gallery.select(load_img, outputs=[input_image, original_image, selected_points]) + + input_image.upload(store_img, [input_image], [input_image, original_image, selected_points]) + + # Get points and draw arrows or zero-length vectors based on the toggle + def get_point(img, sel_pix, zero_length, evt: gr.SelectData): + sel_pix.append(evt.index) # Append the point's location (coordinates) + + # Zero-length vector case: Draw a single dot at the clicked point + if zero_length: + point = sel_pix[-1] # Last point clicked + cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA) # Draw a dot at the point + sel_pix.append(evt.index) + else: + # Regular case: two clicks for an arrow + # Check if this is the first point (start point for the arrow) + if len(sel_pix) % 2 == 1: + # Draw a dot at the start point to give feedback + start_point = sel_pix[-1] # Last point is the start + cv2.circle(img, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA) + + # Check if two points have been selected (start and end points for an arrow) + if len(sel_pix) % 2 == 0: + # Draw an arrow between the last two points + start_point = sel_pix[-2] # Second last point is the start + end_point = sel_pix[-1] # Last point is the end + + # Draw arrow + cv2.arrowedLine(img, start_point, end_point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA) + + # Draw a dot at the end point + cv2.circle(img, end_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA) + + return img if isinstance(img, np.ndarray) else np.array(img) + + input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image]) + + # Undo the last 
selected action + def undo_arrows(orig_img, sel_pix, zero_length): + temp = orig_img.copy() + # if zero_length: + # # Undo the last zero-length vector (just the last dot) + # if len(sel_pix) >= 1: + # sel_pix.pop() # Remove the last point + # else: + if len(sel_pix) >= 2: + sel_pix.pop() # Remove the last end point + sel_pix.pop() # Remove the last start point + + # Redraw all remaining arrows and dots + for i in range(0, len(sel_pix), 2): + start_point = sel_pix[i] + end_point = sel_pix[i + 1] + if start_point == end_point: + # Zero-length vector: Draw a dot + color = dot_color_fixed + else: + cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length) + color = arrow_color + # Draw arrow + + # Draw dots at start and end points + cv2.circle(temp, start_point, dot_radius, color, dot_thickness) + cv2.circle(temp, end_point, dot_radius, color, dot_thickness) + + # If there is an odd number of points (e.g., only a start point), draw a dot for it + if len(sel_pix) == 1: + start_point = sel_pix[0] + cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness) + + return temp if isinstance(temp, np.ndarray) else np.array(temp) + + undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image]) + + + # Clear all points and reset the image + def clear_all_points(orig_img, sel_pix): + sel_pix.clear() # Clear all points + return orig_img # Reset image to original + + clear_button.click(clear_all_points, [original_image, selected_points], [input_image]) + + # Dummy model function to simulate running a model + def run_model_on_points(points, input_image, original_image): + H = input_image.shape[0] + W = input_image.shape[1] + factor = 224/H + # Example: pretend the model processes points and returns a simple transformation on the image + points = torch.from_numpy(np.array(points).reshape(-1, 4)) * factor + + points = points[:, [1, 0, 3, 2]] + + print(points) + + img = Image.fromarray(original_image) + 
+ transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224) + ]) + img = np.array(transform(img)) + + # np.save("img.npy", original_image) + + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 + + img = img[None] + + # reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224 + x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16) + + # Imagenet-normalize the inputs (standardization) + x = utils.imagenet_normalize(x).to(device) + with torch.no_grad(): + counterfactual = model.get_counterfactual(x, points) + + counterfactual = counterfactual.squeeze() + + counterfactual = counterfactual.clamp(0, 1).permute(1,2,0).detach().cpu().numpy() + + # for i in range(0, len(points), 2): + # # Draw rectangles on the points as model output example + # cv2.rectangle(processed_image, points[i], points[i + 1], (255, 0, 0), 3) + return counterfactual + + # Run model when the button is clicked + run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image]) + + + + # Launch the app +demo.queue().launch(inbrowser=True, share=True) diff --git a/gradio_app_intervention_better_circles_pil.py b/gradio_app_intervention_better_circles_pil.py new file mode 100644 index 0000000000000000000000000000000000000000..c039ac03269c65285b9618431474b138daeb0af0 --- /dev/null +++ b/gradio_app_intervention_better_circles_pil.py @@ -0,0 +1,250 @@ +import cv2 +import numpy as np +import gradio as gr +import cwm.utils as utils + +# Points color and arrow properties +arrow_color = (0, 255, 0) # Green color for all arrows +dot_color = (0, 255, 0) # Green color for the dots at start and end +dot_color_fixed = (255, 0, 0) # Red color for zero-length vectors +thickness = 3 # Thickness of the arrow +tip_length = 0.3 # The length of the arrow tip relative to the arrow length +dot_radius = 7 # Radius for the dots +dot_thickness = -1 # Thickness for solid circle (-1 fills the circle) +from PIL import 
Image +import torch +#load model +from cwm.model.model_factory import model_factory + +from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Load CWM 3-frame model (automatically download pre-trained checkpoint) +model = model_factory.load_model('vitb_8x8patch_3frames').to(device) + +model.requires_grad_(False) +model.eval() + +model = model.to(torch.float16) + + +import matplotlib.pyplot as plt +from matplotlib.patches import FancyArrowPatch +from PIL import Image +import numpy as np + +def draw_arrows_matplotlib(img, selected_points, zero_length): + """ + Draw arrows on the image using matplotlib for better quality arrows and dots. + """ + fig, ax = plt.subplots() + ax.imshow(img) + + for i in range(0, len(selected_points), 2): + start_point = selected_points[i] + end_point = selected_points[i + 1] + + if start_point == end_point or zero_length: + # Draw a dot for zero-length vectors or if only one point is clicked + ax.scatter(start_point[0], start_point[1], color='red', s=100) # Red dot for zero-length vector + else: + # Draw arrows + arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]), + color='green', linewidth=2, arrowstyle='->', mutation_scale=15) + ax.add_patch(arrow) + + # Optionally, draw a small circle (dot) at the start and end points + ax.scatter(start_point[0], start_point[1], color='green', s=100) # Green dot at start + ax.scatter(end_point[0], end_point[1], color='green', s=100) # Green dot at end + + # Save the image to a numpy array + fig.canvas.draw() + img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close(fig) + return img_array + +from PIL import ImageDraw +with gr.Blocks() as demo: + with gr.Row(): + gr.Markdown('''# Generate interventions!🚀 + Upload an image and click to select the start and end points 
# NOTE: the callbacks and event bindings below belong to the
# ``with gr.Blocks() as demo:`` context opened earlier in the script; they use
# the components created there (input_image, original_image, selected_points,
# zero_length_toggle, the buttons and output_image) and the drawing constants
# dot_radius, dot_color, dot_color_fixed, arrow_color and thickness.

def _draw_dot(draw, point, fill):
    """Draw a filled circle of radius ``dot_radius`` centred on ``point``."""
    x, y = point[0], point[1]
    draw.ellipse([x - dot_radius, y - dot_radius,
                  x + dot_radius, y + dot_radius],
                 fill=fill)


def resize_to_square(img, size=512):
    """Return ``img`` resized to ``size`` x ``size`` pixels as a numpy array.

    The aspect ratio is not preserved: both sides are forced to ``size``.
    """
    return np.array(Image.fromarray(img).resize((size, size)))


def store_img(img):
    """Handle an upload: square the image and reset the annotation state.

    Returns:
        ``(displayed image, pristine copy kept for undo/clear, empty point
        list)``. Both images are the same resized array.
    """
    resized_img = resize_to_square(img)
    print(f"Image uploaded with shape: {resized_img.shape}")
    return resized_img, resized_img, []


input_image.upload(store_img, [input_image], [input_image, original_image, selected_points])


def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
    """Record a click on the image and draw its feedback.

    ``sel_pix`` is mutated in place and holds points as consecutive
    ``(start, end)`` pairs; a fixed (zero-length) point is stored as a pair
    of two equal entries.

    Args:
        img: Currently displayed image (with earlier annotations).
        sel_pix: Per-session list of clicked ``[x, y]`` points.
        zero_length: Whether the "keep fixed" toggle is enabled.
        evt: Gradio select event; ``evt.index`` is the clicked location.

    Returns:
        The image with the new dot or arrow drawn on it.
    """
    clicked = evt.index
    pil_img = Image.fromarray(img)
    draw = ImageDraw.Draw(pil_img)

    if zero_length:
        if len(sel_pix) % 2 == 1:
            # An arrow start is still waiting for its end point. Pin it as a
            # fixed point so the list keeps its (start, end) pairing and the
            # stored state matches what is drawn.
            pending = sel_pix[-1]
            sel_pix.append(pending)
            _draw_dot(draw, pending, dot_color_fixed)
        # A fixed point is stored as a pair of identical entries.
        sel_pix.append(clicked)
        sel_pix.append(clicked)
        _draw_dot(draw, clicked, dot_color_fixed)
    else:
        sel_pix.append(clicked)
        if len(sel_pix) % 2 == 1:
            # First click of an arrow: only mark the start.
            _draw_dot(draw, clicked, dot_color)
        else:
            # Second click: connect the pending start to this end point.
            start_point = tuple(sel_pix[-2])
            end_point = tuple(clicked)
            draw.line([start_point, end_point], fill=arrow_color, width=thickness)
            _draw_dot(draw, end_point, dot_color)

    return np.array(pil_img)


input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])


def undo_arrows(orig_img, sel_pix, zero_length):
    """Undo the last action and redraw every remaining annotation.

    A pending arrow start (odd number of stored points) is removed on its
    own; otherwise the last complete pair (arrow or fixed point) is removed.

    Args:
        orig_img: Pristine image without annotations, or ``None`` when
            nothing has been uploaded yet.
        sel_pix: Per-session list of clicked points, mutated in place.
        zero_length: Unused; kept because the click binding passes it.

    Returns:
        The pristine image with the remaining annotations redrawn, or
        ``None`` when there is no image.
    """
    if orig_img is None:
        return None

    if len(sel_pix) % 2 == 1:
        sel_pix.pop()  # drop the lone pending start point
    elif sel_pix:
        sel_pix.pop()  # end point of the last pair
        sel_pix.pop()  # start point of the last pair

    # Draw on a copy so the stored original stays untouched.
    pil_img = Image.fromarray(orig_img.copy())
    draw = ImageDraw.Draw(pil_img)

    # After the pops above the list has even length, so it splits cleanly
    # into (start, end) pairs.
    for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
        if start_point == end_point:
            # Fixed point: a single dot in the "fixed" colour.
            _draw_dot(draw, start_point, dot_color_fixed)
            continue
        draw.line([tuple(start_point), tuple(end_point)], fill=arrow_color, width=thickness)
        # Same endpoint colour as get_point uses when the arrow is first drawn.
        _draw_dot(draw, start_point, dot_color)
        _draw_dot(draw, end_point, dot_color)

    return np.array(pil_img)


undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])


def clear_all_points(orig_img, sel_pix):
    """Forget every stored point and return the un-annotated image."""
    sel_pix.clear()
    return orig_img


clear_button.click(clear_all_points, [original_image, selected_points], [input_image])


def run_model_on_points(points, input_image, original_image):
    """Run the counterfactual model on the un-annotated image and the points.

    Args:
        points: Flat list of clicked ``[x, y]`` points in displayed-image
            pixels, read as ``(start, end)`` pairs. A trailing point without
            a partner is ignored.
        input_image: Displayed image; only its height and width are used, to
            rescale the points to the model resolution.
        original_image: Pristine image that is fed to the model.

    Returns:
        The model output clamped to ``[0, 1]`` as a numpy array.

    Raises:
        gr.Error: If no image has been uploaded yet.
    """
    if input_image is None or original_image is None:
        raise gr.Error("Please upload an image before running the model.")

    model_size = 224  # spatial resolution the image is resized to for the model
    height, width = input_image.shape[0], input_image.shape[1]

    # Only complete (start, end) pairs fit the (-1, 4) layout below.
    paired = points[:len(points) - len(points) % 2]
    coords = np.asarray(paired, dtype=np.float32).reshape(-1, 4)

    # Rescale x by width and y by height; each row is (x0, y0, x1, y1).
    scale = np.array([model_size / width, model_size / height] * 2, dtype=np.float32)
    points = torch.from_numpy(coords * scale)

    # Reorder each row to (y0, x0, y1, x1) before handing it to the model.
    points = points[:, [1, 0, 3, 2]]

    print(points)

    img = Image.fromarray(original_image).resize((model_size, model_size))
    img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
    img = img[None]

    # reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224
    x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16)

    # Imagenet-normalize the inputs (standardization)
    x = utils.imagenet_normalize(x).to(device)
    with torch.no_grad():
        counterfactual = model.get_counterfactual(x, points)

    # presumably (C, H, W) after squeeze, moved to channels-last for display
    # - TODO confirm against model.get_counterfactual
    counterfactual = counterfactual.squeeze()
    return counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()


# Run model when the button is clicked
run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image])
"cell_type": "markdown", + "id": "caa3a025-2eab-4b79-a57e-490b3bf8f282", + "metadata": {}, + "source": [ + "## Prepare model inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "24d89320-905b-4ca2-87c5-dfa902dedd87", + "metadata": {}, + "outputs": [], + "source": [ + "# Load input image from an image path\n", + "img_path = '/ccn2/u/honglinc/bbnet_notebooks/jordon.jpeg' \n", + "\n", + "# normalized to the range [0, 1]\n", + "img = read_image(img_path)[:, :1480].cuda() / 255. \n", + "\n", + "# resize image to input resolution [224, 224]\n", + "img = torch.nn.functional.interpolate(img[None], size=224, mode='bicubic') \n", + "\n", + "# reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224\n", + "x = img[:, :, None].expand(-1, -1, 3, -1, -1).to(torch.float16) \n", + "\n", + "# Imagenet-normalize the inputs (standardization)\n", + "x = utils.imagenet_normalize(x)\n", + "\n", + "# Prepare masks for the 3-frame model: 28x28=784 patches per image, the first two are unmasked, last one is masked\n", + "bool_masked_pos = torch.ones(1, 784*3).to(x.device).bool()\n", + "bool_masked_pos[:, 0:784*2] = False\n", + "\n", + "# patch size\n", + "patch_size = model.encoder.patch_size[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "99773a0d-a568-410a-82e2-d8d7a3de91ca", + "metadata": {}, + "source": [ + "## Generate counterfactual predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9accdceb-b5fc-4d94-a9ca-2589fb9d4e37", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 784, 192]) torch.Size([1, 784]) torch.Size([1, 784, 192]) 28 28\n" + ] + } + ], + "source": [ + "move_patches=[[8, 11], [16, 12], [22, 4]] # Specify which patches to move\n", + "static_patches=[[10, 22], [4, 4]] # Specify which patches to hold still\n", + "delta = [-3, -1] # Specify the direction of motion\n", + "\n", + "# Model inference\n", + "with 
torch.cuda.amp.autocast(enabled=True):\n", + " output = model.get_counterfactual(x, bool_masked_pos, move_patches, static_patches, delta) # model reconstruction output\n", + " recon = utils.unpatchify_cwm(output, patch_size=patch_size, mask=bool_masked_pos[:, -784:]) # reshape the output to an image " + ] + }, + { + "cell_type": "markdown", + "id": "eed12ad9-a7c8-49c7-add2-d1ea8129aea6", + "metadata": {}, + "source": [ + "## Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8b6983e6-5b76-4e93-a385-c03c9da80ed0", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", + "\n", + "# Visualize the input image and selected patches\n", + "axs[0].imshow(img.cpu()[0].permute(1,2,0).clamp(0, 1))\n", + "axs[0].set_axis_off()\n", + "axs[0].set_title('Input image')\n", + "\n", + "for (px, py) in move_patches:\n", + " circle = patches.Rectangle((py * patch_size, px * patch_size), 7, 7, facecolor='#00b530', edgecolor='white', linewidth=2, label='moving')\n", + " axs[0].add_patch(circle)\n", + "\n", + "# Visualize the counterfactual prediction\n", + "axs[1].imshow(recon[0].clamp(0, 1).float().permute(1,2,0).cpu().detach())\n", + "axs[1].set_axis_off()\n", + "axs[1].set_title('Counterfactual prediction')\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/cwm/2frame_clumping_tpu.sh b/scripts/cwm/2frame_clumping_tpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..61bb73dfc5b86daeb944a53346563c07d2a1a268 --- /dev/null +++ b/scripts/cwm/2frame_clumping_tpu.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_2frame_clumping/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_2frames_1tube_cf2 \ + --context_frames 1 \ + --target_frames 1 \ + --temporal_units 'ms' \ + 
--sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla diff --git a/scripts/cwm/2frame_no_clumping_tpu.sh b/scripts/cwm/2frame_no_clumping_tpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..30e3e38fb5ef9dbfa507a9961adc41807944f3b2 --- /dev/null +++ b/scripts/cwm/2frame_no_clumping_tpu.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_2frame_no_clumping/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_2frames_1tube \ + --context_frames 1 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 diff --git a/scripts/cwm/2frame_no_clumping_tpu_mr0.90.sh b/scripts/cwm/2frame_no_clumping_tpu_mr0.90.sh new file mode 100755 index 0000000000000000000000000000000000000000..75fbd91c8270d635a00702f2f8b3edfb69d3bd48 --- /dev/null +++ b/scripts/cwm/2frame_no_clumping_tpu_mr0.90.sh @@ -0,0 +1,31 @@ 
+OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_2frame_no_clumping_mr0.90/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_2frames_1tube \ + --context_frames 1 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 diff --git a/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.85.sh b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.85.sh new file mode 100755 index 0000000000000000000000000000000000000000..1adfca7f3cadb44aeb389633e16d2db84f004a57 --- /dev/null +++ b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.85.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_16x16_no_clumping_mr0.85/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.85 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_16x16patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 32 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target 
\ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ diff --git a/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90.sh b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90.sh new file mode 100755 index 0000000000000000000000000000000000000000..bb17a97788393e77b5145031d9c44210360e1d9f --- /dev/null +++ b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_16x16_no_clumping_mr0.90/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_16x16patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 32 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ diff --git a/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90_extra_data.sh b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90_extra_data.sh new file mode 100755 index 0000000000000000000000000000000000000000..b63f9b13a467a560a24c65a05f3754afc3fe9397 --- /dev/null +++ b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.90_extra_data.sh @@ -0,0 +1,34 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_16x16_no_clumping_mr0.90_extra_data/" 
+DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_16x16patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 48 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --resume ./checkpoint-799.pth \ No newline at end of file diff --git a/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.95.sh b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.95.sh new file mode 100755 index 0000000000000000000000000000000000000000..654fa2489e2946591988e72c23ff66aee27ad927 --- /dev/null +++ b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.95.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_16x16_no_clumping_mr0.95/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.95 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_16x16patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + 
--sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 32 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ diff --git a/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.98.sh b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.98.sh new file mode 100755 index 0000000000000000000000000000000000000000..2e878e948cef2dc6e5af0b11c7c78b3fcb435780 --- /dev/null +++ b/scripts/cwm/3frame_16x16_no_clumping_tpu_mr0.98.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_16x16_no_clumping_mr0.98/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.98 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_16x16patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 32 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ diff --git a/scripts/cwm/3frame_clumping3_tpu.sh b/scripts/cwm/3frame_clumping3_tpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..b1560d4d8cc2aab3c8579deb8b5602173b8910e7 --- /dev/null +++ b/scripts/cwm/3frame_clumping3_tpu.sh @@ -0,0 +1,30 @@ 
+OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_clumping3/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube_cf3 \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla diff --git a/scripts/cwm/3frame_clumping_gpu.sh b/scripts/cwm/3frame_clumping_gpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..4d41333645b66c798cfd5017d9f0a1a0dd1bd7cf --- /dev/null +++ b/scripts/cwm/3frame_clumping_gpu.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR='/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_clumping_v2/' +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 \ + --master_addr=10.102.2.157 --master_port=32240 \ + --nnodes=1 --node_rank=0 \ + run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube_cf2 \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 3 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 10 \ + --epochs 800 
\ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 0 diff --git a/scripts/cwm/3frame_clumping_tpu.sh b/scripts/cwm/3frame_clumping_tpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..53970008ab660c06fd0bf38cfaddb47858cece6e --- /dev/null +++ b/scripts/cwm/3frame_clumping_tpu.sh @@ -0,0 +1,30 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_clumping_v3/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube_cf2 \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ No newline at end of file diff --git a/scripts/cwm/3frame_no_clumping_gpu.sh b/scripts/cwm/3frame_no_clumping_gpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..38fa8d71f629f625ded73561a39bd6ae61711d37 --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_gpu.sh @@ -0,0 +1,32 @@ +OUTPUT_DIR='/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_clumping_debug/' +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 python -m torch.distributed.launch 
--nproc_per_node=1 \ + --master_addr=10.102.2.137 --master_port=32240 \ + --nnodes=1 --node_rank=0 \ + run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 diff --git a/scripts/cwm/3frame_no_clumping_gpu_extra_data.sh b/scripts/cwm/3frame_no_clumping_gpu_extra_data.sh new file mode 100755 index 0000000000000000000000000000000000000000..ba062436cbe01e208c84f3b0c130bb90c329dcdb --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_gpu_extra_data.sh @@ -0,0 +1,34 @@ +OUTPUT_DIR='/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_clumping_debug/' +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset/how_to_100m/how_to_100m_train_list.txt" + +CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 \ + --master_addr=10.102.2.137 --master_port=32240 \ + --nnodes=1 --node_rank=0 \ + run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_H} ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + 
--sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 3 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 1 \ + --save_ckpt_freq 1 \ + --epochs 19 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 0 diff --git a/scripts/cwm/3frame_no_clumping_maskvit_gpu.sh b/scripts/cwm/3frame_no_clumping_maskvit_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..842df2f83287df411d30f30693131acc6d98d65d --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_maskvit_gpu.sh @@ -0,0 +1,33 @@ +OUTPUT_DIR='/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_maskvit/' +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ + --master_addr=10.102.2.153 --master_port=32240 \ + --nnodes=2 --node_rank=1 \ + run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table_maskvit \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 16 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 10 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 20 \ + --num_workers 32 \ + --min_lr 1e-5 diff --git a/scripts/cwm/3frame_no_clumping_tpu.sh b/scripts/cwm/3frame_no_clumping_tpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..61fd7281abaa2125168b14299a7a1f0fe53b7898 --- /dev/null +++ 
b/scripts/cwm/3frame_no_clumping_tpu.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_ep3200/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.99 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 100 \ + --epochs 3200 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --resume ./checkpoint-899.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.60.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.60.sh new file mode 100755 index 0000000000000000000000000000000000000000..83af544a2260ab536cede1dcf4b90dad0ece9bfe --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.60.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.60/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.60 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + 
--save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.70.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.70.sh new file mode 100755 index 0000000000000000000000000000000000000000..c14b98571766b10920d6314956843e68b6f36bdd --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.70.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.70/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.70 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.80.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.80.sh new file mode 100755 index 0000000000000000000000000000000000000000..3897d0c36b352c5d183fca3378dc1f07ec312f43 --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.80.sh @@ -0,0 +1,32 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.80_rerun/" 
+DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.80 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 \ + #--resume checkpoint-249.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.85.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.85.sh new file mode 100755 index 0000000000000000000000000000000000000000..bf990bf7de1f3bff1fe51a7c70da85f980286252 --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.85.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.85/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.85 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 
1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_huge.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_huge.sh new file mode 100755 index 0000000000000000000000000000000000000000..0f38124e47c0ea550fdde24a8ae305ad6ffe8b4a --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_huge.sh @@ -0,0 +1,34 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.85_extra_data_ep400_huge/" +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.85 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vithuge_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 2 \ + --accum_iter 8 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 20 \ + --save_ckpt_freq 10 \ + --epochs 400 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 8 \ + --use_xla \ + --resume checkpoint-219.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_large.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_large.sh new file mode 100755 index 
0000000000000000000000000000000000000000..a40b857a49309a632c92caebc4b07472601da51b --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.85_extra_data_ep400_large.sh @@ -0,0 +1,33 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.85_extra_data_ep400_large/" +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.85 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitlarge_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 8 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 20 \ + --save_ckpt_freq 10 \ + --epochs 400 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.90.sh new file mode 100755 index 0000000000000000000000000000000000000000..32e19d695ab15e5eb30b7bcf0f2161786ddbbccb --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90.sh @@ -0,0 +1,32 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + 
--mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 \ + --resume ./checkpoint-699.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data.sh new file mode 100755 index 0000000000000000000000000000000000000000..614e4978a0f0d7caaaf2eca3f4b05c8993b06b0a --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data.sh @@ -0,0 +1,36 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_wo_howto/" + +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 5 \ + --save_ckpt_freq 5 \ + --epochs 105 \ 
+ --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 \ + --resume checkpoint-29.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400.sh new file mode 100755 index 0000000000000000000000000000000000000000..4483c67362e378a2d6978e2fa9e3b18a7d0c8cec --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400.sh @@ -0,0 +1,34 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/" +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 20 \ + --save_ckpt_freq 10 \ + --epochs 400 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --resume checkpoint-239.pth \ No newline at end of file diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_huge.sh 
b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_huge.sh new file mode 100755 index 0000000000000000000000000000000000000000..12b9adaf4624e428c2eaa341d0b4cb136371cf0d --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_huge.sh @@ -0,0 +1,34 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400_huge/" +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vithuge_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 2 \ + --accum_iter 8 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 20 \ + --save_ckpt_freq 10 \ + --epochs 400 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 8 \ + --use_xla \ + --resume checkpoint-339.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_large.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_large.sh new file mode 100755 index 0000000000000000000000000000000000000000..da9eac8e6a59dfe51c33f931422cd5ec9027e30d --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_extra_data_ep400_large.sh @@ -0,0 +1,34 @@ 
+OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400_large/" +DATA_PATH_Ktr="/ccn2/dataset/Kinetics700/kinetics_700_train_list.txt" +DATA_PATH_M="/ccn2/dataset/Moments/multi_moment_train_list.txt" +DATA_PATH_E="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/ego4d_train_list_320p_chunked_imu.txt" +DATA_PATH_H="/ccn2/dataset2/how_to_100m/how_to_100m_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path_list ${DATA_PATH_Ktr} ${DATA_PATH_M} ${DATA_PATH_E} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitlarge_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 8 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 20 \ + --save_ckpt_freq 10 \ + --epochs 400 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --resume checkpoint-229.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.90_rotated_table.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_rotated_table.sh new file mode 100755 index 0000000000000000000000000000000000000000..c350428465a6f7266f035050b4d84442134b1cbb --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.90_rotated_table.sh @@ -0,0 +1,31 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.90_rotated_tbl/" +DATA_PATH="${HOME}/cwm/cwm/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/cwm/cwm/run_cwm_pretraining_legacy.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 6}' \ + --model vitbase_8x8patch_8frames_1tube \ + --context_frames 
2 \ + --target_frames 6 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 1 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 256 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 1 \ + --use_xla \ + --min_lr 1e-5 \ diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr0.95.sh b/scripts/cwm/3frame_no_clumping_tpu_mr0.95.sh new file mode 100755 index 0000000000000000000000000000000000000000..646c7f647223f4a868bd0ab1aeeea226665728f3 --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr0.95.sh @@ -0,0 +1,32 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr0.95/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 0.95 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 \ + --resume ./checkpoint-649.pth diff --git a/scripts/cwm/3frame_no_clumping_tpu_mr1.0.sh b/scripts/cwm/3frame_no_clumping_tpu_mr1.0.sh new file mode 100755 index 
0000000000000000000000000000000000000000..efca48088d3c32d5a83b5258fcde7a0c422bdf78 --- /dev/null +++ b/scripts/cwm/3frame_no_clumping_tpu_mr1.0.sh @@ -0,0 +1,32 @@ +OUTPUT_DIR="${HOME}/checkpoints/cwm_cvpr_checkpoints/ablation_3frame_no_clumping_mr1.0/" +DATA_PATH="${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt" + +python3 ~/BBNet/bbnet/models/VideoMAE-main/run_cwm_pretraining.py \ + --data_path ${DATA_PATH} \ + --mask_type rotated_table \ + --mask_ratio 1.0 \ + --mask_kwargs '{"tube_length": 1}' \ + --model vitbase_8x8patch_3frames_1tube \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 16 \ + --accum_iter 2 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --no_normlize_target \ + --rescale_size 224 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 \ + --use_xla \ + --min_lr 1e-5 \ + #--resume checkpoint-249.pth diff --git a/scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh b/scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh new file mode 100755 index 0000000000000000000000000000000000000000..f48072225171857f78c15dcda1afcbb6e875daff --- /dev/null +++ b/scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh @@ -0,0 +1,39 @@ +OUTPUT_DIR='checkpoints/3frame_patch8x8_mr0.90_gpu/' +DATA_PATH="cwm/data/video_file_lists/kinetics_400_train_list.txt" +MASTER_ADDRESS=10.102.2.146 +NNODES=1 +NODE_RANK=0 +NPROC_PER_NODE=1 + +echo "master addr: $MASTER_ADDRESS" +echo "num of nodes: $NNODES" +echo "node rank: $NODE_RANK" +echo "procs per node: $NPROC_PER_NODE" + +OMP_NUM_THREADS=1 torchrun \ + --nproc_per_node=$NPROC_PER_NODE --nnodes=$NNODES --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDRESS --master_port=19234 \ + cwm/run_pretraining.py \ + --data_path 
${DATA_PATH} \ + --model vitb_8x8patch_3frames \ + --mask_type rotated_table \ + --mask_ratio 0.90 \ + --mask_kwargs '{"tube_length": 1}' \ + --context_frames 2 \ + --target_frames 1 \ + --temporal_units 'ms' \ + --sampling_rate 150 \ + --context_target_gap 150 150 \ + --batch_size 1 \ + --accum_iter 1 \ + --opt adamw \ + --opt_betas 0.9 0.95 \ + --warmup_epochs 40 \ + --save_ckpt_freq 50 \ + --epochs 800 \ + --augmentation_type 'multiscale' \ + --augmentation_scales 1.0 0.875 0.75 0.66 \ + --log_dir ${OUTPUT_DIR} \ + --output_dir ${OUTPUT_DIR} \ + --print_freq 1 \ + --num_workers 16 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..df22e58d1eeec4d90c8906e6bbdc1673da29383d --- /dev/null +++ b/setup.py @@ -0,0 +1,33 @@ +from setuptools import setup, find_packages + +setup( + name="cwm", + version="1.0", + packages=find_packages(), + description="Video modeling transformers package for MAE", + author="Dan Bear, Kevin Feigelis, Honglin Chen, Rahul Venkatesh, Wanhee Lee, Dan Yamins, Klemen Kotar", + install_requires=[ + # 'scipy', + # 'scikit-learn', + # 'matplotlib', + # 'h5py', + # 'kornia', + # 'future', + # 'einops', + # 'timm', + # 'opencv-python', + # 'decord', + # 'iopath', + # 'pandas', + # 'IPython', + # 'tensorboardx', + # 'tensorboard', + # 'positional_encodings', + # 'scikit-image', + # 'packaging', + # 'wandb', + # 'numpy==1.23.5', + # 'torch==2.0.0', + # 'torchvision==0.15.1' + ], +)