from transformers import pipeline, AutoTokenizer
# ------------------------------
# 1) CRITIQUE MODEL & TOKENIZER (for fast mode, this step can be skipped)
# ------------------------------
CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
# DistilGPT-2 inherits GPT-2's maximum context length of 1024 tokens (prompt + generated text).
GPT2_MAX_CONTEXT = 1024
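# Illustrative note: the tokenizer is used to count prompt tokens, e.g.
#   len(critique_tokenizer.encode("Some critique prompt", add_special_tokens=False))
# and _truncate_prompt_for_gpt2 (below) clips any prompt whose token count would
# exceed GPT2_MAX_CONTEXT minus a small generation buffer.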
# ------------------------------
# 2) SUB-QUESTION GENERATION
# ------------------------------
def generate_sub_questions(main_query: str):
"""
Naively generates sub-questions for the given main query.
"""
return [
f"1) What are common causes of {main_query}?",
f"2) Which medications are typically used for {main_query}?",
f"3) What are non-pharmacological approaches to {main_query}?"
]
# ------------------------------
# 3) SELF-CRITIQUE & REFINEMENT
# ------------------------------
def self_critique_and_refine(query: str, initial_answer: str, docs: list, fast_mode: bool = False):
"""
Uses a smaller model (DistilGPT-2) for self-critique and optionally refines the answer.
If fast_mode is True, the self-critique step is skipped, and the initial answer is returned.
"""
if fast_mode:
# Fast mode: Skip self-critique and refinement
return initial_answer, "Self-critique skipped for speed."
# Construct the critique prompt.
critique_prompt = (
f"The following is an answer to the question '{query}'. "
"Evaluate its correctness, clarity, and completeness. "
"List any missing details or inaccuracies.\n\n"
f"ANSWER:\n{initial_answer}\n\n"
"CRITIQUE:"
)
# Truncate the prompt to ensure there's room for 20 new tokens.
truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)
    # Generate the critique using DistilGPT-2. return_full_text=False keeps only
    # the newly generated critique; otherwise the echoed prompt (which itself
    # contains words like "missing") would always trigger the refinement below.
    critique_gen = critique_pipeline(
        truncated_prompt,
        max_new_tokens=20,
        truncation=True,
        return_full_text=False
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"].strip()
    else:
        critique_text = "No critique generated."
    # If critique flags issues, refine the answer using BioGPT.
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        from backend import qa_pipeline  # Import here to avoid circular imports.
        refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
        if refined_gen and isinstance(refined_gen, list):
            refined_answer = refined_gen[0]["generated_text"]
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer

    return refined_answer, critique_text
# ------------------------------
# 4) HELPER: GPT-2 PROMPT TRUNCATION
# ------------------------------
def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
"""
Truncates the input prompt so that its token count plus a reserved buffer does not exceed GPT-2's maximum context.
"""
tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
max_allowed = GPT2_MAX_CONTEXT - buffer
if len(tokens) > max_allowed:
tokens = tokens[:max_allowed]
truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
return truncated_text
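# ------------------------------
# 5) USAGE SKETCH (illustrative)
# ------------------------------
# A minimal sketch of how these helpers fit together, assuming an initial answer
# and retrieved documents are produced elsewhere in the app (the values below are
# placeholders, not real outputs). fast_mode=True is used so the sketch does not
# require backend.qa_pipeline.
if __name__ == "__main__":
    demo_query = "hypertension"
    for sub_q in generate_sub_questions(demo_query):
        print(sub_q)

    demo_docs = ["Placeholder context document about hypertension management."]
    demo_answer = "Hypertension is commonly managed with lifestyle changes and antihypertensive drugs."
    refined, critique = self_critique_and_refine(
        demo_query, demo_answer, demo_docs, fast_mode=True
    )
    print(refined)
    print(critique)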