mgbam committed
Commit 1375d63 · verified · 1 Parent(s): fc878b0

Update mini_ladder.py

Files changed (1)
  1. mini_ladder.py +52 -11
mini_ladder.py CHANGED
@@ -1,12 +1,22 @@
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer
 
-# A second pipeline for self-critique (using a lighter model for demonstration)
-CRITIQUE_MODEL = "gpt2"  # This can be replaced with another model as needed
-critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL)
+# ------------------------------
+# 1) CRITIQUE MODEL & TOKENIZER
+# ------------------------------
+# Using GPT-2 for self-critique
+CRITIQUE_MODEL_NAME = "gpt2"
+critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
+critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
 
+# GPT-2 typically has a max context length of 1024 tokens
+GPT2_MAX_CONTEXT = 1024
+
+# ------------------------------
+# 2) SUB-QUESTION GENERATION
+# ------------------------------
 def generate_sub_questions(main_query: str):
     """
-    Naively generates sub-questions for the given main query.
+    Naive approach to generating sub-questions.
     """
     return [
         f"1) What are common causes of {main_query}?",
@@ -14,11 +24,15 @@ def generate_sub_questions(main_query: str):
         f"3) What are non-pharmacological approaches to {main_query}?"
     ]
 
+# ------------------------------
+# 3) SELF-CRITIQUE & REFINEMENT
+# ------------------------------
 def self_critique_and_refine(query: str, initial_answer: str, docs: list):
     """
-    Critiques the initial answer and refines it if necessary.
+    1) Critique the initial answer (GPT-2).
+    2) If needed, refine using the original BioGPT pipeline.
     """
-    # Step 1: Generate a critique using a critique prompt
+    # A) Construct the critique prompt
     critique_prompt = (
         f"The following is an answer to the question '{query}'. "
         "Evaluate its correctness, clarity, and completeness. "
@@ -26,14 +40,24 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
         f"ANSWER:\n{initial_answer}\n\n"
         "CRITIQUE:"
     )
-    critique_gen = critique_pipeline(critique_prompt, max_new_tokens=80, truncation=True)
+
+    # B) Truncate the critique prompt to fit GPT-2's max context
+    truncated_critique_prompt = _truncate_prompt_for_gpt2(critique_prompt)
+
+    # C) Generate the critique
+    critique_gen = critique_pipeline(
+        truncated_critique_prompt,
+        max_new_tokens=80,  # how many tokens to generate for the critique
+        truncation=True     # ensure we don't exceed the final length
+    )
     if critique_gen and isinstance(critique_gen, list):
         critique_text = critique_gen[0]["generated_text"]
     else:
         critique_text = "No critique generated."
 
-    # Step 2: If the critique suggests issues, refine the answer using the original QA pipeline.
+    # D) If critique suggests issues, refine using BioGPT
     if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
+        # Build a refine prompt that includes docs
        refine_prompt = (
             f"Question: {query}\n"
             f"Current Answer: {initial_answer}\n"
@@ -43,8 +67,11 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
             + "\n\n".join(docs)
             + "\nREFINED ANSWER:"
         )
-        # Import the qa_pipeline from backend to reuse it (local import to avoid circular dependencies)
-        from backend import qa_pipeline
+
+        # If BioGPT has similar context limits, you can truncate here too
+        # e.g., refine_prompt = _truncate_prompt_for_biogpt(refine_prompt)
+
+        from backend import qa_pipeline  # Import to avoid circular references
         refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
         if refined_gen and isinstance(refined_gen, list):
             refined_answer = refined_gen[0]["generated_text"]
@@ -54,3 +81,17 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
             refined_answer = initial_answer
 
     return refined_answer, critique_text
+
+# ------------------------------
+# 4) HELPER: GPT-2 TRUNCATION
+# ------------------------------
+def _truncate_prompt_for_gpt2(prompt_text: str) -> str:
+    """
+    Token-level truncation to ensure the prompt doesn't exceed GPT-2's 1024-token limit.
+    """
+    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
+    if len(tokens) > GPT2_MAX_CONTEXT:
+        # Keep the first 1024 tokens
+        tokens = tokens[:GPT2_MAX_CONTEXT]
+    truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
+    return truncated_text
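
Note on the new helper: GPT-2's 1024-token window covers the prompt and the generated tokens together, so a prompt truncated to the full GPT2_MAX_CONTEXT followed by max_new_tokens=80 of critique would overflow the context. A minimal variant that reserves generation headroom (the headroom parameter is an assumption, not part of this commit):

# Sketch only: reserve room for the tokens the critique step will generate.
def _truncate_prompt_with_headroom(prompt_text: str, max_new_tokens: int = 80) -> str:
    budget = GPT2_MAX_CONTEXT - max_new_tokens  # e.g. 1024 - 80 = 944 prompt tokens
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    if len(tokens) > budget:
        tokens = tokens[:budget]  # keep only the first `budget` prompt tokens
    return critique_tokenizer.decode(tokens, skip_special_tokens=True)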
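
The refine step only gestures at _truncate_prompt_for_biogpt in a comment. A hedged sketch of that helper, assuming the backend model is microsoft/biogpt with a 1024-token window (both are assumptions; read the real values from the backend's model config):

# Sketch only: _truncate_prompt_for_biogpt is not defined in this commit.
from transformers import AutoTokenizer

BIOGPT_MODEL_NAME = "microsoft/biogpt"  # assumed; should match backend's qa_pipeline model
BIOGPT_MAX_CONTEXT = 1024               # assumed; see model.config.max_position_embeddings

biogpt_tokenizer = AutoTokenizer.from_pretrained(BIOGPT_MODEL_NAME)

def _truncate_prompt_for_biogpt(prompt_text: str, max_new_tokens: int = 120) -> str:
    budget = BIOGPT_MAX_CONTEXT - max_new_tokens  # leave room for the refined answer
    tokens = biogpt_tokenizer.encode(prompt_text, add_special_tokens=False)
    if len(tokens) > budget:
        tokens = tokens[:budget]
    return biogpt_tokenizer.decode(tokens, skip_special_tokens=True)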
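
For context, a minimal end-to-end sketch of how the updated module might be called. backend.qa_pipeline is assumed to be a transformers text-generation pipeline (as the local import implies), and the query, docs, and first-pass answer are illustrative:

# Sketch only: illustrative driver, not part of this commit.
from mini_ladder import generate_sub_questions, self_critique_and_refine

query = "chronic migraine"
docs = [
    "Doc 1: overview of common migraine triggers and causes...",
    "Doc 2: summary of pharmacological and preventive treatments...",
]

for sub_q in generate_sub_questions(query):
    print(sub_q)

initial_answer = "Migraines can be treated with rest."  # placeholder first-pass answer
refined_answer, critique = self_critique_and_refine(query, initial_answer, docs)
print("CRITIQUE:", critique)
print("REFINED ANSWER:", refined_answer)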