File size: 953 Bytes
66c3d74
 
315d67e
66c3d74
 
 
 
315d67e
66c3d74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Initialize the model
@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, model and retriever once per server process.

    Streamlit re-executes this script from the top on every user interaction.
    Without caching, each rerun would reload all three pretrained components
    before the page could respond.

    Returns:
        A ``(tokenizer, model, retriever)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
    # use_dummy_dataset=True loads a small stand-in index instead of the full one.
    rag_retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
    return rag_tokenizer, rag_model, rag_retriever


# Keep the original module-level names so the rest of the script is unchanged.
tokenizer, model, retriever = _load_rag_components()

# Streamlit UI
st.title("AI Health Assistant (RAG-based)")

def get_answer_rag(question):
    """Answer *question* using the module-level RAG tokenizer, retriever and model.

    Implements the retrieve-then-generate flow for a retriever that is not
    attached to the model: embed the question, fetch supporting passages,
    score each passage against the question embedding, then generate an
    answer conditioned on those passages.

    Args:
        question: The user's question as plain text.

    Returns:
        The decoded answer string with special tokens removed.
    """
    inputs = tokenizer(question, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # The retriever searches with the dense question embedding, not token ids,
    # so the question has to go through the question encoder first.
    with torch.no_grad():
        question_hidden_states = model.question_encoder(input_ids)[0]

    # Calling the retriever (rather than its low-level ``retrieve`` method)
    # returns the passages already tokenized as generator-ready context tensors.
    # The number of passages follows the retriever's configured default.
    docs = retriever(
        input_ids.numpy(),
        question_hidden_states.numpy(),
        return_tensors="pt",
    )

    # Relevance of each passage = inner product of the question embedding with
    # that passage's embedding.
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        docs["retrieved_doc_embeds"].float().transpose(1, 2),
    ).squeeze(1)

    # input_ids is deliberately not passed: the passages and their scores are
    # supplied explicitly, so generation does not need a retriever on the model.
    outputs = model.generate(
        context_input_ids=docs["context_input_ids"],
        context_attention_mask=docs["context_attention_mask"],
        doc_scores=doc_scores,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

# Ask the user for a health-related question
question = st.text_input("Ask a health-related question:").strip()
# Whitespace-only input is treated as empty so a stray space does not trigger
# a full retrieval and generation pass.
if question:
    answer = get_answer_rag(question)
    st.write(f"Answer: {answer}")