akshay107 committed on
Commit d93f604 · verified · 1 Parent(s): 7ca134e

Upload merge.py

Files changed (1)
  1. merge.py +56 -0
merge.py ADDED
@@ -0,0 +1,56 @@
+import torch
+import json
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from datasets import load_dataset
+from peft import LoraConfig, PeftModel
+
+device_map = "auto"
+model = AutoModelForCausalLM.from_pretrained(
+    "/path/to/meta-llama3-8b",
+    #low_cpu_mem_usage=True,
+    return_dict=True,
+    torch_dtype=torch.float16,
+    device_map=device_map,
+)
+
+model = PeftModel.from_pretrained(model, "/path/to/llama3-8b-adapter", device_map=device_map)
+model = model.merge_and_unload()
+
+tokenizer = AutoTokenizer.from_pretrained("/path/to/meta-llama3-8b", trust_remote_code=True)
+tokenizer.pad_token_id = 18610
+
+pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=4096, do_sample=False)
+print("Padding side:",tokenizer.padding_side)
+val_dataset = load_dataset("csv", data_files={'val':'/path/to/actseq-val-new.csv'})["val"]
+test_dataset = load_dataset("csv", data_files={'test':'/path/to/actseq-test-new.csv'})["test"]
+
+
+def formatting_prompts_func(example):
+    output_texts = []
+    for i in range(len(example['dial_with_actions'])):
+        text = f"Predict the action sequence (AS) for the Minecraft excerpt:\n {example['dial_with_actions'][i]}\n ### AS:"
+        output_texts.append(text)
+    return output_texts
+
+
+val_texts = formatting_prompts_func(val_dataset)
+test_texts = formatting_prompts_func(test_dataset)
+
+print("Val Length:", len(val_texts))
+print("Test Length:", len(test_texts))
+
+f = open("/path/to/val-output-file","w")
+
+for text in val_texts:
+    print(text)
+    print(pipe(text)[0]["generated_text"], file=f)
+
+f.close()
+
+f = open("/path/to/test-output-file","w")
+
+for text in test_texts:
+    print(text)
+    print(pipe(text)[0]["generated_text"], file=f)
+
+f.close()
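
merge.py merges the LoRA adapter into the base Llama-3 8B weights in memory and runs greedy generation over the validation and test CSVs; it does not write the merged weights back to disk. A minimal follow-up sketch for persisting them, assuming a hypothetical output directory, could be:

# Not part of merge.py: save the merged model so later runs can skip the adapter merge.
# "/path/to/merged-llama3-8b" is an assumed output path, not taken from the commit.
merged_dir = "/path/to/merged-llama3-8b"
model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)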