mgbam committed
Commit 76cdef3 · verified · 1 Parent(s): 70284a9

Update mini_ladder.py

Files changed (1)
  1. mini_ladder.py +16 -16
mini_ladder.py CHANGED
@@ -3,11 +3,13 @@ from transformers import pipeline, AutoTokenizer
 # ------------------------------
 # 1) CRITIQUE MODEL & TOKENIZER
 # ------------------------------
-CRITIQUE_MODEL_NAME = "gpt2"
+# Use DistilGPT-2 (a smaller, distilled version of GPT-2) for faster inference on CPU.
+CRITIQUE_MODEL_NAME = "distilgpt2"
 critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
 critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
 
-# GPT-2 has a maximum context length of 1024 tokens.
+# DistilGPT-2 has a maximum context length similar to GPT-2 (around 1024 tokens),
+# but we reserve a smaller buffer since we now generate fewer tokens.
 GPT2_MAX_CONTEXT = 1024
 
 # ------------------------------
@@ -28,8 +30,8 @@ def generate_sub_questions(main_query: str):
 # ------------------------------
 def self_critique_and_refine(query: str, initial_answer: str, docs: list):
     """
-    1) Uses GPT-2 to critique the initial answer.
-    2) If the critique indicates missing or incomplete details, refines the answer using BioGPT.
+    Uses a smaller model (DistilGPT-2) for self-critique, with a reduced max_new_tokens.
+    If the critique indicates issues, refines the answer using BioGPT.
     """
     # A) Construct the critique prompt.
     critique_prompt = (
@@ -40,14 +42,13 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
         "CRITIQUE:"
     )
 
-    # B) Truncate the prompt so that prompt tokens + new tokens <= GPT2_MAX_CONTEXT.
-    #    Reserve a buffer for new tokens (default 80 tokens).
-    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=80)
+    # B) Truncate the prompt so that prompt tokens + new tokens (20) <= GPT2_MAX_CONTEXT.
+    truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)
 
-    # C) Generate the critique using the truncated prompt.
+    # C) Generate the critique using DistilGPT-2.
     critique_gen = critique_pipeline(
         truncated_prompt,
-        max_new_tokens=20,  # tokens to generate for critique
+        max_new_tokens=20,  # Reduced new tokens for speed.
         truncation=True
     )
     if critique_gen and isinstance(critique_gen, list):
@@ -55,7 +56,7 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
     else:
         critique_text = "No critique generated."
 
-    # D) If the critique flags issues, refine the answer using BioGPT.
+    # D) If the critique flags issues, refine using BioGPT.
     if any(word in critique_text.lower() for word in ["missing", "incomplete", "incorrect", "lacks"]):
         refine_prompt = (
             f"Question: {query}\n"
@@ -66,8 +67,8 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
             + "\n\n".join(docs)
             + "\nREFINED ANSWER:"
         )
-        # Optionally, if BioGPT also has context limits, apply a similar truncation method.
-        from backend import qa_pipeline  # Import here to avoid circular imports.
+        # Optionally, you might also truncate the refine_prompt if needed.
+        from backend import qa_pipeline  # Import here to avoid circular dependencies.
         refined_gen = qa_pipeline(refine_prompt, max_new_tokens=120, truncation=True)
         if refined_gen and isinstance(refined_gen, list):
             refined_answer = refined_gen[0]["generated_text"]
@@ -81,13 +82,12 @@ def self_critique_and_refine(query: str, initial_answer: str, docs: list):
 # ------------------------------
 # 4) HELPER: GPT-2 PROMPT TRUNCATION
 # ------------------------------
-def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 80) -> str:
+def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
     """
-    Truncates the input prompt so that its token count plus a reserved buffer
-    (for new tokens) does not exceed GPT-2's maximum context length.
+    Truncates the input prompt so that its token count plus a reserved buffer for new tokens
+    does not exceed GPT-2's (or DistilGPT-2's) maximum context length.
     """
     tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
-    # Ensure we leave room for 'buffer' tokens for generation.
     max_allowed = GPT2_MAX_CONTEXT - buffer
     if len(tokens) > max_allowed:
         tokens = tokens[:max_allowed]
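
Below is a minimal, self-contained sketch of how the updated pieces fit together. It assumes the truncation helper ends by decoding the kept tokens back into a string (that return step falls outside the hunks shown above), and the sample prompt text is purely illustrative.

# Sketch under the assumptions stated above; not part of the committed file.
from transformers import pipeline, AutoTokenizer

CRITIQUE_MODEL_NAME = "distilgpt2"
critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL_NAME)
critique_tokenizer = AutoTokenizer.from_pretrained(CRITIQUE_MODEL_NAME)
GPT2_MAX_CONTEXT = 1024

def _truncate_prompt_for_gpt2(prompt_text: str, buffer: int = 20) -> str:
    """Keep prompt tokens plus `buffer` generated tokens within the 1024-token context."""
    tokens = critique_tokenizer.encode(prompt_text, add_special_tokens=False)
    max_allowed = GPT2_MAX_CONTEXT - buffer  # 1024 - 20 = at most 1004 prompt tokens
    if len(tokens) > max_allowed:
        tokens = tokens[:max_allowed]
    # Assumed final step: decode the kept tokens back into prompt text.
    return critique_tokenizer.decode(tokens)

# Illustrative critique call, mirroring steps B and C of self_critique_and_refine.
critique_prompt = "QUESTION: ...\nINITIAL ANSWER: ...\nCRITIQUE:"
truncated_prompt = _truncate_prompt_for_gpt2(critique_prompt, buffer=20)
critique_gen = critique_pipeline(truncated_prompt, max_new_tokens=20, truncation=True)
critique_text = critique_gen[0]["generated_text"]

With buffer=20, at most 1004 prompt tokens are kept, so the prompt plus the 20 generated critique tokens stays within DistilGPT-2's 1024-token context.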