# Medic / mini_ladder.py
from transformers import pipeline, AutoTokenizer
# ------------------------------
# 1) CRITIQUE MODEL & TOKENIZER (the critique step itself is skipped in fast mode)
# ------------------------------
CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
# DistilGPT-2 inherits GPT-2's maximum context length of 1024 tokens.
GPT2_MAX_CONTEXT = 1024
# ------------------------------
# 2) SUB-QUESTION GENERATION
# ------------------------------
def generate_sub_questions(main_query: str):
"""
Naively generates sub-questions for the given main query.
"""
return [
f"1) What are common causes of {main_query}?",
f"2) Which medications are typically used for {main_query}?",
f"3) What are non-pharmacological approaches to {main_query}?"
]
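# Example (illustrative only): generate_sub_questions("hypertension") returns three
# templated strings such as "1) What are common causes of hypertension?".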
# ------------------------------
# 3) SELF-CRITIQUE & REFINEMENT
# ------------------------------
def self_critique_and_refine(query: str, initial_answer: str, docs: list, fast_mode: bool = False):
"""
Uses a smaller model (DistilGPT-2) for self-critique and optionally refines the answer.
If fast_mode is True, the self-critique step is skipped, and the initial answer is returned.
"""
if fast_mode:
# Fast mode: Skip self-critique and refinement
return initial_answer, "Self-critique skipped for speed."
# Construct the critique prompt.
critique_prompt = (
f"The following is an answer to the question '{query}'. "
"Evaluate its correctness, clarity, and completeness. "
"List any missing details or inaccuracies.\n\n"
f"ANSWER:\n{initial_answer}\n\n"
"CRITIQUE:"
)
# Truncate the prompt to ensure there's room for 20 new tokens.
truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)
    # Generate the critique using DistilGPT-2. return_full_text=False makes the
    # pipeline return only the newly generated critique instead of echoing the
    # prompt, so the keyword check below is not triggered by the prompt's own wording.
    critique_gen = critique_pipeline(
        truncated_prompt,
        max_new_tokens=20,
        truncation=True,
        return_full_text=False,
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"].strip()
    else:
        critique_text = "No critique generated."
# If critique flags issues, refine the answer using BioGPT.
if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
refine_prompt = (
f"Question: {query}\n"
f"Current Answer: {initial_answer}\n"
f"Critique: {critique_text}\n"
"Refine the answer by adding missing or corrected information. "
"Use the context below if needed:\n\n"
+ "\n\n".join(docs)
+ "\nREFINED ANSWER:"
)
from backend import qa_pipeline # Import here to avoid circular imports.
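        # NOTE: depending on how backend.qa_pipeline is configured, its
        # "generated_text" may echo the refine prompt; if so, consider passing
        # return_full_text=False or stripping the prompt from the output.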
refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
if refined_gen and isinstance(refined_gen, list):
refined_answer = refined_gen[0]["generated_text"]
else:
refined_answer = initial_answer
else:
refined_answer = initial_answer
return refined_answer, critique_text
# ------------------------------
# 4) HELPER: GPT-2 PROMPT TRUNCATION
# ------------------------------
def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
"""
Truncates the input prompt so that its token count plus a reserved buffer does not exceed GPT-2's maximum context.
"""
tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
max_allowed = GPT2_MAX_CONTEXT - buffer
if len(tokens) > max_allowed:
tokens = tokens[:max_allowed]
truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
return truncated_text
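# ------------------------------
# 5) EXAMPLE USAGE
# ------------------------------
# Minimal illustrative sketch (not part of the original pipeline): the query,
# answer, and docs below are placeholder values. fast_mode=True is used so the
# demo does not depend on backend.qa_pipeline; set it to False to exercise the
# full critique-and-refine path.
if __name__ == "__main__":
    demo_query = "hypertension"
    demo_answer = "Hypertension is usually managed with lifestyle changes and medication."
    demo_docs = ["Reduced sodium intake, exercise, and weight loss lower blood pressure."]
    answer, critique = self_critique_and_refine(demo_query, demo_answer, demo_docs, fast_mode=True)
    print("CRITIQUE:", critique)
    print("ANSWER:", answer)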