from transformers import pipeline, AutoTokenizer

# DistilGPT-2 serves as a lightweight self-critique model.
CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)

# GPT-2-family models have a fixed context window of 1024 tokens.
GPT2_MAX_CONTEXT = 1024


def generate_sub_questions(main_query: str):
    """
    Naively generates sub-questions for the given main query.
    """
    return [
        f"1) What are common causes of {main_query}?",
        f"2) Which medications are typically used for {main_query}?",
        f"3) What are non-pharmacological approaches to {main_query}?"
    ]
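

# Illustrative result (the query "hypertension" is a hypothetical example, not from the original code):
#   >>> generate_sub_questions("hypertension")
#   ['1) What are common causes of hypertension?',
#    '2) Which medications are typically used for hypertension?',
#    '3) What are non-pharmacological approaches to hypertension?']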


def self_critique_and_refine(query: str, initial_answer: str, docs: list, fast_mode: bool = False):
    """
    Uses a smaller model (DistilGPT-2) for self-critique and optionally refines the answer.
    If fast_mode is True, the self-critique step is skipped and the initial answer is returned.
    """
    if fast_mode:
        return initial_answer, "Self-critique skipped for speed."

    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )

    # Keep the prompt within GPT-2's 1024-token context, reserving room for the generated critique.
    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)

    # return_full_text=False keeps the echoed prompt out of the output; otherwise the keyword
    # check below would always match words such as "missing" from the prompt itself.
    critique_gen = critique_pipeline(
        truncated_prompt,
        max_new_tokens=20,
        truncation=True,
        return_full_text=False
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"].strip()
    else:
        critique_text = "No critique generated."

    # Only refine when the critique flags a problem.
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        # The main QA pipeline from the backend module generates the refined answer.
        from backend import qa_pipeline
        refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
        if refined_gen and isinstance(refined_gen, list):
            refined_answer = refined_gen[0]["generated_text"]
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer

    return refined_answer, critique_text


def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
    """
    Truncates the input prompt so that its token count plus a reserved buffer
    does not exceed GPT-2's maximum context.
    """
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    max_allowed = GPT2_MAX_CONTEXT - buffer
    if len(tokens) > max_allowed:
        tokens = tokens[:max_allowed]
    truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text
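

# A minimal smoke-test sketch, assuming this file is run directly. The sample query and the
# placeholder draft/documents below are hypothetical illustrations, not part of the original
# pipeline; fast_mode=True means the critique and refinement path (and the backend import)
# is skipped entirely.
if __name__ == "__main__":
    sample_query = "hypertension"
    for sub_q in generate_sub_questions(sample_query):
        print(sub_q)

    draft_answer = "Hypertension is persistently elevated blood pressure."
    retrieved_docs = ["Placeholder context document about hypertension."]
    answer, critique = self_critique_and_refine(
        query=sample_query,
        initial_answer=draft_answer,
        docs=retrieved_docs,
        fast_mode=True,  # skip the DistilGPT-2 critique for a quick check
    )
    print(answer)
    print(critique)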