Spaces:
Sleeping
Sleeping
File size: 4,219 Bytes
633e625 332312f 633e625 332312f 633e625 332312f 633e625 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from typing import List
import faiss
import numpy as np
import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Inference only: switch autograd off globally so generate/encode build no graphs.
torch.set_grad_enabled(False)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Dense retriever that embeds both the scraped passages and the question.
_RETRIEVER_CHECKPOINT = "multi-qa-MiniLM-L6-cos-v1"
retriever = SentenceTransformer(_RETRIEVER_CHECKPOINT, device=device)

# Seq2seq generator that writes the answer from "question + context".
# from_tf=True loads the checkpoint from TensorFlow weights (presumably the hub
# repo only ships TF weights — confirm before dropping the flag).
_GENERATOR_CHECKPOINT = "MahmoudH/t5-v1_1-base-abs_qa"
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_GENERATOR_CHECKPOINT, from_tf=True)
model = model.to(device)
def scrape(urls: List[str], chunk_size: int = 100, min_words: int = 15) -> Dataset:
    """Download each URL and split its visible text into word chunks.

    Every non-empty text node of a page (BeautifulSoup's ``stripped_strings``,
    not only ``<p>`` tags) is split on whitespace. It is then cut into chunks
    of at most ``chunk_size`` words. Chunks shorter than ``min_words`` words
    are dropped, which filters out short fragments.

    Args:
        urls: Page URLs to fetch (any iterable of strings works).
        chunk_size: Maximum number of words per chunk; must be positive.
        min_words: Minimum number of words a chunk needs to be kept.

    Returns:
        A ``datasets.Dataset`` with ``context`` and ``url`` columns. It has no
        rows when no page could be fetched or none produced a long-enough chunk.
    """
    data = []
    for url in urls:
        try:
            # The timeout stops one slow host from hanging the request. A failed
            # download skips that page instead of aborting the whole answer.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException:
            continue
        soup = BeautifulSoup(response.text, "html.parser")
        for string in soup.stripped_strings:
            # Split the raw text. repr() would wrap it in quote characters and
            # turn special characters into escape sequences inside the passages.
            words = string.split()
            for start in range(0, len(words), chunk_size):
                chunk = words[start : start + chunk_size]
                if len(chunk) >= min_words:
                    data.append({"context": " ".join(chunk), "url": url})
    return Dataset.from_list(data)
def search_web(query: str, num_results: int = 5) -> List[str]:
    """Return result URLs scraped from Google's HTML search page for *query*.

    This parses Google's result markup (result blocks are ``<div class="g">``).
    It may return an empty list if the layout changes or the request is blocked.

    Args:
        query: Free-text search query; ``requests`` URL-encodes it.
        num_results: How many leading result blocks to inspect.

    Returns:
        Unique URLs starting with ``http``, in the order they appear on the page.
    """
    # Set the user agent to avoid being blocked by Google
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
    }
    # Pass the query through ``params`` so characters such as "&", "#" or "+"
    # are percent-encoded instead of breaking the search URL.
    response = requests.get(
        "https://www.google.com/search",
        params={"q": query},
        headers=headers,
        timeout=10,
    )
    soup = BeautifulSoup(response.content, "html.parser")
    urls: List[str] = []
    for result in soup.find_all("div", class_="g")[:num_results]:
        link = result.find("a", href=True)
        if link is None:
            # Some result blocks carry no link; skip them instead of crashing.
            continue
        href = link["href"]
        # Keep only absolute links and preserve rank order while de-duplicating.
        if href.startswith("http") and href not in urls:
            urls.append(href)
    return urls
def generate_answer(question_doc: str) -> str:
    """Generate an answer from a prompt using the module-level seq2seq ``model``.

    Args:
        question_doc: Prompt text; ``predict`` builds it as
            ``"question: ... context: ..."``.

    Returns:
        The first decoded generated sequence, with surrounding whitespace stripped.
    """
    # Truncate to the 1024-token budget. Without ``truncation=True`` a long
    # support document is passed through at full length. A single prompt needs
    # no padding, so it is not padded out to max_length.
    inputs = tokenizer(
        question_doc,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    model_output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=256,
        length_penalty=1.5,
        # Beam-sample decoding: answers can differ between calls.
        do_sample=True,
        num_beams=4,
    )
    answer = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    return answer.strip()
def predict(question: str) -> str:
    """Answer *question* from web search results (retrieve-then-generate).

    Pipeline: Google search, then scrape the result pages into word chunks,
    then embed the chunks with ``retriever`` and index them with FAISS (inner
    product), then take the 20 chunks nearest to the question embedding and
    pass them to ``generate_answer``.

    Returns:
        The generated answer. If the question is blank, or no usable web text
        was found, a short explanatory message is returned instead.
    """
    if not question or not question.strip():
        return "Please enter a question."
    urls = search_web(question)
    data = scrape(urls)
    if data.num_rows == 0:
        # There is nothing to embed or index; the FAISS calls below would fail.
        return "Sorry, no usable web content was found for this question."
    # Create vector embeddings and add a FAISS index. The checkpoint name
    # suggests normalized (cosine) embeddings, in which case inner product
    # equals cosine similarity — confirm against the model card.
    data_with_embeds = data.map(
        lambda batch: {"embeddings": retriever.encode(batch["context"])},
        batched=True,
    )
    data_with_embeds.add_faiss_index(
        column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
    )
    # Get the most relevant passages (the scores are not needed).
    _, relevant_examples = data_with_embeds.get_nearest_examples(
        "embeddings", retriever.encode([question]), k=20
    )
    # Support document: the retrieved passages, each prefixed with "<P>".
    doc = "<P> " + " <P> ".join(relevant_examples["context"])
    return generate_answer(f"question: {question} context: {doc}")
# Gradio UI: one question box in, one answer box out, served by `predict`.
input_box = gr.Textbox(label="Question")
output_box = gr.Textbox(label="Answer")
description = """
<div style="text-align: center;">
<p style="font-style: italic;"> Disclaimer: This is just a stupid demo and it crashes a lot. Don't take it too seriously.</p>
✌😎
</div>
"""
# .queue() routes requests through Gradio's queue, so slow
# search + scrape + generate calls are processed in turn.
demo = gr.Interface(
    fn=predict, inputs=input_box, outputs=output_box, description=description
).queue()

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()
|