File size: 3,411 Bytes
4de8fd3
57998d7
a10ed5c
57998d7
 
 
5b7126d
 
 
4de8fd3
 
 
4ef8a52
57998d7
087827a
3710fa9
087827a
 
5b7126d
 
 
 
 
 
 
 
 
 
 
 
 
 
4ef8a52
 
 
 
 
 
5b7126d
86668bc
 
4ef8a52
5b7126d
4ef8a52
 
86668bc
087827a
 
5b7126d
 
2ad73ca
5b7126d
 
 
2ad73ca
5b7126d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6ac152
087827a
5b7126d
375d479
d5afdae
5b7126d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
from haystack.utils import fetch_archive_from_http, clean_wiki_text, convert_files_to_docs
from haystack.schema import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.nodes import FARMReader, TfidfRetriever
import logging
from markdown import markdown
from annotated_text import annotation
import validators
import json

# Haystack components: an in-memory document store, a TF-IDF retriever bound
# to that store, an extractive reader, and the pipeline that chains them.
# NOTE(review): the retriever is constructed while the store is still empty;
# documents are only written later by load_and_write_data(). Confirm that the
# installed Haystack version re-fits TF-IDF once documents have been added.
document_store = InMemoryDocumentStore()
retriever = TfidfRetriever(document_store=document_store)
reader = FARMReader(
    model_name_or_path="deepset/tinyroberta-squad2",
    use_gpu=True,
)
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)

def set_state_if_absent(key, value):
    """Store *value* under *key* in the Streamlit session state unless *key* is already set.

    Args:
        key: Session-state key to initialise.
        value: Value assigned only when the key is missing.
    """
    if key in st.session_state:
        # An existing entry wins; never overwrite it.
        return
    st.session_state[key] = value


st.set_page_config(
    page_title="Game of Thrones QA with Haystack",
    page_icon="https://haystack.deepset.ai/img/HaystackIcon.png",
)

# Seed the session with a starter question and an empty result slot; keys that
# are already present in the session state are left untouched.
for _key, _default in (("question", "Who is Arya's father"), ("results", None)):
    set_state_if_absent(_key, _default)


def reset_results(*args):
    """Clear the stored results in the session state.

    Passed as the ``on_change`` callback of the question text input; any
    positional arguments are accepted and ignored.
    """
    st.session_state["results"] = None

def load_and_write_data(doc_dir='./article_txt_got'):
    """Convert the files under *doc_dir* to documents and write them to the store.

    Files are cleaned with ``clean_wiki_text`` and split into paragraphs
    before being written to the module-level ``document_store``.

    Args:
        doc_dir: Directory containing the source files. Defaults to
            ``'./article_txt_got'``, so existing zero-argument calls behave
            exactly as before.
    """
    docs = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)
    document_store.write_documents(docs)


# Streamlit app: page title, question input, and document indexing.

st.title("Game of Thrones QA with Haystack")
question = st.text_input(
    "",
    value=st.session_state.question,
    max_chars=100,
    on_change=reset_results,  # editing the question clears the stored results
)

# Called unconditionally at module level, so it runs each time this script executes.
load_and_write_data()

def ask_question(question):
    """Run the QA pipeline for *question* and flatten its answers for display.

    Args:
        question: The query text passed to the pipeline.

    Returns:
        A list with one dict per answer returned by the pipeline. Entries
        that have answer text carry ``context`` (wrapped in "..."),
        ``answer``, ``relevance`` (the score as a percentage, rounded to two
        decimals) and ``offset_start_in_doc``. Entries without answer text
        carry ``context`` and ``answer`` set to ``None`` plus ``relevance``.
    """
    prediction = pipeline.run(query=question, params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}})
    results = []
    for answer in prediction["answers"]:
        # The pipeline may hand back answer objects exposing to_dict() rather
        # than plain dicts; such objects cannot be subscripted, so convert
        # them first. Plain dicts pass through unchanged.
        if hasattr(answer, "to_dict"):
            answer = answer.to_dict()

        relevance = round(answer["score"] * 100, 2)
        if answer["answer"]:
            results.append(
                {
                    "context": "..." + answer["context"] + "...",
                    "answer": answer["answer"],
                    "relevance": relevance,
                    "offset_start_in_doc": answer["offsets_in_document"][0]["start"],
                }
            )
        else:
            # No answer text: keep only the score so the UI can still report it.
            results.append(
                {
                    "context": None,
                    "answer": None,
                    "relevance": relevance,
                }
            )
    return results
    

# Answer the current question (if any) and keep the results in the session.
if question:
    try:
        logging.info("Asked %s", question)
        st.session_state.results = ask_question(question)
    except Exception as e:
        # Top-level boundary: record the traceback, drop any stale results and
        # tell the user, instead of silently rendering an empty page.
        logging.exception(e)
        st.session_state.results = None
        st.error("An error occurred while answering the question. Please try again.")
    


# Render every stored result: the context with the answer highlighted, followed
# by its relevance, or an informational note when there is no answer text.
if st.session_state.results:
    st.write('## Top Results')
    for result in st.session_state.results:
        answer, context = result["answer"], result["context"]
        if not answer:
            st.info(
                "🤔    Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
            continue

        highlighted_answer = str(annotation(answer, "ANSWER", "#8ef"))
        start_idx = context.find(answer)
        if start_idx == -1:
            # str.find returns -1 when the answer is not in the context;
            # slicing with that value would cut the text at the wrong places,
            # so show the answer in front of the untouched context instead.
            body = highlighted_answer + " " + context
        else:
            end_idx = start_idx + len(answer)
            body = context[:start_idx] + highlighted_answer + context[end_idx:]

        st.write(markdown(body), unsafe_allow_html=True)
        st.markdown(f"**Relevance:** {result['relevance']}")