# Captured from a Hugging Face Spaces page (status at capture time: Build error).
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
from peft import LoraConfig, get_peft_model, PeftModel | |
import torch | |
import streamlit as st | |
from PIL import Image | |
from streamlit_chat import message | |
from io import BytesIO, StringIO | |
# Inference is pinned to the CPU. To use a GPU when one is present, switch to:
#     device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"


@st.cache_resource
def load_model():
    """Load the BLIP-2 processor and the LoRA fine-tuned model for inference.

    The base checkpoint is ``Salesforce/blip2-flan-t5-xl``; the LoRA adapter
    weights are read from the local ``./blip2_fakenews_all`` directory.

    The result is cached with ``st.cache_resource`` so the checkpoint is
    loaded once per process instead of on every Streamlit script rerun
    (every widget interaction reruns the whole script).

    Returns:
        tuple: ``(processor, model)`` where ``processor`` is the
        ``Blip2Processor`` and ``model`` is the adapter-wrapped
        ``Blip2ForConditionalGeneration`` placed on ``device`` in eval mode.
    """
    base_checkpoint = "Salesforce/blip2-flan-t5-xl"
    adapter_dir = "./blip2_fakenews_all"

    processor = Blip2Processor.from_pretrained(base_checkpoint)
    model = Blip2ForConditionalGeneration.from_pretrained(base_checkpoint)

    # PeftModel.from_pretrained already attaches the trained adapter. Wrapping
    # the result in get_peft_model() again would stack a second, untrained
    # adapter on top, which is only meaningful when starting a training run.
    model = PeftModel.from_pretrained(model, adapter_dir)

    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled
    return processor, model
st.title('Blip2 Fake News Debunker')

# Conversation state survives Streamlit reruns only through session_state:
#   generated  - model replies, oldest first
#   past       - user messages, oldest first
#   bot_prompt - prompt lines accumulated for the model
for _state_key in ('generated', 'past', 'bot_prompt'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
def get_text():
    """Render the chat text box and return whatever the user has typed."""
    return st.text_input(
        'Start to chat:',
        placeholder="Hello! Let's start to chat from here! ",
    )
def generate_output(image, prompt):
    """Generate the model's reply for an image and a text prompt.

    Uses the module-level ``processor`` and ``model``.

    Args:
        image: The news image (a PIL image) passed to the processor.
        prompt: The full conversation prompt as a single string.

    Returns:
        str: The decoded generation(s), joined with single spaces.
    """
    encoding = processor(images=image, text=prompt, max_length=512, truncation=True,
                         padding="max_length", return_tensors="pt")
    # Keep the inputs on the same device the module-level setting selects.
    encoding = encoding.to(device)

    # The prompt is padded to 512 tokens, so the attention mask must be passed
    # along; without it generate() falls back to an all-ones mask and the
    # model attends to the padding tokens as if they were real input.
    predictions = model.generate(input_ids=encoding['input_ids'],
                                 attention_mask=encoding['attention_mask'],
                                 pixel_values=encoding['pixel_values'],
                                 max_length=20)
    decoded = processor.batch_decode(predictions, skip_special_tokens=True)
    return " ".join(decoded)
if st.button('Start a new chat'):
    # Reset the conversation. Only cached data and this session's state are
    # cleared; st.cache_resource is deliberately left alone because clearing
    # it is process-wide and has nothing to do with conversation state.
    st.cache_data.clear()
    # Snapshot the keys so the loop never iterates a mapping it is mutating.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    # st.experimental_rerun is deprecated in favour of st.rerun; prefer the
    # stable API when present and fall back for older Streamlit releases.
    _rerun = getattr(st, 'rerun', None) or st.experimental_rerun
    _rerun()
# Two-column page: image upload on the left, news text and chat on the right.
col1, col2 = st.columns(2)
# Placeholder used to show the upload hint while no file has been provided.
show_file = st.empty()

_accepted_types = ["png", "jpg"]

with col1:
    st.markdown("Step 1: ")
    uploaded_file = st.file_uploader("Upload a news image here: ", type=_accepted_types)
    if not uploaded_file:
        show_file.info(f"Please upload a file of type: {', '.join(_accepted_types)}")
    elif isinstance(uploaded_file, BytesIO):
        # Decode the upload once; `image` is later handed to the model.
        image = Image.open(uploaded_file)
        st.image(image)

with col2:
    st.markdown("Step 2: ")
    txt = st.text_area("Paste news content here: ")
    st.markdown("Step 3: ")
    user_input = get_text()

# Processor and model used by generate_output().
processor, model = load_model()
def main():
    """Run one chat turn (when inputs are present) and render the history.

    When both an image and a chat message are available, the running prompt
    stored in ``st.session_state['bot_prompt']`` is extended with the user's
    message, sent to the model together with the image, and the reply is
    recorded in the session history. The history is then drawn in the right
    column, newest exchange first.
    """
    state = st.session_state

    if uploaded_file and user_input:
        if not state['bot_prompt']:
            # First turn: seed the prompt with the pasted news text as context.
            # NOTE(review): the seed already ends with the user's question and
            # the same question is appended again below as "You: ..." - confirm
            # whether the duplication on the first turn is intended.
            seed = ("Qustions: What is this news about? "
                    "\nAnswer: " + txt +
                    "\nQustions: " + user_input)
            state['bot_prompt'] = [line for line in seed.split('\n') if line]
            print(f'init: {state["bot_prompt"]}')

        state['bot_prompt'].append(f'You: {user_input}')

        # The model takes a single string, so flatten the prompt lines.
        input_prompt = '\n'.join(state['bot_prompt'])
        print(f'bot prompt input list:\n{state["bot_prompt"]}')
        print(f'bot prompt input string:\n{input_prompt}')

        output = generate_output(image, prompt=input_prompt)
        state['past'].append(user_input)
        state['generated'].append(output)
        # Feed the reply back into the prompt for the next turn.
        state['bot_prompt'].append(f'Answer: {output}')

    with col2:
        replies = state['generated']
        questions = state['past']
        # Newest exchange first; the index doubles as a stable widget key.
        for idx in reversed(range(len(replies))):
            message(replies[idx], key=str(idx))
            message(questions[idx], is_user=True, key=f'{idx}_user')


if __name__ == '__main__':
    main()