Wendy committed
Commit 02cb6ef · verified · 1 Parent(s): 2661ed3

Upload main_.py with huggingface_hub
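For context, uploads like this one are usually made with the `huggingface_hub` Python client; the sketch below is a hedged illustration of such a call, with the repo id and repo type as placeholders, since the target repository is not shown on this page.

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="main_.py",   # local file to upload
    path_in_repo="main_.py",      # destination path inside the repo
    repo_id="<user>/<repo>",      # placeholder: actual repository id not shown on this page
    repo_type="model",            # assumption: adjust if the target is a dataset or Space
    commit_message="Upload main_.py with huggingface_hub",
)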

Files changed (1)
  1. main_.py +316 -0
main_.py ADDED
@@ -0,0 +1,316 @@
+import os
+import numpy as np
+import torch
+import re
+import json
+import argparse
+import random
+from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
+from model import T5ForMultimodalGeneration
+from utils_data import AITWDatasetImg, load_data
+from rich.table import Column, Table
+from rich import box
+from rich.console import Console
+console = Console(record=True)
+import action_matching, action_type
+import evaluate
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_root', type=str, default='dataset/blip/general_blip')
+    parser.add_argument('--output_dir', type=str, default='experiments')
+    parser.add_argument('--model', type=str, default='declare-lab/flan-alpaca-base')
+    parser.add_argument('--data_ratio', type=float, default=None)
+    parser.add_argument('--eval_name', type=str, default=None, help='the saved subset name used for evaluation')
+    parser.add_argument('--local_rank', type=int, default=-1)
+    parser.add_argument('--epoch', type=int, default=2)
+    parser.add_argument('--lr', type=float, default=5e-5)
+    parser.add_argument('--warmup_ratio', type=float, default=0.1)
+    parser.add_argument('--bs', type=int, default=1)
+    parser.add_argument('--debug_num', type=int, default=None)
+    parser.add_argument('--input_len', type=int, default=512)
+    parser.add_argument('--output_len', type=int, default=256)
+    parser.add_argument('--img_dim', type=int, default=1408)
+    parser.add_argument('--eval_bs', type=int, default=16)
+    parser.add_argument('--eval_acc', type=int, default=None, help='evaluation accumulation steps')
+    parser.add_argument('--all_data', type=float, default=None, help='whether to use all the data for training; set the ratio for Google Apps to save computation')
+    parser.add_argument('--eval_subset', type=str, default=None, help='which subset to use for evaluation/test when training with all data')
+    parser.add_argument('--use_history', type=int, default=8, help='use textual action history')
+    parser.add_argument('--use_img_history', action='store_true', help='use screen history')
+    parser.add_argument('--use_future', type=int, default=16, help='plan future actions before giving the current action')
+    parser.add_argument('--use_layout', action='store_true', help='use annotated layout information')
+    parser.add_argument('--transform_axis', default=True, action='store_true', help='use coordinate normalization')
+    parser.add_argument('--use_generate', default=True, action='store_true', help='only for the baseline, to improve inference speed')
+    parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
+    parser.add_argument('--user_msg', type=str, default="debug", help='experiment type in the save_dir')
+    parser.add_argument('--img_type', type=str, default="blip", help='type of image features')
+    parser.add_argument('--evaluate_dir', type=str, default=None, help='the directory of the model to evaluate')
+    parser.add_argument('--seed', type=int, default=42, help='random seed')
+
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+
+    # training logger to log training progress
+    training_logger = Table(
+        Column("Epoch", justify="center"),
+        Column("Steps", justify="center"),
+        Column("Loss", justify="center"),
+        title="Training Status",
+        pad_edge=False,
+        box=box.ASCII,
+    )
+
+    args = parse_args()
+    print("args", args)
+    print('====Input Arguments====')
+    print(json.dumps(vars(args), indent=2, sort_keys=False))
+
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)  # pytorch random seed
+    np.random.seed(args.seed)  # numpy random seed
+    torch.backends.cudnn.deterministic = True
+
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+    if args.evaluate_dir is not None:
+        args.model = args.evaluate_dir
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    console.log(f"""[Model]: Loading {args.model}...\n""")
+    console.log(f"[Data]: Reading data...\n")
+
+    if args.evaluate_dir is not None:
+        save_dir = args.evaluate_dir
+    else:
+        model_name = args.model.replace("/", "-")
+        gpu_count = torch.cuda.device_count()
+        save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_lr{args.lr}_bs{args.bs * gpu_count}_ip{args.input_len}_op{args.output_len}_ep{args.epoch}"
+        if not os.path.exists(save_dir):
+            os.mkdir(save_dir)
+        print(save_dir)
+
+    model = T5ForMultimodalGeneration.from_pretrained(args.model, args.img_dim)
+
+    # the training split is only needed when not running in evaluation-only mode
+    if args.evaluate_dir is not None:
+        train_set = None
+    else:
+        training_data = load_data(args, "train")
+        train_set = AITWDatasetImg(
+            training_data,
+            tokenizer,
+            args.input_len,
+            args.output_len
+        )
+    eval_data = load_data(args, "val")
+    eval_set = AITWDatasetImg(
+        eval_data,
+        tokenizer,
+        args.input_len,
+        args.output_len
+    )
+    test_data = load_data(args, "test")
+    test_set = AITWDatasetImg(
+        test_data,
+        tokenizer,
+        args.input_len,
+        args.output_len
+    )
+    block = 2000
+    for i in range(len(test_set)):
+        test_set[i] = test_set[i][:block]
+    for i in range(len(eval_set)):
+        eval_set[i] = eval_set[i][:block]
+    datacollator = DataCollatorForSeq2Seq(tokenizer)
+    print("model parameters: ", model.num_parameters())
+
+    # rougeL for rationale generation
+    metric = evaluate.load("rouge")
+    def compute_metrics_rouge(eval_preds):
+        preds, targets = eval_preds
+        if isinstance(preds, tuple):
+            preds = preds[0]
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
+        # measure generation length on the token ids, before decoding to text
+        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+        result = metric.compute(predictions=preds, references=targets)
+        result = {k: round(v * 100, 4) for k, v in result.items()}
+        result["gen_len"] = np.mean(prediction_lens)
+        return result
+
+    # only use the last model for evaluation to save time
+    if args.final_eval:
+        training_args = Seq2SeqTrainingArguments(
+            save_dir,
+            do_train=True if args.evaluate_dir is None else False,
+            do_eval=False,
+            warmup_ratio=args.warmup_ratio,
+            evaluation_strategy="no",
+            logging_strategy="steps",
+            save_strategy="epoch",
+            save_total_limit=2,
+            learning_rate=args.lr,
+            eval_accumulation_steps=args.eval_acc,
+            per_device_train_batch_size=args.bs,
+            per_device_eval_batch_size=args.eval_bs,
+            weight_decay=0.01,
+            num_train_epochs=args.epoch,
+            predict_with_generate=args.use_generate,
+            generation_max_length=args.output_len,
+            report_to="none",
+            local_rank=args.local_rank
+        )
+    # evaluate at each epoch
+    else:
+        training_args = Seq2SeqTrainingArguments(
+            save_dir,
+            do_train=True if args.evaluate_dir is None else False,
+            do_eval=True,
+            warmup_ratio=args.warmup_ratio,
+            evaluation_strategy="epoch",
+            logging_strategy="steps",
+            save_strategy="epoch",
+            save_total_limit=2,
+            learning_rate=args.lr,
+            eval_accumulation_steps=args.eval_acc,
+            per_device_train_batch_size=args.bs,
+            per_device_eval_batch_size=args.eval_bs,
+            weight_decay=0.01,
+            num_train_epochs=args.epoch,
+            metric_for_best_model="rougeL",
+            predict_with_generate=args.use_generate,
+            generation_max_length=args.output_len,
+            load_best_model_at_end=True,
+            report_to="none",
+            local_rank=args.local_rank
+        )
+
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_set,
+        eval_dataset=eval_set,
+        data_collator=datacollator,
+        tokenizer=tokenizer,
+        compute_metrics=compute_metrics_rouge
+    )
+
+    if args.evaluate_dir is None:
+        trainer.train()
+        trainer.save_model(save_dir)
+
+    # metrics = trainer.evaluate(eval_dataset=test_set, max_length=args.output_len)
+    # trainer.log_metrics("test", metrics)
+    # trainer.save_metrics("test", metrics)
+    metrics = {}
+
+    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
+    if trainer.is_world_process_zero():
+        preds, targets = predict_results.predictions, predict_results.label_ids
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
+        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+        action_correct = 0
+        text_correct = 0
+        type_correct = 0
+
+        reference_test_positions = test_set.anno_positions
+
+        output_data = []
+
+        # the action fields follow the "Action Decision:" marker in the generated text
+        pattern = r'(?<=Action Decision:\s).*'
+
+        assert len(preds) == len(targets) == len(reference_test_positions)
+        for idx, pred in enumerate(preds):
+            # parse the reference action; skip the sample if it cannot be parsed
+            try:
+                result = re.search(pattern, targets[idx])
+                target_text = result.group(0)
+                target_text = target_text.strip()
+
+                reference = eval("{" + target_text + "}")
+            except:
+                print("reference error")
+                continue
+
+            # parse the predicted action
+            try:
+                result = re.search(pattern, preds[idx])
+                pred_text = result.group(0)
+                pred_text = pred_text.strip()
+
+                pred = eval("{" + pred_text + "}")
+                action_1_touch_yx = eval(pred["touch_point"])
+                action_1_lift_yx = eval(pred["lift_point"])
+                action_1_action_type = action_type.ActionType[pred["action_type"]].value
+                action_1_typed_text = pred["typed_text"].lower()
+                action_1_typed_text = action_1_typed_text.strip()
+
+                action_1_wrap = f'"action_type": "{action_1_action_type}", "touch_point": "{action_1_touch_yx}", "lift_point": "{action_1_lift_yx}", "typed_text": "{action_1_typed_text}"'
+                action_1_wrap = action_1_wrap.replace('"', "'")
+            except:
+                # fall back to a placeholder action so the fields used below are always defined
+                pred = eval('{ "action_type": "TYPE", "touch_point": "[-1.0, -1.0]", "lift_point": "[-1.0, -1.0]", "typed_text": "Invalid"}')
+                action_1_touch_yx = eval(pred["touch_point"])
+                action_1_lift_yx = eval(pred["lift_point"])
+                action_1_action_type = action_type.ActionType[pred["action_type"]].value
+                action_1_typed_text = pred["typed_text"].lower()
+                action_1_wrap = f'"action_type": "{action_1_action_type}", "touch_point": "{action_1_touch_yx}", "lift_point": "{action_1_lift_yx}", "typed_text": "{action_1_typed_text}"'
+                action_1_wrap = action_1_wrap.replace('"', "'")
+
+            action_2_touch_yx = eval(reference["touch_point"])
+            action_2_lift_yx = eval(reference["lift_point"])
+            action_2_action_type = action_type.ActionType[reference["action_type"]].value
+            action_2_typed_text = reference["typed_text"].lower()
+
+            action_2_wrap = f'"action_type": "{action_2_action_type}", "touch_point": "{action_2_touch_yx}", "lift_point": "{action_2_lift_yx}", "typed_text": "{action_2_typed_text}"'
+            action_2_wrap = action_2_wrap.replace('"', "'")
+
+            annotation_positions = reference_test_positions[idx]
+
+            try:
+                check_match = action_matching.check_actions_match(
+                    action_1_touch_yx,
+                    action_1_lift_yx,
+                    action_1_action_type,
+                    action_2_touch_yx,
+                    action_2_lift_yx,
+                    action_2_action_type,
+                    annotation_positions
+                )
+
+            except Exception as exc:
+                print(idx, action_1_touch_yx, action_1_lift_yx)
+                check_match = False
+                match_label = "invalid"
+
+            if check_match:
+                action_correct += 1
+                match_label = 1
+            else:
+                match_label = 0
+            if check_match and (action_1_typed_text in action_2_typed_text or action_2_typed_text in action_1_typed_text):
+                text_correct += 1
+            if action_1_action_type == action_2_action_type:
+                type_correct += 1
+
+            action_data = {"pred": action_1_wrap, "target": action_2_wrap, "match_label": match_label}
+            output_data.append(action_data)
+
+        metrics["accuracy"] = "{:.2f}".format(action_correct / len(targets) * 100)
+        metrics["text_acc"] = "{:.2f}".format(text_correct / len(targets) * 100)
+        metrics["type_acc"] = "{:.2f}".format(type_correct / len(targets) * 100)
+        metrics["action_correct"] = action_correct
+        metrics["text_correct"] = text_correct
+        metrics["type_correct"] = type_correct
+        metrics["total_num"] = len(targets)
+        print(metrics)
+        output_data = {
+            "metrics": metrics,
+            "data": output_data
+        }
+        print(save_dir)
+        if args.eval_name:
+            output_prediction_file = os.path.join(save_dir, f"predictions_ans_test_{args.eval_name}.json")
+        else:
+            output_prediction_file = os.path.join(save_dir, "predictions_ans_test.json")
+        with open(output_prediction_file, "w") as writer:
+            writer.write(json.dumps(output_data, indent=4))