zhou777 commited on
Commit
2de6616
·
verified ·
1 Parent(s): 126f031

Update true_pre_2level.py

Browse files
Files changed (1) hide show
  1. true_pre_2level.py +21 -15
true_pre_2level.py CHANGED
@@ -9,6 +9,7 @@ import csv
9
  # Set global variables
10
  NUM_EXAMPLES = 900 # 要处理的 JSONL 文件的行数(即例子数量)The number of lines (i.e., examples) in the JSONL file to be processed.
11
  QUESTIONS_PER_EXAMPLE = 2 # 每个例子的标准问题数量 The standard number of questions per example
 
12
 
13
  model = '/path/model'
14
  jsonl_file = '/path/val.jsonl'
@@ -72,18 +73,10 @@ def process_conversations(data, session=None):
72
  for i in range(0, len(conv), 2):
73
  human_question = conv[i]['value']
74
 
75
- # if '<image>' in human_question:
76
-
77
- # else:
78
- # response = pipe.chat(human_question, session=session, gen_config=gen_config)
79
-
80
- # generated_answer = response.response.text.strip()
81
-
82
  if i == 0: # First question
83
  response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
84
  generated_answer = response.response.text.strip()
85
  first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
86
- # predicted_first_level = extract_level(generated_answer, r'The FirstLevel is (.+)$')
87
  predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
88
  # Update first_level_accuracy
89
  if first_level in first_level_accuracy:
@@ -106,16 +99,22 @@ def process_conversations(data, session=None):
106
 
107
  elif i == 2: # Second question
108
  # 提取POI信息 POI
 
109
  poi_info = extract_poi_info(conv[i]['value'])
110
  # 提取行人密度信息 People
 
111
  pedestrian_density = extract_pedestrian_density(conv[i]['value'])
 
 
 
 
 
112
 
113
- if predicted_first_level in first_level_to_second_levels:
114
- second_levels = first_level_to_second_levels[predicted_first_level]
115
  # 构建下一个问题,包括POI和行人密度信息
