LCZZZZ commited on
Commit
e5412f3
·
verified ·
1 Parent(s): 7ae5b5f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. .github/ISSUE_TEMPLATE/bug_report.md +38 -0
  3. .github/ISSUE_TEMPLATE/feature_request.md +20 -0
  4. .gitignore +184 -0
  5. CODE_OF_CONDUCT.md +128 -0
  6. LICENSE.md +14 -0
  7. LICENSE_Lavis.md +14 -0
  8. MiniGPT4_Train.md +41 -0
  9. MiniGPTv2.pdf +3 -0
  10. MiniGPTv2_Train.md +24 -0
  11. README.md +208 -8
  12. SECURITY.md +21 -0
  13. checkpoint_stage2.pth +3 -0
  14. dataset/README_1_STAGE.md +96 -0
  15. dataset/README_2_STAGE.md +19 -0
  16. dataset/README_MINIGPTv2_FINETUNE.md +285 -0
  17. dataset/convert_cc_sbu.py +20 -0
  18. dataset/convert_laion.py +20 -0
  19. dataset/download_cc_sbu.sh +6 -0
  20. dataset/download_laion.sh +6 -0
  21. demo.py +171 -0
  22. demo_v2.py +651 -0
  23. environment.yml +35 -0
  24. eval_configs/minigpt4_eval.yaml +22 -0
  25. eval_configs/minigpt4_llama2_eval.yaml +22 -0
  26. eval_configs/minigptv2_benchmark_evaluation.yaml +79 -0
  27. eval_configs/minigptv2_eval.yaml +24 -0
  28. eval_scripts/EVAL_README.md +104 -0
  29. eval_scripts/eval_data/refcoco+_testA.json +0 -0
  30. eval_scripts/eval_data/refcoco+_testB.json +0 -0
  31. eval_scripts/eval_data/refcoco+_val.json +0 -0
  32. eval_scripts/eval_data/refcoco_testA.json +0 -0
  33. eval_scripts/eval_data/refcoco_testB.json +0 -0
  34. eval_scripts/eval_data/refcoco_val.json +0 -0
  35. eval_scripts/eval_data/refcocog_test.json +0 -0
  36. eval_scripts/eval_data/refcocog_val.json +0 -0
  37. eval_scripts/eval_ref.py +128 -0
  38. eval_scripts/eval_vqa.py +252 -0
  39. examples/ad_1.png +0 -0
  40. examples/ad_2.png +0 -0
  41. examples/cook_1.png +0 -0
  42. examples/cook_2.png +0 -0
  43. examples/describe_1.png +0 -0
  44. examples/describe_2.png +0 -0
  45. examples/fact_1.png +0 -0
  46. examples/fact_2.png +0 -0
  47. examples/fix_1.png +0 -0
  48. examples/fix_2.png +0 -0
  49. examples/fun_1.png +0 -0
  50. examples/fun_2.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ MiniGPTv2.pdf filter=lfs diff=lfs merge=lfs -text
