zgjiangtoby commited on
Commit
d710f4c
·
1 Parent(s): a3d97fd
Files changed (3) hide show
  1. app.py +111 -22
  2. requirements.txt +0 -312
  3. test.py +35 -0
app.py CHANGED
@@ -2,34 +2,123 @@ from transformers import Blip2Processor, Blip2ForConditionalGeneration
2
  from peft import LoraConfig, get_peft_model, PeftModel
3
  import torch
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
7
 
8
- config = LoraConfig(
9
- r=16,
10
- lora_alpha=32,
11
- lora_dropout=0.05,
12
- bias="none",
13
- )
14
 
15
- # model_name = "/home/yejiang/blip2_fakedit/saved_model/blip2_fakenews_all"
16
- #
17
- # processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
18
- # device_map = {"": 0}
19
- # model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl",
20
- # load_in_8bit=True,
21
- # device_map=device_map)
22
- # model = PeftModel.from_pretrained(model, model_name)
23
- # model = get_peft_model(model, config)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  col1, col2 = st.columns(2)
 
 
27
  with col1:
28
- chat = st.text_input("Chat",
29
- label_visibility=st.session_state.visibility,
30
- disabled=st.session_state.disabled,
31
- placeholder=st.session_state.placeholder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- if chat:
34
- st.write("You entered: ", chat)
35
 
 
 
 
2
  from peft import LoraConfig, get_peft_model, PeftModel
3
  import torch
4
  import streamlit as st
5
+ from PIL import Image
6
+ from streamlit_chat import message
7
+ from io import BytesIO, StringIO
8
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ device = "cpu"
10
+ @st.cache_resource
11
+ def load_model():
12
+ config = LoraConfig(
13
+ r=16,
14
+ lora_alpha=32,
15
+ lora_dropout=0.05,
16
+ bias="none",
17
+ )
18
 
19
+ model_name = "./blip2_fakenews_all"
20
+ #
21
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
22
+ # device_map = {"": 0}
23
+ device_map = "auto"
24
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl",
25
+ load_in_8bit=True,
26
+ device_map=device_map)
27
+ model = PeftModel.from_pretrained(model, model_name)
28
+ model = get_peft_model(model, config)
29
+ return processor, model
30
 
 
 
 
 
 
 
31
 
32
+ st.title('Blip2 Fake News Debunker')
 
 
 
 
 
 
 
 
33
 
34
+ if 'generated' not in st.session_state:
35
+ st.session_state['generated'] = []
36
+
37
+ if 'past' not in st.session_state:
38
+ st.session_state['past'] = []
39
+
40
+ if 'bot_prompt' not in st.session_state:
41
+ st.session_state.bot_prompt = []
42
+
43
+
44
+ def get_text():
45
+ chat = st.text_input('Start to chat:', placeholder="Hello! Let's start to chat from here! ")
46
+ return chat
47
+
48
+ def generate_output(image, prompt):
49
+ encoding = processor(images=image, text=prompt, max_length=512, truncation=True,
50
+ padding="max_length", return_tensors="pt")
51
+ predictions = model.generate(input_ids=encoding['input_ids'].to(device),
52
+ pixel_values=encoding['pixel_values'].to(device, torch.float16),
53
+ max_length=20)
54
+ p = processor.batch_decode(predictions, skip_special_tokens=True)
55
+ out = " ".join(p)
56
+ return out
57
+
58
+ if st.button('Start a new chat'):
59
+ st.cache_resource.clear()
60
+ st.cache_data.clear()
61
+ for key in st.session_state.keys():
62
+ del st.session_state[key]
63
+ st.experimental_rerun()
64
 
65
  col1, col2 = st.columns(2)
66
+ show_file = st.empty()
67
+
68
  with col1:
