atticusg commited on
Commit
03e24fd
·
verified ·
1 Parent(s): e1060e3

Delete token_position.py

Browse files
Files changed (1) hide show
  1. token_position.py +0 -65
token_position.py DELETED
@@ -1,65 +0,0 @@
1
- """
2
- Token position definitions for MCQA task submission.
3
- This file provides token position functions that identify key tokens in MCQA prompts.
4
- """
5
-
6
- import re
7
- from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index
8
-
9
-
10
- def get_token_positions(pipeline, causal_model):
11
- """
12
- Get token positions for the simple MCQA task.
13
-
14
- Args:
15
- pipeline: The language model pipeline with tokenizer
16
- causal_model: The causal model for the task
17
-
18
- Returns:
19
- list[TokenPosition]: List of TokenPosition objects for intervention experiments
20
- """
21
- def get_correct_symbol_index(input, pipeline, causal_model):
22
- """
23
- Find the index of the correct answer symbol in the prompt.
24
-
25
- Args:
26
- input (Dict): The input dictionary to a causal model
27
- pipeline: The tokenizer pipeline
28
- causal_model: The causal model
29
-
30
- Returns:
31
- list[int]: List containing the index of the correct answer symbol token
32
- """
33
- # Run the model to get the answer position
34
- output = causal_model.run_forward(input)
35
- pointer = output["answer_pointer"]
36
- correct_symbol = output[f"symbol{pointer}"]
37
- prompt = input["raw_input"]
38
-
39
- # Find all single uppercase letters in the prompt
40
- matches = list(re.finditer(r"\b[A-Z]\b", prompt))
41
-
42
- # Find the match corresponding to our correct symbol
43
- symbol_match = None
44
- for match in matches:
45
- if prompt[match.start():match.end()] == correct_symbol:
46
- symbol_match = match
47
- break
48
-
49
- if not symbol_match:
50
- raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")
51
-
52
- # Get the substring up to the symbol match end
53
- substring = prompt[:symbol_match.end()]
54
- tokenized_substring = list(pipeline.load(substring)["input_ids"][0])
55
-
56
- # The symbol token will be at the end of the substring
57
- return [len(tokenized_substring) - 1]
58
-
59
- # Create TokenPosition objects
60
- token_positions = [
61
- TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"),
62
- TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0]+1], pipeline, id="correct_symbol_period"),
63
- TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token")
64
- ]
65
- return token_positions