File size: 1,780 Bytes
76b04ec
 
 
 
 
 
1307d30
76b04ec
1307d30
76b04ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1307d30
 
76b04ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from retriever import retrieve_relevant_docs
from langchain_core.prompts import PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_google_genai import ChatGoogleGenerativeAI

import os
# Module-level Gemini chat model. The API key is taken from the
# GOOGLE_API_KEY environment variable; the document chain below is built on it.
llm = ChatGoogleGenerativeAI(
    model="models/gemini-1.5-flash",
    temperature=0.3,
    api_key=os.getenv("GOOGLE_API_KEY"),
)

# Prompt for the stuff-documents chain: {context} is filled with the
# combined documents and {input} with the user's question. The model is
# instructed to answer only from that context and to fall back to a fixed
# refusal sentence when unsure.
prompt = PromptTemplate.from_template(
    """
You are a helpful medical assistant. Use only the dataset context below to answer.

Context:
{context}

Question: {input}

If you are unsure, say "Sorry, I couldn't find an answer based on the dataset." Do not guess.
"""
)

# Combine-documents step: stuffs the documents into {context} and calls llm.
document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)

# Retrieval step feeding document_chain. `graph` is the alias exposed to the
# Streamlit app; both names refer to the same chain object.
# NOTE(review): retrieve_relevant_docs() is called with no arguments, so it
# presumably returns a retriever/runnable rather than documents — confirm in
# retriever.py.
graph = retriever_chain = create_retrieval_chain(
    retriever=retrieve_relevant_docs(),
    combine_docs_chain=document_chain,
)

# Manual fallback function if needed
def generate_answer(query: str, context: str) -> str:
    """Answer *query* using only *context*, bypassing the retrieval chain.

    Intended for callers that already hold the context text. It formats the
    same module-level ``prompt`` template the retrieval chain uses and sends it
    to the shared module-level ``llm``, rather than building new model
    clients on every call.

    Args:
        query: The user's question; substituted for ``{input}``.
        context: Dataset text the answer must be grounded in; substituted for
            ``{context}``. ``None``, empty, or whitespace-only context
            short-circuits without calling the model.

    Returns:
        The model's reply with surrounding whitespace stripped, or the fixed
        "Sorry, ..." message when there is no usable context.
    """
    # No context means nothing to ground an answer in: skip the model call.
    if not context or not context.strip():
        return "Sorry, I couldn't find an answer based on the dataset."

    # Reuse the chain's template so the fallback sends identical instructions
    # (the previous inline f-string prefixed every line with 4 spaces and had
    # to be kept in sync by hand).
    message = prompt.format(context=context, input=query)
    response = llm.invoke(message)
    return response.content.strip()