69
+ st.markdown("Step 1: ")
70
+ uploaded_file = st.file_uploader("Upload a news image here: ", type=["png", "jpg"])
71
+
72
+ if not uploaded_file:
73
+ show_file.info("Please upload a file of type: " + ", ".join(["png", "jpg"]))
74
+ if isinstance(uploaded_file, BytesIO):
75
+ image = Image.open(uploaded_file)
76
+ st.image(image)
77
+
78
+
79
+ with col2:
80
+ st.markdown("Step 2: ")
81
+ txt = st.text_area("Paste news content here: ")
82
+ st.markdown("Step 3: ")
83
+ user_input = get_text()
84
+ # if user_input:
85
+ # st.write("You: ", user_input)
86
+
87
+ processor, model = load_model()
88
+ def main():
89
+ if uploaded_file and user_input:
90
+ prompt = "Qustions: What is this news about? " \
91
+ "\nAnswer: " + txt + \
92
+ "\nQustions: " + user_input
93
+
94
+ if len(st.session_state.bot_prompt) == 0:
95
+ pr: list = prompt.split('\n')
96
+ pr = [p for p in pr if len(p)] # remove empty string
97
+ st.session_state.bot_prompt = pr
98
+ print(f'init: {st.session_state.bot_prompt}')
99
+
100
+ if user_input:
101
+ st.session_state.bot_prompt.append(f'You: {user_input}')
102
+
103
+ # Convert a list of prompts to a string for the GPT bot.
104
+ input_prompt: str = '\n'.join(st.session_state.bot_prompt)
105
+ print(f'bot prompt input list:\n{st.session_state.bot_prompt}')
106
+ print(f'bot prompt input string:\n{input_prompt}')
107
+
108
+ output = generate_output(image, prompt=input_prompt)
109
+
110
+ st.session_state.past.append(user_input)
111
+ st.session_state.generated.append(output)
112
+
113
+ # Add bot response for next prompt.
114
+ st.session_state.bot_prompt.append(f'Answer: {output}')
115
+ with col2:
116
+ if st.session_state['generated']:
117
+ for i in range(len(st.session_state['generated']) - 1, -1, -1):
118
+ message(st.session_state["generated"][i], key=str(i))
119
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
120
+
121
 
 
 
122
 
