# mini_ladder.py
from transformers import pipeline, AutoTokenizer
# ------------------------------
# 1) CRITIQUE MODEL & TOKENIZER
# ------------------------------
# Using GPT-2 for self-critique
CRITIQUE_MODEL_NAME = "gpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
# GPT-2 typically has a max context length of 1024 tokens
GPT2_MAX_CONTEXT = 1024
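# Note: a transformers text-generation pipeline returns a list of dicts, e.g.
#   [{"generated_text": "..."}]
# and by default it echoes the prompt at the start of "generated_text" unless
# return_full_text=False is passed at call time.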
# ------------------------------
# 2) SUB-QUESTION GENERATION
# ------------------------------
def generate_sub_questions(main_query: str):
"""
Naive approach to generating sub-questions.
"""
return [
f"1) What are common causes of {main_query}?",
f"2) Which medications are typically used for {main_query}?",
f"3) What are non-pharmacological approaches to {main_query}?"
]
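# Example (hypothetical query): generate_sub_questions("hypertension") returns
#   ["1) What are common causes of hypertension?",
#    "2) Which medications are typically used for hypertension?",
#    "3) What are non-pharmacological approaches to hypertension?"]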
# ------------------------------
# 3) SELF-CRITIQUE & REFINEMENT
# ------------------------------
def self_critique_and_refine(query: str, initial_answer: str, docs: list):
"""
1) Critique the initial answer (GPT-2).
2) If needed, refine using the original BioGPT pipeline.
"""
    # A) Construct the critique prompt
    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )
    # B) Truncate the critique prompt to fit GPT-2's max context
    truncated_critique_prompt = _truncate_prompt_for_gpt2(critique_prompt)
    # C) Generate the critique with GPT-2. return_full_text=False keeps the
    # prompt out of the output, so the keyword check below sees only the
    # critique itself (the prompt already contains the word "missing").
    critique_gen = critique_pipeline(
        truncated_critique_prompt,
        max_new_tokens=80,      # tokens to generate for the critique
        truncation=True,        # ensure the input does not exceed the model's limit
        return_full_text=False,
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"]
    else:
        critique_text = "No critique generated."
    # D) If the critique suggests issues, refine using BioGPT
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        # Build a refine prompt that includes the retrieved docs
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        # If BioGPT has similar context limits, the prompt can be truncated here too,
        # e.g. refine_prompt = _truncate_prompt_for_biogpt(refine_prompt)
        from backend import qa_pipeline  # Local import to avoid a circular dependency
        refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
        if refined_gen and isinstance(refined_gen, list):
            # Note: if qa_pipeline is a standard text-generation pipeline, the
            # returned text may include the prompt; pass return_full_text=False
            # there as well if only the refined answer is wanted.
            refined_answer = refined_gen[0]["generated_text"]
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer
    return refined_answer, critique_text
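# Example usage (sketch; assumes `docs` is a list of retrieved context strings and
# `initial_answer` was produced elsewhere, e.g. by backend.qa_pipeline):
#   refined, critique = self_critique_and_refine("hypertension", initial_answer, docs)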
# ------------------------------
# 4) HELPER: GPT-2 TRUNCATION
# ------------------------------
def _truncate_prompt_for_gpt2(prompt_text: str, reserve_for_generation: int = 80) -> str:
    """
    Token-level truncation so that prompt + generated tokens stay within
    GPT-2's 1024-token context. `reserve_for_generation` should match the
    max_new_tokens passed to the critique pipeline.
    """
    max_prompt_tokens = GPT2_MAX_CONTEXT - reserve_for_generation
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    if len(tokens) > max_prompt_tokens:
        # Keep only the leading tokens that fit, leaving room for generation
        tokens = tokens[:max_prompt_tokens]
    truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text
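# ------------------------------
# 5) QUICK SMOKE TEST (optional)
# ------------------------------
# Minimal sketch for manual testing; not part of the module's API. The query and
# documents below are hypothetical, and the refinement branch additionally requires
# the `backend` module (with its qa_pipeline) to be importable, since
# self_critique_and_refine imports it lazily.
if __name__ == "__main__":
    sample_query = "hypertension"
    for sub_q in generate_sub_questions(sample_query):
        print(sub_q)
    toy_docs = ["Lifestyle changes such as reduced sodium intake can lower blood pressure."]
    refined, critique = self_critique_and_refine(
        sample_query,
        "Hypertension is usually treated with medication.",
        toy_docs,
    )
    print("CRITIQUE:", critique)
    print("REFINED ANSWER:", refined)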