from transformers import Blip2Processor, Blip2ForConditionalGeneration
from peft import PeftModel
import torch
import streamlit as st
from PIL import Image
from streamlit_chat import message
from io import BytesIO

# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"


@st.cache_resource
def load_model():
    """Load the BLIP-2 base model and the fine-tuned LoRA adapter."""
    model_name = "./blip2_fakenews_all"
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
    # The trained adapter already carries its LoRA weights, so loading it with
    # PeftModel.from_pretrained is enough; wrapping it again with
    # get_peft_model would stack a fresh, untrained adapter on top.
    model = PeftModel.from_pretrained(model, model_name)
    model.to(device)
    return processor, model


st.title('Blip2 Fake News Debunker')

# Conversation state: model answers, user questions, and the running prompt.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'bot_prompt' not in st.session_state:
    st.session_state.bot_prompt = []


def get_text():
    chat = st.text_input('Start to chat:', placeholder="Hello! Let's start to chat from here! ")
    return chat


def generate_output(image, prompt):
    """Run one BLIP-2 generation step on the image plus the dialogue prompt."""
    encoding = processor(images=image, text=prompt, max_length=512,
                         truncation=True, padding="max_length",
                         return_tensors="pt").to(device)
    predictions = model.generate(input_ids=encoding['input_ids'],
                                 pixel_values=encoding['pixel_values'],
                                 max_length=20)
    p = processor.batch_decode(predictions, skip_special_tokens=True)
    out = " ".join(p)
    return out


if st.button('Start a new chat'):
    st.cache_resource.clear()
    st.cache_data.clear()
    # Copy the keys first: deleting while iterating raises a runtime error.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.experimental_rerun()

col1, col2 = st.columns(2)
show_file = st.empty()

with col1:
    st.markdown("Step 1: ")
    uploaded_file = st.file_uploader("Upload a news image here: ", type=["png", "jpg"])
    if not uploaded_file:
        show_file.info("Please upload a file of type: " + ", ".join(["png", "jpg"]))
    if isinstance(uploaded_file, BytesIO):
        image = Image.open(uploaded_file)
        st.image(image)

with col2:
    st.markdown("Step 2: ")
    txt = st.text_area("Paste news content here: ")
    st.markdown("Step 3: ")
    user_input = get_text()

processor, model = load_model()


def main():
    if uploaded_file and user_input:
        prompt = "Questions: What is this news about? " \
                 "\nAnswer: " + txt + \
                 "\nQuestions: " + user_input

        if len(st.session_state.bot_prompt) == 0:
            # First turn: seed the dialogue history from the initial prompt.
            pr: list = prompt.split('\n')
            pr = [p for p in pr if len(p)]  # drop empty strings
            st.session_state.bot_prompt = pr
            print(f'init: {st.session_state.bot_prompt}')
        else:
            # Later turns: the seeded prompt already contains the first
            # question, so only append the new one to avoid duplicating it.
            st.session_state.bot_prompt.append(f'Questions: {user_input}')

        # Convert the list of dialogue turns into a single prompt string.
        input_prompt: str = '\n'.join(st.session_state.bot_prompt)
        print(f'bot prompt input list:\n{st.session_state.bot_prompt}')
        print(f'bot prompt input string:\n{input_prompt}')

        output = generate_output(image, prompt=input_prompt)
        st.session_state.past.append(user_input)
        st.session_state.generated.append(output)
        # Keep the bot's answer so the next turn sees the full dialogue.
        st.session_state.bot_prompt.append(f'Answer: {output}')

    with col2:
        if st.session_state['generated']:
            # Render the newest exchange first.
            for i in range(len(st.session_state['generated']) - 1, -1, -1):
                message(st.session_state["generated"][i], key=str(i))
                message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')


if __name__ == '__main__':
    main()
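
# Usage note (assumes this script is saved as app.py and that the fine-tuned
# LoRA adapter directory ./blip2_fakenews_all sits next to it):
#
#   streamlit run app.py
#
# The base "Salesforce/blip2-flan-t5-xl" weights are downloaded from the
# Hugging Face Hub on first launch and cached by @st.cache_resource across
# Streamlit reruns.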