123
+ if __name__ == '__main__':
124
+ main()
requirements.txt DELETED
@@ -1,312 +0,0 @@
1
- absl-py==1.2.0
2
- accelerate==0.17.0
3
- aiofiles==22.1.0
4
- aiohttp @ file:///home/conda/feedstock_root/build_artifacts/aiohttp_1649013154501/work
5
- aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1667935791922/work
6
- aiostream==0.4.5
7
- altair==4.2.0
8
- antlr4-python3-runtime==4.9.3
9
- anyio==3.6.1
10
- arxiv==1.4.7
11
- asttokens==2.2.1
12
- async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1640026696943/work
13
- attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
14
- backcall==0.2.0
15
- beautifulsoup4==4.7.1
16
- bitsandbytes==0.37.1
17
- blinker==1.5
18
- blis==0.7.9
19
- blobfile==2.0.2
20
- boto3==1.26.13
21
- botocore==1.29.13
22
- Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
23
- braceexpand==0.1.7
24
- brotlipy==0.7.0
25
- cached-property==1.5.2
26
- cachetools==5.2.0
27
- catalogue==2.0.8
28
- certifi==2022.12.7
29
- cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
30
- cfgv==3.3.1
31
- chardet==3.0.4
32
- charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
33
- chex==0.1.6
34
- click==8.1.3
35
- clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
36
- clip-client==0.6.2
37
- clip-server==0.6.2
38
- colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
39
- comet-ml==3.12.2
40
- commonmark==0.9.1
41
- confection==0.0.4
42
- configobj==5.0.8
43
- contexttimer==0.3.3
44
- contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work
45
- cryptography @ file:///tmp/build/80754af9/cryptography_1652101588893/work
46
- cssselect==1.0.3
47
- cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
48
- cymem==2.0.7
49
- dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
50
- datasets==2.8.0
51
- decorator==5.1.1
52
- decord==0.6.0
53
- deepspeed==0.8.3
54
- dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1666603105584/work
55
- distlib==0.3.6
56
- dm-tree==0.1.8
57
- docarray==0.16.3
58
- docker==6.0.0
59
- docker-pycreds==0.4.0
60
- dulwich==0.21.3
61
- einops==0.6.0
62
- en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
63
- entrypoints==0.4
64
- etils==1.0.0
65
- evaluate==0.4.0
66
- everett==3.1.0
67
- exceptiongroup==1.1.0
68
- executing==1.2.0
69
- fairscale==0.4.4
70
- fastapi==0.82.0
71
- feedfinder2==0.0.4
72
- feedparser==5.2.1
73
- ffmpy==0.3.0
74
- filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1672354931606/work
75
- Flask==1.0.2
76
- Flask-Cors==3.0.7
77
- flax==0.6.6
78
- fonttools==4.37.1
79
- frozenlist @ file:///croot/frozenlist_1670004507010/work
80
- fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1668082755814/work
81
- ftfy==6.1.1
82
- gdown==4.5.1
83
- gitdb==4.0.10
84
- GitPython==3.1.30
85
- google-auth==2.9.1
86
- google-auth-oauthlib==0.4.6
87
- gradio==3.25.0
88
- gradio_client==0.1.3
89
- grpcio==1.47.0
90
- grpcio-health-checking==1.47.0
91
- grpcio-reflection==1.47.0
92
- h11==0.13.0
93
- hjson==3.1.0
94
- httpcore==0.17.0
95
- httptools==0.4.0
96
- httpx==0.24.0
97
- huggingface-hub==0.13.4
98
- identify==2.5.18
99
- idna @ file:///croot/idna_1666125576474/work
100
- imageio==2.22.4
101
- imbalanced-learn==0.10.1
102
- importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1672612343532/work
103
- importlib-resources==5.12.0
104
- iniconfig==2.0.0
105
- iopath==0.1.10
106
- ipython==8.11.0
107
- itsdangerous==1.1.0
108
- jax==0.4.5
109
- jaxlib==0.4.4
110
- jcloud==0.0.35
111
- jedi==0.18.2
112
- jieba3k==0.35.1
113
- jina==3.8.3
114
- jina-hubble-sdk==0.15.5
115
- Jinja2==3.1.2
116
- jmespath==1.0.1
117
- joblib==1.2.0
118
- jsonschema==4.17.3
119
- kaggle==1.5.13
120
- kiwisolver==1.4.4
121
- langcodes==3.3.0
122
- langdetect==1.0.7
123
- latex2mathml==3.75.2
124
- linkify-it-py==2.0.0
125
- lxml==4.9.1
126
- lz4==4.0.2
127
- Markdown==3.4.1
128
- markdown-it-py==2.2.0
129
- MarkupSafe==2.1.1
130
- matplotlib @ file:///croot/matplotlib-suite_1670466153205/work
131
- matplotlib-inline==0.1.6
132
- mdit-py-plugins==0.3.3
133
- mdtex2html==1.2.0
134
- mdurl==0.1.2
135
- mkl-fft==1.3.1
136
- mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work
137
- mkl-service==2.4.0
138
- msgpack==1.0.4
139
- multidict @ file:///croot/multidict_1665674239670/work
140
- multiprocess==0.70.13
141
- munkres==1.1.4
142
- murmurhash==1.0.9
143
- networkx==2.8.8
144
- newspaper3k==0.2.8
145
- ninja==1.11.1
146
- nltk==3.4
147
- nodeenv==1.7.0
148
- numexpr @ file:///croot/numexpr_1668713893690/work
149
- numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1652801679809/work
150
- nvidia-ml-py3==7.352.0
151
- oauthlib==3.0.1
152
- omegaconf==2.3.0
153
- open-clip-torch==1.3.0
154
- openai==0.27.0
155
- opencv-python-headless==4.5.5.64
156
- opendatasets==0.1.22
157
- openprompt==1.0.1
158
- opt-einsum==3.3.0
159
- optax==0.1.4
160
- orbax==0.1.3
161
- orjson==3.8.10
162
- packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
163
- pandas==1.5.2
164
- parso==0.8.3
165
- pathspec==0.10.1
166
- pathtools==0.1.2
167
- pathy==0.10.1
168
- peft @ git+https://github.com/huggingface/peft.git@4fd374e80d670781c0d82c96ce94d1215ff23306
169
- pexpect==4.8.0
170
- pickleshare==0.7.5
171
- Pillow==9.4.0
172
- platformdirs==3.1.0
173
- plotly==5.13.1
174
- pluggy==1.0.0
175
- portalocker==2.7.0
176
- pre-commit==3.1.1
177
- preshed==3.0.8
178
- prometheus-client==0.14.1
179
- promise==2.3
180
- prompt-toolkit==3.0.38
181
- protobuf==3.19.4
182
- psutil==5.9.4
183
- ptyprocess==0.7.0
184
- pure-eval==0.2.2
185
- py-cpuinfo==9.0.0
186
- pyarrow==8.0.0
187
- pyasn1==0.4.8
188
- pyasn1-modules==0.2.8
189
- pybase64==1.2.3
190
- pycocoevalcap==1.2
191
- pycocotools==2.0.6
192
- pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
193
- pycryptodome==3.17
194
- pycryptodomex==3.17
195
- pydantic==1.10.2
196
- pydeck==0.8.0
197
- pydub==0.25.1
198
- Pygments==2.13.0
199
- Pympler==1.0.1
200
- PyMuPDF==1.21.1
201
- pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
202
- pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
203
- pyrsistent==0.19.3
204
- PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work
205
- pytest==7.2.0
206
- python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
207
- python-docx==0.8.11
208
- python-dotenv==0.18.0
209
- python-magic==0.4.27
210
- python-markdown-math==0.8
211
- python-multipart==0.0.5
212
- python-slugify==8.0.1
213
- pytorch-pretrained-bert==0.6.2
214
- pytorch-transformers==1.0.0
215
- pytz @ file:///opt/conda/conda-bld/pytz_1654762638606/work
216
- pytz-deprecation-shim==0.1.0.post0
217
- PyWavelets==1.4.1
218
- pywebarchive==0.5.0
219
- PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757097602/work
220
- regex==2022.7.25
221
- requests @ file:///opt/conda/conda-bld/requests_1657734628632/work
222
- requests-file==1.4.3
223
- requests-oauthlib==1.2.0
224
- requests-toolbelt==0.10.1
225
- responses==0.10.15
226
- rich==12.5.1
227
- rouge==1.0.0
228
- rouge-score==0.1.2
229
- rsa==4.9
230
- s3transfer==0.6.0
231
- salesforce-lavis==1.0.0
232
- scikit-image==0.19.3
233
- scikit-learn==1.1.2
234
- scipy==1.9.1
235
- seaborn @ file:///croot/seaborn_1669627814970/work
236
- semantic-version==2.10.0
237
- semver==2.13.0
238
- sentencepiece==0.1.96
239
- sentry-sdk==1.12.1
240
- seqeval==1.2.2
241
- setproctitle==1.3.2
242
- shortuuid==1.0.11
243
- simpletransformers==0.63.9
244
- singledispatch==3.4.0.3
245
- six @ file:///tmp/build/80754af9/six_1644875935023/work
246
- skorch==0.12.1
247
- smart-open==6.3.0
248
- smmap==5.0.0
249
- sniffio==1.3.0
250
- soupsieve==1.8
251
- spacy==3.4.4
252
- spacy-langdetect==0.1.2
253
- spacy-legacy==3.0.11
254
- spacy-loggers==1.0.4
255
- srsly==2.4.5
256
- stack-data==0.6.2
257
- starlette==0.19.1
258
- streamlit==1.16.0
259
- tabulate==0.9.0
260
- tenacity==8.2.2
261
- tensorboard==2.9.1
262
- tensorboard-data-server==0.6.1
263
- tensorboard-plugin-wit==1.8.1
264
- tensorboardX==2.5.1
265
- tensorstore==0.1.33
266
- text-unidecode==1.3
267
- thinc==8.1.6
268
- threadpoolctl==3.1.0
269
- tifffile==2022.10.10
270
- tiktoken==0.4.0
271
- timm==0.4.12
272
- tinysegmenter==0.3
273
- tldextract==2.2.1
274
- tokenizers==0.12.1
275
- toml==0.10.2
276
- tomli==2.0.1
277
- toolz==0.12.0
278
- torch==1.13.1
279
- torchaudio==0.13.1
280
- torchvision==0.14.1
281
- tornado==6.2
282
- tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1662214488106/work
283
- traitlets==5.9.0
284
- transformers @ git+https://github.com/huggingface/transformers.git@a9bd5df16a46356463f2712dd8f6c109fa83d6f9
285
- twython==3.7.0
286
- typer==0.7.0
287
- typing_extensions==4.5.0
288
- tzdata==2022.7
289
- tzlocal==4.2
290
- uc-micro-py==1.0.1
291
- urllib3 @ file:///croot/urllib3_1670526988650/work
292
- uvicorn==0.18.3
293
- uvloop==0.16.0
294
- validators==0.20.0
295
- virtualenv==20.20.0
296
- wandb==0.13.7
297
- wasabi==0.10.1
298
- watchdog==2.2.1
299
- watchfiles==0.16.1
300
- waybackpy==3.0.6
301
- wcwidth==0.2.5
302
- webdataset==0.2.35
303
- websocket-client==1.4.1
304
- websockets==10.3
305
- Werkzeug==2.2.2
306
- wrapt==1.15.0
307
- wurlitzer==3.0.3
308
- xmlx==2.0.0
309
- xxhash==3.0.0
310
- yacs==0.1.8
311
- yarl @ file:///home/conda/feedstock_root/build_artifacts/yarl_1648966524636/work
312
- zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1669453021653/work
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import requests
3
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
+ import torch
5
+ from peft import LoraConfig, get_peft_model, PeftModel
6
+
7
+ config = LoraConfig(
8
+ r=16,
9
+ lora_alpha=32,
10
+ lora_dropout=0.05,
11
+ bias="none",
12
+ )
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model_name = "./blip2_fakenews_all"
16
+
17
+ # device_map = {"": 0}
18
+ device_map = "auto"
19
+ processor = Blip2Processor.from_pretrained("blip2")
20
+ model = Blip2ForConditionalGeneration.from_pretrained("blip2",
21
+ load_in_8bit=True,
22
+ device_map=device_map)
23
+ model = PeftModel.from_pretrained(model, model_name)
24
+ model = get_peft_model(model, config)
25
+
26
+
27
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
28
+ image = Image.open(requests.get(url, stream=True).raw)
29
+
30
+ prompt = "Question: Is this real or fake? Answer: real. Question: Why? "
31
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
32
+
33
+ generated_ids = model.generate(**inputs)
34
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
35
+ print(generated_text)