from transformers import pipeline

# Small local model used to critique draft answers; a stronger
# instruction-tuned model can be swapped in for better critiques.
CRITIQUE_MODEL = "gpt2"

critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL)
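
# Note: a text-generation pipeline returns a list of dicts, roughly
# [{"generated_text": "..."}]; the code below reads results via
# [0]["generated_text"].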


def generate_sub_questions(main_query: str):
    """Naively generate sub-questions for the given main query."""
    return [
        f"1) What are common causes of {main_query}?",
        f"2) Which medications are typically used for {main_query}?",
        f"3) What are non-pharmacological approaches to {main_query}?",
    ]
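
# Example (hypothetical query):
#   generate_sub_questions("hypertension")
#   -> ["1) What are common causes of hypertension?", ...]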


def self_critique_and_refine(query: str, initial_answer: str, docs: list):
    """Critique the initial answer and refine it if the critique flags problems."""
    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )
    # return_full_text=False keeps the prompt out of the generated text;
    # without it, the keyword check below would always fire on words such
    # as "missing" that appear in the prompt itself.
    critique_gen = critique_pipeline(
        critique_prompt, max_new_tokens=80, truncation=True, return_full_text=False
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"]
    else:
        critique_text = "No critique generated."

    # Only refine when the critique flags a concrete problem.
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        # Local import: qa_pipeline lives in backend and is only needed
        # on the refinement path.
        from backend import qa_pipeline

        refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
        if refined_gen and isinstance(refined_gen, list):
            refined_answer = refined_gen[0]["generated_text"]
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer

    return refined_answer, critique_text
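

# Minimal usage sketch with a hypothetical query and placeholder docs;
# assumes backend exposes the qa_pipeline imported above, and that docs
# would normally come from the retriever.
if __name__ == "__main__":
    question = "hypertension"
    draft = "Hypertension is treated with lifestyle changes."
    context_docs = [
        "Doc 1: First-line drugs include ACE inhibitors and thiazide diuretics.",
        "Doc 2: Reducing sodium intake and regular exercise lower blood pressure.",
    ]
    for sub_q in generate_sub_questions(question):
        print(sub_q)
    refined, critique = self_critique_and_refine(question, draft, context_docs)
    print("CRITIQUE:", critique)
    print("REFINED:", refined)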