from transformers import pipeline, AutoTokenizer

# ------------------------------
# 1) CRITIQUE MODEL & TOKENIZER (for fast mode, this step can be skipped)
# ------------------------------
CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)

# DistilGPT-2 has a maximum context length of 1024 tokens.
GPT2_MAX_CONTEXT = 1024


# ------------------------------
# 2) SUB-QUESTION GENERATION
# ------------------------------
def generate_sub_questions(main_query: str):
    """
    Naively generates sub-questions for the given main query.
    """
    return [
        f"1) What are common causes of {main_query}?",
        f"2) Which medications are typically used for {main_query}?",
        f"3) What are non-pharmacological approaches to {main_query}?",
    ]


# ------------------------------
# 3) SELF-CRITIQUE & REFINEMENT
# ------------------------------
def self_critique_and_refine(query: str, initial_answer: str, docs: list, fast_mode: bool = False):
    """
    Uses a smaller model (DistilGPT-2) to critique the initial answer and,
    if the critique flags problems, refines the answer with the main QA model.
    If fast_mode is True, the self-critique step is skipped and the initial
    answer is returned unchanged.
    """
    if fast_mode:
        # Fast mode: skip self-critique and refinement entirely.
        return initial_answer, "Self-critique skipped for speed."

    # Construct the critique prompt.
    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )

    # Truncate the prompt to ensure there's room for 20 new tokens.
    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)

    # Generate the critique using DistilGPT-2. return_full_text=False returns
    # only the newly generated tokens; otherwise the prompt itself (which
    # contains words like "missing") would be echoed into critique_text and
    # always trigger the refinement branch below.
    critique_gen = critique_pipeline(
        truncated_prompt,
        max_new_tokens=20,
        truncation=True,
        return_full_text=False,
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"]
    else:
        critique_text = "No critique generated."

    # If the critique flags issues, refine the answer using BioGPT.
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        from backend import qa_pipeline  # Imported here to avoid circular imports.
        refined_gen = qa_pipeline(
            refine_prompt,
            max_new_tokens=120,
            truncation=True,
            return_full_text=False,  # Return only the refined answer, not the prompt.
        )
        if refined_gen and isinstance(refined_gen, list):
            refined_answer = refined_gen[0]["generated_text"]
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer

    return refined_answer, critique_text


# ------------------------------
# 4) HELPER: GPT-2 PROMPT TRUNCATION
# ------------------------------
def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
    """
    Truncates the input prompt so that its token count plus a reserved buffer
    does not exceed GPT-2's maximum context length.
    """
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    max_allowed = GPT2_MAX_CONTEXT - buffer
    if len(tokens) > max_allowed:
        tokens = tokens[:max_allowed]
    return critique_tokenizer.decode(tokens, skip_special_tokens=True)
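

# ------------------------------
# 5) USAGE SKETCH
# ------------------------------
# A minimal, illustrative driver showing how the pieces above fit together.
# This is a sketch, not part of the original module: `retrieved_docs` and
# `draft_answer` are hypothetical stand-ins for whatever the surrounding
# retrieval and generation stages (and backend.qa_pipeline) would normally
# supply. fast_mode=True skips the DistilGPT-2 critique and the backend
# import, so the sketch runs without the rest of the application.

if __name__ == "__main__":
    main_query = "hypertension"

    # Break the main query into focused sub-questions.
    for sub_q in generate_sub_questions(main_query):
        print(sub_q)

    # Hypothetical retrieved context and draft answer.
    retrieved_docs = [
        "Hypertension is commonly managed with ACE inhibitors and diuretics.",
        "Lifestyle changes such as reduced sodium intake can lower blood pressure.",
    ]
    draft_answer = "Hypertension can be treated with medication."

    answer, critique = self_critique_and_refine(
        main_query, draft_answer, retrieved_docs, fast_mode=True
    )
    print("ANSWER:", answer)
    print("CRITIQUE:", critique)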