37
+ examples_v2/cockdial.png filter=lfs diff=lfs merge=lfs -text
38
+ examples_v2/float.png filter=lfs diff=lfs merge=lfs -text
39
+ figs/demo.png filter=lfs diff=lfs merge=lfs -text
40
+ figs/minigpt2_demo.png filter=lfs diff=lfs merge=lfs -text
41
+ figs/online_demo.png filter=lfs diff=lfs merge=lfs -text
42
+ figs/overview.png filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. Go to '...'
16
+ 2. Click on '....'
17
+ 3. Scroll down to '....'
18
+ 4. See error
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. iOS]
28
+ - Browser [e.g. chrome, safari]
29
+ - Version [e.g. 22]
30
+
31
+ **Smartphone (please complete the following information):**
32
+ - Device: [e.g. iPhone6]
33
+ - OS: [e.g. iOS8.1]
34
+ - Browser [e.g. stock browser, safari]
35
+ - Version [e.g. 22]
36
+
37
+ **Additional context**
38
+ Add any other context about the problem here.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context**
20
+ Add any other context or screenshots about the feature request here.
.gitignore ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ .idea/
162
+
163
+ wandb/
164
+ jobs/logs/
165
+ *.out
166
+ *ipynb
167
+ .history/
168
+ *.json
169
+ *.sh
170
+ .ipynb_common
171
+ logs/
172
+ results/
173
+ prompts/
174
+ output/
175
+ ckpt/
176
+ divide_vqa.py
177
+ jobs/
178
+
179
+ *.slurm
180
+ slurm*
181
+ sbatch_generate*
182
+ eval_data/
183
+ dataset/Evaluation.md
184
+ jupyter_notebook.slurm
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ https://discord.gg/2aNvvYVv.
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
LICENSE.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright 2023 Deyao Zhu
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
LICENSE_Lavis.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022 Salesforce, Inc.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MiniGPT4_Train.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Training of MiniGPT-4
2
+
3
+ The training of MiniGPT-4 contains two alignment stages.
4
+
5
+ **1. First pretraining stage**
6
+
7
+ In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
8
+ to align the vision and language model. To download and prepare the datasets, please check
9
+ our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
10
+ After the first stage, the visual features are mapped and can be understood by the language
11
+ model.
12
+ To launch the first stage training, run the following command. In our experiments, we use 4 A100.
13
+ You can change the save path in the config file
14
+ [train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
15
+
16
+ ```bash
17
+ torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
18
+ ```
19
+
20
+ A MiniGPT-4 checkpoint with only stage one training can be downloaded
21
+ [here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
22
+ Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently.
23
+
24
+
25
+ **2. Second finetuning stage**
26
+
27
+ In the second stage, we use a small high quality image-text pair dataset created by ourselves
28
+ and convert it to a conversation format to further align MiniGPT-4.
29
+ To download and prepare our second stage dataset, please check our
30
+ [second stage dataset preparation instruction](dataset/README_2_STAGE.md).
31
+ To launch the second stage alignment,
32
+ first specify the path to the checkpoint file trained in stage 1 in
33
+ [train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
34
+ You can also specify the output path there.
35
+ Then, run the following command. In our experiments, we use 1 A100.
36
+
37
+ ```bash
38
+ torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
39
+ ```
40
+
41
+ After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly.
MiniGPTv2.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:429b0f5e3d70828fd691ef4ffb90c6efa094a8454bf03f8ec00b10fcd443f346
3
+ size 4357853
MiniGPTv2_Train.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Finetune of MiniGPT-4
2
+
3
+
4
+ You firstly need to prepare the dataset. you can follow this step to prepare the dataset.
5
+ our [dataset preparation](dataset/README_MINIGPTv2_FINETUNE.md).
6
+
7
+ In the train_configs/minigptv2_finetune.yaml, you need to set up the following paths:
8
+
9
+ llama_model checkpoint path: "/path/to/llama_checkpoint"
10
+
11
+ ckpt: "/path/to/pretrained_checkpoint"
12
+
13
+ ckpt save path: "/path/to/save_checkpoint"
14
+
15
+ For ckpt, you may load from our pretrained model checkpoints:
16
+ | MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo) |
17
+ |------------------------------|------------------------------|------------------------------|
18
+ | [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
19
+
20
+
21
+ ```bash
22
+ torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigptv2_finetune.yaml
23
+ ```
24
+
README.md CHANGED
@@ -1,12 +1,212 @@
1
  ---
2
- title: MiniGPT 4
3
- emoji: 🏃
4
- colorFrom: blue
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.38.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: MiniGPT-4
3
+ app_file: demo_v2.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.47.1
 
 
6
  ---
7
+ # MiniGPT-V
8
 
9
+ <font size='5'>**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**</font>
10
+
11
+ Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨
12
+
13
+ ☨equal last author
14
+
15
+ <a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2310.09478.pdf'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/MiniGPT-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'> <a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Gradio-Demo-blue'></a> [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=atFCwV2hSY4)
16
+
17
+
18
+ <font size='5'> **MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**</font>
19
+
20
+ Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
21
+
22
+ *equal contribution
23
+
24
+ <a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
25
+
26
+ *King Abdullah University of Science and Technology*
27
+
28
+ ## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE)
29
+
30
+ <font size='4'> **Example Community Efforts Built on Top of MiniGPT-4 ** </font>
31
+
32
+ * <a href='https://github.com/waltonfuture/InstructionGPT-4?tab=readme-ov-file'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **InstructionGPT-4**: A 200-Instruction Paradigm for Fine-Tuning MiniGPT-4 Lai Wei, Zihao Jiang, Weiran Huang, Lichao Sun, Arxiv, 2023
33
+
34
+ * <a href='https://openaccess.thecvf.com/content/ICCV2023W/CLVL/papers/Aubakirova_PatFig_Generating_Short_and_Long_Captions_for_Patent_Figures_ICCVW_2023_paper.pdf'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **PatFig**: Generating Short and Long Captions for Patent Figures.", Aubakirova, Dana, Kim Gerdes, and Lufei Liu, ICCVW, 2023
35
+
36
+
37
+ * <a href='https://github.com/JoshuaChou2018/SkinGPT-4'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **SkinGPT-4**: An Interactive Dermatology Diagnostic System with Visual Large Language Model, Juexiao Zhou and Xiaonan He and Liyuan Sun and Jiannan Xu and Xiuying Chen and Yuetan Chu and Longxi Zhou and Xingyu Liao and Bin Zhang and Xin Gao, Arxiv, 2023
38
+
39
+
40
+ * <a href='https://huggingface.co/Tyrannosaurus/ArtGPT-4'><img src='https://img.shields.io/badge/Project-Page-Green'></a> **ArtGPT-4**: Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4.", Yuan, Zhengqing, Huiwen Xue, Xinyi Wang, Yongming Liu, Zhuanzhe Zhao, and Kun Wang, Arxiv, 2023
41
+
42
+
43
+ </font>
44
+
45
+ ## News
46
+ [Oct.31 2023] We release the evaluation code of our MiniGPT-v2.
47
+
48
+ [Oct.24 2023] We release the finetuning code of our MiniGPT-v2.
49
+
50
+ [Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
51
+
52
+ [Aug.28 2023] We now provide a llama 2 version of MiniGPT-4
53
+
54
+ ## Online Demo
55
+
56
+ Click the image to chat with MiniGPT-v2 around your images
57
+ [![demo](figs/minigpt2_demo.png)](https://minigpt-v2.github.io/)
58
+
59
+ Click the image to chat with MiniGPT-4 around your images
60
+ [![demo](figs/online_demo.png)](https://minigpt-4.github.io)
61
+
62
+
63
+ ## MiniGPT-v2 Examples
64
+
65
+ ![MiniGPT-v2 demos](figs/demo.png)
66
+
67
+
68
+
69
+ ## MiniGPT-4 Examples
70
+ | | |
71
+ :-------------------------:|:-------------------------:
72
+ ![find wild](figs/examples/wop_2.png) | ![write story](figs/examples/ad_2.png)
73
+ ![solve problem](figs/examples/fix_1.png) | ![write Poem](figs/examples/rhyme_1.png)
74
+
75
+ More examples can be found in the [project page](https://minigpt-4.github.io).
76
+
77
+
78
+
79
+ ## Getting Started
80
+ ### Installation
81
+
82
+ **1. Prepare the code and the environment**
83
+
84
+ Git clone our repository, creating a python environment and activate it via the following command
85
+
86
+ ```bash
87
+ git clone https://github.com/Vision-CAIR/MiniGPT-4.git
88
+ cd MiniGPT-4
89
+ conda env create -f environment.yml
90
+ conda activate minigptv
91
+ ```
92
+
93
+
94
+ **2. Prepare the pretrained LLM weights**
95
+
96
+ **MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 version.
97
+ Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs.
98
+
99
+ | Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B |
100
+ :------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
101
+ [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)
102
+
103
+
104
+ Then, set the variable *llama_model* in the model config file to the LLM weight path.
105
+
106
+ * For MiniGPT-v2, set the LLM path
107
+ [here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14.
108
+
109
+ * For MiniGPT-4 (Llama2), set the LLM path
110
+ [here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
111
+
112
+ * For MiniGPT-4 (Vicuna), set the LLM path
113
+ [here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
114
+
115
+ **3. Prepare the pretrained model checkpoints**
116
+
117
+ Download the pretrained model checkpoints
118
+
119
+
120
+ | MiniGPT-v2 (after stage-2) | MiniGPT-v2 (after stage-3) | MiniGPT-v2 (online developing demo)|
121
+ |------------------------------|------------------------------|------------------------------|
122
+ | [Download](https://drive.google.com/file/d/1Vi_E7ZtZXRAQcyz4f8E6LtLh2UXABCmu/view?usp=sharing) |[Download](https://drive.google.com/file/d/1HkoUUrjzFGn33cSiUkI-KcT-zysCynAz/view?usp=sharing) | [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
123
+
124
+
125
+ For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
126
+ in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
127
+
128
+
129
+
130
+ | MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
131
+ |----------------------------|---------------------------|---------------------------------|
132
+ | [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
133
+
134
+ For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
135
+ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version.
136
+
137
+
138
+
139
+ ### Launching Demo Locally
140
+
141
+ For MiniGPT-v2, run
142
+ ```
143
+ python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
144
+ ```
145
+
146
+ For MiniGPT-4 (Vicuna version), run
147
+
148
+ ```
149
+ python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
150
+ ```
151
+
152
+ For MiniGPT-4 (Llama2 version), run
153
+
154
+ ```
155
+ python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
156
+ ```
157
+
158
+
159
+ To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1.
160
+ This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM.
161
+ For more powerful GPUs, you can run the model
162
+ in 16 bit by setting `low_resource` to `False` in the relevant config file:
163
+
164
+ * MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6)
165
+ * MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6)
166
+ * MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6)
167
+
168
+ Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
169
+
170
+
171
+ ### Training
172
+ For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
173
+
174
+ For finetuning details of MiniGPT-v2, check [here](MiniGPTv2_Train.md)
175
+
176
+
177
+ ### Evaluation
178
+ For finetuning details of MiniGPT-v2, check [here](eval_scripts/EVAL_README.md)
179
+
180
+
181
+ ## Acknowledgement
182
+
183
+ + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before!
184
+ + [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
185
+ + [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
186
+ + [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model.
187
+
188
+
189
+ If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
190
+ ```bibtex
191
+
192
+
193
+ @article{chen2023minigptv2,
194
+ title={MiniGPT-v2: large language model as a unified interface for vision-language multi-task learning},
195
+ author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechu and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
196
+ year={2023},
197
+ journal={arXiv preprint arXiv:2310.09478},
198
+ }
199
+
200
+ @article{zhu2023minigpt,
201
+ title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
202
+ author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
203
+ journal={arXiv preprint arXiv:2304.10592},
204
+ year={2023}
205
+ }
206
+ ```
207
+
208
+
209
+ ## License
210
+ This repository is under [BSD 3-Clause License](LICENSE.md).
211
+ Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
212
+ BSD 3-Clause License [here](LICENSE_Lavis.md).
SECURITY.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Security Policy
2
+
3
+ ## Supported Versions
4
+
5
+ Use this section to tell people about which versions of your project are
6
+ currently being supported with security updates.
7
+
8
+ | Version | Supported |
9
+ | ------- | ------------------ |
10
+ | 5.1.x | :white_check_mark: |
11
+ | 5.0.x | :x: |
12
+ | 4.0.x | :white_check_mark: |
13
+ | < 4.0 | :x: |
14
+
15
+ ## Reporting a Vulnerability
16
+
17
+ Use this section to tell people how to report a vulnerability.
18
+
19
+ Tell them where to go, how often they can expect to get an update on a
20
+ reported vulnerability, what to expect if the vulnerability is accepted or
21
+ declined, etc.
checkpoint_stage2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d23539a5d0f2e02539dde09de21c89bcb054bb989d0a2800b280abd9d7f57a0
3
+ size 679804205
dataset/README_1_STAGE.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Download the filtered Conceptual Captions, SBU, LAION datasets
2
+
3
+ ### Pre-training datasets download:
4
+ We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
5
+
6
+ It requires ~2.3T to store LAION and CC3M+CC12M+SBU datasets
7
+
8
+ Image source | Filtered synthetic caption by ViT-L
9
+ --- | :---:
10
+ CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
11
+ LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
12
+
13
+ This will download two json files
14
+ ```
15
+ ccs_synthetic_filtered_large.json
16
+ laion_synthetic_filtered_large.json
17
+ ```
18
+
19
+ ## prepare the data step-by-step
20
+
21
+
22
+ ### setup the dataset folder and move the annotation file to the data storage folder
23
+ ```
24
+ export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
25
+ mkdir ${MINIGPT4_DATASET}/cc_sbu
26
+ mkdir ${MINIGPT4_DATASET}/laion
27
+ mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
28
+ mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
29
+ ```
30
+
31
+ ### Convert the scripts to data storate folder
32
+ ```
33
+ cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
34
+ cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
35
+ cp convert_laion.py ${MINIGPT4_DATASET}/laion
36
+ cp download_laion.sh ${MINIGPT4_DATASET}/laion
37
+ ```
38
+
39
+
40
+ ### Convert the laion and cc_sbu annotation file format to be img2dataset format
41
+ ```
42
+ cd ${MINIGPT4_DATASET}/cc_sbu
43
+ python convert_cc_sbu.py
44
+
45
+ cd ${MINIGPT4_DATASET}/laion
46
+ python convert_laion.py
47
+ ```
48
+
49
+ ### Download the datasets with img2dataset
50
+ ```
51
+ cd ${MINIGPT4_DATASET}/cc_sbu
52
+ sh download_cc_sbu.sh
53
+ cd ${MINIGPT4_DATASET}/laion
54
+ sh download_laion.sh
55
+ ```
56
+
57
+
58
+ The final dataset structure
59
+
60
+ ```
61
+ .
62
+ ├── ${MINIGPT4_DATASET}
63
+ │ ├── cc_sbu
64
+ │ ├── convert_cc_sbu.py
65
+ │ ├── download_cc_sbu.sh
66
+ │ ├── ccs_synthetic_filtered_large.json
67
+ │ ├── ccs_synthetic_filtered_large.tsv
68
+ │ └── cc_sbu_dataset
69
+ │ ├── 00000.tar
70
+ │ ├── 00000.parquet
71
+ │ ...
72
+ │ ├── laion
73
+ │ ├── convert_laion.py
74
+ │ ├── download_laion.sh
75
+ │ ├── laion_synthetic_filtered_large.json
76
+ │ ├── laion_synthetic_filtered_large.tsv
77
+ │ └── laion_dataset
78
+ │ ├── 00000.tar
79
+ │ ├── 00000.parquet
80
+ │ ...
81
+ ...
82
+ ```
83
+
84
+
85
+ ## Set up the dataset configuration files
86
+
87
+ Then, set up the LAION dataset loading path in
88
+ [here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 as
89
+ ${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
90
+
91
+ and the Conceptual Captoin and SBU datasets loading path in
92
+ [here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 as
93
+ ${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
94
+
95
+
96
+
dataset/README_2_STAGE.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Second Stage Data Preparation
2
+
3
+ Our second stage dataset can be downloaded from
4
+ [here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link)
5
+ After extraction, you will get a data follder with the following structure:
6
+
7
+ ```
8
+ cc_sbu_align
9
+ ├── filter_cap.json
10
+ └── image
11
+ ├── 2.jpg
12
+ ├── 3.jpg
13
+ ...
14
+ ```
15
+
16
+ Put the folder to any path you want.
17
+ Then, set up the dataset path in the dataset config file
18
+ [here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
19
+
dataset/README_MINIGPTv2_FINETUNE.md ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Download the dataset for finetuning the MiniGPT-v2
2
+
3
+
4
+ Download the dataset
5
+
6
+ Image source | Download path
7
+ --- | :---:
8
+ COCO 2014 images | <a href="http://images.cocodataset.org/zips/train2014.zip">images</a> &nbsp;&nbsp; <a href="https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json"> captions</a>
9
+ COCO VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json">vqa train</a> &nbsp;&nbsp; <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json"> vqa val</a>
10
+ Visual Genome | <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip">images part1</a> &nbsp;&nbsp; <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip">images part2</a> &nbsp;&nbsp; <a href="https://homes.cs.washington.edu/~ranjay/visualgenome/data/dataset/image_data.json.zip"> image meta data </a>
11
+ TextCaps | <a href="https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip">images</a> &nbsp;&nbsp; <a href="https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json"> annotations</a>
12
+ RefCOCO | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip"> annotations </a>
13
+ RefCOCO+ | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip"> annotations </a>
14
+ RefCOCOg | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip"> annotations </a>
15
+ OKVQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json"> annotations </a>
16
+ AOK-VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json"> annotations </a>
17
+ OCR-VQA | <a href="https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing"> annotations </a>
18
+ GQA | <a href="https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip">images</a> &nbsp;&nbsp; <a href="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json"> annotations </a>
19
+ Filtered flickr-30k | <a href="https://drive.google.com/drive/folders/19c_ggBI77AvdtYlPbuI0ZpnPz73T5teX?usp=sharing"> annotations </a>
20
+ Multi-task conversation | <a href="https://drive.google.com/file/d/11HHqB2c29hbSk-WLxdta-nG8UCUrcCN1/view?usp=sharing"> annotations </a>
21
+ Filtered unnatural instruction | <a href="https://drive.google.com/file/d/1lXNnBcb5WU-sc8Fe2T2N8J0NRw4sBLev/view?usp=sharing"> annotations </a>
22
+ LLaVA | <a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/complex_reasoning_77k.json"> Compelex reasoning </a> &nbsp;&nbsp;<a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/detail_23k.json"> Detailed description </a> &nbsp;&nbsp; <a href="https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/conversation_58k.json"> Conversation </a>
23
+
24
+
25
+
26
+ ### COCO captions
27
+ Download the COCO 2014 images and captions
28
+
29
+ coco 2014 images path
30
+
31
+ ```
32
+ ${MINIGPTv2_DATASET}
33
+ ├── coco
34
+ │ ├── images
35
+ ...
36
+ ```
37
+
38
+
39
+ coco caption annotation path
40
+
41
+ ```
42
+ ${MINIGPTv2_DATASET}
43
+ ├── coco_captions
44
+ │ └── annotations
45
+ │ ├── coco_karpathy_train.json
46
+ ...
47
+ ```
48
+
49
+ Set **image_path** to the COCO 2014 image folder.
50
+ Similarly, set **ann_path** to the coco_karpathy_train.json path
51
+ - [minigpt4/configs/datasets/coco/caption.yaml](../minigpt4/configs/datasets/coco/caption.yaml)
52
+
53
+ ### COCO VQA
54
+ Download the vqa v2 train and validation json files
55
+
56
+ ```
57
+ ├── ${MINIGPTv2_DATASET}
58
+ │ ├── vqav2
59
+ │ ├── vqa_train.json
60
+ | ├── vqa_val.json
61
+ ```
62
+
63
+ Set **image_path** to the COCO 2014 image folder.
64
+ Similarly, set **ann_path** to the vqa_train.json and vqa_val.json path
65
+ - [minigpt4/configs/datasets/coco/defaults_vqa.yaml](../minigpt4/configs/datasets/coco/defaults_vqa.yaml)
66
+
67
+
68
+ ### Visual genome
69
+ Download visiual genome images and annotation files
70
+
71
+ ```
72
+ ${MINIGPTv2_DATASET}
73
+ ├── visual_genome
74
+ │ ├── VG_100K
75
+ │ ├── VG_100K_2
76
+ │ └── region_descriptions.json
77
+ │ └── image_data.json
78
+ ...
79
+ ```
80
+
81
+ Set **image_path** to visual_genome folder.
82
+ Similarly, set **ann_path** to the visual_genome folder.
83
+
84
+ - [minigpt4/configs/datasets/vg/ref.yaml](../minigpt4/configs/datasets/vg/ref.yaml)
85
+
86
+
87
+ ### TextCaps
88
+ Download the TextCaps images and annotation files
89
+
90
+ ```
91
+ ├── ${MINIGPTv2_DATASET}
92
+ │ ├── textcaps
93
+ │ ├── train_images
94
+ │ ├── TextCaps_0.1_train.json
95
+ ```
96
+
97
+ Set **image_path** to TextCaps train_images folder.
98
+ Similarly, set **ann_path** to the TextCaps_0.1_train.json path
99
+
100
+ - [minigpt4/configs/datasets/textcaps/caption.yaml](../minigpt4/configs/datasets/textcaps/caption.yaml)
101
+
102
+ ### RefCOCO, RefCOCO+, RefCOCOg
103
+ Download the RefCOCO, RefCOCO+, RefCOCOg annotation files
104
+
105
+ ```
106
+
107
+ ${MINIGPTv2_DATASET}
108
+ ├── refcoco_annotations
109
+ │ ├── refcoco
110
+ │ │ ├── instances.json
111
+ │ │ ├��─ refs(google).p
112
+ │ │ └── refs(unc).p
113
+ │ ├── refcoco+
114
+ │ │ ├── instances.json
115
+ │ │ └── refs(unc).p
116
+ │ └── refcocog
117
+ │ ├── instances.json
118
+ │ ├── refs(google).p
119
+ │ └─── refs(und).p
120
+ ...
121
+ ```
122
+
123
+
124
+ Set **image_path** to the COCO 2014 image folder.
125
+ Similarly, set **ann_path** in all the following configs to the above folder *refcoco_annotations* that contains refcoco, refcoco+, and refcocog.
126
+
127
+ - [minigpt4/configs/datasets/coco_bbox/refcoco.yaml](../minigpt4/configs/datasets/coco_bbox/refcoco.yaml)
128
+ - [minigpt4/configs/datasets/coco_bbox/refcocog.yaml](../minigpt4/configs/datasets/coco_bbox/refcocog.yaml)
129
+ - [minigpt4/configs/datasets/coco_bbox/refcocop.yaml](../minigpt4/configs/datasets/coco_bbox/refcocop.yaml)
130
+ - [minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml)
131
+ - [minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml)
132
+ - [minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml](../minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml)
133
+
134
+
135
+
136
+
137
+ ### OKVQA
138
+
139
+
140
+ ```
141
+ Location_you_like
142
+ ├── ${MINIGPTv2_DATASET}
143
+ │ ├── okvqa
144
+ │ ├── okvqa_train.json
145
+ ```
146
+
147
+ Set **image_path** to the COCO 2014 image folder.
148
+ Similarly, set **ann_path** to the location of the OKVQA dataset
149
+ - [minigpt4/configs/datasets/okvqa/defaults.yaml](../minigpt4/configs/datasets/okvqa/defaults.yaml)
150
+
151
+
152
+ ### COCO-VQA
153
+
154
+ - [OK-VQA Input Questions](https://okvqa.allenai.org/static/data/OpenEnded_mscoco_train2014_questions.json.zip)
155
+ - [OK-VQA Annotations](https://okvqa.allenai.org/static/data/mscoco_train2014_annotations.json.zip)
156
+
157
+
158
+ ### AOK-VQA
159
+ Download the AOK-VQA annotation dataset
160
+
161
+ ```
162
+ export AOKVQA_DIR=YOUR_DATASET_PATH
163
+ mkdir -p ${AOKVQA_DIR}
164
+ curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
165
+ ```
166
+
167
+ ```
168
+ Location_you_like
169
+ ├── ${MINIGPTv2_DATASET}
170
+ │ ├── aokvqa
171
+ │ ├── aokvqa_v1p0_train.json
172
+ ```
173
+
174
+
175
+ Set **image_path** to the COCO 2014 image folder.
176
+ Similarly, set **ann_path** to the location of the AOKVQA dataset
177
+ - [minigpt4/configs/datasets/aokvqa/defaults.yaml](../minigpt4/configs/datasets/aokvqa/defaults.yaml)
178
+
179
+
180
+
181
+ ### OCR-VQA
182
+ Download the OCR-VQA annotation files
183
+ download the images with loadDataset.py script
184
+
185
+ ```
186
+ Location_you_like
187
+ ├── ${MINIGPTv2_DATASET}
188
+ │ ├── ocrvqa
189
+ │ ├── images
190
+ │ ├── dataset.json
191
+ ```
192
+
193
+ Set **image_path** as the ocrvqa/images folder.
194
+ Similarly, set **ann_path** to the dataset.json
195
+ - [minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml](../minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml)
196
+
197
+ ### GQA
198
+ Download the GQA annotation files and images
199
+
200
+ ```
201
+ Location_you_like
202
+ ├── ${MINIGPTv2_DATASET}
203
+ │ ├── gqa
204
+ │ ├── images
205
+ │ ├── train_balanced_questions.json
206
+ ```
207
+
208
+ Set **image_path** as the gqa/images folder.
209
+ Similarly, set **ann_path** to the train_balanced_questions.json
210
+ - [minigpt4/configs/datasets/gqa/balanced_val.yaml](../minigpt4/configs/datasets/gqa/balanced_val.yaml)
211
+
212
+
213
+
214
+ ### filtered Flickr-30k
215
+ Download filtered Flickr-30k images (fill this [form](https://forms.illinois.edu/sec/229675) on official website or from [kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset/download?datasetVersionNumber=1)) and annotation files
216
+
217
+ ```
218
+ ${MINIGPTv2_DATASET}
219
+ ├── filtered_flickr
220
+ │ ├── images
221
+ │ ├── captiontobbox.json
222
+ │ ├── groundedcaption.json
223
+ │ └── phrasetobbox.json
224
+ ...
225
+ ```
226
+
227
+ Set **image_path** as the flickr-30k images foler.
228
+ Similarly, set **ann_path** to the groundedcaption.json, captiontobbox.json and phrasetobbox.json for the
229
+ grounded image caption, caption to bbox, and phrase to bbox datasets.
230
+
231
+ - [minigpt4/configs/datasets/flickr/default.yaml](../minigpt4/configs/datasets/flickr/default.yaml)
232
+ - [minigpt4/configs/datasets/flickr/caption_to_phrase.yaml](../minigpt4/configs/datasets/flickr/caption_to_phrase.yaml)
233
+ - [minigpt4/configs/datasets/flickr/object_to_phrase.yaml](../minigpt4/configs/datasets/flickr/object_to_phrase.yaml)
234
+
235
+
236
+ ### Multi-task conversation
237
+ Download the multi-task converstation dataset
238
+
239
+ ```
240
+ Location_you_like
241
+ ${MINIGPTv2_DATASET}
242
+ ├── multitask_conversation
243
+ │ └── multitask_conversation.json
244
+ ...
245
+ ```
246
+
247
+ Set **image_path** as the COCO 2014 images folder.
248
+ Similarly, set **ann_path** to the multitask_conversation.json file path
249
+
250
+ - [minigpt4/configs/datasets/multitask_conversation/default.yaml](../minigpt4/configs/datasets/multitask_conversation/default.yaml)
251
+
252
+ ### Unnatural instruction
253
+ Download the filtered unnatural instruction annotation files (we remove the very long sentences from the original unnatural instruction dataset)
254
+
255
+ ```
256
+ Location_you_like
257
+ ├── ${MINIGPTv2_DATASET}
258
+ │ ├── unnatural_instructions
259
+ │ ├── filtered_unnatural_instruction.json
260
+ ```
261
+
262
+ There is no image path.
263
+ Similarly, set **ann_path** to the filtered_unnatural_instruction.json file path
264
+
265
+ - [minigpt4/configs/datasets/nlp/unnatural_instruction.yaml](../minigpt4/configs/datasets/nlp/unnatural_instruction.yaml)
266
+
267
+ ### LLaVA
268
+
269
+ ```
270
+ Location_you_like
271
+ ├── ${MINIGPTv2_DATASET}
272
+ │ ├── llava
273
+ │ ├── conversation_58k.json
274
+ │ ├── detail_23k.json
275
+ │ ├── complex_reasoning_77k.json
276
+ ```
277
+
278
+ Set **image_path** to the COCO 2014 image folder.
279
+ Similarly, set **ann_path** to the location of the previous downloaded conversation_58k.json,
280
+ detail_23k.json, and complex_reasoning_77k.json in conversation.yaml, detail.yaml, and reason.yaml, respectively.
281
+
282
+
283
+ - [minigpt4/configs/datasets/llava/conversation.yaml](../minigpt4/configs/datasets/llava/conversation.yaml)
284
+ - [minigpt4/configs/datasets/llava/detail.yaml](../minigpt4/configs/datasets/llava/detail.yaml)
285
+ - [minigpt4/configs/datasets/llava/reason.yaml](../minigpt4/configs/datasets/llava/reason.yaml)
dataset/convert_cc_sbu.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import csv
3
+
4
+ # specify input and output file paths
5
+ input_file = 'ccs_synthetic_filtered_large.json'
6
+ output_file = 'ccs_synthetic_filtered_large.tsv'
7
+
8
+ # load JSON data from input file
9
+ with open(input_file, 'r') as f:
10
+ data = json.load(f)
11
+
12
+ # extract header and data from JSON
13
+ header = data[0].keys()
14
+ rows = [x.values() for x in data]
15
+
16
+ # write data to TSV file
17
+ with open(output_file, 'w') as f:
18
+ writer = csv.writer(f, delimiter='\t')
19
+ writer.writerow(header)
20
+ writer.writerows(rows)
dataset/convert_laion.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import csv
3
+
4
+ # specify input and output file paths
5
+ input_file = 'laion_synthetic_filtered_large.json'
6
+ output_file = 'laion_synthetic_filtered_large.tsv'
7
+
8
+ # load JSON data from input file
9
+ with open(input_file, 'r') as f:
10
+ data = json.load(f)
11
+
12
+ # extract header and data from JSON
13
+ header = data[0].keys()
14
+ rows = [x.values() for x in data]
15
+
16
+ # write data to TSV file
17
+ with open(output_file, 'w') as f:
18
+ writer = csv.writer(f, delimiter='\t')
19
+ writer.writerow(header)
20
+ writer.writerows(rows)
dataset/download_cc_sbu.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ img2dataset --url_list ccs_synthetic_filtered_large.tsv --input_format "tsv"\
4
+ --url_col "url" --caption_col "caption" --output_format webdataset\
5
+ --output_folder cc_sbu_dataset --processes_count 16 --thread_count 128 --image_size 224 \
6
+ --enable_wandb True
dataset/download_laion.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ img2dataset --url_list laion_synthetic_filtered_large.tsv --input_format "tsv"\
4
+ --url_col "url" --caption_col "caption" --output_format webdataset\
5
+ --output_folder laion_dataset --processes_count 16 --thread_count 128 --image_size 224 \
6
+ --enable_wandb True
demo.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import gradio as gr
9
+
10
+ from transformers import StoppingCriteriaList
11
+
12
+ from minigpt4.common.config import Config
13
+ from minigpt4.common.dist_utils import get_rank
14
+ from minigpt4.common.registry import registry
15
+ from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
16
+
17
+ # imports modules for registration
18
+ from minigpt4.datasets.builders import *
19
+ from minigpt4.models import *
20
+ from minigpt4.processors import *
21
+ from minigpt4.runners import *
22
+ from minigpt4.tasks import *
23
+
24
+
25
+ def parse_args():
26
+ parser = argparse.ArgumentParser(description="Demo")
27
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
28
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
29
+ parser.add_argument(
30
+ "--options",
31
+ nargs="+",
32
+ help="override some settings in the used config, the key-value pair "
33
+ "in xxx=yyy format will be merged into config file (deprecate), "
34
+ "change to --cfg-options instead.",
35
+ )
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def setup_seeds(config):
41
+ seed = config.run_cfg.seed + get_rank()
42
+
43
+ random.seed(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+
47
+ cudnn.benchmark = False
48
+ cudnn.deterministic = True
49
+
50
+
51
+ # ========================================
52
+ # Model Initialization
53
+ # ========================================
54
+
55
+ conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
56
+ 'pretrain_llama2': CONV_VISION_LLama2}
57
+
58
+ print('Initializing Chat')
59
+ args = parse_args()
60
+ cfg = Config(args)
61
+
62
+ model_config = cfg.model_cfg
63
+ model_config.device_8bit = args.gpu_id
64
+ model_cls = registry.get_model_class(model_config.arch)
65
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
66
+
67
+ CONV_VISION = conv_dict[model_config.model_type]
68
+
69
+ vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
70
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
71
+
72
+ stop_words_ids = [[835], [2277, 29937]]
73
+ stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
74
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
75
+
76
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
77
+ print('Initialization Finished')
78
+
79
+
80
+ # ========================================
81
+ # Gradio Setting
82
+ # ========================================
83
+
84
+
85
+ def gradio_reset(chat_state, img_list):
86
+ if chat_state is not None:
87
+ chat_state.messages = []
88
+ if img_list is not None:
89
+ img_list = []
90
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
91
+
92
+
93
+ def upload_img(gr_img, text_input, chat_state):
94
+ if gr_img is None:
95
+ return None, None, gr.update(interactive=True), chat_state, None
96
+ chat_state = CONV_VISION.copy()
97
+ img_list = []
98
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
99
+ chat.encode_img(img_list)
100
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
101
+
102
+
103
+ def gradio_ask(user_message, chatbot, chat_state):
104
+ if len(user_message) == 0:
105
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
106
+ chat.ask(user_message, chat_state)
107
+ chatbot = chatbot + [[user_message, None]]
108
+ return '', chatbot, chat_state
109
+
110
+
111
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
112
+ llm_message = chat.answer(conv=chat_state,
113
+ img_list=img_list,
114
+ num_beams=num_beams,
115
+ temperature=temperature,
116
+ max_new_tokens=300,
117
+ max_length=2000)[0]
118
+ chatbot[-1][1] = llm_message
119
+ return chatbot, chat_state, img_list
120
+
121
+
122
+ title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
123
+ description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
124
+ article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
125
+ """
126
+
127
+ #TODO show examples below
128
+
129
+ with gr.Blocks() as demo:
130
+ gr.Markdown(title)
131
+ gr.Markdown(description)
132
+ gr.Markdown(article)
133
+
134
+ with gr.Row():
135
+ with gr.Column(scale=1):
136
+ image = gr.Image(type="pil")
137
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
138
+ clear = gr.Button("Restart")
139
+
140
+ num_beams = gr.Slider(
141
+ minimum=1,
142
+ maximum=10,
143
+ value=1,
144
+ step=1,
145
+ interactive=True,
146
+ label="beam search numbers)",
147
+ )
148
+
149
+ temperature = gr.Slider(
150
+ minimum=0.1,
151
+ maximum=2.0,
152
+ value=1.0,
153
+ step=0.1,
154
+ interactive=True,
155
+ label="Temperature",
156
+ )
157
+
158
+ with gr.Column(scale=2):
159
+ chat_state = gr.State()
160
+ img_list = gr.State()
161
+ chatbot = gr.Chatbot(label='MiniGPT-4')
162
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
163
+
164
+ upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
165
+
166
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
167
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
168
+ )
169
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
170
+
171
+ demo.launch(share=True, enable_queue=True)
demo_v2.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ from collections import defaultdict
5
+
6
+ import cv2
7
+ import re
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+ import html
13
+ import gradio as gr
14
+ from torch.nn import DataParallel
15
+ import torchvision.transforms as T
16
+ import torch.backends.cudnn as cudnn
17
+
18
+ from minigpt4.common.config import Config
19
+
20
+ from minigpt4.common.registry import registry
21
+ from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
22
+
23
+ # imports modules for registration
24
+ from minigpt4.datasets.builders import *
25
+ from minigpt4.models import *
26
+ from minigpt4.processors import *
27
+ from minigpt4.runners import *
28
+ from minigpt4.tasks import *
29
+
30
+
31
+ def parse_args():
32
+ parser = argparse.ArgumentParser(description="Demo")
33
+ parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
34
+ help="path to configuration file.")
35
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
36
+ parser.add_argument(
37
+ "--options",
38
+ nargs="+",
39
+ help="override some settings in the used config, the key-value pair "
40
+ "in xxx=yyy format will be merged into config file (deprecate), "
41
+ "change to --cfg-options instead.",
42
+ )
43
+ args = parser.parse_args()
44
+ return args
45
+
46
+
47
+ random.seed(42)
48
+ np.random.seed(42)
49
+ torch.manual_seed(42)
50
+
51
+ cudnn.benchmark = False
52
+ cudnn.deterministic = True
53
+
54
+ print('Initializing Chat')
55
+ args = parse_args()
56
+ cfg = Config(args)
57
+
58
+ device = 'cuda:{}'.format(args.gpu_id)
59
+
60
+ model_config = cfg.model_cfg
61
+ model_config.device_8bit = args.gpu_id
62
+ model_cls = registry.get_model_class(model_config.arch)
63
+ model = model_cls.from_config(model_config).to(device)
64
+ bounding_box_size = 100
65
+
66
+ vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
67
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
68
+
69
+ # model = DataParallel(model)
70
+ model = model.eval()
71
+
72
+ CONV_VISION = Conversation(
73
+ system="",
74
+ roles=(r"<s>[INST] ", r" [/INST]"),
75
+ messages=[],
76
+ offset=2,
77
+ sep_style=SeparatorStyle.SINGLE,
78
+ sep="",
79
+ )
80
+
81
+
82
+ def extract_substrings(string):
83
+ # first check if there is no-finished bracket
84
+ index = string.rfind('}')
85
+ if index != -1:
86
+ string = string[:index + 1]
87
+
88
+ pattern = r'<p>(.*?)\}(?!<)'
89
+ matches = re.findall(pattern, string)
90
+ substrings = [match for match in matches]
91
+
92
+ return substrings
93
+
94
+
95
+ def is_overlapping(rect1, rect2):
96
+ x1, y1, x2, y2 = rect1
97
+ x3, y3, x4, y4 = rect2
98
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
99
+
100
+
101
+ def computeIoU(bbox1, bbox2):
102
+ x1, y1, x2, y2 = bbox1
103
+ x3, y3, x4, y4 = bbox2
104
+ intersection_x1 = max(x1, x3)
105
+ intersection_y1 = max(y1, y3)
106
+ intersection_x2 = min(x2, x4)
107
+ intersection_y2 = min(y2, y4)
108
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
109
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
110
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
111
+ union_area = bbox1_area + bbox2_area - intersection_area
112
+ iou = intersection_area / union_area
113
+ return iou
114
+
115
+
116
+ def save_tmp_img(visual_img):
117
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
118
+ file_path = "/tmp/gradio" + file_name
119
+ visual_img.save(file_path)
120
+ return file_path
121
+
122
+
123
+ def mask2bbox(mask):
124
+ if mask is None:
125
+ return ''
126
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
127
+ mask = np.array(mask)[:, :, 0]
128
+
129
+ rows = np.any(mask, axis=1)
130
+ cols = np.any(mask, axis=0)
131
+
132
+ if rows.sum():
133
+ # Get the top, bottom, left, and right boundaries
134
+ rmin, rmax = np.where(rows)[0][[0, -1]]
135
+ cmin, cmax = np.where(cols)[0][[0, -1]]
136
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
137
+ else:
138
+ bbox = ''
139
+
140
+ return bbox
141
+
142
+
143
+ def escape_markdown(text):
144
+ # List of Markdown special characters that need to be escaped
145
+ md_chars = ['<', '>']
146
+
147
+ # Escape each special character
148
+ for char in md_chars:
149
+ text = text.replace(char, '\\' + char)
150
+
151
+ return text
152
+
153
+
154
+ def reverse_escape(text):
155
+ md_chars = ['\\<', '\\>']
156
+
157
+ for char in md_chars:
158
+ text = text.replace(char, char[1:])
159
+
160
+ return text
161
+
162
+
163
+ colors = [
164
+ (255, 0, 0),
165
+ (0, 255, 0),
166
+ (0, 0, 255),
167
+ (210, 210, 0),
168
+ (255, 0, 255),
169
+ (0, 255, 255),
170
+ (114, 128, 250),
171
+ (0, 165, 255),
172
+ (0, 128, 0),
173
+ (144, 238, 144),
174
+ (238, 238, 175),
175
+ (255, 191, 0),
176
+ (0, 128, 0),
177
+ (226, 43, 138),
178
+ (255, 0, 255),
179
+ (0, 215, 255),
180
+ ]
181
+
182
+ color_map = {
183
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
184
+ color_id, color in enumerate(colors)
185
+ }
186
+
187
+ used_colors = colors
188
+
189
+
190
+ def visualize_all_bbox_together(image, generation):
191
+ if image is None:
192
+ return None, ''
193
+
194
+ generation = html.unescape(generation)
195
+
196
+ image_width, image_height = image.size
197
+ image = image.resize([500, int(500 / image_width * image_height)])
198
+ image_width, image_height = image.size
199
+
200
+ string_list = extract_substrings(generation)
201
+ if string_list: # it is grounding or detection
202
+ mode = 'all'
203
+ entities = defaultdict(list)
204
+ i = 0
205
+ j = 0
206
+ for string in string_list:
207
+ try:
208
+ obj, string = string.split('</p>')
209
+ except ValueError:
210
+ print('wrong string: ', string)
211
+ continue
212
+ bbox_list = string.split('<delim>')
213
+ flag = False
214
+ for bbox_string in bbox_list:
215
+ integers = re.findall(r'-?\d+', bbox_string)
216
+ if len(integers) == 4:
217
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
218
+ left = x0 / bounding_box_size * image_width
219
+ bottom = y0 / bounding_box_size * image_height
220
+ right = x1 / bounding_box_size * image_width
221
+ top = y1 / bounding_box_size * image_height
222
+
223
+ entities[obj].append([left, bottom, right, top])
224
+
225
+ j += 1
226
+ flag = True
227
+ if flag:
228
+ i += 1
229
+ else:
230
+ integers = re.findall(r'-?\d+', generation)
231
+
232
+ if len(integers) == 4: # it is refer
233
+ mode = 'single'
234
+
235
+ entities = list()
236
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
237
+ left = x0 / bounding_box_size * image_width
238
+ bottom = y0 / bounding_box_size * image_height
239
+ right = x1 / bounding_box_size * image_width
240
+ top = y1 / bounding_box_size * image_height
241
+ entities.append([left, bottom, right, top])
242
+ else:
243
+ # don't detect any valid bbox to visualize
244
+ return None, ''
245
+
246
+ if len(entities) == 0:
247
+ return None, ''
248
+
249
+ if isinstance(image, Image.Image):
250
+ image_h = image.height
251
+ image_w = image.width
252
+ image = np.array(image)
253
+
254
+ elif isinstance(image, str):
255
+ if os.path.exists(image):
256
+ pil_img = Image.open(image).convert("RGB")
257
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
258
+ image_h = pil_img.height
259
+ image_w = pil_img.width
260
+ else:
261
+ raise ValueError(f"invaild image path, {image}")
262
+ elif isinstance(image, torch.Tensor):
263
+
264
+ image_tensor = image.cpu()
265
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
266
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
267
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
268
+ pil_img = T.ToPILImage()(image_tensor)
269
+ image_h = pil_img.height
270
+ image_w = pil_img.width
271
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
272
+ else:
273
+ raise ValueError(f"invaild image format, {type(image)} for {image}")
274
+
275
+ indices = list(range(len(entities)))
276
+
277
+ new_image = image.copy()
278
+
279
+ previous_bboxes = []
280
+ # size of text
281
+ text_size = 0.5
282
+ # thickness of text
283
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
284
+ box_line = 2
285
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
286
+ base_height = int(text_height * 0.675)
287
+ text_offset_original = text_height - base_height
288
+ text_spaces = 2
289
+
290
+ # num_bboxes = sum(len(x[-1]) for x in entities)
291
+ used_colors = colors # random.sample(colors, k=num_bboxes)
292
+
293
+ color_id = -1
294
+ for entity_idx, entity_name in enumerate(entities):
295
+ if mode == 'single' or mode == 'identify':
296
+ bboxes = entity_name
297
+ bboxes = [bboxes]
298
+ else:
299
+ bboxes = entities[entity_name]
300
+ color_id += 1
301
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
302
+ skip_flag = False
303
+ orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
304
+
305
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
306
+ new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
307
+
308
+ if mode == 'all':
309
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
310
+
311
+ x1 = orig_x1 - l_o
312
+ y1 = orig_y1 - l_o
313
+
314
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
315
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
316
+ x1 = orig_x1 + r_o
317
+
318
+ # add text background
319
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
320
+ text_line)
321
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
322
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
323
+
324
+ for prev_bbox in previous_bboxes:
325
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
326
+ prev_bbox['phrase'] == entity_name:
327
+ skip_flag = True
328
+ break
329
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
330
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
331
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
332
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
333
+
334
+ if text_bg_y2 >= image_h:
335
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
336
+ text_bg_y2 = image_h
337
+ y1 = image_h
338
+ break
339
+ if not skip_flag:
340
+ alpha = 0.5
341
+ for i in range(text_bg_y1, text_bg_y2):
342
+ for j in range(text_bg_x1, text_bg_x2):
343
+ if i < image_h and j < image_w:
344
+ if j < text_bg_x1 + 1.35 * c_width:
345
+ # original color
346
+ bg_color = color
347
+ else:
348
+ # white
349
+ bg_color = [255, 255, 255]
350
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
351
+ np.uint8)
352
+
353
+ cv2.putText(
354
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
355
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
356
+ )
357
+
358
+ previous_bboxes.append(
359
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
360
+
361
+ if mode == 'all':
362
+ def color_iterator(colors):
363
+ while True:
364
+ for color in colors:
365
+ yield color
366
+
367
+ color_gen = color_iterator(colors)
368
+
369
+ # Add colors to phrases and remove <p></p>
370
+ def colored_phrases(match):
371
+ phrase = match.group(1)
372
+ color = next(color_gen)
373
+ return f'<span style="color:rgb{color}">{phrase}</span>'
374
+
375
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
376
+ generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
377
+ else:
378
+ generation_colored = ''
379
+
380
+ pil_image = Image.fromarray(new_image)
381
+ return pil_image, generation_colored
382
+
383
+
384
+ def gradio_reset(chat_state, img_list):
385
+ if chat_state is not None:
386
+ chat_state.messages = []
387
+ if img_list is not None:
388
+ img_list = []
389
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
390
+ interactive=True), chat_state, img_list
391
+
392
+
393
+ def image_upload_trigger(upload_flag, replace_flag, img_list):
394
+ # set the upload flag to true when receive a new image.
395
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
396
+ upload_flag = 1
397
+ if img_list:
398
+ replace_flag = 1
399
+ return upload_flag, replace_flag
400
+
401
+
402
+ def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
403
+ # set the upload flag to true when receive a new image.
404
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
405
+ upload_flag = 1
406
+ if img_list or replace_flag == 1:
407
+ replace_flag = 1
408
+
409
+ return upload_flag, replace_flag
410
+
411
+
412
+ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
413
+ if len(user_message) == 0:
414
+ text_box_show = 'Input should not be empty!'
415
+ else:
416
+ text_box_show = ''
417
+
418
+ if isinstance(gr_img, dict):
419
+ gr_img, mask = gr_img['image'], gr_img['mask']
420
+ else:
421
+ mask = None
422
+
423
+ if '[identify]' in user_message:
424
+ # check if user provide bbox in the text input
425
+ integers = re.findall(r'-?\d+', user_message)
426
+ if len(integers) != 4: # no bbox in text
427
+ bbox = mask2bbox(mask)
428
+ user_message = user_message + bbox
429
+
430
+ if chat_state is None:
431
+ chat_state = CONV_VISION.copy()
432
+
433
+ if upload_flag:
434
+ if replace_flag:
435
+ chat_state = CONV_VISION.copy() # new image, reset everything
436
+ replace_flag = 0
437
+ chatbot = []
438
+ img_list = []
439
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
440
+ upload_flag = 0
441
+
442
+ chat.ask(user_message, chat_state)
443
+
444
+ chatbot = chatbot + [[user_message, None]]
445
+
446
+ if '[identify]' in user_message:
447
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
448
+ if visual_img is not None:
449
+ file_path = save_tmp_img(visual_img)
450
+ chatbot = chatbot + [[(file_path,), None]]
451
+
452
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
453
+
454
+
455
+ def gradio_answer(chatbot, chat_state, img_list, temperature):
456
+ llm_message = chat.answer(conv=chat_state,
457
+ img_list=img_list,
458
+ temperature=temperature,
459
+ max_new_tokens=500,
460
+ max_length=2000)[0]
461
+ chatbot[-1][1] = llm_message
462
+ return chatbot, chat_state
463
+
464
+
465
+ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
466
+ if len(img_list) > 0:
467
+ if not isinstance(img_list[0], torch.Tensor):
468
+ chat.encode_img(img_list)
469
+ streamer = chat.stream_answer(conv=chat_state,
470
+ img_list=img_list,
471
+ temperature=temperature,
472
+ max_new_tokens=500,
473
+ max_length=2000)
474
+ output = ''
475
+ for new_output in streamer:
476
+ escapped = escape_markdown(new_output)
477
+ output += escapped
478
+ chatbot[-1][1] = output
479
+ yield chatbot, chat_state
480
+ chat_state.messages[-1][1] = '</s>'
481
+ return chatbot, chat_state
482
+
483
+
484
+ def gradio_visualize(chatbot, gr_img):
485
+ if isinstance(gr_img, dict):
486
+ gr_img, mask = gr_img['image'], gr_img['mask']
487
+
488
+ unescaped = reverse_escape(chatbot[-1][1])
489
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
490
+ if visual_img is not None:
491
+ if len(generation_color):
492
+ chatbot[-1][1] = generation_color
493
+ file_path = save_tmp_img(visual_img)
494
+ chatbot = chatbot + [[None, (file_path,)]]
495
+
496
+ return chatbot
497
+
498
+
499
+ def gradio_taskselect(idx):
500
+ prompt_list = [
501
+ '',
502
+ '[grounding] describe this image in detail',
503
+ '[refer] ',
504
+ '[detection] ',
505
+ '[identify] what is this ',
506
+ '[vqa] '
507
+ ]
508
+ instruct_list = [
509
+ '**Hint:** Type in whatever you want',
510
+ '**Hint:** Send the command to generate a grounded image description',
511
+ '**Hint:** Type in a phrase about an object in the image and send the command',
512
+ '**Hint:** Type in a caption or phrase, and see object locations in the image',
513
+ '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw',
514
+ '**Hint:** Send a question to get a short answer',
515
+ ]
516
+ return prompt_list[idx], instruct_list[idx]
517
+
518
+
519
+
520
+
521
+ chat = Chat(model, vis_processor, device=device)
522
+
523
+ title = """<h1 align="center">MiniGPT-v2 Demo</h1>"""
524
+ description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
525
+ # article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p><a href='https://www.youtube.com/watch?v=atFCwV2hSY4'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p>"""
526
+ article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
527
+
528
+ introduction = '''
529
+ For Abilities Involving Visual Grounding:
530
+ 1. Grounding: CLICK **Send** to generate a grounded image description.
531
+ 2. Refer: Input a referring object and CLICK **Send**.
532
+ 3. Detection: Write a caption or phrase, and CLICK **Send**.
533
+ 4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
534
+ 5. VQA: Input a visual question and CLICK **Send**.
535
+ 6. No Tag: Input whatever you want and CLICK **Send** without any tagging
536
+
537
+ You can also simply chat in free form!
538
+ '''
539
+
540
+ text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
541
+ scale=8)
542
+ with gr.Blocks() as demo:
543
+ gr.Markdown(title)
544
+ # gr.Markdown(description)
545
+ gr.Markdown(article)
546
+
547
+ with gr.Row():
548
+ with gr.Column(scale=0.5):
549
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
550
+
551
+ temperature = gr.Slider(
552
+ minimum=0.1,
553
+ maximum=1.5,
554
+ value=0.6,
555
+ step=0.1,
556
+ interactive=True,
557
+ label="Temperature",
558
+ )
559
+
560
+ clear = gr.Button("Restart")
561
+
562
+ gr.Markdown(introduction)
563
+
564
+ with gr.Column():
565
+ chat_state = gr.State(value=None)
566
+ img_list = gr.State(value=[])
567
+ chatbot = gr.Chatbot(label='MiniGPT-v2')
568
+
569
+ dataset = gr.Dataset(
570
+ components=[gr.Textbox(visible=False)],
571
+ samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
572
+ type="index",
573
+ label='Task Shortcuts',
574
+ )
575
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
576
+ with gr.Row():
577
+ text_input.render()
578
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
579
+
580
+ upload_flag = gr.State(value=0)
581
+ replace_flag = gr.State(value=0)
582
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
583
+
584
+ with gr.Row():
585
+ with gr.Column():
586
+ gr.Examples(examples=[
587
+ ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
588
+ img_list],
589
+ ["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list],
590
+ ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
591
+ img_list],
592
+ ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
593
+ replace_flag, img_list],
594
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
595
+ outputs=[upload_flag, replace_flag])
596
+ with gr.Column():
597
+ gr.Examples(examples=[
598
+ ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
599
+ upload_flag, replace_flag, img_list],
600
+ ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
601
+ ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
602
+ ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
603
+ replace_flag, img_list],
604
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
605
+ outputs=[upload_flag, replace_flag])
606
+
607
+ dataset.click(
608
+ gradio_taskselect,
609
+ inputs=[dataset],
610
+ outputs=[text_input, task_inst],
611
+ show_progress="hidden",
612
+ postprocess=False,
613
+ queue=False,
614
+ )
615
+
616
+ text_input.submit(
617
+ gradio_ask,
618
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
619
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
620
+ ).success(
621
+ gradio_stream_answer,
622
+ [chatbot, chat_state, img_list, temperature],
623
+ [chatbot, chat_state]
624
+ ).success(
625
+ gradio_visualize,
626
+ [chatbot, image],
627
+ [chatbot],
628
+ queue=False,
629
+ )
630
+
631
+ send.click(
632
+ gradio_ask,
633
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
634
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
635
+ ).success(
636
+ gradio_stream_answer,
637
+ [chatbot, chat_state, img_list, temperature],
638
+ [chatbot, chat_state]
639
+ ).success(
640
+ gradio_visualize,
641
+ [chatbot, image],
642
+ [chatbot],
643
+ queue=False,
644
+ )
645
+
646
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
647
+
648
+ demo.launch(share=True, enable_queue=True)
649
+
650
+
651
+
environment.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: minigptv
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ - anaconda
6
+ dependencies:
7
+ - python=3.9
8
+ - cudatoolkit
9
+ - pip
10
+ - pip:
11
+ - torch==2.0.0
12
+ - torchaudio
13
+ - torchvision
14
+ - huggingface-hub==0.18.0
15
+ - matplotlib==3.7.0
16
+ - psutil==5.9.4
17
+ - iopath
18
+ - pyyaml==6.0
19
+ - regex==2022.10.31
20
+ - tokenizers==0.13.2
21
+ - tqdm==4.64.1
22
+ - transformers==4.30.0
23
+ - timm==0.6.13
24
+ - webdataset==0.2.48
25
+ - omegaconf==2.3.0
26
+ - opencv-python==4.7.0.72
27
+ - decord==0.6.0
28
+ - peft==0.2.0
29
+ - sentence-transformers
30
+ - gradio==3.47.1
31
+ - accelerate==0.20.3
32
+ - bitsandbytes==0.37.0
33
+ - scikit-image
34
+ - visual-genome
35
+ - wandb
eval_configs/minigpt4_eval.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: minigpt4
3
+ model_type: pretrain_vicuna0
4
+ max_txt_len: 160
5
+ end_sym: "###"
6
+ low_resource: True
7
+ prompt_template: '###Human: {} ###Assistant: '
8
+ ckpt: 'please set this value to the path of pretrained checkpoint'
9
+
10
+
11
+ datasets:
12
+ cc_sbu_align:
13
+ vis_processor:
14
+ train:
15
+ name: "blip2_image_eval"
16
+ image_size: 224
17
+ text_processor:
18
+ train:
19
+ name: "blip_caption"
20
+
21
+ run:
22
+ task: image_text_pretrain
eval_configs/minigpt4_llama2_eval.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: minigpt4
3
+ model_type: pretrain_llama2
4
+ max_txt_len: 160
5
+ end_sym: "</s>"
6
+ low_resource: True
7
+ prompt_template: '[INST] {} [/INST] '
8
+ ckpt: 'please set this value to the path of pretrained checkpoint'
9
+
10
+
11
+ datasets:
12
+ cc_sbu_align:
13
+ vis_processor:
14
+ train:
15
+ name: "blip2_image_eval"
16
+ image_size: 224
17
+ text_processor:
18
+ train:
19
+ name: "blip_caption"
20
+
21
+ run:
22
+ task: image_text_pretrain
eval_configs/minigptv2_benchmark_evaluation.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: minigpt_v2
3
+ model_type: pretrain
4
+ max_txt_len: 500
5
+ end_sym: "</s>"
6
+ low_resource: False
7
+ prompt_template: '[INST] {} [/INST]'
8
+ llama_model: ""
9
+ ckpt: ""
10
+ lora_r: 64
11
+ lora_alpha: 16
12
+
13
+
14
+ datasets:
15
+ cc_sbu_align:
16
+ vis_processor:
17
+ train:
18
+ name: "blip2_image_eval"
19
+ image_size: 448
20
+ text_processor:
21
+ train:
22
+ name: "blip_caption"
23
+
24
+ evaluation_datasets:
25
+ refcoco:
26
+ eval_file_path: /path/to/eval/annotation/path
27
+ img_path: /path/to/eval/image/path
28
+ max_new_tokens: 20
29
+ batch_size: 10
30
+ refcocog:
31
+ eval_file_path: /path/to/eval/annotation/path
32
+ img_path: /path/to/eval/image/path
33
+ max_new_tokens: 20
34
+ batch_size: 10
35
+ refcoco+:
36
+ eval_file_path: /path/to/eval/annotation/path
37
+ img_path: /path/to/eval/image/path
38
+ max_new_tokens: 20
39
+ batch_size: 10
40
+ gqa:
41
+ eval_file_path: /path/to/eval/annotation/path
42
+ img_path: /path/to/eval/image/path
43
+ max_new_tokens: 20
44
+ batch_size: 10
45
+ okvqa:
46
+ eval_file_path: /path/to/eval/annotation/path
47
+ img_path: /path/to/eval/image/path
48
+ max_new_tokens: 20
49
+ batch_size: 10
50
+ vizwiz:
51
+ eval_file_path: /path/to/eval/annotation/path
52
+ img_path: /path/to/eval/image/path
53
+ max_new_tokens: 20
54
+ batch_size: 10
55
+ iconvqa:
56
+ eval_file_path: /path/to/eval/annotation/path
57
+ img_path: /path/to/eval/image/path
58
+ max_new_tokens: 20
59
+ batch_size: 10
60
+ vsr:
61
+ eval_file_path: cambridgeltl/vsr_zeroshot
62
+ img_path: /path/to/eval/image/path
63
+ max_new_tokens: 20
64
+ batch_size: 10
65
+ hm:
66
+ eval_file_path: /path/to/eval/annotation/path
67
+ img_path: /path/to/eval/image/path
68
+ max_new_tokens: 20
69
+ batch_size: 100
70
+
71
+ run:
72
+ task: image_text_pretrain
73
+ name: minigptv2_evaluation
74
+ save_path: /path/to/save/folder_path
75
+
76
+
77
+
78
+
79
+
eval_configs/minigptv2_eval.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: minigpt_v2
3
+ model_type: pretrain
4
+ max_txt_len: 500
5
+ end_sym: "</s>"
6
+ low_resource: True
7
+ prompt_template: '[INST] {} [/INST]'
8
+ ckpt: "/data3/chengzhi/MiniGPT-4/checkpoint_stage2.pth"
9
+ lora_r: 64
10
+ lora_alpha: 16
11
+
12
+
13
+ datasets:
14
+ cc_sbu_align:
15
+ vis_processor:
16
+ train:
17
+ name: "blip2_image_eval"
18
+ image_size: 448
19
+ text_processor:
20
+ train:
21
+ name: "blip_caption"
22
+
23
+ run:
24
+ task: image_text_pretrain
eval_scripts/EVAL_README.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Evaluation Instruction for MiniGPT-v2
2
+
3
+ ### Data preparation
4
+ Images download
5
+ Image source | Download path
6
+ --- | :---:
7
+ OKVQA| <a href="https://drive.google.com/drive/folders/1jxIgAhtaLu_YqnZEl8Ym11f7LhX3nptN?usp=sharing">annotations</a> &nbsp;&nbsp; <a href="http://images.cocodataset.org/zips/train2017.zip"> images</a>
8
+ gqa | <a href="https://drive.google.com/drive/folders/1-dF-cgFwstutS4qq2D9CFQTDS0UTmIft?usp=drive_link">annotations</a> &nbsp;&nbsp; <a href="https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip">images</a>
9
+ hateful meme | <a href="https://github.com/faizanahemad/facebook-hateful-memes">images and annotations</a>
10
+ iconqa | <a href="https://iconqa.github.io/#download">images and annotation</a>
11
+ vizwiz | <a href="https://vizwiz.org/tasks-and-datasets/vqa/">images and annotation</a>
12
+ RefCOCO | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip"> annotations </a>
13
+ RefCOCO+ | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip"> annotations </a>
14
+ RefCOCOg | <a href="https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip"> annotations </a>
15
+
16
+ ### Evaluation dataset structure
17
+
18
+ ```
19
+ ${MINIGPTv2_EVALUATION_DATASET}
20
+ ├── gqa
21
+ │ └── test_balanced_questions.json
22
+ │ ├── testdev_balanced_questions.json
23
+ │ ├── gqa_images
24
+ ├── hateful_meme
25
+ │ └── hm_images
26
+ │ ├── dev.jsonl
27
+ ├── iconvqa
28
+ │ └── iconvqa_images
29
+ │ ├── choose_text_val.json
30
+ ├── vizwiz
31
+ │ └── vizwiz_images
32
+ │ ├── val.json
33
+ ├── vsr
34
+ │ └── vsr_images
35
+ ├── okvqa
36
+ │ ├── okvqa_test_split.json
37
+ │ ├── mscoco_val2014_annotations_clean.json
38
+ │ ├── OpenEnded_mscoco_val2014_questions_clean.json
39
+ ├── refcoco
40
+ │ └── instances.json
41
+ │ ├── refs(google).p
42
+ │ ├── refs(unc).p
43
+ ├── refcoco+
44
+ │ └── instances.json
45
+ │ ├── refs(unc).p
46
+ ├── refercocog
47
+ │ └── instances.json
48
+ │ ├── refs(google).p
49
+ │ ├── refs(und).p
50
+ ...
51
+ ```
52
+
53
+
54
+ ### environment setup
55
+
56
+ ```
57
+ export PYTHONPATH=$PYTHONPATH:/path/to/directory/of/MiniGPT-4
58
+ ```
59
+
60
+ ### config file setup
61
+
62
+ Set **llama_model** to the path of LLaMA model.
63
+ Set **ckpt** to the path of our pretrained model.
64
+ Set **eval_file_path** to the path of the annotation files for each evaluation data.
65
+ Set **img_path** to the img_path for each evaluation dataset.
66
+ Set **save_path** to the save_path for each evaluation dataset.
67
+
68
+ in [eval_configs/minigptv2_benchmark_evaluation.yaml](../eval_configs/minigptv2_benchmark_evaluation.yaml)
69
+
70
+
71
+
72
+
73
+ ### start evalauting RefCOCO, RefCOCO+, RefCOCOg
74
+ port=port_number
75
+ cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
76
+
77
+ dataset names:
78
+ | refcoco | refcoco+ | refcocog |
79
+ | ------- | -------- | -------- |
80
+
81
+ ```
82
+ torchrun --master-port ${port} --nproc_per_node 1 eval_ref.py \
83
+ --cfg-path ${cfg_path} --dataset refcoco,refcoco+,refcocog --resample
84
+ ```
85
+
86
+
87
+ ### start evaluating visual question answering
88
+
89
+ port=port_number
90
+ cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml
91
+
92
+ dataset names:
93
+ | okvqa | vizwiz | iconvqa | gqa | vsr | hm |
94
+ | ------- | -------- | -------- |-------- | -------- | -------- |
95
+
96
+
97
+ ```
98
+ torchrun --master-port ${port} --nproc_per_node 1 eval_vqa.py \
99
+ --cfg-path ${cfg_path} --dataset okvqa,vizwiz,iconvqa,gqa,vsr,hm
100
+ ```
101
+
102
+
103
+
104
+
eval_scripts/eval_data/refcoco+_testA.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcoco+_testB.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcoco+_val.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcoco_testA.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcoco_testB.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcoco_val.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcocog_test.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_data/refcocog_val.json ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/eval_ref.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import argparse
5
+ from collections import defaultdict
6
+ import random
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+ from minigpt4.common.config import Config
13
+ from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
14
+ from minigpt4.conversation.conversation import CONV_VISION_minigptv2
15
+
16
+ from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData
17
+
18
+ def list_of_str(arg):
19
+ return list(map(str, arg.split(',')))
20
+
21
+ parser = eval_parser()
22
+ parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
23
+ parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
24
+ parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
25
+ args = parser.parse_args()
26
+
27
+ cfg = Config(args)
28
+
29
+ eval_dict = {'refcoco': ['val','testA','testB'],
30
+ 'refcoco+': ['val','testA','testB'],
31
+ 'refcocog': ['val','test']}
32
+
33
+
34
+ model, vis_processor = init_model(args)
35
+ model.eval()
36
+ CONV_VISION = CONV_VISION_minigptv2
37
+ conv_temp = CONV_VISION.copy()
38
+ conv_temp.system = ""
39
+
40
+ #
41
+ model.eval()
42
+ save_path = cfg.run_cfg.save_path
43
+
44
+
45
+
46
+ for dataset in args.dataset:
47
+ for split in eval_dict[dataset]:
48
+
49
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
50
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
51
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
52
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
53
+
54
+ with open(os.path.join(eval_file_path,f"{dataset}/{dataset}_{split}.json"), 'r') as f:
55
+ refcoco = json.load(f)
56
+
57
+ data = RefCOCOEvalData(refcoco, vis_processor, img_path)
58
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
59
+ minigpt4_predict = defaultdict(list)
60
+ resamples = []
61
+
62
+ for images, questions, img_ids in tqdm(eval_dataloader):
63
+ texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
64
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
65
+ for answer, img_id, question in zip(answers, img_ids, questions):
66
+ answer = answer.replace("<unk>","").replace(" ","").strip()
67
+ pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
68
+ if re.match(pattern, answer):
69
+ minigpt4_predict[img_id].append(answer)
70
+ else:
71
+ resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]})
72
+ if args.resample:
73
+ for i in range(20):
74
+ data = RefCOCOEvalData(resamples, vis_processor, img_path)
75
+ resamples = []
76
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
77
+ for images, questions, img_ids in tqdm(eval_dataloader):
78
+ texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
79
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
80
+ for answer, img_id, question in zip(answers, img_ids, questions):
81
+ answer = answer.replace("<unk>","").replace(" ","").strip()
82
+ pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
83
+ if re.match(pattern, answer) or i == 4:
84
+ minigpt4_predict[img_id].append(answer)
85
+ else:
86
+ resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]})
87
+
88
+ if len(resamples) == 0:
89
+ break
90
+
91
+ file_save_path = os.path.join(save_path,f"{args.dataset}_{split}.json")
92
+ with open(file_save_path,'w') as f:
93
+ json.dump(minigpt4_predict, f)
94
+
95
+ count=0
96
+ total=len(refcoco)
97
+ res=args.res
98
+ refcoco_dict = defaultdict()
99
+ for item in refcoco:
100
+ refcoco_dict[item['img_id']] = item
101
+ for img_id in refcoco_dict:
102
+ item = refcoco_dict[img_id]
103
+ bbox = item['bbox']
104
+ outputs = minigpt4_predict[img_id]
105
+ for output in outputs:
106
+ try:
107
+ integers = re.findall(r'\d+', output)
108
+ pred_bbox = [int(num) for num in integers]
109
+ height = item['height']
110
+ width = item['width']
111
+ pred_bbox[0] = pred_bbox[0] / res * width
112
+ pred_bbox[1] = pred_bbox[1] / res * height
113
+ pred_bbox[2] = pred_bbox[2] / res * width
114
+ pred_bbox[3] = pred_bbox[3] / res * height
115
+
116
+ gt_bbox = [0,0,0,0]
117
+ gt_bbox[0] = bbox[0]
118
+ gt_bbox[1] = bbox[1]
119
+ gt_bbox[2] = bbox[0] + bbox[2]
120
+ gt_bbox[3] = bbox[1] + bbox[3]
121
+
122
+ iou_score = computeIoU(pred_bbox, gt_bbox)
123
+ if iou_score > 0.5:
124
+ count+=1
125
+ except:
126
+ continue
127
+
128
+ print(f'{dataset} {split}:', count / total * 100, flush=True)
eval_scripts/eval_vqa.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import argparse
5
+ from collections import defaultdict
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+ from datasets import load_dataset
13
+
14
+
15
+ from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,IconQAEvalData,GQAEvalData,VSREvalData,HMEvalData
16
+ from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
17
+ from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
18
+
19
+ from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
20
+ from minigpt4.conversation.conversation import CONV_VISION_minigptv2
21
+ from minigpt4.common.config import Config
22
+
23
+
24
+ def list_of_str(arg):
25
+ return list(map(str, arg.split(',')))
26
+
27
+ parser = eval_parser()
28
+ parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
29
+ args = parser.parse_args()
30
+ cfg = Config(args)
31
+
32
+
33
+
34
+ model, vis_processor = init_model(args)
35
+ conv_temp = CONV_VISION_minigptv2.copy()
36
+ conv_temp.system = ""
37
+ model.eval()
38
+ save_path = cfg.run_cfg.save_path
39
+
40
+
41
+ if 'okvqa' in args.dataset:
42
+
43
+ eval_file_path = cfg.evaluation_datasets_cfg["okvqa"]["eval_file_path"]
44
+ img_path = cfg.evaluation_datasets_cfg["okvqa"]["img_path"]
45
+ batch_size = cfg.evaluation_datasets_cfg["okvqa"]["batch_size"]
46
+ max_new_tokens = cfg.evaluation_datasets_cfg["okvqa"]["max_new_tokens"]
47
+
48
+
49
+ evaluation_annntation_path = os.path.join(eval_file_path, "okvqa_test_split.json")
50
+ with open(evaluation_annntation_path) as f:
51
+ ok_vqa_test_split = json.load(f)
52
+
53
+ data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path)
54
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
55
+ minigpt4_predict = []
56
+
57
+ for images, questions, question_ids, img_ids in eval_dataloader:
58
+ texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
59
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
60
+
61
+ for answer, question_id, question, img_id in zip(answers, question_ids, questions, img_ids):
62
+ result = dict()
63
+ answer = answer.lower().replace('<unk>','').strip()
64
+ result['answer'] = answer
65
+ result['question_id'] = int(question_id)
66
+ minigpt4_predict.append(result)
67
+
68
+ file_save_path= os.path.join(save_path,"okvqa.json")
69
+ with open(file_save_path,'w') as f:
70
+ json.dump(minigpt4_predict, f)
71
+
72
+ annFile = os.path.join(eval_file_path,"mscoco_val2014_annotations_clean.json")
73
+ quesFile = os.path.join(eval_file_path,"OpenEnded_mscoco_val2014_questions_clean.json" )
74
+
75
+ vqa = VQA(annFile, quesFile)
76
+ vqaRes = vqa.loadRes(file_save_path, quesFile)
77
+
78
+ vqaEval = VQAEval(vqa, vqaRes, n=2)
79
+ vqaEval.evaluate()
80
+ print ("Overall OKVQA Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']), flush=True)
81
+
82
+ if 'vizwiz' in args.dataset:
83
+
84
+ eval_file_path = cfg.evaluation_datasets_cfg["vizwiz"]["eval_file_path"]
85
+ img_path = cfg.evaluation_datasets_cfg["vizwiz"]["img_path"]
86
+ batch_size = cfg.evaluation_datasets_cfg["vizwiz"]["batch_size"]
87
+ max_new_tokens = cfg.evaluation_datasets_cfg["vizwiz"]["max_new_tokens"]
88
+
89
+ vizwiz = json.load(open(eval_file_path, 'r'))
90
+
91
+ data = VizWizEvalData(vizwiz, vis_processor, img_path)
92
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
93
+ minigpt4_predict = []
94
+ total_acc = []
95
+ for images, texts, gt_answers in tqdm(eval_dataloader):
96
+ texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
97
+ with torch.no_grad():
98
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False,repetition_penalty=1.0)
99
+
100
+ for answer, gt_answer in zip(answers, gt_answers):
101
+ result = dict()
102
+ result['answer'] = answer.replace('<unk>','').strip()
103
+ minigpt4_predict.append(result)
104
+ count=0
105
+ gt_answer = gt_answer.split('_')
106
+ for gt in gt_answer:
107
+ if gt.lower() == answer.lower():
108
+ count += 1
109
+ acc = min(count/3.0, 1.0)
110
+ total_acc.append(acc)
111
+
112
+ file_save_path = os.path.join(save_path, "vizwiz.json")
113
+ with open(file_save_path,'w') as f:
114
+ json.dump(minigpt4_predict, f)
115
+ print('vizwiz Acc: ', np.average(total_acc)* 100.0, flush=True)
116
+
117
+
118
+ if 'iconvqa' in args.dataset:
119
+
120
+ eval_file_path = cfg.evaluation_datasets_cfg["iconvqa"]["eval_file_path"]
121
+ img_path = cfg.evaluation_datasets_cfg["iconvqa"]["img_path"]
122
+ batch_size = cfg.evaluation_datasets_cfg["iconvqa"]["batch_size"]
123
+ max_new_tokens = cfg.evaluation_datasets_cfg["iconvqa"]["max_new_tokens"]
124
+
125
+ iconqa_text_val = json.load(open(eval_file_path,"r"))
126
+
127
+ data = IconQAEvalData(iconqa_text_val, vis_processor, img_path)
128
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
129
+
130
+ count = 0
131
+ for images, texts, candidates, answers in tqdm(eval_dataloader):
132
+ candidates = [candidate.split('_') for candidate in candidates]
133
+ num_cand = [len(candidate) for candidate in candidates]
134
+ for candidate in candidates:
135
+ candidate.extend(['none'] * (max(num_cand) - len(candidate)))
136
+ candidates = [list(x) for x in zip(*candidates)]
137
+ instructions = ["<s>[INST] <Img><ImageHere></Img> {} [/INST]".format(text) for text in texts]
138
+ answer_ranks = model.multi_select(images, instructions, candidates, num_cand=num_cand)
139
+ for idx, answer in enumerate(answers):
140
+ if answer_ranks[idx][0] == answer:
141
+ count += 1
142
+
143
+ print('iconqa Acc: ', count / len(iconqa_text_val) * 100.0, flush=True)
144
+
145
+
146
+ if 'gqa' in args.dataset:
147
+
148
+ eval_file_path = cfg.evaluation_datasets_cfg["gqa"]["eval_file_path"]
149
+ img_path = cfg.evaluation_datasets_cfg["gqa"]["img_path"]
150
+ batch_size = cfg.evaluation_datasets_cfg["gqa"]["batch_size"]
151
+ max_new_tokens = cfg.evaluation_datasets_cfg["gqa"]["max_new_tokens"]
152
+
153
+ gqa = json.load(open(eval_file_path))
154
+ data = GQAEvalData(gqa, vis_processor, img_path)
155
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
156
+ count=0
157
+ total=0
158
+ minigpt4_predict = []
159
+ for images, texts, labels in tqdm(eval_dataloader):
160
+ texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
161
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
162
+
163
+ for answer, label in zip(answers, labels):
164
+ result = dict()
165
+ result['pred'] = answer.lower().replace('<unk>','').strip()
166
+ result['gt'] = label
167
+ minigpt4_predict.append(result)
168
+ if answer.lower() == label:
169
+ count+=1
170
+ total+=1
171
+ print('gqa val:', count / total * 100, flush=True)
172
+
173
+ file_save_path = os.path.join(save_path, "gqa.json")
174
+ with open(file_save_path,'w') as f:
175
+ json.dump(minigpt4_predict, f)
176
+
177
+ if 'vsr' in args.dataset:
178
+
179
+ img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"]
180
+ batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"]
181
+ max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"]
182
+
183
+ annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
184
+ data = VSREvalData(annotation, vis_processor, img_path)
185
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
186
+ count=0
187
+ total=0
188
+
189
+ minigpt4_predict = []
190
+
191
+ for images, texts, labels in tqdm(eval_dataloader):
192
+ texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
193
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
194
+
195
+ for answer, label in zip(answers, labels):
196
+ result = dict()
197
+ result['pred'] = answer.replace('<unk>','').strip()
198
+ result['gt'] = label
199
+ minigpt4_predict.append(result)
200
+ if answer.lower() == label.lower():
201
+ count+=1
202
+ total+=1
203
+ print('vsr test:', count / total * 100, flush=True)
204
+ file_save_path = os.path.join(save_path,"vsr.json")
205
+ with open(file_save_path,'w') as f:
206
+ json.dump(minigpt4_predict, f)
207
+
208
+ if 'hm' in args.dataset:
209
+
210
+ eval_file_path = cfg.evaluation_datasets_cfg["hm"]["eval_file_path"]
211
+ img_path = cfg.evaluation_datasets_cfg["hm"]["img_path"]
212
+ batch_size = cfg.evaluation_datasets_cfg["hm"]["batch_size"]
213
+ max_new_tokens = cfg.evaluation_datasets_cfg["hm"]["max_new_tokens"]
214
+
215
+ annotation = []
216
+ with open(eval_file_path, 'r') as jsonl_file:
217
+ for line in jsonl_file:
218
+ json_obj = json.loads(line)
219
+ annotation.append(json_obj)
220
+
221
+ data = HMEvalData(annotation, vis_processor, img_path)
222
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
223
+ count=0
224
+ total=0
225
+
226
+ minigpt4_predict = []
227
+
228
+ for images, texts, labels in tqdm(eval_dataloader):
229
+ texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
230
+
231
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
232
+
233
+ for answer, label in zip(answers, labels):
234
+ result = dict()
235
+ if answer.lower().strip() =="yes":
236
+ answer=1
237
+ elif answer.lower().strip()=="no":
238
+ answer=0
239
+ else:
240
+ print("non-matching answer",answer)
241
+
242
+ result['pred'] = answer
243
+ result['gt'] = int(label)
244
+ minigpt4_predict.append(result)
245
+ if answer == label:
246
+ count+=1
247
+ total+=1
248
+
249
+ print('hm val:', count / total * 100, flush=True)
250
+ file_save_path = os.path.join(save_path, "hm.json")
251
+ with open(file_save_path,'w') as f:
252
+ json.dump(minigpt4_predict, f)
examples/ad_1.png ADDED
examples/ad_2.png ADDED
examples/cook_1.png ADDED
examples/cook_2.png ADDED
examples/describe_1.png ADDED
examples/describe_2.png ADDED
examples/fact_1.png ADDED
examples/fact_2.png ADDED
examples/fix_1.png ADDED
examples/fix_2.png ADDED
examples/fun_1.png ADDED
examples/fun_2.png ADDED