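# Blip2 Fake News Debunker: a Streamlit chat UI around a LoRA-fine-tuned
# BLIP-2 (FLAN-T5-XL) model. Third-party dependencies assumed here:
# transformers, peft, torch, Pillow, streamlit, and streamlit-chat
# ("pip install streamlit-chat").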
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from peft import PeftModel
import torch
import streamlit as st
from PIL import Image
from streamlit_chat import message
from io import BytesIO
# Inference is pinned to CPU; switch to the commented line to use a GPU when available.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
@st.cache_resource
def load_model():
    # Local directory holding the fine-tuned LoRA adapter weights.
    model_name = "./blip2_fakenews_all"

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
    # PeftModel.from_pretrained already attaches the trained adapter; wrapping the
    # result in get_peft_model again would add fresh, untrained LoRA layers on top.
    model = PeftModel.from_pretrained(model, model_name)
    model.to(device)
    model.eval()
    return processor, model


st.title('Blip2 Fake News Debunker')

# Per-session chat history: 'generated' stores the model's replies, 'past'
# stores the user's messages, and 'bot_prompt' accumulates the running prompt.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'bot_prompt' not in st.session_state:
    st.session_state.bot_prompt = []


def get_text():
    chat = st.text_input('Start to chat:', placeholder="Hello! Let's start to chat from here! ")
    return chat

def generate_output(image, prompt):
    # Tokenize the text prompt and preprocess the image in a single call.
    encoding = processor(images=image, text=prompt, max_length=512, truncation=True,
                         padding="max_length", return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        predictions = model.generate(input_ids=encoding['input_ids'],
                                     pixel_values=encoding['pixel_values'],
                                     max_length=20)
    p = processor.batch_decode(predictions, skip_special_tokens=True)
    return " ".join(p)
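# Example call (a sketch; "sample.jpg" is a hypothetical local image path):
#   answer = generate_output(Image.open("sample.jpg"),
#                            "Question: What is this news about?\nAnswer:")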

if st.button('Start a new chat'):
    # Drop the cached model and all per-session state, then rerun the script
    # so the UI starts from a clean slate.
    st.cache_resource.clear()
    st.cache_data.clear()
    for key in list(st.session_state.keys()):  # copy keys: deleting while iterating raises
        del st.session_state[key]
    st.experimental_rerun()

col1, col2 = st.columns(2)
show_file = st.empty()

with col1:
    st.markdown("Step 1: ")
    uploaded_file = st.file_uploader("Upload a news image here: ", type=["png", "jpg"])

    if not uploaded_file:
        show_file.info("Please upload a file of type: " + ", ".join(["png", "jpg"]))
    if isinstance(uploaded_file, BytesIO):
        image = Image.open(uploaded_file)
        st.image(image)


with col2:
    st.markdown("Step 2: ")
    txt = st.text_area("Paste news content here: ")
    st.markdown("Step 3: ")
    user_input = get_text()
    # if user_input:
    #     st.write("You: ", user_input)

# Loaded once per process thanks to @st.cache_resource; reruns reuse the weights.
processor, model = load_model()
def main():
    if uploaded_file and user_input:
        prompt = "Qustions: What is this news about? " \
                 "\nAnswer: " + txt  + \
                 "\nQustions: " + user_input

        if len(st.session_state.bot_prompt) == 0:
            pr: list = prompt.split('\n')
            pr = [p for p in pr if len(p)]  # remove empty string
            st.session_state.bot_prompt = pr
            print(f'init: {st.session_state.bot_prompt}')

        if user_input:
            st.session_state.bot_prompt.append(f'You: {user_input}')

            # Convert the list of prompt lines into a single string for the BLIP-2 model.
            input_prompt: str = '\n'.join(st.session_state.bot_prompt)
            print(f'bot prompt input list:\n{st.session_state.bot_prompt}')
            print(f'bot prompt input string:\n{input_prompt}')

            output = generate_output(image, prompt=input_prompt)

            st.session_state.past.append(user_input)
            st.session_state.generated.append(output)

            # Add bot response for next prompt.
            st.session_state.bot_prompt.append(f'Answer: {output}')
        with col2:
            if st.session_state['generated']:
                for i in range(len(st.session_state['generated']) - 1, -1, -1):
                    message(st.session_state["generated"][i], key=str(i))
                    message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')



if __name__ == '__main__':
    main()
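
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py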