# blip2-debunker / app.py
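# Streamlit chat app: upload a news image, paste the article text, then chat with
# a BLIP-2 (Flan-T5-XL) model carrying a LoRA adapter fine-tuned to debunk fake news.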
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from peft import PeftModel
import torch
import streamlit as st
from PIL import Image
from streamlit_chat import message
from io import BytesIO
# Inference is forced onto the CPU; the GPU path is left commented for reference.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
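
# Model loading is wrapped in st.cache_resource so the weights are read from
# disk only once per session rather than on every Streamlit rerun.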
@st.cache_resource
def load_model():
    # Local directory containing the LoRA adapter fine-tuned on fake-news data.
    model_name = "./blip2_fakenews_all"

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
    # device_map={"": 0} or device_map="auto" could be passed here when running on GPU.
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
    # PeftModel.from_pretrained restores the saved adapter together with its LoRA
    # config; re-wrapping the result with get_peft_model(model, config) would only
    # add fresh, untrained LoRA layers on top, so that redundant call was dropped.
    model = PeftModel.from_pretrained(model, model_name)
    model.eval()
    return processor, model
st.title('Blip2 Fake News Debunker')

if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'bot_prompt' not in st.session_state:
    st.session_state.bot_prompt = []
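# Session-state keys: 'generated' holds model replies, 'past' holds user turns,
# and 'bot_prompt' accumulates the newline-joined prompt fed to the model.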
def get_text():
    chat = st.text_input('Start to chat:', placeholder="Hello! Let's start to chat from here! ")
    return chat
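
# The processor packs the image and the running prompt into one batch; the model
# then generates a short free-text answer conditioned on both modalities.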
def generate_output(image, prompt):
    encoding = processor(images=image, text=prompt, max_length=512, truncation=True,
                         padding="max_length", return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        predictions = model.generate(input_ids=encoding['input_ids'],
                                     pixel_values=encoding['pixel_values'],
                                     max_length=20)
    p = processor.batch_decode(predictions, skip_special_tokens=True)
    return " ".join(p)
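
# A minimal sketch of calling generate_output outside Streamlit (hypothetical
# file name "sample.jpg"; assumes load_model() has already been called):
#
#   img = Image.open("sample.jpg")
#   answer = generate_output(img, "Questions: What is this news about?")
#   print(answer)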
if st.button('Start a new chat'):
    st.cache_resource.clear()
    st.cache_data.clear()
    # Iterate over a copy of the keys: deleting entries while iterating the
    # live view would raise a RuntimeError.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.experimental_rerun()
col1, col2 = st.columns(2)
show_file = st.empty()

with col1:
    st.markdown("Step 1: ")
    uploaded_file = st.file_uploader("Upload a news image here: ", type=["png", "jpg"])
    if not uploaded_file:
        show_file.info("Please upload a file of type: " + ", ".join(["png", "jpg"]))
    if isinstance(uploaded_file, BytesIO):
        image = Image.open(uploaded_file)
        st.image(image)
with col2:
    st.markdown("Step 2: ")
    txt = st.text_area("Paste news content here: ")
    st.markdown("Step 3: ")
    user_input = get_text()
processor, model = load_model()
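
# main() drives one chat turn: build the running prompt from the news text and
# the chat history, query the model, and append both sides to the transcript.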
def main():
    if uploaded_file and user_input:
        prompt = "Questions: What is this news about? " \
                 "\nAnswer: " + txt + \
                 "\nQuestions: " + user_input
        if len(st.session_state.bot_prompt) == 0:
            pr: list = prompt.split('\n')
            pr = [p for p in pr if len(p)]  # drop empty strings
            st.session_state.bot_prompt = pr
            print(f'init: {st.session_state.bot_prompt}')

        if user_input:
            st.session_state.bot_prompt.append(f'You: {user_input}')

        # Flatten the list of turns into a single newline-joined prompt string.
        input_prompt: str = '\n'.join(st.session_state.bot_prompt)
        print(f'bot prompt input list:\n{st.session_state.bot_prompt}')
        print(f'bot prompt input string:\n{input_prompt}')

        output = generate_output(image, prompt=input_prompt)
        st.session_state.past.append(user_input)
        st.session_state.generated.append(output)
        # Record the model's answer so the next turn sees the full history.
        st.session_state.bot_prompt.append(f'Answer: {output}')
    # Show the conversation, most recent exchange first, in the right column.
    with col2:
        if st.session_state['generated']:
            for i in range(len(st.session_state['generated']) - 1, -1, -1):
                message(st.session_state["generated"][i], key=str(i))
                message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
if __name__ == '__main__':
    main()