from transformers import pipeline, AutoTokenizer

# ------------------------------
# 1) CRITIQUE MODEL & TOKENIZER (loaded at import time; unused when fast_mode is enabled)
# ------------------------------
CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)

# DistilGPT-2 inherits GPT-2's maximum context length of 1024 tokens.
GPT2_MAX_CONTEXT = 1024

# ------------------------------
# 2) SUB-QUESTION GENERATION
# ------------------------------
def generate_sub_questions(main_query: str):
    """
    Naively generates sub-questions for the given main query.
    """
    return [
        f"1) What are common causes of {main_query}?",
        f"2) Which medications are typically used for {main_query}?",
        f"3) What are non-pharmacological approaches to {main_query}?"
    ]
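
# For illustration: with a hypothetical query such as "hypertension",
# generate_sub_questions would return the three template strings below.
#   ["1) What are common causes of hypertension?",
#    "2) Which medications are typically used for hypertension?",
#    "3) What are non-pharmacological approaches to hypertension?"]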

# ------------------------------
# 3) SELF-CRITIQUE & REFINEMENT
# ------------------------------
def self_critique_and_refine(query: str, initial_answer: str, docs: list, fast_mode: bool = False):
    """
    Uses a smaller model (DistilGPT-2) for self-critique and optionally refines the answer.
    If fast_mode is True, the self-critique step is skipped, and the initial answer is returned.
    """
    if fast_mode:
        # Fast mode: Skip self-critique and refinement
        return initial_answer, "Self-critique skipped for speed."

    # Construct the critique prompt.
    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )
    
    # Truncate the prompt to ensure there's room for 20 new tokens.
    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)
    
    # Generate the critique with DistilGPT-2. return_full_text=False keeps the
    # prompt out of the output, so the keyword check below only inspects the
    # model's own critique text (the prompt itself contains the word "missing").
    critique_gen = critique_pipeline(
        truncated_prompt,
        max_new_tokens=20,
        truncation=True,
        return_full_text=False
    )
    if critique_gen and isinstance(critique_gen, list):
        critique_text = critique_gen[0]["generated_text"]
    else:
        critique_text = "No critique generated."
    
    # If critique flags issues, refine the answer using BioGPT.
    if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
        refine_prompt = (
            f"Question: {query}\n"
            f"Current Answer: {initial_answer}\n"
            f"Critique: {critique_text}\n"
            "Refine the answer by adding missing or corrected information. "
            "Use the context below if needed:\n\n"
            + "\n\n".join(docs)
            + "\nREFINED ANSWER:"
        )
        from backend import qa_pipeline  # Import here to avoid circular imports.
        refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
        if refined_gen and isinstance(refined_gen, list):
            refined_answer = refined_gen[0]["generated_text"]
            # Text-generation pipelines echo the prompt by default; strip it so
            # only the newly generated refinement is kept.
            if refined_answer.startswith(refine_prompt):
                refined_answer = refined_answer[len(refine_prompt):].strip() or initial_answer
        else:
            refined_answer = initial_answer
    else:
        refined_answer = initial_answer

    return refined_answer, critique_text

# ------------------------------
# 4) HELPER: GPT-2 PROMPT TRUNCATION
# ------------------------------
def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
    """
    Truncates the input prompt so that its token count plus a reserved buffer does not exceed GPT-2's maximum context.
    """
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    max_allowed = GPT2_MAX_CONTEXT - buffer
    if len(tokens) > max_allowed:
        tokens = tokens[:max_allowed]
    truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text
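
# ------------------------------
# 5) EXAMPLE USAGE (hypothetical sketch)
# ------------------------------
# A minimal, illustrative driver showing how the pieces above fit together.
# The query, documents, and draft answer are made-up placeholders; in the real
# pipeline they would presumably come from retrieval and the answer generator
# behind backend.qa_pipeline. fast_mode=True is used so this sketch runs
# without importing backend at all.
if __name__ == "__main__":
    sample_query = "hypertension"  # placeholder query
    sample_docs = [
        "Placeholder context document about lifestyle changes.",
        "Placeholder context document about first-line medications."
    ]
    draft_answer = "Hypertension is usually managed with medication."  # placeholder draft

    # Sub-questions could drive additional retrieval (not shown here).
    for sub_q in generate_sub_questions(sample_query):
        print(sub_q)

    # With fast_mode=True the DistilGPT-2 critique step is skipped entirely.
    answer, critique = self_critique_and_refine(
        sample_query, draft_answer, sample_docs, fast_mode=True
    )
    print("ANSWER:", answer)
    print("CRITIQUE:", critique)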