File size: 3,957 Bytes
4de8fd3
57998d7
a10ed5c
57998d7
 
 
5b7126d
 
 
5467249
4de8fd3
4ef8a52
c0a1eea
 
 
 
 
abe071a
c0a1eea
 
 
 
 
 
 
 
 
 
087827a
5b7126d
 
 
 
c0a1eea
5b7126d
 
 
 
 
 
86668bc
 
82fa495
5467249
 
 
 
 
8598810
5467249
8598810
5467249
 
5b7126d
4ef8a52
087827a
 
5b7126d
 
c0a1eea
2ad73ca
5b7126d
 
 
2ad73ca
5b7126d
 
 
 
 
 
 
 
 
 
 
 
 
b6ac152
087827a
2163596
c0a1eea
 
 
 
 
 
5b7126d
 
 
 
 
 
 
 
 
 
 
b802d0f
5b7126d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from haystack.utils import fetch_archive_from_http, clean_wiki_text, convert_files_to_docs
from haystack.schema import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.nodes import FARMReader, TfidfRetriever
import logging
from markdown import markdown
from annotated_text import annotation
from PIL import Image

# Haystack components
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
def start_haystack():
    """Assemble the extractive question-answering pipeline.

    An in-memory document store is created and populated through
    ``load_and_write_data``. A TF-IDF retriever over that store and a FARM
    reader (``deepset/roberta-base-squad2-distilled``, ``use_gpu=True``) are
    then combined into an ``ExtractiveQAPipeline``.

    ``st.cache`` memoizes the returned pipeline so script reruns do not rebuild
    it. ``hash_funcs`` hashes ``builtins.SwigPyObject`` values to a constant
    (presumably because Streamlit cannot hash them itself; TODO confirm), and
    ``allow_output_mutation=True`` turns off the cache's output-mutation check.
    """
    store = InMemoryDocumentStore()
    load_and_write_data(store)
    tfidf = TfidfRetriever(document_store=store)
    farm = FARMReader(
        model_name_or_path="deepset/roberta-base-squad2-distilled",
        use_gpu=True,
    )
    return ExtractiveQAPipeline(farm, tfidf)

def load_and_write_data(document_store):
    """Index the Game of Thrones articles into *document_store*.

    Every file under ``./article_txt_got`` is turned into Haystack documents.
    The text is cleaned with ``clean_wiki_text`` and split into paragraphs
    (``split_paragraphs=True``). The documents are then written to the store
    in a single ``write_documents`` call.
    """
    document_store.write_documents(
        convert_files_to_docs(
            dir_path='./article_txt_got',
            clean_func=clean_wiki_text,
            split_paragraphs=True,
        )
    )

# Build the QA pipeline at import time. start_haystack is wrapped in st.cache,
# so Streamlit reruns of this script reuse the memoized pipeline.
pipeline = start_haystack()

def set_state_if_absent(key, value):
    """Seed ``st.session_state[key]`` with *value* unless the key already exists.

    Entries that are already present are never overwritten, so values set on
    an earlier script run survive later reruns.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value

# Per-session defaults: the question pre-filled in the text input, and the
# results of the last query (None until ask_question has produced some).
set_state_if_absent("question", "Who is Arya's father?")
set_state_if_absent("results", None)


def reset_results(*args):
    """Forget the results of the previous question.

    Registered as the question input's ``on_change`` callback. Any positional
    arguments are accepted and ignored.
    """
    st.session_state["results"] = None

# Streamlit app

st.title('Haystack Game of Thrones QA ')

# Open the banner in a context manager so the file handle Image.open keeps
# is released once st.image has consumed the picture.
with Image.open('got-haystack.png') as image:
    st.image(image)

st.markdown( """
This QA demo uses a [Haystack Extractive QA Pipeline](https://haystack.deepset.ai/components/ready-made-pipelines#extractiveqapipeline) with 
an [InMemoryDocumentStore](https://haystack.deepset.ai/components/document-store) which contains documents about Game of Thrones 👑
Go ahead and ask questions about the marvellous kingdom!
""", unsafe_allow_html=True)

# Editing the question triggers reset_results, which clears the previous
# results so answers to an old question are not shown for a new one.
question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)

def ask_question(question, retriever_top_k=10, reader_top_k=5):
    """Run *question* through the module-level ``pipeline`` and flatten the answers.

    Args:
        question: The user's query string.
        retriever_top_k: ``top_k`` passed to the "Retriever" node. The default
            of 10 is the value previously hard-coded here.
        reader_top_k: ``top_k`` passed to the "Reader" node (default 5).

    Returns:
        A list with one dict per entry of ``prediction["answers"]``, in the
        order the pipeline returned them.

        An entry with non-empty answer text has these keys:
        ``context`` (wrapped in "..."), ``answer``, ``relevance`` and
        ``offset_start_in_doc``. ``relevance`` is the score times 100, rounded
        to 2 decimals, or None if the answer has no score.
        ``offset_start_in_doc`` is None when the answer carries no document
        offsets.

        An entry with empty answer text has ``context`` and ``answer`` set to
        None, plus ``relevance``.
    """
    prediction = pipeline.run(
        query=question,
        params={"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}},
    )
    results = []
    for answer in prediction["answers"]:
        answer = answer.to_dict()
        # NOTE(review): score/offsets are Optional in Haystack's Answer schema.
        # Guard them so a single incomplete answer cannot raise and make the
        # caller's broad except discard every result.
        score = answer.get("score")
        relevance = round(score * 100, 2) if score is not None else None
        if not answer["answer"]:
            results.append({"context": None, "answer": None, "relevance": relevance})
            continue
        offsets = answer.get("offsets_in_document") or []
        results.append(
            {
                "context": "..." + answer["context"] + "...",
                "answer": answer["answer"],
                "relevance": relevance,
                "offset_start_in_doc": offsets[0]["start"] if offsets else None,
            }
        )
    return results

# Answer the current question on every script run where it is non-empty.
if question:
    with st.spinner("👑    Performing semantic search on royal scripts..."):
        try:
            # Lazy %-style args: the message is only formatted if INFO logging
            # is enabled.
            logging.info("Asked %s", question)
            st.session_state.results = ask_question(question)
        except Exception as e:
            # Top-level UI boundary. Keep the app alive and record the
            # traceback, and tell the user instead of showing nothing at all.
            logging.exception(e)
            st.error("Something went wrong while answering your question. Please try again.")
    


# Render the stored results. Each answer is highlighted inside its context,
# followed by its relevance score.
if st.session_state.results:
    st.write('## Top Results')
    for result in st.session_state.results:
        if result["answer"]:
            answer, context = result["answer"], result["context"]
            highlighted = str(annotation(body=answer, label="ANSWER", background="#964448", color='#ffffff'))
            start_idx = context.find(answer)
            if start_idx == -1:
                # The answer text is not in the context verbatim. Splicing at
                # index -1 would drop the context's last character and repeat
                # part of it, so append the highlighted answer instead.
                body = context + " " + highlighted
            else:
                end_idx = start_idx + len(answer)
                body = context[:start_idx] + highlighted + context[end_idx:]
            st.write(markdown(body), unsafe_allow_html=True)
            st.markdown(f"**Relevance:** {result['relevance']}")
        else:
            st.info(
                "🤔    Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )