harishma-a committed on
Commit
6288425
·
verified ·
1 Parent(s): feeb259

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -1,29 +1,31 @@
1
  import gradio as gr
2
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
 
4
- # Load the RAG model and tokenizer from Hugging Face
5
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
6
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="legacy", use_dummy_dataset=True, trust_remote_code=True)
7
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
8
 
9
- # Function to process the user query
10
def rag_query(query):
    """Answer *query* via RAG: retrieve supporting documents, then generate.

    Returns the decoded answer string for the single input question.
    """
    import torch  # local import: needed only for the document-score bmm

    # Tokenize the input query.
    inputs = tokenizer(query, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Encode the question, then retrieve. RagRetriever.__call__ requires the
    # question hidden states in addition to the token ids — the original call
    # passed only input_ids, which is not a valid retriever invocation.
    question_hidden_states = model.question_encoder(input_ids)[0]
    docs_dict = retriever(
        input_ids.numpy(),
        question_hidden_states.detach().numpy(),
        return_tensors="pt",
    )

    # Score each retrieved document against the question embedding; generate()
    # requires doc_scores whenever context_input_ids are passed explicitly
    # (the original omitted them, which fails at runtime).
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        docs_dict["retrieved_doc_embeds"].float().transpose(1, 2),
    ).squeeze(1)

    # Generate an answer conditioned on the retrieved contexts.
    generated_ids = model.generate(
        context_input_ids=docs_dict["context_input_ids"],
        context_attention_mask=docs_dict["context_attention_mask"],
        doc_scores=doc_scores,
    )

    # Decode the first (only) generated sequence and return it.
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return answer
23
 
24
# Build and launch the Gradio UI for the RAG demo.
iface = gr.Interface(
    fn=rag_query,
    inputs="text",
    outputs="text",
    title="RAG Demo",
    description="Ask a question and get an answer using Retrieval-Augmented Generation (RAG).",
)

iface.launch()
 
1
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# Load tokenizer, retriever, and model once at import time.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
# RagTokenForGeneration.generate() needs either an attached retriever or
# explicit context_input_ids; without one, generate(**inputs) raises at
# runtime. Attach a retriever here so downstream calls work.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="exact",
    use_dummy_dataset=True,
    trust_remote_code=True,
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
 
7
 
8
def rag_generate(query):
    """Answer *query* with the RAG model and return the decoded string."""
    # Tokenize the input question
    inputs = tokenizer(query, return_tensors="pt")

    # Generate output
    # NOTE(review): generate() on a RAG model needs an attached retriever or
    # explicit context_input_ids — verify the model constructed above is
    # loaded with a retriever, otherwise this call raises at runtime.
    generated_ids = model.generate(**inputs)

    # Decode the generated response (first — and only — sequence in the batch)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return answer
19
 
20
# Build the Gradio Blocks UI and start the app.
with gr.Blocks() as rag_app:
    gr.Markdown("# 🤖 RAG Token QA with facebook/rag-token-nq")
    with gr.Row():
        question_box = gr.Textbox(label="Ask your question")
    with gr.Row():
        answer_box = gr.Textbox(label="Answer")

    generate_button = gr.Button("Generate Answer")
    # Wire the button to the RAG inference function.
    generate_button.click(fn=rag_generate, inputs=question_box, outputs=answer_box)

rag_app.launch()