Spaces:
Sleeping
Sleeping
from typing import List | |
import faiss | |
import numpy as np | |
import gradio as gr | |
import requests | |
import torch | |
from bs4 import BeautifulSoup | |
from datasets import Dataset | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Inference only: switch autograd off globally so no computation graphs are built.
torch.set_grad_enabled(False)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Dense retriever: embeds both the scraped passages and the user's question.
RETRIEVER_CHECKPOINT = "multi-qa-MiniLM-L6-cos-v1"
retriever = SentenceTransformer(RETRIEVER_CHECKPOINT, device=device)

# Seq2seq reader that writes the final answer. ``from_tf=True`` loads the
# weights from a TensorFlow checkpoint.
READER_CHECKPOINT = "MahmoudH/t5-v1_1-base-abs_qa"
tokenizer = AutoTokenizer.from_pretrained(READER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(READER_CHECKPOINT, from_tf=True)
model = model.to(device)
def scrape(
    urls: List[str],
    chunk_size: int = 100,
    min_words: int = 15,
    timeout: float = 10.0,
) -> Dataset:
    """Download each URL and split its visible text into word-window passages.

    Every non-blank text fragment of a page (``soup.stripped_strings``) is
    split on whitespace and cut into consecutive windows of at most
    ``chunk_size`` words. Windows shorter than ``min_words`` words are dropped
    as too short to serve as useful context.

    Args:
        urls: Pages to fetch. Any iterable of URL strings works.
        chunk_size: Maximum number of words per passage (must be >= 1).
        min_words: Minimum number of words a passage needs to be kept.
        timeout: Per-request timeout in seconds, so one slow site cannot
            stall the whole request.

    Returns:
        A ``Dataset`` built from rows with ``context`` (passage text) and
        ``url`` (source page) keys. It has no rows when no page produced a
        long-enough passage.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    data = []
    for url in urls:
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException:
            # Best effort: skip unreachable or failing sites instead of
            # aborting the whole answer.
            continue
        # Pass raw bytes so BeautifulSoup can honour the page's own
        # <meta charset>. ``response.text`` would fall back to ISO-8859-1
        # when the HTTP header has no charset.
        soup = BeautifulSoup(response.content, "html.parser")
        # Walk every visible text fragment on the page (not only <p> tags).
        for string in soup.stripped_strings:
            # Split the fragment itself, not its repr(). repr() would add
            # quote characters and turn newlines/tabs into literal "\n"/"\t"
            # text, gluing neighbouring words together.
            words = string.split()
            for start in range(0, len(words), chunk_size):
                window = words[start : start + chunk_size]
                if len(window) >= min_words:
                    data.append({"context": " ".join(window), "url": url})
    return Dataset.from_list(data)
def search_web(query: str, num_results: int = 5, timeout: float = 10.0) -> List[str]:
    """Return the URLs of the top Google search results for *query*.

    Scrapes the Google results HTML and reads the first link inside each
    ``div.g`` element. Only absolute ``http``/``https`` links are kept.

    Args:
        query: Free-text search query; it is URL-encoded here.
        num_results: How many ``div.g`` result blocks to inspect.
        timeout: Request timeout in seconds.

    Returns:
        Distinct result URLs in the order Google listed them, at most
        ``num_results`` of them. The list is empty when the page contains no
        parsable result blocks. NOTE(review): presumably this happens when
        Google serves a consent or captcha page; confirm.
    """
    # Set the user agent to avoid being blocked by Google
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
    }
    # Let requests URL-encode the query. Interpolating it raw into the URL
    # breaks on spaces, '&', '#', '+' and non-ASCII text.
    response = requests.get(
        "https://www.google.com/search",
        params={"q": query},
        headers=headers,
        timeout=timeout,
    )
    soup = BeautifulSoup(response.content, "html.parser")
    # Each organic search result is rendered inside a <div class="g">.
    search_results = soup.find_all("div", class_="g")

    # A list (not a set) keeps Google's ranking order and matches the
    # annotated return type. Duplicates are still skipped.
    urls: List[str] = []
    for result in search_results[:num_results]:
        anchor = result.find("a", href=True)
        if anchor is None:
            # Result block without a usable link: skip instead of crashing.
            continue
        link = anchor["href"]
        if link.startswith("http") and link not in urls:
            urls.append(link)
    return urls
def generate_answer(
    question_doc: str,
    max_input_length: int = 1024,
    max_new_tokens: int = 256,
) -> str:
    """Generate an answer from a single prompt with the module-level reader model.

    Args:
        question_doc: The full prompt, e.g. the
            ``"question: ... context: ..."`` string built by ``predict``.
        max_input_length: The prompt is padded/truncated to this many tokens.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded answer with special tokens removed and surrounding
        whitespace stripped. Sampling is enabled, so repeated calls with
        the same prompt can return different answers.
    """
    # ``padding="max_length"`` replaces the deprecated ``pad_to_max_length``.
    # ``truncation=True`` explicitly cuts an over-long context at
    # ``max_input_length`` tokens. ``return_tensors="pt"`` yields int64
    # tensors directly, with no manual ``torch.LongTensor`` wrapping.
    encoded = tokenizer(
        [question_doc],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    model_output = model.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        max_new_tokens=max_new_tokens,
        length_penalty=1.5,
        # Beam search combined with sampling (4 beams, do_sample=True).
        do_sample=True,
        num_beams=4,
    )
    answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return answer.strip()
def predict(question: str) -> str:
    """Answer *question* by retrieving web passages and generating over them.

    Pipeline:
    1. Google search for the question (``search_web``).
    2. Scrape the result pages into passages (``scrape``).
    3. Embed the passages and the question with ``retriever``.
    4. Take the 20 passages with the highest inner-product score.
    5. Hand them to the seq2seq reader (``generate_answer``).

    Args:
        question: The user's question, as typed in the UI.

    Returns:
        The generated answer. If the question is blank or no usable passage
        could be scraped, a short explanatory message is returned instead;
        indexing an empty dataset would otherwise fail with an obscure error.
    """
    if not question or not question.strip():
        return "Please enter a question."

    urls = search_web(question)
    data = scrape(urls)
    if data.num_rows == 0:
        return "Sorry, I couldn't find any usable web pages for that question."

    # Create vector embeddings and add a Faiss index over them.
    data_with_embeds = data.map(
        lambda batch: {"embeddings": retriever.encode(batch["context"])}, batched=True
    )
    # NOTE(review): inner product equals cosine similarity only for
    # unit-normalized embeddings. Presumably the "cos" retriever checkpoint
    # normalizes its outputs; confirm.
    data_with_embeds.add_faiss_index(
        column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
    )

    # Never ask Faiss for more neighbours than there are passages.
    top_k = min(20, data_with_embeds.num_rows)
    _, relevant_examples = data_with_embeds.get_nearest_examples(
        "embeddings", retriever.encode([question]), k=top_k
    )

    # The support document for the model: passages separated by "<P>" markers.
    doc = "<P> " + " <P> ".join(relevant_examples["context"])
    question_doc = f"question: {question} context: {doc}"
    return generate_answer(question_doc)
# Gradio UI: one question textbox in, one answer textbox out, wired to predict().
input_box = gr.Textbox(label="Question")
output_box = gr.Textbox(label="Answer")
# HTML shown under the title of the interface.
description = """
<div style="text-align: center;">
<p style="font-style: italic;"> Disclaimer: This is just a stupid demo and it crashes a lot. Don't take it too seriously.</p>
✌😎
</div>
"""
# ``queue()`` enables gradio's request queue for the interface; the
# search + scrape + generate pipeline behind predict() can be slow.
demo = gr.Interface(
    fn=predict, inputs=input_box, outputs=output_box, description=description
).queue()
demo.launch()