Commit 8b7a0c4 · Parent: 6c8a46d

datasets
Files changed:
- minigpt4/datasets/__init__.py +0 -0
- minigpt4/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc +0 -0
- minigpt4/datasets/builders/__init__.py +68 -0
- minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc +0 -0
- minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc +0 -0
- minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc +0 -0
- minigpt4/datasets/builders/base_dataset_builder.py +237 -0
- minigpt4/datasets/builders/image_text_pair_builder.py +41 -0
- minigpt4/datasets/data_utils.py +199 -0
- minigpt4/datasets/datasets/__init__.py +0 -0
- minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/aok_vqa_datasets.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/coco_caption.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/coco_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/coco_vqa_datasets.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/face_emotion.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/first_face.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/flickr.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/gqa_datasets.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/llava_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/multitask_conversation.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/ocrvqa_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/text_caps.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/unnatural_instruction.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/vg_dataset.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/__pycache__/vqa_datasets.cpython-39.pyc +0 -0
- minigpt4/datasets/datasets/base_dataset.py +78 -0
- minigpt4/datasets/datasets/dataloader_utils.py +162 -0
- minigpt4/datasets/datasets/first_face.py +174 -0
minigpt4/datasets/__init__.py
ADDED: empty file

minigpt4/datasets/__pycache__/__init__.cpython-39.pyc
ADDED: binary file (152 Bytes)

minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc
ADDED: binary file (6.05 kB)
minigpt4/datasets/builders/__init__.py
ADDED
@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
    FirstfaceCaptionBuilder,
)
from minigpt4.common.registry import registry

__all__ = [
    "FirstfaceCaptionBuilder",
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Example

    >>> dataset = load_dataset("coco_caption", cfg=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    """
    if cfg_path is None:
        cfg = None
    else:
        cfg = load_dataset_config(cfg_path)

    try:
        builder = registry.get_builder_class(name)(cfg)
    except TypeError:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        exit(1)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset


class DatasetZoo:
    def __init__(self) -> None:
        self.dataset_zoo = {
            k: list(v.DATASET_CONFIG_DICT.keys())
            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
        }

    def get_names(self):
        return list(self.dataset_zoo.keys())


dataset_zoo = DatasetZoo()
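Note: a minimal usage sketch of the entry points added above, assuming the minigpt4 package, its registry, and the builder's default config are available in the environment; nothing below is part of this commit.

# Usage sketch (assumes minigpt4 is importable and the default YAML config exists).
from minigpt4.datasets.builders import dataset_zoo, load_dataset

print(dataset_zoo.get_names())                   # -> ['feature_face_caption']
datasets = load_dataset("feature_face_caption")  # cfg_path=None falls back to the default config
print(list(datasets.keys()))                     # e.g. ['train']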
minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc
ADDED: binary file (2.33 kB)

minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc
ADDED: binary file (6.12 kB)

minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc
ADDED: binary file (1.49 kB)
minigpt4/datasets/builders/base_dataset_builder.py
ADDED
@@ -0,0 +1,237 @@
"""
This file is from
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor


class BaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type
        print("BaseDatasetBuilder data type:", self.data_type)

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):

        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create by split datasets inheriting torch.utils.data.Datasets.

        # build() can be dataset-specific. Overwrite to customize.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets


def load_dataset_config(cfg_path):
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]

    return cfg
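For context, load_dataset_config keeps only the first entry under a top-level "datasets" key, and BaseDatasetBuilder then reads data_type, build_info, and the processor sections from it. A rough sketch of that shape follows; the values are illustrative assumptions, not a file shipped in this commit.

# Sketch of the config structure the builder reads (keys mirror the accesses above).
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "datasets": {
        "feature_face_caption": {
            "data_type": "images",
            "build_info": {"ann_path": "/data/ann.txt", "image_path": "/data/images"},
            "vis_processor": {"train": {"name": "blip2_image_train", "image_size": 224}},
            "text_processor": {"train": {"name": "blip_caption"}},
        }
    }
})
dataset_cfg = cfg.datasets[list(cfg.datasets.keys())[0]]  # what load_dataset_config returns
print(dataset_cfg.data_type)  # images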
minigpt4/datasets/builders/image_text_pair_builder.py
ADDED
@@ -0,0 +1,41 @@
import os
import logging
import warnings

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

from minigpt4.datasets.datasets.first_face import FeatureFaceDataset

# FeatureFaceDataset
@registry.register_builder("feature_face_caption")
class FirstfaceCaptionBuilder(BaseDatasetBuilder):
    train_dataset_cls = FeatureFaceDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/firstface/featureface.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            ann_path=build_info.ann_path,
            vis_root=build_info.image_path,
        )

        return datasets
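A minimal sketch of exercising this builder on its own; the default YAML and the data paths it points at are assumed to exist and are not part of this diff.

# Sketch: build the training split registered as "feature_face_caption".
from minigpt4.datasets.builders.image_text_pair_builder import FirstfaceCaptionBuilder

builder = FirstfaceCaptionBuilder()   # cfg=None -> loads DATASET_CONFIG_DICT["default"]
datasets = builder.build_datasets()   # {"train": FeatureFaceDataset(...)}
sample = datasets["train"][0]
print(sample["instruction_input"])
print(sample["answer"], sample["emotion"])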
minigpt4/datasets/data_utils.py
ADDED
@@ -0,0 +1,199 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm

import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset

from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset


decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")


class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipelines doesn't define a sample ratio; it is set to 1 automatically.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            yield next(select_datastream)


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


def reorg_datasets_by_split(datasets, batch_sizes):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by name.
        batch_sizes: dict of batch sizes keyed by the same dataset names.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]} and the matching
        dict of batch sizes by split.
    """
    # if len(datasets) == 1:
    #     return datasets[list(datasets.keys())[0]]
    # else:
    reorg_datasets = dict()
    reorg_batch_sizes = dict()

    # reorganize by split
    for dataset_name, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
                reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
            else:
                reorg_datasets[split_name].append(dataset_split)
                reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])

    return reorg_datasets, reorg_batch_sizes


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. Currently, does not support
    generic IterableDataset because it requires creating separate samplers.

    Currently only supports concatenating training datasets, assuming validation and testing
    have only a single dataset. This is because metrics should not be computed on the concatenated
    datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # if len(iterable_datasets) > 0:
            # concatenate map-style datasets and iterable-style datasets separately
            if len(iterable_datasets) > 1:
                chained_datasets = (
                    ChainDataset(iterable_datasets)
                )
            elif len(iterable_datasets) == 1:
                chained_datasets = iterable_datasets[0]
            else:
                chained_datasets = None

            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets
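To show how the two split-handling helpers above compose, here is a small sketch; the toy tensors, dataset names, and batch sizes are illustrative, and it assumes the package's dependencies (decord, webdataset) are installed so the module imports cleanly.

# Sketch: group per-dataset splits into per-split lists, then merge the train split.
import torch
from torch.utils.data import TensorDataset
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split

face_train = TensorDataset(torch.zeros(10, 3))
coco_train = TensorDataset(torch.ones(20, 3))
face_val = TensorDataset(torch.zeros(5, 3))

datasets = {"face": {"train": face_train, "val": face_val}, "coco": {"train": coco_train}}
batch_sizes = {"face": 4, "coco": 8}

by_split, bs_by_split = reorg_datasets_by_split(datasets, batch_sizes)
# by_split == {"train": [face_train, coco_train], "val": [face_val]}

merged = concat_datasets(by_split)
print(len(merged["train"]), len(merged["val"]))  # 30 5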
minigpt4/datasets/datasets/__init__.py
ADDED: empty file

minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc
ADDED: binary file (161 Bytes)

minigpt4/datasets/datasets/__pycache__/aok_vqa_datasets.cpython-39.pyc
ADDED: binary file (3.83 kB)

minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc
ADDED: binary file (2.84 kB)

minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc
ADDED: binary file (4.59 kB)

minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc
ADDED: binary file (1.84 kB)

minigpt4/datasets/datasets/__pycache__/coco_caption.cpython-39.pyc
ADDED: binary file (4.21 kB)

minigpt4/datasets/datasets/__pycache__/coco_dataset.cpython-39.pyc
ADDED: binary file (11.9 kB)

minigpt4/datasets/datasets/__pycache__/coco_vqa_datasets.cpython-39.pyc
ADDED: binary file (4.04 kB)

minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc
ADDED: binary file (5.07 kB)

minigpt4/datasets/datasets/__pycache__/face_emotion.cpython-39.pyc
ADDED: binary file (4.23 kB)

minigpt4/datasets/datasets/__pycache__/first_face.cpython-39.pyc
ADDED: binary file (5.41 kB)

minigpt4/datasets/datasets/__pycache__/flickr.cpython-39.pyc
ADDED: binary file (4.19 kB)

minigpt4/datasets/datasets/__pycache__/gqa_datasets.cpython-39.pyc
ADDED: binary file (2.05 kB)

minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc
ADDED: binary file (1.39 kB)

minigpt4/datasets/datasets/__pycache__/llava_dataset.cpython-39.pyc
ADDED: binary file (4.1 kB)

minigpt4/datasets/datasets/__pycache__/multitask_conversation.cpython-39.pyc
ADDED: binary file (2.4 kB)

minigpt4/datasets/datasets/__pycache__/ocrvqa_dataset.cpython-39.pyc
ADDED: binary file (2.67 kB)

minigpt4/datasets/datasets/__pycache__/text_caps.cpython-39.pyc
ADDED: binary file (2.79 kB)

minigpt4/datasets/datasets/__pycache__/unnatural_instruction.cpython-39.pyc
ADDED: binary file (1.83 kB)

minigpt4/datasets/datasets/__pycache__/vg_dataset.cpython-39.pyc
ADDED: binary file (3.06 kB)

minigpt4/datasets/datasets/__pycache__/vqa_datasets.cpython-39.pyc
ADDED: binary file (6.2 kB)
minigpt4/datasets/datasets/base_dataset.py
ADDED
@@ -0,0 +1,78 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.annotation = []
        # print("ann paths", ann_paths)
        for ann_path in ann_paths:
            # print("ann_path", ann_path)
            ann = json.load(open(ann_path, "r"))
            if isinstance(ann, dict):
                self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
                # self.annotation.extend(json.load(open(ann_path, "r")))
            else:
                self.annotation.extend(json.load(open(ann_path, "r")))

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with same underlying collater implementations

        all_keys = set()
        for s in samples:
            all_keys.update(s)

        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        return self.datasets[0].collater(samples_shared_keys)
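The key-intersection step in ConcatDataset.collater can be seen in isolation with two hand-made samples; the values below are illustrative only.

# Sketch of the shared-key filtering performed before delegating to the child collater.
samples = [
    {"image": 0, "answer": "happy", "emotion": 2},
    {"image": 1, "answer": "sad"},                 # missing "emotion"
]
shared_keys = set()
for s in samples:
    shared_keys.update(s)
for s in samples:
    shared_keys &= set(s.keys())
filtered = [{k: s[k] for k in s if k in shared_keys} for s in samples]
print(filtered)  # [{'image': 0, 'answer': 'happy'}, {'image': 1, 'answer': 'sad'}]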
minigpt4/datasets/datasets/dataloader_utils.py
ADDED
@@ -0,0 +1,162 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from minigpt4.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Loader]): List of Iterator loaders.
        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert all loaders has __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # random sample from each loader by ratio
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            is_tuple = isinstance(batch, tuple)
            if is_tuple:
                task, batch = batch

            if is_tuple:
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, list) or isinstance(batch, tuple):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
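A small sketch of how these wrappers are meant to combine; the toy datasets and the 3:1 ratio are illustrative, and PrefetchLoader is omitted because it needs a CUDA device.

# Sketch: two infinite loaders sampled at a 3:1 ratio.
import torch
from torch.utils.data import DataLoader, TensorDataset
from minigpt4.datasets.datasets.dataloader_utils import IterLoader, MultiIterLoader

loader_a = IterLoader(DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=2))
loader_b = IterLoader(DataLoader(TensorDataset(torch.ones(8, 3)), batch_size=2))

multi = MultiIterLoader([loader_a, loader_b], ratios=[3, 1])
for _ in range(4):
    (batch,) = next(multi)   # each TensorDataset batch is a 1-tuple of tensors
    print(batch.shape)       # torch.Size([2, 3])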
minigpt4/datasets/datasets/first_face.py
ADDED
@@ -0,0 +1,174 @@
import glob
import os
import json
import pickle
import random
import time
import itertools
import pandas as pd
import json

import torch.nn.functional as F

import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
import torch
from torch.utils.data import Dataset
import webdataset as wds
import cv2

from minigpt4.datasets.datasets.base_dataset import BaseDataset

class FeatureFaceDataset(Dataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_path):

        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.caption_instruction_pool = [
            "Please describe the details of the expression and tone in the video.",
            "Can you provide a description of the facial expression and tone shown by the person in the video?",
            "Could you outline the facial expressions and vocal tones displayed in the video?",
            "Detail the expressions and tone used in the video.",
            "Explain the visual and auditory expressions captured in the video.",
            "Provide an analysis of the expressions and tone featured in the video.",
        ]

        self.emotion_instruction_pool = [
            "Please determine which emotion label in the video represents: happy, sad, neutral, angry, worried, surprise, fear, contempt, doubt.",

            # "Please determine which emotion label in the video represents: happy, sad, neutral, angry, worried, surprise.",
            # "Identify the displayed emotion in the video: is it happy, sad, neutral, angry, worried, or surprise?",
            # "Determine the emotional state shown in the video, choosing from happy, sad, neutral, angry, worried, or surprise.",
            # "Please ascertain the specific emotion portrayed in the video, whether it be happy, sad, neutral, angry, worried, or surprise.",
            # "Assess and label the emotion evident in the video: could it be happy, sad, neutral, angry, worried, surprise?",
        ]

        self.reason_instruction_pool = [
            "Please analyze all the clues in the video and reason out the emotional label of the person in the video.",
            "What is the emotional state of the person in the video? Please tell me the reason.",
            "What are the facial expressions and vocal tone used in the video? What is the intended meaning behind his words? Which emotion does this reflect?",
            "Please integrate information from various modalities to infer the emotional category of the person in the video.",
            "Could you describe the emotion-related features of the individual in the video? What emotional category do they fall into?",
        ]

        # self.task_pool = [
        #     "emotion",
        #     "reason",
        #     "infer",
        # ]

        self.task_pool = [
            "emotion",
        ]

        print("ann_path: ", ann_path)
        self.ann_path = ann_path
        self.file_path = os.path.dirname(ann_path)
        self.tmp = [x.strip().split(' ') for x in open(ann_path)]
        print(('video number:%d' % (len(self.tmp))))

        # emos = ['neutral', 'angry', 'happy', 'sad', 'worried', 'surprise']
        emos = ['neutral', 'angry', 'happy', 'sad', 'worried', 'surprise', 'fear', 'contempt', 'doubt']

        self.emo2idx, self.idx2emo = {}, {}
        for ii, emo in enumerate(emos): self.emo2idx[emo] = ii
        for ii, emo in enumerate(emos): self.idx2emo[ii] = emo

        json_file_path = "/home/user/selected_face/face_emotion/AU_filter_merge.json"
        with open(json_file_path, 'r') as json_file:
            self.AU_filter_json = json.load(json_file)

        reason_json_file_path = "/home/user/selected_face/face_emotion/0512_target_smp_end.json"
        with open(reason_json_file_path, 'r') as json_file:
            self.reason_dict = json.load(json_file)

        self.character_lines = pd.read_csv('/home/user/selected_face/face_emotion/transcription_en_all.csv')


    def __len__(self):
        return len(self.tmp)

    def __getitem__(self, index):
        t = self.tmp[index]
        video_name = t[0]

        image_file = '{}.jpg'.format(video_name)
        image_path = os.path.join(self.vis_root, image_file)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        FaceMAE_feats, VideoMAE_feats, Audio_feats = self.get(video_name)
        if len(VideoMAE_feats.shape) == 1:
            VideoMAE_feats = VideoMAE_feats.unsqueeze(0)
        if len(Audio_feats.shape) == 1:
            Audio_feats = Audio_feats.unsqueeze(0)
        if len(FaceMAE_feats.shape) == 1:
            FaceMAE_feats = FaceMAE_feats.unsqueeze(0)
        video_features = torch.cat((FaceMAE_feats, VideoMAE_feats, Audio_feats), dim=0)


        # random task
        task = random.choice(self.task_pool)
        if task == "emotion":
            caption = t[2]  # llama2 output: only the emotion class
            caption = self.text_processor(caption)
            instruction_pool = self.emotion_instruction_pool
        elif task == "reason":
            caption = self.reason_dict[video_name]['smp_reason_caption']
            infer_str = " Therefore, it is inferred that his emotional state is: "
            caption = caption + infer_str + t[2]

            # caption = ""  # for test reasoning

            caption = self.text_processor(caption)
            instruction_pool = self.reason_instruction_pool

        elif task == "infer":
            infer_str = " Therefore, it is inferred that his emotional state is: "
            caption = t[2]
            instruction_pool = [
                self.reason_dict[video_name]['reason_caption'] + infer_str,
            ]
        elif task == "caption":
            caption = self.AU_filter_json[video_name]['caption']
            caption = self.text_processor(caption)
            instruction_pool = self.caption_instruction_pool


        emotion = self.emo2idx[t[2]]
        sentence = self.character_lines.loc[self.character_lines['name'] == video_name, 'sentence'].values[0]
        character_line = "The person in video says: {}. ".format(sentence)

        instruction = "<video><VideoHere></video> <feature><FeatureHere></feature> {} [{}] {} ".format(character_line, task, random.choice(instruction_pool))

        return {
            "image": image,
            "video_features": video_features,
            "instruction_input": instruction,
            "answer": caption,
            "emotion": emotion,
            "image_id": video_name
        }


    def get(self, video_name):
        # FaceMAE feature
        FaceMAE_feats_path = os.path.join(self.file_path, 'mae_340_UTT', video_name + '.npy')
        FaceMAE_feats = torch.tensor(np.load(FaceMAE_feats_path))

        # VideoMAE feature
        VideoMAE_feats_path = os.path.join(self.file_path, 'maeV_399_UTT', video_name + '.npy')
        VideoMAE_feats = torch.tensor(np.load(VideoMAE_feats_path))

        # Audio feature
        Audio_feats_path = os.path.join(self.file_path, 'HL-UTT', video_name + '.npy')
        Audio_feats = torch.tensor(np.load(Audio_feats_path))

        return FaceMAE_feats, VideoMAE_feats, Audio_feats
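Finally, a sketch of the on-disk layout FeatureFaceDataset appears to assume, reverse-read from __init__, __getitem__, and get() above; the sample name and the meaning of the second column in the annotation file are guesses, not documented by this commit.

# Inferred layout (illustrative names):
#   ann.txt line          -> "<video_name> <col1> <emotion_label>", e.g. "clip_0001 0 happy"
#   image                 -> <vis_root>/clip_0001.jpg
#   FaceMAE feature       -> <dirname(ann_path)>/mae_340_UTT/clip_0001.npy
#   VideoMAE feature      -> <dirname(ann_path)>/maeV_399_UTT/clip_0001.npy
#   audio feature         -> <dirname(ann_path)>/HL-UTT/clip_0001.npy
#   transcript            -> transcription_en_all.csv with "name" and "sentence" columns
line = "clip_0001 0 happy"
video_name, _, emotion_label = line.strip().split(" ")
print(video_name, emotion_label)   # clip_0001 happy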