abs-qa-demo / app.py
MahmoudH's picture
Update app.py
332312f
from typing import List
import faiss
import numpy as np
import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# This app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentence-embedding model used to embed scraped passages and the question.
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)

# Seq2seq model (and matching tokenizer) that writes the final answer.
_GENERATOR_CHECKPOINT = "MahmoudH/t5-v1_1-base-abs_qa"
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_GENERATOR_CHECKPOINT, from_tf=True)
model = model.to(device)
def scrape(urls: List[str], chunk_size: int = 100, min_words: int = 15) -> Dataset:
    """Download pages and split their visible text into passages.

    Every text node yielded by ``soup.stripped_strings`` is split on
    whitespace and cut into consecutive chunks of at most ``chunk_size``
    words. Chunks shorter than ``min_words`` words are discarded.

    Args:
        urls: Page URLs to download (any iterable of strings works).
        chunk_size: Maximum number of words per passage.
        min_words: Minimum number of words a passage needs to be kept.

    Returns:
        A ``Dataset`` with one row per passage and the columns ``context``
        (the passage text) and ``url`` (the page it came from). The dataset
        is empty when no page yielded a long-enough passage.
    """
    data = []
    for url in urls:
        try:
            # A timeout keeps one unresponsive host from hanging the request.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException:
            # Best effort: skip pages that fail and keep the remaining ones.
            continue

        soup = BeautifulSoup(response.text, "html.parser")
        for string in soup.stripped_strings:
            # Split the raw text itself; repr() would add quote characters
            # and turn newlines into literal backslash escapes.
            words = string.split()
            for start in range(0, len(words), chunk_size):
                chunk = words[start : start + chunk_size]
                if len(chunk) >= min_words:
                    data.append({"context": " ".join(chunk), "url": url})
    return Dataset.from_list(data)
def search_web(query: str) -> List[str]:
    """Run a Google search and return the links of the top results.

    Args:
        query: Free-text search query.

    Returns:
        The distinct ``http(s)`` links found in the first five result
        containers (``<div class="g">``), in the order they appear.

    Raises:
        requests.RequestException: If the search request fails or times out.
    """
    # Set the user agent to avoid being blocked by Google
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
    }
    # Let requests build the query string so characters such as "&", "#" or
    # "+" in the question are percent-encoded instead of corrupting the URL.
    response = requests.get(
        "https://www.google.com/search",
        params={"q": query},
        headers=headers,
        timeout=10,
    )
    soup = BeautifulSoup(response.content, "html.parser")
    search_results = soup.find_all("div", class_="g")

    urls = []
    for result in search_results[:5]:
        # Only consider anchors that actually carry an href attribute.
        link = result.find("a", href=True)
        if link is None:
            continue
        href = link["href"]
        # Keep absolute links only, without duplicates, preserving rank order.
        if href.startswith("http") and href not in urls:
            urls.append(href)
    return urls
def generate_answer(question_doc: str) -> str:
    """Generate an answer for a prepared "question + context" prompt.

    Args:
        question_doc: The full model input (question followed by the
            supporting passages).

    Returns:
        The decoded answer with special tokens removed and surrounding
        whitespace stripped.
    """
    # Explicit padding/truncation replaces the deprecated pad_to_max_length
    # flag: the prompt is cut to 1024 tokens and padded up to that length.
    inputs = tokenizer(
        [question_doc],
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(device)
    model_output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=256,
        length_penalty=1.5,
        do_sample=True,
        num_beams=4,
    )
    # A single prompt was submitted, so only the first sequence is needed.
    answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return answer.strip()
def predict(question: str) -> str:
    """Answer a question from web search results.

    Searches the web for the question, scrapes the result pages into
    passages, retrieves the passages closest to the question and feeds them
    to the generation model as a support document.

    Args:
        question: The user's question.

    Returns:
        The generated answer, or an explanatory message when no web content
        could be retrieved.
    """
    urls = search_web(question)
    data = scrape(urls)
    if data.num_rows == 0:
        # Without passages there is no "embeddings" column to index, so the
        # retrieval step below would fail; report the situation instead.
        return "Sorry, no usable web content could be retrieved for this question."

    # Create vector embeddings and add Faiss index
    data_with_embeds = data.map(
        lambda batch: {"embeddings": retriever.encode(batch["context"])},
        batched=True,
    )
    data_with_embeds.add_faiss_index(
        column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
    )

    # Get the most relevant examples (the similarity scores are not used)
    _, relevant_examples = data_with_embeds.get_nearest_examples(
        "embeddings", retriever.encode([question]), k=20
    )

    # The support document for the model: passages separated by "<P>" markers
    doc = "<P> " + " <P> ".join(relevant_examples["context"])

    # Generate answer
    question_doc = f"question: {question} context: {doc}"
    return generate_answer(question_doc)
# Gradio UI: one text box for the question, one for the generated answer.
input_box = gr.Textbox(label="Question")
output_box = gr.Textbox(label="Answer")

# HTML snippet passed to the interface as its description.
description = """
<div style="text-align: center;">
<p style="font-style: italic;"> Disclaimer: This is just a stupid demo and it crashes a lot. Don't take it too seriously.</p>
✌😎
</div>
"""

# Wire predict() into the interface and enable Gradio's request queue.
demo = gr.Interface(
    fn=predict, inputs=input_box, outputs=output_box, description=description
).queue()
demo.launch()