Update true_pre_2level.py

true_pre_2level.py  CHANGED  (+21 -15)
@@ -9,6 +9,7 @@ import csv
 # Set global variables
 NUM_EXAMPLES = 900 # 要处理的 JSONL 文件的行数(即例子数量)The number of lines (i.e., examples) in the JSONL file to be processed.
 QUESTIONS_PER_EXAMPLE = 2 # 每个例子的标准问题数量 The standard number of questions per example
+Trans_level = 1 # 定义全局变量,决定是否切换到跨级判别框架,0为不启用跨级判别框架,1为启用 Define a global variable to decide whether to switch to the trans-level discrimination framework. 0 means disabling the trans-level discrimination framework, and 1 means enabling it.
 
 model = '/path/model'
 jsonl_file = '/path/val.jsonl'
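The new Trans_level switch only affects which FirstLevel label is carried into the SecondLevel question (see the third hunk below). A minimal sketch of that selection logic; pick_first_level is a hypothetical name, and the commit itself inlines this if/else inside process_conversations:

    def pick_first_level(trans_level, predicted_first_level, true_first_level):
        # trans_level == 0: cascade on the model's own FirstLevel prediction
        # trans_level == 1: condition the SecondLevel question on the ground-truth FirstLevel
        return predicted_first_level if trans_level == 0 else true_first_level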
@@ -72,18 +73,10 @@ def process_conversations(data, session=None):
     for i in range(0, len(conv), 2):
         human_question = conv[i]['value']
 
-        # if '<image>' in human_question:
-
-        # else:
-        #     response = pipe.chat(human_question, session=session, gen_config=gen_config)
-
-        # generated_answer = response.response.text.strip()
-
         if i == 0:  # First question
             response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
             generated_answer = response.response.text.strip()
             first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
-            # predicted_first_level = extract_level(generated_answer, r'The FirstLevel is (.+)$')
             predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
             # Update first_level_accuracy
             if first_level in first_level_accuracy:
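The relaxed pattern r'(?:The FirstLevel is )?(.+)$' accepts both a fully phrased answer and a bare label, whereas the removed pattern only matched the phrased form. extract_level itself is outside this diff; the sketch below assumes it is a thin re.search wrapper, and the 'Residential' label is purely illustrative:

    import re

    def extract_level(text, pattern):
        # assumed behaviour: return the first capture group, or None if nothing matches
        match = re.search(pattern, text)
        return match.group(1).strip() if match else None

    extract_level('The FirstLevel is Residential', r'(?:The FirstLevel is )?(.+)$')  # 'Residential'
    extract_level('Residential', r'(?:The FirstLevel is )?(.+)$')                    # 'Residential'
    extract_level('Residential', r'The FirstLevel is (.+)$')                         # None with the old pattern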
@@ -106,16 +99,22 @@ def process_conversations(data, session=None):
 
         elif i == 2:  # Second question
             # 提取POI信息 POI
+            # Extract POI information (Points of Interest)
             poi_info = extract_poi_info(conv[i]['value'])
             # 提取行人密度信息 People
+            # Extract pedestrian density information
             pedestrian_density = extract_pedestrian_density(conv[i]['value'])
+
+            if Trans_level == 0:
+                current_first_level = predicted_first_level
+            else:
+                current_first_level = first_level
 
-            if predicted_first_level in first_level_to_second_levels:
-                second_levels = first_level_to_second_levels[predicted_first_level]
+            if current_first_level in first_level_to_second_levels:
+                second_levels = first_level_to_second_levels[current_first_level]
                 # 构建下一个问题,包括POI和行人密度信息
-                #
-
-                human_question = (f"The FirstLevel category of this image is {predicted_first_level}. "
+                # Construct the next question, including POI and pedestrian density information
+                human_question = (f"The FirstLevel category of this image is {current_first_level}. "
                                   f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
                                   "This image contains some POI (Point of Interest) information, "
                                   "which is now provided to you. You can refer to this POI information "
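The prompt above assumes that first_level_to_second_levels, defined elsewhere in the script, maps each FirstLevel name to the list of its SecondLevel names. A sketch of that shape with purely illustrative labels:

    # Illustrative keys only; the real category names live elsewhere in true_pre_2level.py.
    first_level_to_second_levels = {
        'Residential': ['Apartment', 'Detached house'],
        'Commercial': ['Shopping mall', 'Office building'],
    }

    current_first_level = 'Residential'
    second_levels = first_level_to_second_levels[current_first_level]
    prompt = (f"The FirstLevel category of this image is {current_first_level}. "
              f"Please select the most likely SecondLevel among {', '.join(second_levels)}. ")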
@@ -133,22 +132,25 @@ def process_conversations(data, session=None):
 
             second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
             predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
-            # predicted_second_level = extract_level(generated_answer, r'The SecondLevel is (.+)$')
 
             # Update second_level_accuracy
             # 确保真实的第一级分类存在于结构中
+            # Ensure the real first-level category exists in the structure
             real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
 
             if real_first_level in second_level_accuracy:
                 # 使用真实的一级分类查找对应的二级分类结构
+                # Use the real first-level category to find the corresponding second-level structure
                 second_level_data = second_level_accuracy[real_first_level]
 
                 # 更新统计总数,确保真实的二级分类存在于结构中
+                # Update statistics, ensuring the real second-level category exists in the structure
                 if second_level in second_level_data:
                     second_level_data[second_level]['total'] += 1
                     second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
 
                     # 比较和记录正确性
+                    # Compare and record correctness
                     if second_level == predicted_second_level:
                         correct_second += 1
                         second_level_data[second_level]['correct'] += 1
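The bookkeeping in this hunk assumes two counter dictionaries initialised earlier in the script: first_level_accuracy keyed by FirstLevel, and second_level_accuracy keyed by FirstLevel and then SecondLevel, each leaf holding 'correct' and 'total' counts. A sketch of that layout (keys illustrative):

    first_level_accuracy = {
        'Residential': {'correct': 0, 'total': 0},
    }
    second_level_accuracy = {
        'Residential': {
            'Apartment': {'correct': 0, 'total': 0},
            'Detached house': {'correct': 0, 'total': 0},
        },
    }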
@@ -172,9 +174,9 @@ correct_first_total = 0
 correct_second_total = 0
 error_logs = []
 
-
 with open(jsonl_file, 'r') as f:
     lines = [next(f) for _ in range(NUM_EXAMPLES)]  # 只读取前 NUM_EXAMPLES 行
+    # Only read the first NUM_EXAMPLES lines
 
 with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_csv_file, \
         open('second_level_results_true_0_9K.csv', 'w', newline='') as second_level_csv_file:
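One caveat with lines = [next(f) for _ in range(NUM_EXAMPLES)]: if val.jsonl contains fewer than NUM_EXAMPLES lines, next(f) raises StopIteration partway through. A more defensive variant, not part of this commit, would stop cleanly at end of file:

    from itertools import islice

    with open(jsonl_file, 'r') as f:
        lines = list(islice(f, NUM_EXAMPLES))  # at most NUM_EXAMPLES lines, no exception on short files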
@@ -194,6 +196,7 @@ with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_c
         error_logs.extend(errors)
 
 # 计算和打印正确率
+# Calculate and print accuracy
 total_questions = NUM_EXAMPLES * QUESTIONS_PER_EXAMPLE
 
 first_accuracy = correct_first_total / NUM_EXAMPLES
@@ -204,6 +207,7 @@ print(f'Second question accuracy: {second_accuracy * 100:.2f}%')
 print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
 
 # 计算一级分类正确率
+# Calculate FirstLevel accuracy
 for first_level in first_level_accuracy:
     correct = first_level_accuracy[first_level]['correct']
     total = first_level_accuracy[first_level]['total']
@@ -212,6 +216,7 @@ for first_level in first_level_accuracy:
     print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')
 
 # 计算二级分类正确率
+# Calculate SecondLevel accuracy
 for first_level, second_levels in second_level_accuracy.items():
     for second_level in second_levels:
         correct = second_levels[second_level]['correct']
@@ -221,5 +226,6 @@ for first_level, second_levels in second_level_accuracy.items():
         print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')
 
 # 将错误记录写入日志文件
+# Write error logs to a file
 with open('error_log_0_9K', 'w') as outfile:
     json.dump(error_logs, outfile, indent=4)
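The per-class division (outside the hunks shown here) presumably computes accuracy as correct / total; any category that never appears in the first NUM_EXAMPLES examples would then divide by zero. A zero-safe sketch of the FirstLevel report loop, not part of this commit:

    for first_level, counts in first_level_accuracy.items():
        accuracy = counts['correct'] / counts['total'] if counts['total'] else 0.0
        print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% '
              f'right/total: {counts["correct"]} / {counts["total"]}')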