116
- # 如果使用跨级判别框架,请修改{predicted_first_level}变量为{first_level}变量
117
- # If using a cross-level discriminative framework, please modify the {predicted_first_level} variable to the {first_level} variable.
118
- human_question = (f"The FirstLevel category of this image is {predicted_first_level}. "
119
  f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
120
  "This image contains some POI (Point of Interest) information, "
121
  "which is now provided to you. You can refer to this POI information "
@@ -133,22 +132,25 @@ def process_conversations(data, session=None):
133
 
134
  second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
135
  predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
136
- # predicted_second_level = extract_level(generated_answer, r'The SecondLevel is (.+)$')
137
 
138
  # Update second_level_accuracy
139
  # 确保真实的第一级分类存在于结构中
 
140
  real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
141
 
142
  if real_first_level in second_level_accuracy:
143
  # 使用真实的一级分类查找对应的二级分类结构
 
144
  second_level_data = second_level_accuracy[real_first_level]
145
 
146
  # 更新统计总数,确保真实的二级分类存在于结构中
 
147
  if second_level in second_level_data:
148
  second_level_data[second_level]['total'] += 1
149
  second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
150
 
151
  # 比较和记录正确性
 
152
  if second_level == predicted_second_level:
153
  correct_second += 1
154
  second_level_data[second_level]['correct'] += 1
@@ -172,9 +174,9 @@ correct_first_total = 0
172
  correct_second_total = 0
173
  error_logs = []
174
 
175
-
176
  with open(jsonl_file, 'r') as f:
177
  lines = [next(f) for _ in range(NUM_EXAMPLES)] # 只读取前 NUM_EXAMPLES 行
 
178
 
179
  with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_csv_file, \
180
  open('second_level_results_true_0_9K.csv', 'w', newline='') as second_level_csv_file:
@@ -194,6 +196,7 @@ with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_c
194
  error_logs.extend(errors)
195
 
196
  # 计算和打印正确率
 
197
  total_questions = NUM_EXAMPLES * QUESTIONS_PER_EXAMPLE
198
 
199
  first_accuracy = correct_first_total / NUM_EXAMPLES
@@ -204,6 +207,7 @@ print(f'Second question accuracy: {second_accuracy * 100:.2f}%')
204
  print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
205
 
206
  # 计算一级分类正确率
 
207
  for first_level in first_level_accuracy:
208
  correct = first_level_accuracy[first_level]['correct']
209
  total = first_level_accuracy[first_level]['total']
@@ -212,6 +216,7 @@ for first_level in first_level_accuracy:
212
  print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')
213
 
214
  # 计算二级分类正确率
 
215
  for first_level, second_levels in second_level_accuracy.items():
216
  for second_level in second_levels:
217
  correct = second_levels[second_level]['correct']
@@ -221,5 +226,6 @@ for first_level, second_levels in second_level_accuracy.items():
221
  print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')
222
 
223
  # 将错误记录写入日志文件
 
224
  with open('error_log_0_9K', 'w') as outfile:
225
  json.dump(error_logs, outfile, indent=4)
 
9
  # Set global variables
10
  NUM_EXAMPLES = 900 # 要处理的 JSONL 文件的行数(即例子数量)The number of lines (i.e., examples) in the JSONL file to be processed.
11
  QUESTIONS_PER_EXAMPLE = 2 # 每个例子的标准问题数量 The standard number of questions per example
12
+ Trans_level = 1 # 定义全局变量,决定是否切换到跨级判别框架,0为不启用跨级判别框架,1为启用 Define a global variable to decide whether to switch to the trans-level discrimination framework. 0 means disabling the trans-level discrimination framework, and 1 means enabling it.
13
 
14
  model = '/path/model'
15
  jsonl_file = '/path/val.jsonl'
 
73
  for i in range(0, len(conv), 2):
74
  human_question = conv[i]['value']
75
 
 
 
 
 
 
 
 
76
  if i == 0: # First question
77
  response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
78
  generated_answer = response.response.text.strip()
79
  first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
 
80
  predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
81
  # Update first_level_accuracy
82
  if first_level in first_level_accuracy:
 
99
 
100
  elif i == 2: # Second question
101
  # 提取POI信息 POI
102
+ # Extract POI information (Points of Interest)
103
  poi_info = extract_poi_info(conv[i]['value'])
104
  # 提取行人密度信息 People
105
+ # Extract pedestrian density information
106
  pedestrian_density = extract_pedestrian_density(conv[i]['value'])
107
+
108
+ if Trans_level == 0:
109
+ current_first_level = predicted_first_level
110
+ else:
111
+ current_first_level = first_level
112
 
113
+ if current_first_level in first_level_to_second_levels:
114
+ second_levels = first_level_to_second_levels[current_first_level]
115
  # 构建下一个问题,包括POI和行人密度信息
116
+ # Construct the next question, including POI and pedestrian density information
117
+ human_question = (f"The FirstLevel category of this image is {current_first_level}. "
 
118
  f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
119
  "This image contains some POI (Point of Interest) information, "
120
  "which is now provided to you. You can refer to this POI information "
 
132
 
133
  second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
134
  predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
 
135
 
136
  # Update second_level_accuracy
137
  # 确保真实的第一级分类存在于结构中
138
+ # Ensure the real first-level category exists in the structure
139
  real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
140
 
141
  if real_first_level in second_level_accuracy:
142
  # 使用真实的一级分类查找对应的二级分类结构
143
+ # Use the real first-level category to find the corresponding second-level structure
144
  second_level_data = second_level_accuracy[real_first_level]
145
 
146
  # 更新统计总数,确保真实的二级分类存在于结构中
147
+ # Update statistics, ensuring the real second-level category exists in the structure
148
  if second_level in second_level_data:
149
  second_level_data[second_level]['total'] += 1
150
  second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
151
 
152
  # 比较和记录正确性
153
+ # Compare and record correctness
154
  if second_level == predicted_second_level:
155
  correct_second += 1
156
  second_level_data[second_level]['correct'] += 1
 
174
  correct_second_total = 0
175
  error_logs = []
176
 
 
177
  with open(jsonl_file, 'r') as f:
178
  lines = [next(f) for _ in range(NUM_EXAMPLES)] # 只读取前 NUM_EXAMPLES 行
179
+ # Only read the first NUM_EXAMPLES lines
180
 
181
  with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_csv_file, \
182
  open('second_level_results_true_0_9K.csv', 'w', newline='') as second_level_csv_file:
 
196
  error_logs.extend(errors)
197
 
198
  # 计算和打印正确率
199
+ # Calculate and print accuracy
200
  total_questions = NUM_EXAMPLES * QUESTIONS_PER_EXAMPLE
201
 
202
  first_accuracy = correct_first_total / NUM_EXAMPLES
 
207
  print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
208
 
209
  # 计算一级分类正确率
210
+ # Calculate FirstLevel accuracy
211
  for first_level in first_level_accuracy:
212
  correct = first_level_accuracy[first_level]['correct']
213
  total = first_level_accuracy[first_level]['total']
 
216
  print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')
217
 
218
  # 计算二级分类正确率
219
+ # Calculate SecondLevel accuracy
220
  for first_level, second_levels in second_level_accuracy.items():
221
  for second_level in second_levels:
222
  correct = second_levels[second_level]['correct']
 
226
  print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')
227
 
228
  # 将错误记录写入日志文件
229
+ # Write error logs to a file
230
  with open('error_log_0_9K', 'w') as outfile:
231
  json.dump(error_logs, outfile, indent=4)