mgbam committed on
Commit a23c4d4 · verified · 1 Parent(s): 1375d63

Update mini_ladder.py

Files changed (1)
  1. mini_ladder.py +26 -28
mini_ladder.py CHANGED
@@ -3,12 +3,11 @@ from transformers import pipeline, AutoTokenizer
 # ------------------------------
 # 1) CRITIQUE MODEL & TOKENIZER
 # ------------------------------
-# Using GPT-2 for self-critique
 CRITIQUE_MODEL_NAME = "gpt2"
 critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
 critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
 
-# GPT-2 typically has a max context length of 1024 tokens
+# GPT-2 has a maximum context length of 1024 tokens.
 GPT2_MAX_CONTEXT = 1024
 
 # ------------------------------
@@ -16,7 +15,7 @@ GPT2_MAX_CONTEXT = 1024
 # ------------------------------
 def generate_sub_questions(main_query: str):
     """
-    Naive approach to generating sub-questions.
+    Naively generates sub-questions for the given main query.
     """
     return [
         f"1) What are common causes of {main_query}?",
@@ -29,10 +28,10 @@ def generate_sub_questions(main_query: str):
 # ------------------------------
 def self_critique_and_refine(query: str, initial_answer: str, docs: list):
     """
-    1) Critique the initial answer (GPT-2).
-    2) If needed, refine using the original BioGPT pipeline.
+    1) Uses GPT-2 to critique the initial answer.
+    2) If the critique indicates missing or incomplete details, refines the answer using BioGPT.
     """
-    # A) Construct the critique prompt
+    # A) Construct the critique prompt.
     critique_prompt = (
         f"The following is an answer to the question '{query}'. "
         "Evaluate its correctness, clarity, and completeness. "
@@ -40,24 +39,24 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
         f"ANSWER:\n{initial_answer}\n\n"
         "CRITIQUE:"
     )
-
-    # B) Truncate the critique prompt to fit GPT-2’s max context
-    truncated_critique_prompt = _truncate_prompt_for_gpt2(critique_prompt)
-
-    # C) Generate the critique
+
+    # B) Truncate the prompt so that prompt tokens + new tokens <= GPT2_MAX_CONTEXT.
+    #    Reserve a buffer for new tokens (default 80 tokens).
+    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=80)
+
+    # C) Generate the critique using the truncated prompt.
     critique_gen = critique_pipeline(
-        truncated_critique_prompt,
-        max_new_tokens=80,  # how many tokens to generate for the critique
-        truncation=True  # ensure we don't exceed the final length
+        truncated_prompt,
+        max_new_tokens=80,  # tokens to generate for critique
+        truncation=True
     )
     if critique_gen and isinstance(critique_gen, list):
         critique_text = critique_gen[0]["generated_text"]
     else:
         critique_text = "No critique generated."
-
-    # D) If critique suggests issues, refine using BioGPT
+
+    # D) If the critique flags issues, refine the answer using BioGPT.
     if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
-        # Build a refine prompt that includes docs
        refine_prompt = (
             f"Question: {query}\n"
             f"Current Answer: {initial_answer}\n"
@@ -67,11 +66,8 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
             + "\n\n".join(docs)
             + "\nREFINED ANSWER:"
         )
-
-        # If BioGPT has similar context limits, you can truncate here too
-        # e.g., refine_prompt = _truncate_prompt_for_biogpt(refine_prompt)
-
-        from backend import qa_pipeline  # Import to avoid circular references
+        # Optionally, if BioGPT also has context limits, apply a similar truncation method.
+        from backend import qa_pipeline  # Import here to avoid circular imports.
         refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
         if refined_gen and isinstance(refined_gen, list):
             refined_answer = refined_gen[0]["generated_text"]
@@ -83,15 +79,17 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
     return refined_answer, critique_text
 
 # ------------------------------
-# 4) HELPER: GPT-2 TRUNCATION
+# 4) HELPER: GPT-2 PROMPT TRUNCATION
 # ------------------------------
-def _truncate_prompt_for_gpt2(prompt_text: str) -> str:
+def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 80) -> str:
     """
-    Token-level truncation to ensure the prompt doesn't exceed GPT-2’s 1024-token limit.
+    Truncates the input prompt so that its token count plus a reserved buffer
+    (for new tokens) does not exceed GPT-2's maximum context length.
     """
     tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
-    if len(tokens) > GPT2_MAX_CONTEXT:
-        # Keep the first 1024 tokens
-        tokens = tokens[:GPT2_MAX_CONTEXT]
+    # Ensure we leave room for 'buffer' tokens for generation.
+    max_allowed = GPT2_MAX_CONTEXT - buffer
+    if len(tokens) > max_allowed:
+        tokens = tokens[:max_allowed]
     truncated_text = critique_tokenizer.decode(tokens, skip_special_tokens=True)
     return truncated_text
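
For context, here is a minimal sketch (not part of the commit) of what the new buffer-aware truncation achieves. It assumes only that transformers is installed; the helper is re-stated standalone as truncate_for_gpt2 (a hypothetical name used here so the snippet runs without importing mini_ladder), and the dummy prompt and 80-token buffer simply mirror the values used above.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
GPT2_MAX_CONTEXT = 1024
BUFFER = 80  # same reserve as max_new_tokens in the critique step

def truncate_for_gpt2(prompt_text: str, buffer: int = BUFFER) -> str:
    # Standalone mirror of _truncate_prompt_for_gpt2: keep at most
    # (GPT2_MAX_CONTEXT - buffer) prompt tokens so generation still fits.
    tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
    return tokenizer.decode(tokens[:GPT2_MAX_CONTEXT - buffer], skip_special_tokens=True)

long_prompt = "symptom details " * 2000  # deliberately far beyond 1024 tokens
truncated = truncate_for_gpt2(long_prompt)
prompt_tokens = len(tokenizer.encode(truncated, add_special_tokens=False))
print(prompt_tokens, prompt_tokens + BUFFER <= GPT2_MAX_CONTEXT)
# Expected: roughly 944 tokens and True, i.e. the prompt plus 80 new tokens fits in GPT-2's context.

Reserving the buffer up front is what distinguishes this commit from the previous version, which truncated to the full 1024 tokens and could leave no room for max_new_tokens=80.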