Spaces: Runtime error
Ethan Shen committed
Commit · dda1539
1 Parent(s): bc4858e
Initial commit

Browse files
- .gitignore +3 -0
- LICENSE +126 -0
- app.py +97 -0
- params/g15_d3_mixed.json +27 -0
- params/g20_d3_mixed.json +27 -0
- params/g5_d3_mixed.json +27 -0
- params/p15_d10_mixed.json +26 -0
- params/p15_d2_mixed.json +26 -0
- params/p15_d3_mixed.json +26 -0
- params/p15_d3_ngram4_mixed.json +22 -0
- params/p15_d4_mixed.json +26 -0
- params/p15_d5_mixed.json +26 -0
- params/p15_d6_mixed.json +26 -0
- params/p25_d3_mixed.json +26 -0
- params/p40_d3_mixed.json +12 -0
- params/p5_d3_mixed.json +26 -0
- requirements.txt +11 -0
- superposed/llama/__init__.py +6 -0
- superposed/llama/__pycache__/__init__.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/generation.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/model.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superpose.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superposed_generation.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/superposed_model.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/tokenizer.cpython-312.pyc +0 -0
- superposed/llama/__pycache__/utils.cpython-312.pyc +0 -0
- superposed/llama/generation.py +268 -0
- superposed/llama/metrics.py +109 -0
- superposed/llama/model.py +548 -0
- superposed/llama/superpose.py +328 -0
- superposed/llama/superposed_generation.py +198 -0
- superposed/llama/superposed_model.py +515 -0
- superposed/llama/tokenizer.py +68 -0
- superposed/llama/utils.py +70 -0
- superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc +0 -0
- superposed/ngrams/make_corpus.py +268 -0
- superposed/ngrams/ngram_models.py +115 -0
- superposed/ngrams/test.json +8 -0
- superposed/notebooks/custom.ipynb +289 -0
- superposed/notebooks/nq.ipynb +417 -0
- superposed/notebooks/triviaqa.ipynb +404 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
.env
weights
ckpts-200k
LICENSE
ADDED
@@ -0,0 +1,126 @@
LLAMA 2 COMMUNITY LICENSE AGREEMENT
Llama 2 Version Release Date: July 18, 2023

"Agreement" means the terms and conditions for use, reproduction, distribution and
modification of the Llama Materials set forth herein.

"Documentation" means the specifications, manuals and documentation
accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
libraries/llama-downloads/.

"Licensee" or "you" means you, or your employer or any other person or entity (if
you are entering into this Agreement on such person or entity's behalf), of the age
required under applicable laws, rules or regulations to provide legal consent and that
has legal authority to bind your employer or such other person or entity if you are
entering in this Agreement on their behalf.

"Llama 2" means the foundational large language models and software and
algorithms, including machine-learning model code, trained model weights,
inference-enabling code, training-enabling code, fine-tuning enabling code and other
elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
libraries/llama-downloads/.

"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
Documentation (and any portion thereof) made available under this Agreement.

"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
are an entity, your principal place of business is in the EEA or Switzerland) and Meta
Platforms, Inc. (if you are located outside of the EEA or Switzerland).

By clicking "I Accept" below or by using or distributing any portion or element of the
Llama Materials, you agree to be bound by this Agreement.

1. License Rights and Redistribution.

      a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
transferable and royalty-free limited license under Meta's intellectual property or
other rights owned by Meta embodied in the Llama Materials to use, reproduce,
distribute, copy, create derivative works of, and make modifications to the Llama
Materials.

      b. Redistribution and Use.

      i. If you distribute or make the Llama Materials, or any derivative works
thereof, available to a third party, you shall provide a copy of this Agreement to such
third party.
      ii. If you receive Llama Materials, or any derivative works thereof, from
a Licensee as part of an integrated end user product, then Section 2 of this
Agreement will not apply to you.

      iii. You must retain in all copies of the Llama Materials that you
distribute the following attribution notice within a "Notice" text file distributed as a
part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
Copyright (c) Meta Platforms, Inc. All Rights Reserved."

      iv. Your use of the Llama Materials must comply with applicable laws
and regulations (including trade compliance laws and regulations) and adhere to the
Acceptable Use Policy for the Llama Materials (available at
https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
this Agreement.

      v. You will not use the Llama Materials or any output or results of the
Llama Materials to improve any other large language model (excluding Llama 2 or
derivative works thereof).

2. Additional Commercial Terms. If, on the Llama 2 version release date, the
monthly active users of the products or services made available by or for Licensee,
or Licensee's affiliates, is greater than 700 million monthly active users in the
preceding calendar month, you must request a license from Meta, which Meta may
grant to you in its sole discretion, and you are not authorized to exercise any of the
rights under this Agreement unless or until Meta otherwise expressly grants you
such rights.

3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.

4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
ANY OF THE FOREGOING.

5. Intellectual Property.

      a. No trademark licenses are granted under this Agreement, and in
connection with the Llama Materials, neither Meta nor Licensee may use any name
or mark owned by or associated with the other or any of its affiliates, except as
required for reasonable and customary use in describing and redistributing the
Llama Materials.

      b. Subject to Meta's ownership of Llama Materials and derivatives made by or
for Meta, with respect to any derivative works and modifications of the Llama
Materials that are made by you, as between you and Meta, you are and will be the
owner of such derivative works and modifications.

      c. If you institute litigation or other proceedings against Meta or any entity
(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
constitutes an infringement of intellectual property or other rights owned or licensable
by you, then any licenses granted to you under this Agreement shall terminate as of
the date such litigation or claim is filed or instituted. You will indemnify and hold
harmless Meta from and against any claim by any third party arising out of or related
to your use or distribution of the Llama Materials.

6. Term and Termination. The term of this Agreement will commence upon your
acceptance of this Agreement or access to the Llama Materials and will continue in
full force and effect until terminated in accordance with the terms and conditions
herein. Meta may terminate this Agreement if you are in breach of any term or
condition of this Agreement. Upon termination of this Agreement, you shall delete
and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
termination of this Agreement.

7. Governing Law and Jurisdiction. This Agreement will be governed and
construed under the laws of the State of California without regard to choice of law
principles, and the UN Convention on Contracts for the International Sale of Goods
does not apply to this Agreement. The courts of California shall have exclusive
jurisdiction of any dispute arising out of this Agreement.
app.py
ADDED
@@ -0,0 +1,97 @@
import gradio as gr
import json
import os
import spaces
import torch

from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models

# load_dotenv()
# print(os.getenv("HF_ACCESS_TOKEN"))
login(os.getenv("HF_ACCESS_TOKEN"))
if not os.path.exists("./weights/"):
    os.mkdir("./weights/")
    snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir="./weights/")
weight_path = "./weights/"
# Load params
param_file = "params/p15_d3_mixed.json"
with open(param_file, "r") as f:
    params = json.load(f)
alpha = params["alpha"]
temp = params["temp"]
n_drafts = params["n_drafts"]
prompt_len = params["prompt_len"]
n_token_sample = params["n_token_sample"]
i_weights = params["i_weights"]
i_length = params["i_length"]
# Load main model
model = SuperposedLlama.build(ckpt_dir=weight_path,
                              tokenizer_path=f'{weight_path}/tokenizer.model',
                              max_seq_len=100,
                              max_batch_size=32,
                              model_parallel_size=1)
tokenizer = Tokenizer(f'{weight_path}/tokenizer.model')
# Create ngram models
ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)

def decode(tokenizer, encoding):
    """
    Args:
        tokenizer (Any): Tokenizer
        encoding (torch.Tensor): Encoding
    Returns:
        decoding (str)
    """
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    if len(eos_locs > 0):
        encoding = encoding[:eos_locs[0]]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

@spaces.GPU
def update_options(input, num_tokens):
    tokenized_prompts = tokenizer.encode([input], True, False)
    alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
                                       smoothing="geom",
                                       max_gen_len=num_tokens,
                                       n_token_sample=n_token_sample,
                                       alpha=alpha,
                                       temp=temp,
                                       n_drafts=n_drafts,
                                       i_weights=i_weights,
                                       i_length=i_length,
                                       ngrams=ngrams,
                                       get_time=False,
                                       penalty=200)
    gens = alive_gens[0].reshape(n_drafts, -1)
    return decode(tokenizer, gens[0]), decode(tokenizer, gens[1]), decode(tokenizer, gens[2])

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
    """
    # Superposed Decoding
    Start typing below to see suggestions.
    """)
    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
    inp = gr.Textbox(placeholder="Type anything!", lines=3)
    option1 = gr.Button(value="Option 1")
    option2 = gr.Button(value="Option 2")
    option3 = gr.Button(value="Option 3")
    inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])
    # Button updates
    @option1.click(inputs=[inp, option1], outputs=inp)
    def option1_click(curr, txt):
        return curr + txt
    @option2.click(inputs=[inp, option2], outputs=inp)
    def option2_click(curr, txt):
        return curr + txt
    @option3.click(inputs=[inp, option3], outputs=inp)
    def option3_click(curr, txt):
        return curr + txt

if __name__ == "__main__":
    demo.launch(debug=True)
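
For reference, the reshape-and-decode step in update_options above is what turns one superposed generation into the three button suggestions: sup_generate returns a single row per prompt, which is reshaped into n_drafts rows and decoded independently. Below is a minimal, self-contained sketch of that step, using a hypothetical DummyTokenizer and random token ids in place of real model output:

import torch

class DummyTokenizer:
    # stand-in for superposed.llama.tokenizer.Tokenizer (illustration only)
    eos_id = 2
    def decode(self, ids):
        return " ".join(str(i) for i in ids)

def decode(tokenizer, encoding):
    # same logic as decode() in app.py: truncate at the first EOS, then detokenize
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    if len(eos_locs) > 0:
        encoding = encoding[:eos_locs[0]]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

n_drafts, gen_len = 3, 10
alive_gens = torch.randint(3, 100, (1, n_drafts * gen_len))  # fake sup_generate output
gens = alive_gens[0].reshape(n_drafts, -1)                    # one row per draft
options = [decode(DummyTokenizer(), g) for g in gens]         # one string per UI button
print(options)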
params/g15_d3_mixed.json
ADDED
@@ -0,0 +1,27 @@
{
    "alpha": 0.48,
    "temp": 0.06,
    "n_drafts": 3,
    "prompt_len": 15,
    "n_token_sample": 15,
    "max_gen_len": 15,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/g20_d3_mixed.json
ADDED
@@ -0,0 +1,27 @@
{
    "alpha": 0.5,
    "temp": 0.04,
    "n_drafts": 3,
    "prompt_len": 15,
    "n_token_sample": 15,
    "max_gen_len": 20,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/g5_d3_mixed.json
ADDED
@@ -0,0 +1,27 @@
{
    "alpha": 0.52,
    "temp": 0.06,
    "n_drafts": 3,
    "prompt_len": 15,
    "n_token_sample": 15,
    "max_gen_len": 5,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d10_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.54,
    "temp": 0.12,
    "n_drafts": 10,
    "prompt_len": 15,
    "n_token_sample": 30,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d2_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.62,
    "temp": 0.06,
    "n_drafts": 2,
    "prompt_len": 15,
    "n_token_sample": 6,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d3_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.54,
    "temp": 0.06,
    "n_drafts": 3,
    "prompt_len": 15,
    "n_token_sample": 9,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d3_ngram4_mixed.json
ADDED
@@ -0,0 +1,22 @@
{
    "alpha": 0.55,
    "temp": 0.1,
    "n_drafts": 3,
    "prompt_len": 15,
    "n_token_sample": 9,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15
    ],
    "i_length": [
        1,
        2,
        3
    ]
}
params/p15_d4_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.52,
    "temp": 0.06,
    "n_drafts": 4,
    "prompt_len": 15,
    "n_token_sample": 12,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d5_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.6,
    "temp": 0.06,
    "n_drafts": 5,
    "prompt_len": 15,
    "n_token_sample": 15,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p15_d6_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.52,
    "temp": 0.06,
    "n_drafts": 6,
    "prompt_len": 15,
    "n_token_sample": 18,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p25_d3_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.5,
    "temp": 0.12,
    "n_drafts": 3,
    "prompt_len": 25,
    "n_token_sample": 15,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
params/p40_d3_mixed.json
ADDED
@@ -0,0 +1,12 @@
{
    "alpha": 0.55,
    "temp": 0.1,
    "prompt_len": 40,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [0.01, 0.04, 0.15, 0.18, 0.12],
    "i_length": [1, 2, 3, 4, 5],
    "ckpt_path": "../ckpts-200k"
}
params/p5_d3_mixed.json
ADDED
@@ -0,0 +1,26 @@
{
    "alpha": 0.34,
    "temp": 0.12,
    "n_drafts": 3,
    "prompt_len": 5,
    "n_token_sample": 15,
    "n_token_consider": 32000,
    "mixing_method": "sample_new_weights_with_score",
    "smoothing": "geom",
    "sample_tokens": 0,
    "sample_beams": 0,
    "i_weights": [
        0.01,
        0.04,
        0.15,
        0.18,
        0.12
    ],
    "i_length": [
        1,
        2,
        3,
        4,
        5
    ]
}
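
Each params/*.json file above bundles one decoding configuration: draft count, prompt length, sampling temperature, and the n-gram interpolation weights and orders. A minimal sketch of how app.py consumes one of them follows; the field meanings in the comments are inferred from how the values are passed to sup_generate, so treat them as assumptions rather than documentation.

import json

with open("params/p15_d3_mixed.json", "r") as f:
    params = json.load(f)

alpha = params["alpha"]                    # mixing coefficient passed to sup_generate
temp = params["temp"]                      # sampling temperature
n_drafts = params["n_drafts"]              # number of drafts decoded per prompt
prompt_len = params["prompt_len"]          # prefix length the configuration targets
n_token_sample = params["n_token_sample"]  # tokens sampled per decoding step
i_weights = params["i_weights"]            # interpolation weights, one per n-gram order
i_length = params["i_length"]              # n-gram orders matching i_weights
print(n_drafts, list(zip(i_length, i_weights)))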
requirements.txt
ADDED
@@ -0,0 +1,11 @@
datasets==2.19.0
fairscale==0.4.13
loguru==0.7.2
nltk==3.8.1
numpy==1.26.4
Requests==2.32.2
sentencepiece==0.2.0
setuptools==58.2.0
torch==2.3.0
tqdm==4.66.4
transformers==4.37.2
superposed/llama/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from .generation import Llama, Dialog
from .model import ModelArgs, Transformer
from .tokenizer import Tokenizer
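
The package __init__ above re-exports the main classes, so downstream code (for example the notebooks under superposed/notebooks/) can import them from the package root. A one-line sketch, assuming the repository root is on the Python path:

from superposed.llama import Llama, Dialog, ModelArgs, Transformer, Tokenizer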
superposed/llama/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (335 Bytes).
superposed/llama/__pycache__/generation.cpython-312.pyc
ADDED
Binary file (13.9 kB).
superposed/llama/__pycache__/model.cpython-312.pyc
ADDED
Binary file (26.7 kB).
superposed/llama/__pycache__/superpose.cpython-312.pyc
ADDED
Binary file (19.1 kB).
superposed/llama/__pycache__/superposed_generation.cpython-312.pyc
ADDED
Binary file (10.1 kB).
superposed/llama/__pycache__/superposed_model.cpython-312.pyc
ADDED
Binary file (25.9 kB).
superposed/llama/__pycache__/tokenizer.cpython-312.pyc
ADDED
Binary file (3.26 kB).
superposed/llama/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (3.97 kB).
superposed/llama/generation.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from superposed.llama.model import ModelArgs, Transformer
from superposed.llama.tokenizer import Tokenizer
from superposed.llama.utils import *

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        device: None,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        """
        Build a Llama instance by initializing and loading a pre-trained model.

        Args:
            ckpt_dir (str): Path to the directory containing checkpoint files.
            tokenizer_path (str): Path to the tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            mixed (bool): Whether to mix embeddings or not
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it's determined from the environment. Defaults to None.

        Returns:
            Llama: An instance of the Llama class with the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory,
                or if the model parallel size does not match the number of checkpoint files.

        Note:
            This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.

        """
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        print(local_rank)
        # torch.cuda.set_device(local_rank)
        if device == None:
            torch.cuda.set_device(local_rank)
            device = f"cuda:{local_rank}"
        # seed must be the same in all processes
        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")
        return Llama(model, tokenizer, device)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, device):
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer
        self.device = device

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = True,
        grade: bool = False
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.

        Note:
            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        # assert min_prompt_len == max_prompt_len
        prompt_len = min_prompt_len
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device=self.device)
        input_text_mask = tokens != pad_id
        if grade:
            pad_mask = tokens == pad_id
            tokens = torch.where(tokens == pad_id, 0, tokens)
            logits = self.model.forward(tokens, prev_pos, False)
            tokens[pad_mask] = pad_id
            token_logprobs = -F.cross_entropy(
                input=logits[:, :-1, :].transpose(1, 2),
                target=tokens[:, 1:],
                reduction="none",
                ignore_index=pad_id,
            )
            #if pad_id in tokens:
            #    print(pad_id)
            #    print(tokens)
            #    print(token_logprobs)
            return token_logprobs

        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, False)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        # seq_len = torch.sum(tokens != pad_id, dim=1)
        # return tokens, torch.exp(-1 * torch.sum(logprobs, dim=1) / (seq_len - prompt_len)), torch.exp(-1 * torch.sum(custom_logprobs, dim=1) / )
        if logprobs:
            token_logprobs = token_logprobs.tolist()

        out_ppl = []
        for i, toks in enumerate(tokens.tolist()):
            if logprobs:
                probs = token_logprobs[i][prompt_len : len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                probs = probs[:eos_idx] if logprobs else None
            out_ppl.append(torch.exp(-1 * torch.sum(torch.tensor(probs)) / len(probs)))
        return tokens, torch.tensor(out_ppl) if logprobs else None

def sample_top_p(probs, p, s=1):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.

    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=s)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
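
To make the nucleus cutoff in sample_top_p above concrete, the sketch below traces it on a hand-made 5-token distribution (toy numbers, illustration only): sorted probabilities are kept while the cumulative mass before each token stays at or below p, the rest are zeroed, and the survivors are renormalized before sampling.

import torch

probs = torch.tensor([[0.5, 0.2, 0.15, 0.1, 0.05]])  # toy distribution over tokens 0..4
p = 0.8
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)          # [0.50, 0.70, 0.85, 0.95, 1.00]
mask = probs_sum - probs_sort > p                     # True only for the 0.10 and 0.05 tail
probs_sort[mask] = 0.0                                # keep 0.50, 0.20, 0.15
probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)  # always one of tokens 0, 1, or 2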
superposed/llama/metrics.py
ADDED
@@ -0,0 +1,109 @@
import torch
import nltk
from nltk.translate.bleu_score import SmoothingFunction
from tqdm import tqdm

def calculate_perplexity(model, tokens, prompt_len, bsz=1, marker=False):
    """
    Calculate perplexity of given tokens using provided model, ignoring padding tokens.
    Args:
        model: Llama model
        tokens (List[List[int]] or torch.Tensor): Input tokens (n_prompt * n_draft, seqlen)
        prompt_len (int): Prefix length
        bsz (int): Batch size
        marker (bool): Whether to show progress bar
    Returns:
        Perplexity across all generations (n_prompt * n_drafts)
    """
    it = range(0, len(tokens), bsz)
    if marker:
        it = tqdm(it)
    start = 0
    ppl = torch.zeros(len(tokens))
    for start in it:
        end = start + bsz
        data = tokens[start : end]
        if not isinstance(data, list):
            data = data.tolist()
        # Remove any padding tokens (-1) in generations
        for d_idx in range(len(data)):
            cur = data[d_idx]
            if -1 in cur:
                data[d_idx] = cur[:cur.index(-1)]
        # Calculate cross entropy loss on tokens
        ce_loss = model.generate(data, max_gen_len=0, temperature=-1, top_p=-1, grade=True)
        # Cut off everything past `prompt_len`
        ce_loss = ce_loss[:, prompt_len-1:] # Subtract 1 because the first token (start token) is removed
        # Calculate perplexity
        lengths = (ce_loss != 0).sum(dim=-1)
        mean = ce_loss.sum(dim=-1) / lengths
        ppl[start : end] = torch.exp(-1 * mean)
    return ppl

def calculate_diversity(generations, k=4):
    """
    Calculate diversity of generations using SELF-BLEU.
    Args:
        generations (List[List[List[int]]]): Tokenized input
        k (int, Optional): Number of n-grams to use for bleu
    Returns:
        Average diversity across all generations (float)
    """
    nltk.download('punkt') # Can be deleted once downloaded
    smooth = SmoothingFunction()
    bleus = []

    for drafts in generations:
        tokenized_drafts = []
        # Stringify tokens
        for d in drafts:
            if -1 in d:
                d = d[:d.index(-1)]
            tokenized_drafts.append([str(n) for n in d])
        # Calculate SELF-BLEU
        minlength = min([len(g) for g in tokenized_drafts])
        minlength = min(minlength, k)
        weights = tuple((1. / minlength for _ in range(minlength)))
        for i in range(len(drafts)):
            # Create source and reference (all other drafts)
            src = tokenized_drafts[i]
            ref = tokenized_drafts[:i] + tokenized_drafts[i+1:]
            tmp = nltk.translate.bleu_score.sentence_bleu(references=ref,
                                                          hypothesis=src,
                                                          weights=weights,
                                                          smoothing_function=smooth.method1)
            bleus.append(tmp)
    bleus = torch.Tensor(bleus)
    return torch.mean(bleus)


def calculate_ngram_repetition(sequences):
    """
    Calculate uniqueness scores of `sequences`.
    Args:
        sequences (List[List[int]]): Generated sequences
    Returns:
        (unigram_uniqueness, bigram_uniqueness, trigram_uniqueness)
    """
    u_total = 0
    b_total = 0
    t_total = 0
    # Iterate through all sequences indiscriminately
    for gen in sequences:
        if -1 in gen:
            gen = gen[:gen.index(-1)]
        unigrams, bigrams, trigrams = [], [], []
        o = [str(i) for i in gen]
        # Create lists of n-grams for the generation
        for i in range(len(o)):
            unigrams.append(o[i])
        for i in range(len(o) - 1):
            bigrams.append(o[i] + '_' + o[i + 1])
        for i in range(len(o) - 2):
            trigrams.append(o[i] + '_' + o[i + 1] + '_' + o[i + 2])
        # Calculate uniqueness of the generation
        u, b, t = len(set(unigrams)) / len(unigrams), len(set(bigrams)) / len(bigrams), len(set(trigrams)) / len(trigrams)
        u_total += u
        b_total += b
        t_total += t
    return u_total / len(sequences), b_total / len(sequences), t_total / len(sequences)
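
calculate_ngram_repetition above depends only on the token lists, so it can be sanity-checked without a model. A small usage sketch (toy token ids; assumes the superposed package is importable):

from superposed.llama.metrics import calculate_ngram_repetition

sequences = [
    [5, 6, 7, 5, 6, 7, -1, -1],   # repetitive generation, -1 marks padding
    [1, 2, 3, 4, 5, 6, 7, 8],     # fully unique generation
]
uni, bi, tri = calculate_ngram_repetition(sequences)
print(uni, bi, tri)  # average unigram/bigram/trigram uniqueness across the two sequences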
superposed/llama/model.py
ADDED
@@ -0,0 +1,548 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.

    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        beam: Optional[bool] = None,
        n_beams: Optional[int] = None,
        attention_change_ids: Optional[torch.Tensor] = None
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        _, max_seq_len, n_local_kv_heads, head_dim = self.cache_k.shape
        # KV Cache updates for beam search
        if beam:
            # Extract used cache values
            used_cache_k = self.cache_k[:bsz]
            used_cache_v = self.cache_v[:bsz]
            # Reshape to apply change ids
            t_cache_k = used_cache_k.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
            t_cache_v = used_cache_v.reshape(bsz // n_beams, n_beams, max_seq_len, n_local_kv_heads, head_dim)
            used_cache_k = torch.take_along_dim(t_cache_k, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
            used_cache_v = torch.take_along_dim(t_cache_v, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1)
            # Update cache
            self.cache_k[:bsz] = used_cache_k.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)
            self.cache_v[:bsz] = used_cache_v.reshape(bsz, max_seq_len, n_local_kv_heads, head_dim)

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  # (bs, n_local_heads, seqlen, seqlen)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # (bs, n_local_heads, seqlen, seqlen)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
dim=args.dim,
|
396 |
+
hidden_dim=4 * args.dim,
|
397 |
+
multiple_of=args.multiple_of,
|
398 |
+
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
399 |
+
)
|
400 |
+
self.layer_id = layer_id
|
401 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
402 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
403 |
+
|
404 |
+
def forward(
|
405 |
+
self,
|
406 |
+
x: torch.Tensor,
|
407 |
+
start_pos: int,
|
408 |
+
freqs_cis: torch.Tensor,
|
409 |
+
mask: Optional[torch.Tensor],
|
410 |
+
beam: Optional[bool],
|
411 |
+
n_beams: Optional[int] = None,
|
412 |
+
attention_change_ids: Optional[torch.Tensor] = None
|
413 |
+
):
|
414 |
+
"""
|
415 |
+
Perform a forward pass through the TransformerBlock.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
x (torch.Tensor): Input tensor.
|
419 |
+
start_pos (int): Starting position for attention caching.
|
420 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
421 |
+
mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
torch.Tensor: Output tensor after applying attention and feedforward layers.
|
425 |
+
|
426 |
+
"""
|
427 |
+
if beam:
|
428 |
+
h = x + self.attention.forward(
|
429 |
+
self.attention_norm(x), start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
h = x + self.attention.forward(
|
433 |
+
self.attention_norm(x), start_pos, freqs_cis, mask
|
434 |
+
)
|
435 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
436 |
+
return out
|
437 |
+
|
438 |
+
|
439 |
+
class Transformer(nn.Module):
|
440 |
+
def __init__(self, params: ModelArgs):
|
441 |
+
"""
|
442 |
+
Initialize a Transformer model.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
params (ModelArgs): Model configuration parameters.
|
446 |
+
|
447 |
+
Attributes:
|
448 |
+
params (ModelArgs): Model configuration parameters.
|
449 |
+
vocab_size (int): Vocabulary size.
|
450 |
+
n_layers (int): Number of layers in the model.
|
451 |
+
tok_embeddings (ParallelEmbedding): Token embeddings.
|
452 |
+
layers (torch.nn.ModuleList): List of Transformer blocks.
|
453 |
+
norm (RMSNorm): Layer normalization for the model output.
|
454 |
+
output (ColumnParallelLinear): Linear layer for final output.
|
455 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
456 |
+
|
457 |
+
"""
|
458 |
+
super().__init__()
|
459 |
+
self.params = params
|
460 |
+
self.vocab_size = params.vocab_size
|
461 |
+
self.n_layers = params.n_layers
|
462 |
+
|
463 |
+
self.tok_embeddings = ParallelEmbedding(
|
464 |
+
params.vocab_size, params.dim, init_method=lambda x: x
|
465 |
+
)
|
466 |
+
|
467 |
+
self.layers = torch.nn.ModuleList()
|
468 |
+
for layer_id in range(params.n_layers):
|
469 |
+
self.layers.append(TransformerBlock(layer_id, params))
|
470 |
+
|
471 |
+
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
472 |
+
self.output = ColumnParallelLinear(
|
473 |
+
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
|
474 |
+
)
|
475 |
+
|
476 |
+
self.freqs_cis = precompute_freqs_cis(
|
477 |
+
# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
|
478 |
+
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
|
479 |
+
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
|
480 |
+
)
|
481 |
+
|
482 |
+
|
483 |
+
@torch.inference_mode()
|
484 |
+
def forward(self,
|
485 |
+
tokens: torch.Tensor,
|
486 |
+
start_pos: int,
|
487 |
+
beam: bool,
|
488 |
+
n_beams: Optional[int] = None,
|
489 |
+
attention_change_ids: Optional[torch.Tensor] = None,
|
490 |
+
verbose: Optional[bool] = False):
|
491 |
+
"""
|
492 |
+
Perform a forward pass through the Transformer model.
|
493 |
+
|
494 |
+
Args:
|
495 |
+
tokens (torch.Tensor): Input token indices.
|
496 |
+
start_pos (int): Starting position for attention caching.
|
497 |
+
verbose (bool): Whether to return intermediate hidden layer states
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
torch.Tensor or (torch.Tensor, Dict): output logits after applying the Transformer model.
|
501 |
+
|
502 |
+
"""
|
503 |
+
### ANALYSIS CODE ###
|
504 |
+
if verbose:
|
505 |
+
states = {"layers": [], "tokens": tokens}
|
506 |
+
#
|
507 |
+
|
508 |
+
_bsz, seqlen = tokens.shape
|
509 |
+
h = self.tok_embeddings(tokens)
|
510 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
511 |
+
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
512 |
+
|
513 |
+
### ANALYSIS CODE ###
|
514 |
+
if verbose:
|
515 |
+
states["layers"].append(h)
|
516 |
+
#
|
517 |
+
|
518 |
+
mask = None
|
519 |
+
if seqlen > 1:
|
520 |
+
mask = torch.full(
|
521 |
+
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
|
522 |
+
)
|
523 |
+
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
524 |
+
|
525 |
+
for layer in self.layers:
|
526 |
+
if not beam:
|
527 |
+
h = layer(h, start_pos, freqs_cis, mask, beam)
|
528 |
+
else:
|
529 |
+
h = layer(h, start_pos, freqs_cis, mask, beam, n_beams, attention_change_ids)
|
530 |
+
### ANALYSIS CODE ###
|
531 |
+
if verbose:
|
532 |
+
states["layers"].append(h)
|
533 |
+
#
|
534 |
+
h = self.norm(h)
|
535 |
+
# if want differences, at end, subtract differences from [-1] position of embedding vectors each iteration
|
536 |
+
|
537 |
+
### ANALYSIS CODE ###
|
538 |
+
if verbose:
|
539 |
+
states["layers"].append(h)
|
540 |
+
#
|
541 |
+
|
542 |
+
output = self.output(h).float()
|
543 |
+
|
544 |
+
if verbose:
|
545 |
+
return output, states
|
546 |
+
else:
|
547 |
+
return output
|
548 |
+
|
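The beam branch of Attention.forward above reorders each prompt's KV cache whenever beam search reshuffles drafts. Below is a minimal standalone sketch of that reordering; the tensor sizes and the example attention_change_ids are illustrative assumptions, not values from this repository.

import torch

# Illustrative sizes (assumptions for this sketch only)
n_prompts, n_beams, max_seq_len, n_kv_heads, head_dim = 2, 3, 8, 4, 16
cache_k = torch.randn(n_prompts * n_beams, max_seq_len, n_kv_heads, head_dim)

# For each prompt, which old beam should now occupy each beam slot
attention_change_ids = torch.tensor([[1, 0, 2],
                                     [2, 2, 0]])

# Group the flat cache by prompt, gather beams along dim 1, then flatten back
t_cache_k = cache_k.reshape(n_prompts, n_beams, max_seq_len, n_kv_heads, head_dim)
reordered = torch.take_along_dim(
    t_cache_k, attention_change_ids.reshape(-1, n_beams, 1, 1, 1), 1
)
cache_k = reordered.reshape(n_prompts * n_beams, max_seq_len, n_kv_heads, head_dim)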
superposed/llama/superpose.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
# Implementation loosely based on https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L554
|
2 |
+
import requests
|
3 |
+
import time
|
4 |
+
from datetime import datetime, timedelta
|
5 |
+
from typing import Optional, Literal
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from transformers import LlamaTokenizer
|
10 |
+
|
11 |
+
from superposed.llama.utils import *
|
12 |
+
from superposed.ngrams.ngram_models import NGram
|
13 |
+
|
14 |
+
INF = 1. * 1e7
|
15 |
+
|
16 |
+
# Test by scaling the number of beams and verifying the outputs
|
17 |
+
class Superpose(nn.Module):
|
18 |
+
def __init__(self,
|
19 |
+
initial_tokens,
|
20 |
+
tokenizer,
|
21 |
+
vocab_size,
|
22 |
+
smoothing: Optional[Literal["geom", "all"]] = None,
|
23 |
+
alpha = None,
|
24 |
+
verbose = False,
|
25 |
+
i_weights = None,
|
26 |
+
i_length = None,
|
27 |
+
ngrams = None,
|
28 |
+
sample_beams = False,
|
29 |
+
sample_tokens = False,
|
30 |
+
get_time = False,
|
31 |
+
penalty = 200): # default no effect
|
32 |
+
"""
|
33 |
+
Initialize a beam search class.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
initial_tokens (torch.Tensor): Initial tokens
|
37 |
+
n_prompts (int): Number of prompts
|
38 |
+
tokenizer (Tokenizer): Llama tokenizer
|
39 |
+
vocab_size (int): Total vocab size
|
40 |
+
smoothing (str): Smoothing method ("geom" for default, "all" for only ngram, None for no ngram)
|
41 |
+
ngram_length (int): N gram length to consider
|
42 |
+
alpha (float): Alpha parameter
|
43 |
+
verbose (bool): Whether to print debugging information
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
# primary parameters
|
47 |
+
self.n_prompts, self.n_drafts, _ = initial_tokens.shape
|
48 |
+
self.tokenizer = tokenizer
|
49 |
+
self.vocab_size = vocab_size
|
50 |
+
self.alive_seq = initial_tokens
|
51 |
+
self.fin_seq = initial_tokens
|
52 |
+
self.smoothing = smoothing
|
53 |
+
self.alive_log_probs = torch.zeros(self.n_prompts, self.n_drafts)
|
54 |
+
self.fin_log_probs = torch.full((self.n_prompts, self.n_drafts), float("-inf"))
|
55 |
+
self.alpha = alpha
|
56 |
+
self.verbose = verbose
|
57 |
+
self.penalty = penalty
|
58 |
+
# devices
|
59 |
+
self.cpu = torch.device('cpu')
|
60 |
+
self.gpu = torch.device('cuda')
|
61 |
+
# Interpolation length and weights
|
62 |
+
self.interpolation_weights = i_weights
|
63 |
+
self.i_length = i_length
|
64 |
+
# N-grams
|
65 |
+
self.bigram = ngrams[0] if len(ngrams) >= 1 else None
|
66 |
+
self.trigram = ngrams[1] if len(ngrams) >= 2 else None
|
67 |
+
self.fourgram = ngrams[2] if len(ngrams) >= 3 else None
|
68 |
+
self.fivegram = ngrams[3] if len(ngrams) >= 4 else None
|
69 |
+
self.sixgram = ngrams[4] if len(ngrams) >= 5 else None
|
70 |
+
self.sevengram = ngrams[5] if len(ngrams) >= 6 else None
|
71 |
+
# Timing
|
72 |
+
self.get_time = get_time
|
73 |
+
self.lookup_time = None
|
74 |
+
|
75 |
+
def forward(self, probs, still_prompt, is_first, cur_pos, n_token_sample):
|
76 |
+
"""
|
77 |
+
Apply beam decoding to update generations.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
probs (torch.Tensor): Next token probability distribution
|
81 |
+
still_prompt (torch.Tensor): Flags of prompts that should not generate yet (n_prompts, )
|
82 |
+
is_first (torch.Tensor): Flags of prompts that are on their first generation (n_prompts, )
|
83 |
+
cur_pos (int): Current generation position
|
84 |
+
n_token_sample (int): Number of tokens from model distribution to use
|
85 |
+
|
86 |
+
Return:
|
87 |
+
if standard beam search:
|
88 |
+
attention_change_ids (torch.Tensor): New indices in kv cache (n_prompts, n_drafts)
|
89 |
+
if mixed:
|
90 |
+
token_weights (torch.Tensor): Mixing weights (n_prompts, vocab_size)
|
91 |
+
"""
|
92 |
+
# Adjust input probabilities
|
93 |
+
probs = self.get_top_k(probs, 32000, n_token_sample)
|
94 |
+
reshaped_probs = probs.reshape(self.n_prompts, 1, -1)
|
95 |
+
reshaped_probs = reshaped_probs.repeat(1, self.n_drafts, 1)
|
96 |
+
# Ngram smoothing
|
97 |
+
if self.smoothing is not None:
|
98 |
+
if self.smoothing == "geom":
|
99 |
+
ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=probs)
|
100 |
+
# Make mask and normalize
|
101 |
+
prob_mask = reshaped_probs != 0
|
102 |
+
ngram_probs *= prob_mask
|
103 |
+
# Calculate logprobs and interpolate distributions
|
104 |
+
llm_log_probs = torch.log(reshaped_probs)
|
105 |
+
ngram_log_probs = torch.log(ngram_probs)
|
106 |
+
log_probs = (1 - self.alpha) * llm_log_probs + self.alpha * ngram_log_probs
|
107 |
+
# Apply penalty to drafts where no interpolation occurred
|
108 |
+
is_all_inf = (log_probs != float("-inf")).sum(dim=-1, keepdims=True) == 0
|
109 |
+
log_probs = torch.where(is_all_inf, (1 - self.alpha) * llm_log_probs - self.penalty, log_probs)
|
110 |
+
elif self.smoothing == "all":
|
111 |
+
ngram_probs = self.ngram_probs(self.alive_seq, cur_pos, probs=None)
|
112 |
+
log_probs = torch.log(ngram_probs)
|
113 |
+
else:
|
114 |
+
log_probs = torch.log(reshaped_probs)
|
115 |
+
curr_log_probs = self.alive_log_probs.unsqueeze(dim=2) + log_probs # [n_prompts, n_drafts, vocab_size]
|
116 |
+
# Warning if nan
|
117 |
+
if (torch.any(torch.isnan(curr_log_probs)).item()):
|
118 |
+
raise RuntimeWarning("nan in sequence log probs")
|
119 |
+
# Potential Sequences
|
120 |
+
flat_curr_log_probs = curr_log_probs.reshape(-1, self.vocab_size*self.n_drafts)
|
121 |
+
topk_log_probs, topk_idx = torch.topk(flat_curr_log_probs, 2 * self.n_drafts, dim=-1)
|
122 |
+
topk_beam_id = topk_idx // self.vocab_size # [n_prompts, 2 * n_drafts]
|
123 |
+
topk_idx = topk_idx % self.vocab_size # [n_prompts, 2 * n_drafts]
|
124 |
+
# First timestep uses top-k next tokens
|
125 |
+
is_first_idx = is_first.nonzero(as_tuple=True)[0]
|
126 |
+
if len(is_first_idx) != 0:
|
127 |
+
first_time_log_probs = log_probs[is_first_idx][:, 0, :].squeeze(dim=1)
|
128 |
+
first_time_log_probs, first_time_topk_idx = torch.topk(first_time_log_probs, 2 * self.n_drafts, dim=1)
|
129 |
+
topk_idx[is_first_idx] = first_time_topk_idx
|
130 |
+
topk_log_probs[is_first_idx] = self.alive_log_probs[is_first_idx, 0].unsqueeze(dim=1) + first_time_log_probs
|
131 |
+
# New sequences
|
132 |
+
topk_seq = torch.take_along_dim(self.alive_seq, topk_beam_id.unsqueeze(2), dim=1) # [n_prompts, 2 * n_drafts, vocab_size]
|
133 |
+
topk_seq[:, :, cur_pos] = topk_idx
|
134 |
+
topk_finished = topk_idx == self.tokenizer.eos_id
|
135 |
+
# Only update sequences for those that have begun generating
|
136 |
+
new_alive_seq, new_alive_log_probs = self.grow_alive(topk_seq, topk_log_probs, topk_finished)
|
137 |
+
new_fin_seq, new_fin_log_probs = self.grow_fin(topk_seq, topk_log_probs, topk_finished)
|
138 |
+
still_prompt_probs = still_prompt.reshape(-1, 1)
|
139 |
+
still_prompt_seqs = still_prompt.reshape(-1, 1, 1)
|
140 |
+
self.alive_seq = torch.where(still_prompt_seqs, self.alive_seq, new_alive_seq)
|
141 |
+
self.alive_log_probs = torch.where(still_prompt_probs, self.alive_log_probs, new_alive_log_probs)
|
142 |
+
self.fin_seq = torch.where(still_prompt_seqs, self.fin_seq, new_fin_seq)
|
143 |
+
self.fin_log_probs = torch.where(still_prompt_probs, self.fin_log_probs, new_fin_log_probs)
|
144 |
+
# Create superposition matrix and return it
|
145 |
+
topk_idx = self.alive_seq[:, :, cur_pos].reshape(self.n_prompts, -1)
|
146 |
+
token_weights = self.superposition_matrix(topk_idx)
|
147 |
+
return token_weights
|
148 |
+
|
149 |
+
def grow_alive(self, topk_seq, topk_log_probs, topk_finished):
|
150 |
+
"""
|
151 |
+
Extend running generations.
|
152 |
+
Args:
|
153 |
+
topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, vocab_size)
|
154 |
+
topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
|
155 |
+
topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)
|
156 |
+
Returns:
|
157 |
+
new_alive_seq, new_alive_log_probs
|
158 |
+
"""
|
159 |
+
topk_log_probs = topk_log_probs + topk_finished * -INF
|
160 |
+
new_alive_log_probs, new_alive_idx = torch.topk(topk_log_probs, self.n_drafts, dim=1)
|
161 |
+
new_alive_seq = torch.take_along_dim(topk_seq, new_alive_idx.unsqueeze(2), dim=1)
|
162 |
+
return new_alive_seq, new_alive_log_probs
|
163 |
+
|
164 |
+
def grow_fin(self, topk_seq, topk_log_probs, topk_finished):
|
165 |
+
"""
|
166 |
+
Update stopped generations.
|
167 |
+
Args:
|
168 |
+
topk_seq (torch.Tensor): Top k sequences (n_prompts, 2 * n_drafts, vocab_size)
|
169 |
+
topk_log_probs (torch.Tensor): Log probabilities (n_prompts, 2 * n_drafts)
|
170 |
+
topk_finished (torch.Tensor): Whether a sequence is finished (n_prompts, 2 * n_drafts)
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
new_fin_seq, new_fin_log_probs
|
174 |
+
"""
|
175 |
+
topk_log_probs = topk_log_probs + ~topk_finished * -INF
|
176 |
+
new_fin_seq = torch.cat([self.fin_seq, topk_seq], dim=1)
|
177 |
+
new_fin_log_probs = torch.cat([self.fin_log_probs, topk_log_probs], dim=1)
|
178 |
+
new_fin_log_probs, new_fin_idx = torch.topk(new_fin_log_probs, self.n_drafts, dim=1)
|
179 |
+
new_fin_seq = torch.take_along_dim(new_fin_seq, new_fin_idx.unsqueeze(2), dim=1)
|
180 |
+
return new_fin_seq, new_fin_log_probs
|
181 |
+
|
182 |
+
def get_top_k(self, probs, m, k):
|
183 |
+
"""
|
184 |
+
Zero out all but top-k tokens in a probability distribution.
|
185 |
+
Args:
|
186 |
+
probs (torch.Tensor): Probability distribution tensor.
|
187 |
+
m (float): Number of tokens to consider (only relevant when sampling).
|
188 |
+
k (int): Number of tokens to sample/keep.
|
189 |
+
Returns:
|
190 |
+
torch.Tensor: New probability distribution based on renormalized probabilities.
|
191 |
+
"""
|
192 |
+
n_prompts, _ = probs.shape
|
193 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
194 |
+
top_k_mask = torch.arange(probs.shape[-1])
|
195 |
+
top_k_mask = top_k_mask.expand(probs.shape[0], -1)
|
196 |
+
top_k_mask = top_k_mask >= m # True past the first m elements
|
197 |
+
probs_sort[top_k_mask] = 0.0 # Zero wherever mask = 1
|
198 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
199 |
+
next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
|
200 |
+
# Set all other probs to 0
|
201 |
+
new_probs_map = torch.zeros(probs.shape).bool()
|
202 |
+
new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
|
203 |
+
new_probs = torch.where(new_probs_map, probs, 0)
|
204 |
+
# Renormalize
|
205 |
+
new_probs.div_(new_probs.sum(dim=-1, keepdim=True))
|
206 |
+
return new_probs
|
207 |
+
|
208 |
+
def superposition_matrix(self, tokens):
|
209 |
+
"""
|
210 |
+
Create superposition matrix based on provided tokens.
|
211 |
+
Args:
|
212 |
+
tokens (torch.Tensor): Tokens to mix (n_prompts, n_drafts)
|
213 |
+
Returns:
|
214 |
+
Superposition matrix (n_prompts, vocab_size)
|
215 |
+
"""
|
216 |
+
# Create superposition matrix
|
217 |
+
mixing_matrix = torch.zeros(self.n_prompts, self.vocab_size)
|
218 |
+
# Convert draft log probs to probabilities
|
219 |
+
weightings = log_prob_to_prob(self.alive_log_probs)
|
220 |
+
# Update probabilities in superposition matrix with draft probabilities
|
221 |
+
for p_idx in range(self.n_prompts):
|
222 |
+
for d_idx in range(self.n_drafts):
|
223 |
+
tok_idx = tokens[p_idx][d_idx]
|
224 |
+
mixing_matrix[p_idx][tok_idx] += weightings[p_idx][d_idx]
|
225 |
+
# Renormalize
|
226 |
+
mixing_matrix.div_(mixing_matrix.sum(dim=-1, keepdims=True))
|
227 |
+
return mixing_matrix
|
228 |
+
|
229 |
+
def ngram_probs(self, alive_seq, cur_pos, probs):
|
230 |
+
"""
|
231 |
+
Calculate and return next token distribution using ngram models.
|
232 |
+
Args:
|
233 |
+
alive_seq (torch.Tensor): Current drafts (n_prompts, n_drafts, seqlen)
|
234 |
+
cur_pos (int): Current timestep
|
235 |
+
probs (torch.Tensor): Current next probability distribution from model (n_prompts, vocab_size).
|
236 |
+
As described in the paper, only tokens with nonzero probability in `probs` are considered for the
|
237 |
+
ngram distribution. However, passing in `None` as `probs` will consider all tokens.
|
238 |
+
Returns:
|
239 |
+
Next token distribution for each draft (n_prompts, n_drafts, vocab_size)
|
240 |
+
"""
|
241 |
+
if self.get_time:
|
242 |
+
# Start timer
|
243 |
+
start_time = datetime.now()
|
244 |
+
# Create distribution matrix
|
245 |
+
next_token_probs = torch.zeros(self.n_prompts, self.n_drafts, 32000)
|
246 |
+
if probs is not None:
|
247 |
+
# Loop over all prefixes
|
248 |
+
for p_idx in range(len(alive_seq)):
|
249 |
+
# List of possible tokens for the prefix
|
250 |
+
nz = torch.nonzero(probs[p_idx, :], as_tuple=True)[0].tolist()
|
251 |
+
# Generate next token distribution
|
252 |
+
for draft_idx in range(self.n_drafts):
|
253 |
+
i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
|
254 |
+
new_i_weights = self.interpolation_weights[:i_mask]
|
255 |
+
new_i_length = self.i_length[:i_mask]
|
256 |
+
# For each next token
|
257 |
+
for nt in nz:
|
258 |
+
# Calculate probability using ngram interpolation
|
259 |
+
for i, weight in zip(new_i_length, new_i_weights):
|
260 |
+
if cur_pos - i >= 0:
|
261 |
+
key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
|
262 |
+
if i == 1:
|
263 |
+
prob = self.bigram.prob(key, nt)
|
264 |
+
elif i == 2:
|
265 |
+
prob = self.trigram.prob(key, nt)
|
266 |
+
elif i == 3:
|
267 |
+
prob = self.fourgram.prob(key, nt)
|
268 |
+
elif i == 4:
|
269 |
+
prob = self.fivegram.prob(key, nt)
|
270 |
+
elif i == 5:
|
271 |
+
prob = self.sixgram.prob(key, nt)
|
272 |
+
elif i == 6:
|
273 |
+
prob = self.sevengram.prob(key, nt)
|
274 |
+
if prob >= 0:
|
275 |
+
next_token_probs[p_idx, draft_idx, nt] += weight * prob
|
276 |
+
else:
|
277 |
+
for p_idx in range(len(alive_seq)):
|
278 |
+
for draft_idx in range(self.n_drafts):
|
279 |
+
i_mask = torch.sum(torch.tensor(self.i_length) <= cur_pos)
|
280 |
+
new_i_weights = self.interpolation_weights[:i_mask]
|
281 |
+
new_i_length = self.i_length[:i_mask]
|
282 |
+
for i, weight in zip(new_i_length, new_i_weights):
|
283 |
+
if cur_pos - i >= 0:
|
284 |
+
key = tuple(alive_seq[p_idx, draft_idx, cur_pos-i:cur_pos].tolist())
|
285 |
+
if i == 1:
|
286 |
+
ntd = self.bigram.ntd(key)
|
287 |
+
elif i == 2:
|
288 |
+
ntd = self.trigram.ntd(key)
|
289 |
+
elif i == 3:
|
290 |
+
ntd = self.fourgram.ntd(key)
|
291 |
+
elif i == 4:
|
292 |
+
ntd = self.fivegram.ntd(key)
|
293 |
+
elif i == 5:
|
294 |
+
ntd = self.sixgram.ntd(key)
|
295 |
+
elif i == 6:
|
296 |
+
ntd = self.sevengram.ntd(key)
|
297 |
+
if ntd is not None:
|
298 |
+
next_token_probs[p_idx, draft_idx, :] += weight * ntd
|
299 |
+
if self.get_time:
|
300 |
+
total_time = datetime.now() - start_time
|
301 |
+
self.lookup_time = total_time if self.lookup_time is None else self.lookup_time + total_time
|
302 |
+
return next_token_probs
|
303 |
+
|
304 |
+
def return_results(self, prompt_len=None):
|
305 |
+
"""
|
306 |
+
Return generations and perplexities
|
307 |
+
|
308 |
+
Args:
|
309 |
+
prompt_len (int): Length of prompt in tokens. If None, perplexity is not calculated.
|
310 |
+
Returns:
|
311 |
+
(self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl)
|
312 |
+
OR
|
313 |
+
(self.alive_seq, alive_ppl), (self.fin_seq, fin_ppl), self.lookup_time
|
314 |
+
"""
|
315 |
+
# PPL
|
316 |
+
alive_ppl = 0
|
317 |
+
fin_ppl = 0
|
318 |
+
if prompt_len is not None:
|
319 |
+
alive_ppl = torch.exp(self.alive_log_probs / (-1 * (self.alive_seq.size(dim=-1)-prompt_len)))
|
320 |
+
# Fin ppl
|
321 |
+
fin_seq_lengths = (self.fin_seq != self.tokenizer.pad_id).sum(dim=-1)
|
322 |
+
fin_ppl = torch.exp(self.fin_log_probs / (-1 * (fin_seq_lengths - prompt_len)))
|
323 |
+
fin_ppl += ((fin_ppl == 0) * float("inf"))
|
324 |
+
# print time
|
325 |
+
if not self.get_time:
|
326 |
+
return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl)
|
327 |
+
else:
|
328 |
+
return (self.alive_seq.to(torch.long), alive_ppl), (self.fin_seq.to(torch.long), fin_ppl), self.lookup_time
|
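For intuition, here is a vectorized sketch of what superposition_matrix computes: draft log-probabilities are converted to probabilities, accumulated per token id, and each row is renormalized. The shapes and values are made up for illustration, and scatter_add_ stands in for the explicit double loop used above.

import torch

n_prompts, n_drafts, vocab_size = 2, 3, 10           # illustrative sizes
tokens = torch.tensor([[4, 7, 4],                     # newest token of each draft
                       [1, 2, 3]])
alive_log_probs = torch.log(torch.tensor([[0.5, 0.3, 0.2],
                                          [0.6, 0.3, 0.1]]))

weights = torch.exp(alive_log_probs)                   # draft probabilities
mix = torch.zeros(n_prompts, vocab_size)
mix.scatter_add_(1, tokens, weights)                   # weight per token id (duplicates add up)
mix = mix / mix.sum(dim=-1, keepdim=True)              # renormalize each prompt's row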
superposed/llama/superposed_generation.py
ADDED
@@ -0,0 +1,198 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
3 |
+
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import List, Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from fairscale.nn.model_parallel.initialize import (
|
14 |
+
get_model_parallel_rank,
|
15 |
+
initialize_model_parallel,
|
16 |
+
model_parallel_is_initialized,
|
17 |
+
)
|
18 |
+
|
19 |
+
from superposed.llama.model import ModelArgs
|
20 |
+
from superposed.llama.superposed_model import SuperposedTransformer
|
21 |
+
from superposed.llama.tokenizer import Tokenizer
|
22 |
+
from superposed.llama.superpose import Superpose
|
23 |
+
from superposed.llama.utils import *
|
24 |
+
from superposed.ngrams.ngram_models import make_models
|
25 |
+
|
26 |
+
class SuperposedLlama:
|
27 |
+
@staticmethod
|
28 |
+
def build(
|
29 |
+
ckpt_dir: str,
|
30 |
+
tokenizer_path: str,
|
31 |
+
max_seq_len: int,
|
32 |
+
max_batch_size: int,
|
33 |
+
device = None,
|
34 |
+
model_parallel_size: Optional[int] = None,
|
35 |
+
seed: int = 1,
|
36 |
+
):
|
37 |
+
if not torch.distributed.is_initialized():
|
38 |
+
torch.distributed.init_process_group("nccl")
|
39 |
+
if not model_parallel_is_initialized():
|
40 |
+
if model_parallel_size is None:
|
41 |
+
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
42 |
+
initialize_model_parallel(model_parallel_size)
|
43 |
+
|
44 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
45 |
+
if device is None:
|
46 |
+
torch.cuda.set_device(local_rank)
|
47 |
+
device = torch.cuda.current_device()
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
|
50 |
+
if local_rank > 0:
|
51 |
+
sys.stdout = open(os.devnull, "w")
|
52 |
+
|
53 |
+
start_time = time.time()
|
54 |
+
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
55 |
+
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
56 |
+
assert model_parallel_size == len(
|
57 |
+
checkpoints
|
58 |
+
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
59 |
+
ckpt_path = checkpoints[get_model_parallel_rank()]
|
60 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
61 |
+
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
62 |
+
params = json.loads(f.read())
|
63 |
+
|
64 |
+
model_args: ModelArgs = ModelArgs(
|
65 |
+
max_seq_len=max_seq_len,
|
66 |
+
max_batch_size=max_batch_size,
|
67 |
+
**params,
|
68 |
+
)
|
69 |
+
tokenizer = Tokenizer(model_path=tokenizer_path)
|
70 |
+
model_args.vocab_size = tokenizer.n_words
|
71 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
72 |
+
# Set up superposed decoding
|
73 |
+
model = SuperposedTransformer(model_args)
|
74 |
+
model.load_state_dict(checkpoint, strict=False)
|
75 |
+
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
76 |
+
return SuperposedLlama(model, tokenizer, device)
|
77 |
+
|
78 |
+
def __init__(self, model: SuperposedTransformer, tokenizer: Tokenizer, device):
|
79 |
+
print(device)
|
80 |
+
self.model = model.to(device).eval()
|
81 |
+
self.tokenizer = tokenizer
|
82 |
+
self.device = device
|
83 |
+
|
84 |
+
@torch.inference_mode()
|
85 |
+
def sup_generate(
|
86 |
+
self,
|
87 |
+
prompt_tokens: List[List[int]],
|
88 |
+
smoothing,
|
89 |
+
max_gen_len: int,
|
90 |
+
n_token_sample: int,
|
91 |
+
alpha: float, # weight on ngram probs
|
92 |
+
temp: float,
|
93 |
+
n_drafts: int = 1, # number of beams
|
94 |
+
verbose: bool = False,
|
95 |
+
i_weights = None,
|
96 |
+
i_length = None,
|
97 |
+
ngrams = None,
|
98 |
+
get_time: bool = False,
|
99 |
+
penalty = 200
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
Run multi-sequence generation using superposed embeddings.
|
103 |
+
Args:
|
104 |
+
prompt_tokens (List[List[int]]): Initial tokenized prompts
|
105 |
+
max_gen_len (int): Maximum numbers of tokens to generate
|
106 |
+
alpha (float): Alpha value
|
107 |
+
temp (float): Temperature
|
108 |
+
n_drafts (int): Number of drafts
|
109 |
+
verbose (bool): Whether to save intermediate embeddings for analysis
|
110 |
+
bsz (int): Batch size (default = 16)
|
111 |
+
i_weights (List[float]): Ngram interpolation weights
|
112 |
+
i_length (List[int]): Ngram models to interpolate (1 for bigram, 2 for trigram, etc.)
|
113 |
+
ngrams (Tuple): Ngram models
|
114 |
+
get_time (bool): Return information on time spent doing Ngram lookup
|
115 |
+
penalty (float): Penalty on uninterpolated drafts
|
116 |
+
Returns:
|
117 |
+
(alive_seq, alive_ppl), (fin_seq, fin_ppl): Tuple of (n_prompts, n_drafts, seqlen),
|
118 |
+
(n_prompts, n_drafts) for sequences still generating and sequences that have finished.
|
119 |
+
"""
|
120 |
+
# Check batch size and prompt lengths
|
121 |
+
params = self.model.params
|
122 |
+
bsz = len(prompt_tokens)
|
123 |
+
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
124 |
+
|
125 |
+
min_prompt_len = min(len(t) for t in prompt_tokens)
|
126 |
+
max_prompt_len = max(len(t) for t in prompt_tokens)
|
127 |
+
prompt_len = min_prompt_len
|
128 |
+
assert max_prompt_len <= params.max_seq_len
|
129 |
+
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
|
130 |
+
pad_id = self.tokenizer.pad_id
|
131 |
+
|
132 |
+
# Initialize token tensor and pad where necessary
|
133 |
+
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
|
134 |
+
for k, t in enumerate(prompt_tokens):
|
135 |
+
tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
|
136 |
+
|
137 |
+
# If no generation is possible
|
138 |
+
if min_prompt_len == total_len:
|
139 |
+
raise RuntimeError("no generation possible")
|
140 |
+
|
141 |
+
# Initialize decoding object
|
142 |
+
initial_tokens = tokens.unsqueeze(1).repeat(1, n_drafts, 1)
|
143 |
+
superpose = Superpose(initial_tokens,
|
144 |
+
tokenizer=self.tokenizer,
|
145 |
+
vocab_size=params.vocab_size,
|
146 |
+
smoothing=smoothing,
|
147 |
+
alpha=alpha,
|
148 |
+
i_weights=i_weights,
|
149 |
+
i_length=i_length,
|
150 |
+
ngrams=ngrams,
|
151 |
+
get_time=get_time,
|
152 |
+
penalty=penalty)
|
153 |
+
unseen_first = torch.ones(bsz)
|
154 |
+
# Superposition matrix
|
155 |
+
token_weights = torch.zeros(bsz, self.model.vocab_size)
|
156 |
+
if verbose:
|
157 |
+
state_list = []
|
158 |
+
prev_pos = 0
|
159 |
+
# Begin inference
|
160 |
+
for cur_pos in range(min_prompt_len, total_len):
|
161 |
+
input_text_mask = tokens != pad_id
|
162 |
+
# Take model step
|
163 |
+
if cur_pos == min_prompt_len:
|
164 |
+
token_weights = None
|
165 |
+
logits = self.model.forward(tokens[:, prev_pos:cur_pos],
|
166 |
+
start_pos=prev_pos,
|
167 |
+
token_weights=token_weights,
|
168 |
+
verbose=verbose)
|
169 |
+
if verbose:
|
170 |
+
logits, states = logits
|
171 |
+
# Softmax
|
172 |
+
if temp > 0:
|
173 |
+
probs = torch.softmax(logits[:, -1] / temp, dim=-1)
|
174 |
+
else:
|
175 |
+
raise RuntimeError("Temperature must be greater than 0 while mixing")
|
176 |
+
if verbose:
|
177 |
+
states["end_probs"] = probs
|
178 |
+
state_list.append(states)
|
179 |
+
# Flag prompts on first generation
|
180 |
+
is_first = torch.mul(tokens[:, cur_pos] == pad_id, unseen_first)
|
181 |
+
unseen_first[is_first.nonzero(as_tuple=True)[0]] = 0
|
182 |
+
# Flag prompts not yet generating
|
183 |
+
still_prompt = input_text_mask[:, cur_pos]
|
184 |
+
# Superposition pass
|
185 |
+
token_weights = superpose(probs, still_prompt, is_first, cur_pos, n_token_sample)
|
186 |
+
# Do not superpose for prompts not yet generating
|
187 |
+
keep_idx = input_text_mask[:, cur_pos].ravel().nonzero()
|
188 |
+
keep_token_weights = torch.zeros_like(token_weights)
|
189 |
+
keep_token_weights[keep_idx, tokens[keep_idx, cur_pos]] = 1
|
190 |
+
token_weights = torch.where(input_text_mask[:, cur_pos].unsqueeze(1).expand(-1, self.model.vocab_size),
|
191 |
+
keep_token_weights, token_weights)
|
192 |
+
prev_pos = cur_pos
|
193 |
+
results = superpose.return_results(prompt_len)
|
194 |
+
if verbose:
|
195 |
+
torch.save(state_list, "../embeddings.pt")
|
196 |
+
return results
|
197 |
+
else:
|
198 |
+
return results
|
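A hypothetical driver for the class above; the checkpoint paths and hyperparameters are placeholders, and a CUDA machine with Llama 2 weights is assumed.

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer

model = SuperposedLlama.build(ckpt_dir="weights/llama-2-7b",            # placeholder path
                              tokenizer_path="weights/tokenizer.model",  # placeholder path
                              max_seq_len=512,
                              max_batch_size=8)
tok = Tokenizer(model_path="weights/tokenizer.model")
prompts = [tok.encode("The capital of France is", bos=True, eos=False)]
# smoothing=None disables n-gram interpolation, so an empty ngrams list suffices
(alive_seq, alive_ppl), (fin_seq, fin_ppl) = model.sup_generate(
    prompt_tokens=prompts, smoothing=None, max_gen_len=16,
    n_token_sample=9, alpha=0.0, temp=0.7, n_drafts=3, ngrams=[])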
superposed/llama/superposed_model.py
ADDED
@@ -0,0 +1,515 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
3 |
+
|
4 |
+
import math
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional, Tuple
|
7 |
+
|
8 |
+
import fairscale.nn.model_parallel.initialize as fs_init
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from fairscale.nn.model_parallel.layers import (
|
12 |
+
ColumnParallelLinear,
|
13 |
+
ParallelEmbedding,
|
14 |
+
RowParallelLinear,
|
15 |
+
)
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class ModelArgs:
|
21 |
+
dim: int = 4096
|
22 |
+
n_layers: int = 32
|
23 |
+
n_heads: int = 32
|
24 |
+
n_kv_heads: Optional[int] = None
|
25 |
+
vocab_size: int = -1 # defined later by tokenizer
|
26 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
27 |
+
ffn_dim_multiplier: Optional[float] = None
|
28 |
+
norm_eps: float = 1e-5
|
29 |
+
|
30 |
+
max_batch_size: int = 32
|
31 |
+
max_seq_len: int = 2048
|
32 |
+
|
33 |
+
|
34 |
+
class RMSNorm(torch.nn.Module):
|
35 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
36 |
+
"""
|
37 |
+
Initialize the RMSNorm normalization layer.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
dim (int): The dimension of the input tensor.
|
41 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
42 |
+
|
43 |
+
Attributes:
|
44 |
+
eps (float): A small value added to the denominator for numerical stability.
|
45 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
46 |
+
|
47 |
+
"""
|
48 |
+
super().__init__()
|
49 |
+
self.eps = eps
|
50 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
51 |
+
|
52 |
+
def _norm(self, x):
|
53 |
+
"""
|
54 |
+
Apply the RMSNorm normalization to the input tensor.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
x (torch.Tensor): The input tensor.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
torch.Tensor: The normalized tensor.
|
61 |
+
|
62 |
+
"""
|
63 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
"""
|
67 |
+
Forward pass through the RMSNorm layer.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
x (torch.Tensor): The input tensor.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
74 |
+
|
75 |
+
"""
|
76 |
+
output = self._norm(x.float()).type_as(x)
|
77 |
+
k = output * self.weight
|
78 |
+
return k
|
79 |
+
|
80 |
+
|
81 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
82 |
+
"""
|
83 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
84 |
+
|
85 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
86 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
87 |
+
The returned tensor contains complex values in complex64 data type.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
dim (int): Dimension of the frequency tensor.
|
91 |
+
end (int): End index for precomputing frequencies.
|
92 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
torch.Tensor: Precomputed frequency tensor with complex exponentials.
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
"""
|
101 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
102 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
103 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
104 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
105 |
+
return freqs_cis
|
106 |
+
|
107 |
+
|
108 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
109 |
+
"""
|
110 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
111 |
+
|
112 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
113 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
117 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
torch.Tensor: Reshaped frequency tensor.
|
121 |
+
|
122 |
+
Raises:
|
123 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
124 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
125 |
+
"""
|
126 |
+
ndim = x.ndim
|
127 |
+
assert 0 <= 1 < ndim
|
128 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
129 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
130 |
+
return freqs_cis.view(*shape)
|
131 |
+
|
132 |
+
|
133 |
+
def apply_rotary_emb(
|
134 |
+
xq: torch.Tensor,
|
135 |
+
xk: torch.Tensor,
|
136 |
+
freqs_cis: torch.Tensor,
|
137 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
138 |
+
"""
|
139 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
140 |
+
|
141 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
142 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
143 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
144 |
+
returned as real tensors.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings.
|
148 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings.
|
149 |
+
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
"""
|
157 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
158 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
159 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
160 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
161 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
162 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
163 |
+
|
164 |
+
|
165 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
166 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
167 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
168 |
+
if n_rep == 1:
|
169 |
+
return x
|
170 |
+
return (
|
171 |
+
x[:, :, :, None, :]
|
172 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
173 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
class Attention(nn.Module):
|
178 |
+
"""Multi-head attention module."""
|
179 |
+
def __init__(self, args: ModelArgs):
|
180 |
+
"""
|
181 |
+
Initialize the Attention module.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
args (ModelArgs): Model configuration parameters.
|
185 |
+
|
186 |
+
Attributes:
|
187 |
+
n_kv_heads (int): Number of key and value heads.
|
188 |
+
n_local_heads (int): Number of local query heads.
|
189 |
+
n_local_kv_heads (int): Number of local key and value heads.
|
190 |
+
n_rep (int): Number of repetitions for local heads.
|
191 |
+
head_dim (int): Dimension size of each attention head.
|
192 |
+
wq (ColumnParallelLinear): Linear transformation for queries.
|
193 |
+
wk (ColumnParallelLinear): Linear transformation for keys.
|
194 |
+
wv (ColumnParallelLinear): Linear transformation for values.
|
195 |
+
wo (RowParallelLinear): Linear transformation for output.
|
196 |
+
cache_k (torch.Tensor): Cached keys for attention.
|
197 |
+
cache_v (torch.Tensor): Cached values for attention.
|
198 |
+
|
199 |
+
"""
|
200 |
+
super().__init__()
|
201 |
+
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
202 |
+
model_parallel_size = fs_init.get_model_parallel_world_size()
|
203 |
+
self.n_local_heads = args.n_heads // model_parallel_size
|
204 |
+
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
205 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
206 |
+
self.head_dim = args.dim // args.n_heads
|
207 |
+
|
208 |
+
self.wq = ColumnParallelLinear(
|
209 |
+
args.dim,
|
210 |
+
args.n_heads * self.head_dim,
|
211 |
+
bias=False,
|
212 |
+
gather_output=False,
|
213 |
+
init_method=lambda x: x,
|
214 |
+
)
|
215 |
+
self.wk = ColumnParallelLinear(
|
216 |
+
args.dim,
|
217 |
+
self.n_kv_heads * self.head_dim,
|
218 |
+
bias=False,
|
219 |
+
gather_output=False,
|
220 |
+
init_method=lambda x: x,
|
221 |
+
)
|
222 |
+
self.wv = ColumnParallelLinear(
|
223 |
+
args.dim,
|
224 |
+
self.n_kv_heads * self.head_dim,
|
225 |
+
bias=False,
|
226 |
+
gather_output=False,
|
227 |
+
init_method=lambda x: x,
|
228 |
+
)
|
229 |
+
self.wo = RowParallelLinear(
|
230 |
+
args.n_heads * self.head_dim,
|
231 |
+
args.dim,
|
232 |
+
bias=False,
|
233 |
+
input_is_parallel=True,
|
234 |
+
init_method=lambda x: x,
|
235 |
+
)
|
236 |
+
|
237 |
+
self.cache_k = torch.zeros(
|
238 |
+
(
|
239 |
+
args.max_batch_size,
|
240 |
+
args.max_seq_len,
|
241 |
+
self.n_local_kv_heads,
|
242 |
+
self.head_dim,
|
243 |
+
)
|
244 |
+
).cuda()
|
245 |
+
self.cache_v = torch.zeros(
|
246 |
+
(
|
247 |
+
args.max_batch_size,
|
248 |
+
args.max_seq_len,
|
249 |
+
self.n_local_kv_heads,
|
250 |
+
self.head_dim,
|
251 |
+
)
|
252 |
+
).cuda()
|
253 |
+
|
254 |
+
def forward(
|
255 |
+
self,
|
256 |
+
x: torch.Tensor,
|
257 |
+
start_pos: int,
|
258 |
+
freqs_cis: torch.Tensor,
|
259 |
+
mask: Optional[torch.Tensor]
|
260 |
+
):
|
261 |
+
"""
|
262 |
+
Forward pass of the attention module.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
x (torch.Tensor): Input tensor.
|
266 |
+
start_pos (int): Starting position for caching.
|
267 |
+
freqs_cis (torch.Tensor): Precomputed frequency tensor.
|
268 |
+
mask (torch.Tensor, optional): Attention mask tensor.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
torch.Tensor: Output tensor after attention.
|
272 |
+
|
273 |
+
"""
|
274 |
+
bsz, seqlen, _ = x.shape
|
275 |
+
|
276 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
277 |
+
|
278 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
279 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
280 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
281 |
+
|
282 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
283 |
+
|
284 |
+
self.cache_k = self.cache_k.to(xq)
|
285 |
+
self.cache_v = self.cache_v.to(xq)
|
286 |
+
|
287 |
+
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
288 |
+
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
289 |
+
|
290 |
+
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
291 |
+
values = self.cache_v[:bsz, : start_pos + seqlen]
|
292 |
+
|
293 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
294 |
+
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
295 |
+
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
296 |
+
|
297 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
298 |
+
keys = keys.transpose(1, 2)
|
299 |
+
values = values.transpose(1, 2)
|
300 |
+
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
301 |
+
if mask is not None:
|
302 |
+
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
303 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
304 |
+
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
305 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
306 |
+
return self.wo(output)
|
307 |
+
|
308 |
+
|
309 |
+
class FeedForward(nn.Module):
|
310 |
+
def __init__(
|
311 |
+
self,
|
312 |
+
dim: int,
|
313 |
+
hidden_dim: int,
|
314 |
+
multiple_of: int,
|
315 |
+
ffn_dim_multiplier: Optional[float],
|
316 |
+
):
|
317 |
+
"""
|
318 |
+
Initialize the FeedForward module.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
dim (int): Input dimension.
|
322 |
+
hidden_dim (int): Hidden dimension of the feedforward layer.
|
323 |
+
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
324 |
+
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
325 |
+
|
326 |
+
Attributes:
|
327 |
+
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
328 |
+
w2 (RowParallelLinear): Linear transformation for the second layer.
|
329 |
+
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
330 |
+
|
331 |
+
"""
|
332 |
+
super().__init__()
|
333 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
334 |
+
# custom dim factor multiplier
|
335 |
+
if ffn_dim_multiplier is not None:
|
336 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
337 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
338 |
+
|
339 |
+
self.w1 = ColumnParallelLinear(
|
340 |
+
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
|
341 |
+
)
|
342 |
+
self.w2 = RowParallelLinear(
|
343 |
+
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
|
344 |
+
)
|
345 |
+
self.w3 = ColumnParallelLinear(
|
346 |
+
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
|
347 |
+
)
|
348 |
+
|
349 |
+
def forward(self, x):
|
350 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
351 |
+
|
352 |
+
|
353 |
+
class MixedTransformerBlock(nn.Module):
|
354 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
355 |
+
"""
|
356 |
+
Initialize a TransformerBlock.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
layer_id (int): Identifier for the layer.
|
360 |
+
args (ModelArgs): Model configuration parameters.
|
361 |
+
|
362 |
+
Attributes:
|
363 |
+
n_heads (int): Number of attention heads.
|
364 |
+
dim (int): Dimension size of the model.
|
365 |
+
head_dim (int): Dimension size of each attention head.
|
366 |
+
attention (Attention): Attention module.
|
367 |
+
feed_forward (FeedForward): FeedForward module.
|
368 |
+
layer_id (int): Identifier for the layer.
|
369 |
+
attention_norm (RMSNorm): Layer normalization for attention output.
|
370 |
+
ffn_norm (RMSNorm): Layer normalization for feedforward output.
|
371 |
+
|
372 |
+
"""
|
373 |
+
super().__init__()
|
374 |
+
self.n_heads = args.n_heads
|
375 |
+
self.dim = args.dim
|
376 |
+
self.head_dim = args.dim // args.n_heads
|
377 |
+
self.attention = Attention(args)
|
378 |
+
self.feed_forward = FeedForward(
|
379 |
+
dim=args.dim,
|
380 |
+
hidden_dim=4 * args.dim,
|
381 |
+
multiple_of=args.multiple_of,
|
382 |
+
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
383 |
+
)
|
384 |
+
self.layer_id = layer_id
|
385 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
386 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
387 |
+
|
388 |
+
def forward(
|
389 |
+
self,
|
390 |
+
x: torch.Tensor,
|
391 |
+
start_pos: int,
|
392 |
+
freqs_cis: torch.Tensor,
|
393 |
+
mask: Optional[torch.Tensor]
|
394 |
+
):
|
395 |
+
"""
|
396 |
+
Perform a forward pass through the TransformerBlock.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
x (torch.Tensor): Input tensor.
|
400 |
+
start_pos (int): Starting position for attention caching.
|
401 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
402 |
+
mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
|
403 |
+
|
404 |
+
Returns:
|
405 |
+
torch.Tensor: Output tensor after applying attention and feedforward layers.
|
406 |
+
|
407 |
+
"""
|
408 |
+
h = x + self.attention.forward(
|
409 |
+
self.attention_norm(x), start_pos, freqs_cis, mask
|
410 |
+
)
|
411 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
412 |
+
return out
|
413 |
+
|
414 |
+
class SuperposedTransformer(nn.Module):
|
415 |
+
def __init__(self, params: ModelArgs):
|
416 |
+
"""
|
417 |
+
Initialize a Transformer model.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
params (ModelArgs): Model configuration parameters.
|
421 |
+
|
422 |
+
Attributes:
|
423 |
+
params (ModelArgs): Model configuration parameters.
|
424 |
+
vocab_size (int): Vocabulary size.
|
425 |
+
n_layers (int): Number of layers in the model.
|
426 |
+
tok_embeddings (ParallelEmbedding): Token embeddings.
|
427 |
+
layers (torch.nn.ModuleList): List of Transformer blocks.
|
428 |
+
norm (RMSNorm): Layer normalization for the model output.
|
429 |
+
output (ColumnParallelLinear): Linear layer for final output.
|
430 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
431 |
+
|
432 |
+
"""
|
433 |
+
super().__init__()
|
434 |
+
self.params = params
|
435 |
+
self.vocab_size = params.vocab_size
|
436 |
+
self.n_layers = params.n_layers
|
437 |
+
|
438 |
+
self.tok_embeddings = ParallelEmbedding(
|
439 |
+
params.vocab_size, params.dim, init_method=lambda x: x
|
440 |
+
)
|
441 |
+
|
442 |
+
self.tok_mixing_embeddings = ColumnParallelLinear(
|
443 |
+
params.vocab_size, params.dim, bias=False, init_method=lambda x: x
|
444 |
+
) # dimensions are placeholders; the weight is overwritten below
|
445 |
+
self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)
|
446 |
+
|
447 |
+
self.layers = torch.nn.ModuleList()
|
448 |
+
for layer_id in range(params.n_layers):
|
449 |
+
self.layers.append(MixedTransformerBlock(layer_id, params))
|
450 |
+
|
451 |
+
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
452 |
+
self.output = ColumnParallelLinear(
|
453 |
+
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
|
454 |
+
)
|
455 |
+
|
456 |
+
self.freqs_cis = precompute_freqs_cis(
|
457 |
+
# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
|
458 |
+
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
|
459 |
+
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
|
460 |
+
)
|
461 |
+
|
462 |
+
@torch.inference_mode()
|
463 |
+
def forward(self,
|
464 |
+
tokens: torch.Tensor,
|
465 |
+
start_pos: int,
|
466 |
+
token_weights: Optional[torch.Tensor],
|
467 |
+
verbose: Optional[bool] = False):
|
468 |
+
"""
|
469 |
+
Perform a forward pass through the Transformer model.
|
470 |
+
|
471 |
+
Args:
|
472 |
+
tokens (torch.Tensor): Input token indices.
|
473 |
+
start_pos (int): Starting position for attention caching.
|
474 |
+
token_weights (torch.Tensor): Superposition matrix.
|
475 |
+
verbose (bool): Whether to return intermediate hidden layer states
|
476 |
+
|
477 |
+
Returns:
|
478 |
+
torch.Tensor or (torch.Tensor, Dict): Output logits after applying the Transformer model.
|
479 |
+
|
480 |
+
"""
|
481 |
+
if verbose:
|
482 |
+
states = {"layers": [], "weights": None}
|
483 |
+
_bsz, seqlen = tokens.shape
|
484 |
+
if token_weights is not None:
|
485 |
+
h = self.tok_mixing_embeddings(token_weights.half()).unsqueeze(1)
|
486 |
+
else:
|
487 |
+
h = self.tok_embeddings(tokens)
|
488 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
489 |
+
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
490 |
+
if verbose:
|
491 |
+
states["layers"].append(h)
|
492 |
+
states["weights"] = token_weights
|
493 |
+
|
494 |
+
mask = None
|
495 |
+
if seqlen > 1:
|
496 |
+
mask = torch.full(
|
497 |
+
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
|
498 |
+
)
|
499 |
+
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
500 |
+
|
501 |
+
for layer in self.layers:
|
502 |
+
h = layer(h, start_pos, freqs_cis, mask)
|
503 |
+
if verbose:
|
504 |
+
states["layers"].append(h)
|
505 |
+
|
506 |
+
h = self.norm(h)
|
507 |
+
if verbose:
|
508 |
+
states["layers"].append(h)
|
509 |
+
|
510 |
+
output = self.output(h).float()
|
511 |
+
|
512 |
+
if verbose:
|
513 |
+
return output, states
|
514 |
+
else:
|
515 |
+
return output
|
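The weight tying of tok_mixing_embeddings above means that a superposition vector over the vocabulary is embedded as a weighted combination of the ordinary token embeddings. A minimal sketch of that equivalence, using plain nn.Embedding/nn.Linear stand-ins for the parallel layers and made-up sizes (not part of the repo):

import torch
import torch.nn as nn

# Toy sizes, purely for illustration
vocab_size, dim = 10, 4

tok_embeddings = nn.Embedding(vocab_size, dim)
tok_mixing = nn.Linear(vocab_size, dim, bias=False)
tok_mixing.weight = nn.Parameter(tok_embeddings.weight.T)  # same tying as in __init__ above

# A superposition over two candidate tokens (ids 3 and 7) with weights 0.7 / 0.3
token_weights = torch.zeros(1, vocab_size)
token_weights[0, 3], token_weights[0, 7] = 0.7, 0.3

mixed = tok_mixing(token_weights)  # (1, dim)
manual = 0.7 * tok_embeddings.weight[3] + 0.3 * tok_embeddings.weight[7]
print(torch.allclose(mixed[0], manual))  # True: the mixed embedding is the weighted sum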
superposed/llama/tokenizer.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()


class Tokenizer:
    """tokenizing and encoding/decoding text using SentencePiece."""
    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a SentencePiece model.

        Args:
            model_path (str): The path to the SentencePiece model file.
        """
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.

        Returns:
            List[int]: A list of token IDs.
        """
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.sp_model.decode(t)
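A quick usage sketch for this class; the local path to a Llama 2 tokenizer.model is an assumption, not something shipped in this repo:

from superposed.llama.tokenizer import Tokenizer

tok = Tokenizer("weights/tokenizer.model")  # hypothetical local path to a SentencePiece model
ids = tok.encode("Hi my name is", bos=True, eos=False)
print(ids)          # token IDs, starting with tok.bos_id
print(tok.decode(ids))  # roughly recovers the input string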
superposed/llama/utils.py
ADDED
@@ -0,0 +1,70 @@
import torch

def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to probability distribution and normalize.
    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
    Returns:
        Probability distribution (n_prompts, n_drafts, vocab_size)
    """
    # stability constant
    log_probs = log_probs + torch.max(log_probs, dim=-1, keepdim=True)[0]
    probs = torch.softmax(log_probs / temp, dim=-1)
    return probs

def decode(tokenizer, encoding):
    """
    Decode a list of tokens to a string
    Args:
        tokenizer (Any): Tokenizer
        encoding (torch.Tensor): Encoding
    Returns:
        decoding (str)
    """
    pad_locs = (encoding == -1).nonzero()
    if len(pad_locs) > 0:
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.
    Args:
        gens (n_prompts * n_drafts, seq_len): Generations to print
        logprobs (n_prompts * n_drafts): Log probs of each generation
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt
        prompt_len (int): Number of tokens in prompt
    """
    n_prompts, n_drafts, seq_len = gens.shape
    gens = gens.reshape(-1, seq_len)
    logprobs = logprobs.flatten()
    count = 0
    for i in range(len(gens)):
        d = decode(tokenizer, gens[i])
        # first draft of this prompt
        if i % n_drafts == 0:
            count = 0
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gens[i][:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {logprobs[i]} {count}: {d}", file=output_file)
        count += 1

def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for debugging
    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(n_drafts+2, dim=-1)
            print(f"\t\tTokens: {[tokenizer.decode([i.item()]) for i in next_token_idx]}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)
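A small illustration of the two core helpers; the tensors here are toy values, not real model output:

import torch
from superposed.llama.utils import log_prob_to_prob

# Toy log-prob tensor: 1 prompt, 1 draft, vocab of 4
log_probs = torch.log(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]]))
probs = log_prob_to_prob(log_probs, temp=1)
print(probs.sum(dim=-1))  # each distribution sums to 1 after the softmax

# decode() truncates a draft at the first -1 padding token before detokenizing;
# it needs a real Tokenizer, so this is only the call pattern:
# text = decode(tokenizer, torch.tensor([1, 6324, 590, -1, -1]))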
superposed/ngrams/__pycache__/ngram_models.cpython-312.pyc
ADDED
Binary file (5.53 kB)
superposed/ngrams/make_corpus.py
ADDED
@@ -0,0 +1,268 @@
import multiprocessing
import argparse
import os
import pickle
import glob
import json
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, LlamaTokenizer
from loguru import logger


def create_corpuses(
    ckpt_path,
    start_doc,
    end_doc,
    dataset,
    tokenizer,
    train_bigram: bool,
    train_trigram: bool,
    train_fourgram: bool,
    train_fivegram: bool,
    train_sixgram: bool,
    train_sevengram: bool
):
    bigram_corpus = {}
    trigram_corpus = {}
    fourgram_corpus = {}
    fivegram_corpus = {}
    sixgram_corpus = {}
    sevengram_corpus = {}

    bigram_corpus_counts = {}
    trigram_corpus_counts = {}
    fourgram_corpus_counts = {}
    fivegram_corpus_counts = {}
    sixgram_corpus_counts = {}
    sevengram_corpus_counts = {}

    iterations = end_doc - start_doc
    for i in tqdm(range(iterations)):
        t = dataset[start_doc + i]["text"]
        encoded_text = tokenizer.encode(t)
        for start_idx in range(1, len(encoded_text)):  # count from first real to eos
            pOne = encoded_text[start_idx-1] if start_idx >= 1 else None
            pTwo = encoded_text[start_idx-2] if start_idx >= 2 else None
            pThree = encoded_text[start_idx-3] if start_idx >= 3 else None
            pFour = encoded_text[start_idx-4] if start_idx >= 4 else None
            pFive = encoded_text[start_idx-5] if start_idx >= 5 else None
            pSix = encoded_text[start_idx-6] if start_idx >= 6 else None

            token = encoded_text[start_idx]
            # bigram
            if train_bigram and start_idx >= 1:
                prior = pOne
                if prior not in bigram_corpus:
                    bigram_corpus[prior] = {}
                    bigram_corpus_counts[prior] = 0
                bigram_corpus[prior][token] = bigram_corpus[prior].get(token, 0) + 1
                bigram_corpus_counts[prior] += 1
            # trigram
            if train_trigram and start_idx >= 2:
                prior = (pTwo, pOne)
                if prior not in trigram_corpus:
                    trigram_corpus[prior] = {}
                    trigram_corpus_counts[prior] = 0
                trigram_corpus[prior][token] = trigram_corpus[prior].get(token, 0) + 1
                trigram_corpus_counts[prior] += 1
            # fourgram
            if train_fourgram and start_idx >= 3:
                prior = (pThree, pTwo, pOne)
                if prior not in fourgram_corpus:
                    fourgram_corpus[prior] = {}
                    fourgram_corpus_counts[prior] = 0
                fourgram_corpus[prior][token] = fourgram_corpus[prior].get(token, 0) + 1
                fourgram_corpus_counts[prior] += 1
            # fivegram
            if train_fivegram and start_idx >= 4:
                prior = (pFour, pThree, pTwo, pOne)
                if prior not in fivegram_corpus:
                    fivegram_corpus[prior] = {}
                    fivegram_corpus_counts[prior] = 0
                fivegram_corpus[prior][token] = fivegram_corpus[prior].get(token, 0) + 1
                fivegram_corpus_counts[prior] += 1
            # sixgram
            if train_sixgram and start_idx >= 5:
                prior = (pFive, pFour, pThree, pTwo, pOne)
                if prior not in sixgram_corpus:
                    sixgram_corpus[prior] = {}
                    sixgram_corpus_counts[prior] = 0
                sixgram_corpus[prior][token] = sixgram_corpus[prior].get(token, 0) + 1
                sixgram_corpus_counts[prior] += 1
            # sevengram
            if train_sevengram and start_idx >= 6:
                prior = (pSix, pFive, pFour, pThree, pTwo, pOne)
                if prior not in sevengram_corpus:
                    sevengram_corpus[prior] = {}
                    sevengram_corpus_counts[prior] = 0
                sevengram_corpus[prior][token] = sevengram_corpus[prior].get(token, 0) + 1
                sevengram_corpus_counts[prior] += 1
    save_corpus(ckpt_path, bigram_corpus, trigram_corpus, fourgram_corpus, fivegram_corpus, sixgram_corpus, sevengram_corpus, start_doc, end_doc)
    save_counts(ckpt_path, bigram_corpus_counts, trigram_corpus_counts, fourgram_corpus_counts, fivegram_corpus_counts, sixgram_corpus_counts, sevengram_corpus_counts, start_doc, end_doc)

def merge_corpus_helper(c1, c2):
    """
    Merge the corpuses c1 and c2, returning the merged result.
    """
    for prior in c2:
        # if share prior
        if prior in c1:
            c1_prior = c1[prior]
            c2_prior = c2[prior]
            for token in c2_prior:
                # if share token
                if token in c1_prior:
                    c1_prior[token] += c2_prior[token]
                # else just use c2's
                else:
                    c1_prior[token] = c2_prior[token]
        else:
            # else just use c2's
            c1[prior] = c2[prior]
    return c1

def merge_counts_helper(c1, c2):
    """
    Merge the count corpuses c1 and c2, returning the merged result.
    """
    for prior in c2:
        if prior in c1:
            c1[prior] += c2[prior]
        else:
            c1[prior] = c2[prior]
    return c1

def save_corpus(save_dir, b_d, t_d, fo_d, fi_d, si_d, se_d, start_doc, end_doc):
    """
    Save corpuses b_d (bigram) to se_d (sevengram), where the corpus contains mappings
    {prefix : {next_token1: ct, next_token2: ct, ...}}.
    """
    prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
    for p, corpus in zip(prefixes, [b_d, t_d, fo_d, fi_d, si_d, se_d]):
        with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
            pickle.dump(corpus, f)

def save_counts(save_dir, b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct, start_doc, end_doc):
    """
    Save count corpuses b_ct (bigram) to se_ct (sevengram), where each count
    corpus contains mappings {prefix : total}.
    """
    prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
    for p, corpus in zip(prefixes, [b_ct, t_ct, fo_ct, fi_ct, si_ct, se_ct]):
        with open(f"{save_dir}/{p}{start_doc}-{end_doc}.pkl", "wb") as f:
            pickle.dump(corpus, f)

def merge_corpuses(ckpt_path):
    """
    Helper to merge corpuses in `ckpt_path`, where `ckpt_path` might contain
    multiple bigram, trigram, etc. corpuses from each process.
    """
    prefixes = ["b_d", "t_d", "fo_d", "fi_d", "si_d", "se_d"]
    for prefix in prefixes:
        if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
            os.remove(f"{ckpt_path}/{prefix}_final.pkl")
        corpus = None
        for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
            with open(filepath, "rb") as f:
                current = pickle.load(f)
            if corpus is None:
                corpus = current
            else:
                corpus = merge_corpus_helper(corpus, current)
            os.remove(filepath)
        with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
            pickle.dump(corpus, f)

def merge_counts(ckpt_path):
    """
    Helper to merge count corpuses in `ckpt_path`, where `ckpt_path` might contain
    multiple bigram, trigram, etc. count corpuses from each process.
    """
    prefixes = ["b_ct", "t_ct", "fo_ct", "fi_ct", "si_ct", "se_ct"]
    for prefix in prefixes:
        if os.path.exists(f"{ckpt_path}/{prefix}_final.pkl"):
            os.remove(f"{ckpt_path}/{prefix}_final.pkl")

        counts = None
        for filepath in glob.glob(f"{ckpt_path}/{prefix}*"):
            with open(filepath, "rb") as f:
                current = pickle.load(f)
            if counts is None:
                counts = current
            else:
                counts = merge_counts_helper(counts, current)
            os.remove(filepath)
        with open(f"{ckpt_path}/{prefix}_final.pkl", "wb") as f:
            pickle.dump(counts, f)


if __name__ == "__main__":
    # Input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_path", type=str, help="Path to store ngram models")
    parser.add_argument("start_doc", type=str, help="# of first document")
    parser.add_argument("end_doc", type=str, help="# of last document")
    parser.add_argument("c", type=int, help="number of processes")
    parser.add_argument("--tok_name", type=str, help="name of HF tokenizer, or llama", default="llama")
    for arg_name in ["--bigram", "--trigram", "--fourgram", "--fivegram", "--sixgram", "--sevengram"]:
        parser.add_argument(arg_name, type=str, help=f"Whether to make a {arg_name} model")
    parser.add_argument("--dset_name", type=str, help="name of HF dataset")
    parser.add_argument("--dset_path", type=str, help="path to dataset")
    # Parse arguments
    args = parser.parse_args()
    start_doc_ovr = int(args.start_doc)
    end_doc_ovr = int(args.end_doc)
    n_cores = args.c
    tok_name = args.tok_name
    ckpt_path = args.ckpt_path
    dset_name = args.dset_name
    dset_path = args.dset_path
    if not dset_name and not dset_path:
        raise RuntimeError("Please provide a dataset")
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)
    logger.info(f"{start_doc_ovr} {end_doc_ovr} {n_cores}")

    # Load dataset and tokenizer
    if dset_name:
        ds = load_dataset(dset_name, cache_dir="../../../datasets/")["train"].shuffle(seed=42)
    else:
        with open(dset_path, "r") as f:
            ds = json.load(f)["train"]
    if tok_name == "llama":
        # REPLACE WITH YOUR OWN PATH
        tokenizer = LlamaTokenizer.from_pretrained("../../7B_HF", add_bos_token=False)
    else:
        tokenizer = AutoTokenizer.from_pretrained(tok_name)

    # Start running
    num_processes = n_cores
    total_docs = end_doc_ovr - start_doc_ovr
    docs_per_c = (total_docs) // num_processes
    processes = []
    for core in range(n_cores):
        start_doc = core * docs_per_c  # relative start doc
        end_doc = (core + 1) * docs_per_c if core < n_cores - 1 else total_docs  # relative end doc
        logger.info(f"Starting core {core} from document {start_doc} to {end_doc}")
        process = multiprocessing.Process(target=create_corpuses,
                                          args=(ckpt_path,
                                                start_doc_ovr + start_doc,
                                                start_doc_ovr + end_doc,
                                                ds, tokenizer,
                                                args.bigram,
                                                args.trigram,
                                                args.fourgram,
                                                args.fivegram,
                                                args.sixgram,
                                                args.sevengram))
        processes.append(process)
        process.start()
    for process in processes:
        process.join()
    logger.info("Finished Saving")
    logger.info("Merging...")
    merge_corpuses(ckpt_path)
    merge_counts(ckpt_path)
    logger.info("Merged.")
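To make the merge step concrete, here is a tiny sketch of what merge_corpus_helper does with two per-process corpora; the prefixes and token IDs are invented for illustration:

from superposed.ngrams.make_corpus import merge_corpus_helper

c1 = {(5, 7): {9: 3}}  # prefix (5, 7) was followed by token 9 three times
c2 = {(5, 7): {9: 1, 11: 2}, (2, 3): {4: 1}}
merged = merge_corpus_helper(c1, c2)
print(merged)  # {(5, 7): {9: 4, 11: 2}, (2, 3): {4: 1}} — counts summed, new prefixes copied over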
superposed/ngrams/ngram_models.py
ADDED
@@ -0,0 +1,115 @@
import pickle
import sys

import torch

class NGram():
    def __init__(self, corpus, corpus_counts, type):
        self.corpus = corpus
        self.counts = corpus_counts
        self.type = type

    def prob(self, key, next):
        """
        Args:
            key (tuple): tuple of token ID's forming prior
            next (int): next token ID
        """
        l = len(key)
        if self.type == "bigram":
            assert l == 1
            key = key[0]
        elif self.type == "trigram":
            assert l == 2
        elif self.type == "fourgram":
            assert l == 3
        elif self.type == "fivegram":
            assert l == 4
        elif self.type == "sixgram":
            assert l == 5
        elif self.type == "sevengram":
            assert l == 6

        count = 0
        if key in self.corpus:
            count = self.corpus[key].get(next, 0)
            total = sum(self.corpus[key].values())
            return count / total
        else:
            return -1

    def ntd(self, key, vocab_size=32000):
        """
        Args:
            key (tuple): tuple of token ID's forming prior
        Returns:
            prob_tensor (torch.Tensor): (vocab_size, ) of full next token probabilities
        """
        if key in self.corpus:
            prob_tensor = torch.zeros(vocab_size)
            total = sum(self.corpus[key].values())
            for next_token in self.corpus[key]:
                prob_tensor[next_token] = self.corpus[key][next_token] / total
            return prob_tensor
        else:
            return None

def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
    """
    Loads and returns a list corresponding to bigram to sevengram models, containing
    the models whose parameters are `True`. See below for expected corpus names.
    Args:
        ckpt_path (str): Location of ngram models
        bigram-sevengram: Which models to load
    Returns:
        List of n-gram models
    """
    models = []
    if bigram:
        print("Making bigram...")
        with open(f"{ckpt_path}/b_d_final.pkl", "rb") as f:
            bigram = pickle.load(f)
        bigram_model = NGram(bigram, None, "bigram")
        models.append(bigram_model)
        print(sys.getsizeof(bigram))

    if trigram:
        print("Making trigram...")
        with open(f"{ckpt_path}/t_d_final.pkl", "rb") as f:
            trigram = pickle.load(f)
        trigram_model = NGram(trigram, None, "trigram")
        models.append(trigram_model)
        print(sys.getsizeof(trigram))

    if fourgram:
        print("Making fourgram...")
        with open(f"{ckpt_path}/fo_d_final.pkl", "rb") as f:
            fourgram = pickle.load(f)
        fourgram_model = NGram(fourgram, None, "fourgram")
        models.append(fourgram_model)
        print(sys.getsizeof(fourgram))

    if fivegram:
        print("Making fivegram...")
        with open(f"{ckpt_path}/fi_d_final.pkl", "rb") as f:
            fivegram = pickle.load(f)
        fivegram_model = NGram(fivegram, None, "fivegram")
        models.append(fivegram_model)
        print(sys.getsizeof(fivegram))

    if sixgram:
        print("Making sixgram...")
        with open(f"{ckpt_path}/si_d_final.pkl", "rb") as f:
            sixgram = pickle.load(f)
        sixgram_model = NGram(sixgram, None, "sixgram")
        models.append(sixgram_model)
        print(sys.getsizeof(sixgram))

    if sevengram:
        print("Making sevengram...")
        with open(f"{ckpt_path}/se_d_final.pkl", "rb") as f:
            sevengram = pickle.load(f)
        sevengram_model = NGram(sevengram, None, "sevengram")
        models.append(sevengram_model)

    return models
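A toy example of how these models are queried; the corpus dict and token IDs below are invented for illustration, whereas real corpora come from make_corpus.py via make_models:

from superposed.ngrams.ngram_models import NGram

# Fake trigram corpus: prefix (5, 7) was followed by token 9 three times and token 11 once
corpus = {(5, 7): {9: 3, 11: 1}}
trigram = NGram(corpus, None, "trigram")

print(trigram.prob((5, 7), 9))   # 0.75
print(trigram.prob((1, 2), 9))   # -1 (unseen prefix)

dist = trigram.ntd((5, 7), vocab_size=32000)
print(dist[9].item(), dist[11].item())  # 0.75 0.25, zeros elsewhere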
superposed/ngrams/test.json
ADDED
@@ -0,0 +1,8 @@
{
    "train": [
        {"text": "Hi my name is"},
        {"text": "This is a story of"},
        {"text": "In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying"},
        {"text": "There is one class of AutoModel for each task, and for each backend (PyTorch, TensorFlow, or Flax)."}
    ]
}
superposed/notebooks/custom.ipynb
ADDED
@@ -0,0 +1,289 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"id": "119805f4-8589-4379-ad87-a7bad4c0e658",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stderr",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
15 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
|
16 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
|
17 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
|
18 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
|
19 |
+
"2024-05-30 03:09:58.230601: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
|
20 |
+
"2024-05-30 03:09:58.280835: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
21 |
+
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
22 |
+
"2024-05-30 03:10:03.250651: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
|
23 |
+
]
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"source": [
|
27 |
+
"%load_ext autoreload\n",
|
28 |
+
"%autoreload 2\n",
|
29 |
+
"\n",
|
30 |
+
"import json\n",
|
31 |
+
"import os\n",
|
32 |
+
"import pickle\n",
|
33 |
+
"from datetime import datetime\n",
|
34 |
+
"\n",
|
35 |
+
"import evaluate\n",
|
36 |
+
"import torch\n",
|
37 |
+
"from tqdm import tqdm\n",
|
38 |
+
"\n",
|
39 |
+
"from eval import *\n",
|
40 |
+
"from superposed.llama.metrics import *\n",
|
41 |
+
"from superposed.llama.generation import Llama\n",
|
42 |
+
"from superposed.llama.superposed_generation import SuperposedLlama\n",
|
43 |
+
"from superposed.llama.tokenizer import Tokenizer\n",
|
44 |
+
"from superposed.ngrams.ngram_models import make_models"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": 4,
|
50 |
+
"id": "51c15900-c8b8-46d9-a884-6842a391ef48",
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"sup_device = torch.device(\"cuda:0\")\n",
|
55 |
+
"tokenizer = Tokenizer('../../7B/tokenizer.model')"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 5,
|
61 |
+
"id": "9817d9a4-ad64-41c6-b87b-b1e422b836a9",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [
|
64 |
+
{
|
65 |
+
"name": "stdout",
|
66 |
+
"output_type": "stream",
|
67 |
+
"text": [
|
68 |
+
"Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
|
69 |
+
]
|
70 |
+
}
|
71 |
+
],
|
72 |
+
"source": [
|
73 |
+
"# Params\n",
|
74 |
+
"param_file = \"../../params/p15_d3_mixed.json\"\n",
|
75 |
+
"with open(param_file, \"r\") as f:\n",
|
76 |
+
" params = json.load(f)\n",
|
77 |
+
" print(f\"Parameters: {params}\")\n",
|
78 |
+
"alpha = params[\"alpha\"]\n",
|
79 |
+
"temp = params[\"temp\"]\n",
|
80 |
+
"n_drafts = params[\"n_drafts\"]\n",
|
81 |
+
"prompt_len = params[\"prompt_len\"]\n",
|
82 |
+
"n_token_sample = params[\"n_token_sample\"]\n",
|
83 |
+
"i_weights = params[\"i_weights\"]\n",
|
84 |
+
"i_length = params[\"i_length\"]"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": 6,
|
90 |
+
"id": "9c99098e-a38b-4c78-a0e9-8c80309830bb",
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [
|
93 |
+
{
|
94 |
+
"name": "stdout",
|
95 |
+
"output_type": "stream",
|
96 |
+
"text": [
|
97 |
+
"Making bigram...\n",
|
98 |
+
"1310800\n",
|
99 |
+
"Making trigram...\n",
|
100 |
+
"671088728\n",
|
101 |
+
"Making fourgram...\n",
|
102 |
+
"2684354648\n",
|
103 |
+
"Making fivegram...\n",
|
104 |
+
"5368709200\n",
|
105 |
+
"Making sixgram...\n",
|
106 |
+
"5368709200\n"
|
107 |
+
]
|
108 |
+
}
|
109 |
+
],
|
110 |
+
"source": [
|
111 |
+
"# Create ngram models\n",
|
112 |
+
"ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": 7,
|
118 |
+
"id": "c3331332-242c-4e98-9f11-58c6dc0ef581",
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [
|
121 |
+
{
|
122 |
+
"name": "stdout",
|
123 |
+
"output_type": "stream",
|
124 |
+
"text": [
|
125 |
+
"> initializing model parallel with size 1\n",
|
126 |
+
"> initializing ddp with size 1\n",
|
127 |
+
"> initializing pipeline with size 1\n"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"name": "stderr",
|
132 |
+
"output_type": "stream",
|
133 |
+
"text": [
|
134 |
+
"/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
|
135 |
+
" _C._set_default_tensor_type(t)\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"name": "stdout",
|
140 |
+
"output_type": "stream",
|
141 |
+
"text": [
|
142 |
+
"Loaded in 25.15 seconds\n",
|
143 |
+
"cuda:0\n"
|
144 |
+
]
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"weight_path = \"../../7B/\"\n",
|
149 |
+
"model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
|
150 |
+
" tokenizer_path=f'{weight_path}/tokenizer.model', \n",
|
151 |
+
" max_seq_len=100, \n",
|
152 |
+
" max_batch_size=32,\n",
|
153 |
+
" device=sup_device,\n",
|
154 |
+
" model_parallel_size=1)"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "markdown",
|
159 |
+
"id": "e2b48c23-d6a3-43b1-ad4c-54524aacfda6",
|
160 |
+
"metadata": {},
|
161 |
+
"source": [
|
162 |
+
"# Inference"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 11,
|
168 |
+
"id": "5093373b-bf76-47e3-8f99-1045b60f29c3",
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"def decode(tokenizer, encoding):\n",
|
173 |
+
" \"\"\"\n",
|
174 |
+
" Args:\n",
|
175 |
+
" tokenizer (Any): Tokenizer\n",
|
176 |
+
" encoding (torch.Tensor): Encoding\n",
|
177 |
+
" Returns:\n",
|
178 |
+
" decoding (str)\n",
|
179 |
+
" \"\"\"\n",
|
180 |
+
" eos_locs = (encoding == tokenizer.eos_id).nonzero()\n",
|
181 |
+
" if len(eos_locs > 0):\n",
|
182 |
+
" encoding = encoding[:eos_locs[0]]\n",
|
183 |
+
" return tokenizer.decode(encoding.to(torch.int32).tolist())"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": 22,
|
189 |
+
"id": "18703b19-f3e9-46e4-ab1c-c6d3b403c6d2",
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"prompts = [\n",
|
194 |
+
" \"Hi my name is\",\n",
|
195 |
+
" \"The Seattle Seahawks were Super Bowl\",\n",
|
196 |
+
" \"Penguins are birds native to\"\n",
|
197 |
+
"]\n",
|
198 |
+
"tokenized_prompts = tokenizer.encode(prompts, True, False)"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 23,
|
204 |
+
"id": "d39cd735-9480-4979-ac92-bbd470f75570",
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [],
|
207 |
+
"source": [
|
208 |
+
"alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts, \n",
|
209 |
+
" smoothing=\"geom\",\n",
|
210 |
+
" max_gen_len=10, \n",
|
211 |
+
" n_token_sample=n_token_sample,\n",
|
212 |
+
" alpha=alpha, \n",
|
213 |
+
" temp=temp,\n",
|
214 |
+
" n_drafts=n_drafts,\n",
|
215 |
+
" i_weights=i_weights,\n",
|
216 |
+
" i_length=i_length,\n",
|
217 |
+
" ngrams=ngrams,\n",
|
218 |
+
" get_time=False,\n",
|
219 |
+
" penalty=200)"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"cell_type": "code",
|
224 |
+
"execution_count": 24,
|
225 |
+
"id": "cfefa793-e49e-483a-a504-5cc9e23f619d",
|
226 |
+
"metadata": {},
|
227 |
+
"outputs": [],
|
228 |
+
"source": [
|
229 |
+
"gens = alive_gens[0].reshape(len(prompts) * n_drafts, -1)"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": 25,
|
235 |
+
"id": "5abf87ab-2ee0-4204-868b-1215abf0c8aa",
|
236 |
+
"metadata": {},
|
237 |
+
"outputs": [
|
238 |
+
{
|
239 |
+
"name": "stdout",
|
240 |
+
"output_type": "stream",
|
241 |
+
"text": [
|
242 |
+
"Hi\n",
|
243 |
+
"my name\n",
|
244 |
+
"is L\n",
|
245 |
+
"inda,\n",
|
246 |
+
"I am\n",
|
247 |
+
"a \n",
|
248 |
+
"40\n",
|
249 |
+
"year old\n",
|
250 |
+
"woman who\n"
|
251 |
+
]
|
252 |
+
}
|
253 |
+
],
|
254 |
+
"source": [
|
255 |
+
"for i in gens:\n",
|
256 |
+
" print(decode(tokenizer, i))"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": null,
|
262 |
+
"id": "e73dc3cc-baa5-468d-bdd1-827465bdeb62",
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [],
|
265 |
+
"source": []
|
266 |
+
}
|
267 |
+
],
|
268 |
+
"metadata": {
|
269 |
+
"kernelspec": {
|
270 |
+
"display_name": "Python 3 (ipykernel)",
|
271 |
+
"language": "python",
|
272 |
+
"name": "python3"
|
273 |
+
},
|
274 |
+
"language_info": {
|
275 |
+
"codemirror_mode": {
|
276 |
+
"name": "ipython",
|
277 |
+
"version": 3
|
278 |
+
},
|
279 |
+
"file_extension": ".py",
|
280 |
+
"mimetype": "text/x-python",
|
281 |
+
"name": "python",
|
282 |
+
"nbconvert_exporter": "python",
|
283 |
+
"pygments_lexer": "ipython3",
|
284 |
+
"version": "3.11.5"
|
285 |
+
}
|
286 |
+
},
|
287 |
+
"nbformat": 4,
|
288 |
+
"nbformat_minor": 5
|
289 |
+
}
|
superposed/notebooks/nq.ipynb
ADDED
@@ -0,0 +1,417 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"The autoreload extension is already loaded. To reload it, use:\n",
|
13 |
+
" %reload_ext autoreload\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"%load_ext autoreload\n",
|
19 |
+
"%autoreload 2\n",
|
20 |
+
"\n",
|
21 |
+
"import json\n",
|
22 |
+
"import os\n",
|
23 |
+
"import re\n",
|
24 |
+
"from datetime import datetime\n",
|
25 |
+
"\n",
|
26 |
+
"import torch\n",
|
27 |
+
"from datasets import load_dataset\n",
|
28 |
+
"from tqdm import tqdm\n",
|
29 |
+
"\n",
|
30 |
+
"from eval import *\n",
|
31 |
+
"from superposed.llama.metrics import *\n",
|
32 |
+
"from superposed.llama.generation import Llama\n",
|
33 |
+
"from superposed.llama.superposed_generation import SuperposedLlama\n",
|
34 |
+
"from superposed.llama.tokenizer import Tokenizer\n",
|
35 |
+
"from superposed.ngrams.ngram_models import make_models"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "markdown",
|
40 |
+
"metadata": {},
|
41 |
+
"source": [
|
42 |
+
"# Setup"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 3,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"nq = load_dataset(\"nq_open\")[\"validation\"]"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": 6,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [
|
59 |
+
{
|
60 |
+
"name": "stdout",
|
61 |
+
"output_type": "stream",
|
62 |
+
"text": [
|
63 |
+
"Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
|
64 |
+
]
|
65 |
+
}
|
66 |
+
],
|
67 |
+
"source": [
|
68 |
+
"# Params\n",
|
69 |
+
"param_file = \"../../params/p15_d3_mixed.json\"\n",
|
70 |
+
"with open(param_file, \"r\") as f:\n",
|
71 |
+
" params = json.load(f)\n",
|
72 |
+
" print(f\"Parameters: {params}\")\n",
|
73 |
+
"alpha = params[\"alpha\"]\n",
|
74 |
+
"temp = params[\"temp\"]\n",
|
75 |
+
"n_drafts = params[\"n_drafts\"]\n",
|
76 |
+
"prompt_len = params[\"prompt_len\"]\n",
|
77 |
+
"n_token_sample = params[\"n_token_sample\"]\n",
|
78 |
+
"i_weights = params[\"i_weights\"]\n",
|
79 |
+
"i_length = params[\"i_length\"]"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "markdown",
|
84 |
+
"metadata": {},
|
85 |
+
"source": [
|
86 |
+
"# Create Models"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": 7,
|
92 |
+
"metadata": {},
|
93 |
+
"outputs": [
|
94 |
+
{
|
95 |
+
"name": "stdout",
|
96 |
+
"output_type": "stream",
|
97 |
+
"text": [
|
98 |
+
"Making bigram...\n",
|
99 |
+
"1310800\n",
|
100 |
+
"Making trigram...\n",
|
101 |
+
"671088728\n",
|
102 |
+
"Making fourgram...\n",
|
103 |
+
"2684354648\n",
|
104 |
+
"Making fivegram...\n",
|
105 |
+
"5368709200\n",
|
106 |
+
"Making sixgram...\n",
|
107 |
+
"5368709200\n"
|
108 |
+
]
|
109 |
+
}
|
110 |
+
],
|
111 |
+
"source": [
|
112 |
+
"ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": 9,
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"sup_device = torch.device(\"cuda:0\")\n",
|
122 |
+
"reg_device = torch.device(\"cuda:1\")"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": 11,
|
128 |
+
"metadata": {},
|
129 |
+
"outputs": [
|
130 |
+
{
|
131 |
+
"name": "stdout",
|
132 |
+
"output_type": "stream",
|
133 |
+
"text": [
|
134 |
+
"> initializing model parallel with size 1\n",
|
135 |
+
"> initializing ddp with size 1\n",
|
136 |
+
"> initializing pipeline with size 1\n"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"name": "stderr",
|
141 |
+
"output_type": "stream",
|
142 |
+
"text": [
|
143 |
+
"/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
|
144 |
+
" _C._set_default_tensor_type(t)\n"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"name": "stdout",
|
149 |
+
"output_type": "stream",
|
150 |
+
"text": [
|
151 |
+
"Loaded in 33.68 seconds\n",
|
152 |
+
"cuda:0\n"
|
153 |
+
]
|
154 |
+
}
|
155 |
+
],
|
156 |
+
"source": [
|
157 |
+
"# load superposed\n",
|
158 |
+
"weight_path = \"../../7B/\"\n",
|
159 |
+
"sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
|
160 |
+
" tokenizer_path=f'{weight_path}/tokenizer.model', \n",
|
161 |
+
" max_seq_len=1000, \n",
|
162 |
+
" max_batch_size=16,\n",
|
163 |
+
" device=sup_device,\n",
|
164 |
+
" model_parallel_size=1)"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 12,
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [
|
172 |
+
{
|
173 |
+
"name": "stdout",
|
174 |
+
"output_type": "stream",
|
175 |
+
"text": [
|
176 |
+
"0\n",
|
177 |
+
"Loaded in 22.47 seconds\n"
|
178 |
+
]
|
179 |
+
}
|
180 |
+
],
|
181 |
+
"source": [
|
182 |
+
"# load regular\n",
|
183 |
+
"reg_model = Llama.build(ckpt_dir=weight_path, \n",
|
184 |
+
" tokenizer_path=f'{weight_path}/tokenizer.model', \n",
|
185 |
+
" max_seq_len=1000, \n",
|
186 |
+
" max_batch_size=16,\n",
|
187 |
+
" device=reg_device, # reg_device,\n",
|
188 |
+
" model_parallel_size=1)"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"execution_count": 13,
|
194 |
+
"metadata": {},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "markdown",
|
202 |
+
"metadata": {},
|
203 |
+
"source": [
|
204 |
+
"# Evaluation"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 14,
|
210 |
+
"metadata": {},
|
211 |
+
"outputs": [],
|
212 |
+
"source": [
|
213 |
+
"model_types = [\"greedy\", \"superposed\", \"regular\"]\n",
|
214 |
+
"model_type = model_types[1]"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": 17,
|
220 |
+
"metadata": {},
|
221 |
+
"outputs": [],
|
222 |
+
"source": [
|
223 |
+
"def evaluate_nq(model_type, question, max_gen_len):\n",
|
224 |
+
" question = \"Answer these questions:\\n\\nQ: \" + question + \"?\\nA:\"\n",
|
225 |
+
" text_len = len(question) # for truncating\n",
|
226 |
+
" prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
|
227 |
+
" if model_type == \"regular\" or model_type == \"greedy\":\n",
|
228 |
+
" if model_type == \"regular\":\n",
|
229 |
+
" input = [question for _ in range(n_drafts)]\n",
|
230 |
+
" print(input)\n",
|
231 |
+
" sequences, _ = evaluate_nucleus_losses(data=input,\n",
|
232 |
+
" model=reg_model,\n",
|
233 |
+
" tokenizer=tokenizer,\n",
|
234 |
+
" prompt_len=prompt_len,\n",
|
235 |
+
" max_gen_len=max_gen_len,\n",
|
236 |
+
" temp=0.6,\n",
|
237 |
+
" bsz=8,\n",
|
238 |
+
" marker=False)\n",
|
239 |
+
" else:\n",
|
240 |
+
" sequences, _ = evaluate_nucleus_losses(data=[question],\n",
|
241 |
+
" model=reg_model,\n",
|
242 |
+
" tokenizer=tokenizer,\n",
|
243 |
+
" prompt_len=prompt_len,\n",
|
244 |
+
" max_gen_len=max_gen_len,\n",
|
245 |
+
" temp=0,\n",
|
246 |
+
" bsz=8,\n",
|
247 |
+
" marker=False)\n",
|
248 |
+
" n_pd, seq_len = sequences.shape\n",
|
249 |
+
" elif model_type == \"superposed\":\n",
|
250 |
+
" sequences, _ = evaluate_mixed_losses(data=[question],\n",
|
251 |
+
" model=sup_model,\n",
|
252 |
+
" tokenizer=tokenizer,\n",
|
253 |
+
" prompt_len=prompt_len,\n",
|
254 |
+
" max_gen_len=max_gen_len,\n",
|
255 |
+
" alpha=alpha,\n",
|
256 |
+
" temp=temp,\n",
|
257 |
+
" n_drafts=n_drafts,\n",
|
258 |
+
" n_token_sample=n_token_sample,\n",
|
259 |
+
" smoothing=None, # Use greedy\n",
|
260 |
+
" bsz=8,\n",
|
261 |
+
" i_weights=i_weights,\n",
|
262 |
+
" i_length=i_length,\n",
|
263 |
+
" ngrams=ngrams,\n",
|
264 |
+
" marker=False)\n",
|
265 |
+
" n_p, n_d, seq_len = sequences.shape\n",
|
266 |
+
" # Process results\n",
|
267 |
+
" sequences = sequences.reshape(-1, seq_len).tolist()\n",
|
268 |
+
" for d_idx in range(len(sequences)):\n",
|
269 |
+
" draft = sequences[d_idx]\n",
|
270 |
+
" if -1 in draft:\n",
|
271 |
+
" draft = draft[:draft.index(-1)]\n",
|
272 |
+
" sequences[d_idx] = draft\n",
|
273 |
+
" decoded_seq = tokenizer.decode(sequences)\n",
|
274 |
+
" answers = []\n",
|
275 |
+
" for s in decoded_seq:\n",
|
276 |
+
" answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
|
277 |
+
" return answers\n",
|
278 |
+
" "
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": null,
|
284 |
+
"metadata": {},
|
285 |
+
"outputs": [],
|
286 |
+
"source": [
|
287 |
+
"# Run evaluation\n",
|
288 |
+
"predictions = []\n",
|
289 |
+
"print(f\"Precision from 1 to {n_drafts}\")\n",
|
290 |
+
"for sample in tqdm(nq):\n",
|
291 |
+
" # Adaptively determine max generation length\n",
|
292 |
+
" longest = 0\n",
|
293 |
+
" shortest = 1000\n",
|
294 |
+
" for answer in sample[\"answer\"]:\n",
|
295 |
+
" tmp = tokenizer.encode([answer], False, False)[0]\n",
|
296 |
+
" if len(tmp) > longest:\n",
|
297 |
+
" longest = len(tmp)\n",
|
298 |
+
" if len(tmp) < shortest:\n",
|
299 |
+
" shortest = len(tmp)\n",
|
300 |
+
" question = sample[\"question\"]\n",
|
301 |
+
" answer = evaluate_nq(model_type, question, max_gen_len=shortest+3)\n",
|
302 |
+
" predictions.append({\"question\": question, \"answer\": answer})"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": 52,
|
308 |
+
"metadata": {},
|
309 |
+
"outputs": [],
|
310 |
+
"source": [
|
311 |
+
"# Separate results into precisions\n",
|
312 |
+
"precisions = {}\n",
|
313 |
+
"for i in range(1, n_drafts+1):\n",
|
314 |
+
" prec = str(i)\n",
|
315 |
+
" responses = []\n",
|
316 |
+
" for result in predictions:\n",
|
317 |
+
" responses.append({\"question\": result[\"question\"], \"answer\": result[\"answer\"][:i]})\n",
|
318 |
+
" precisions[prec] = responses"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": 53,
|
324 |
+
"metadata": {},
|
325 |
+
"outputs": [
|
326 |
+
{
|
327 |
+
"name": "stdout",
|
328 |
+
"output_type": "stream",
|
329 |
+
"text": [
|
330 |
+
"{'question': 'when was the last time anyone was on the moon', 'answer': ['2019', '2019', '2019-', '2019-', '1019']}\n",
|
331 |
+
"================\n",
|
332 |
+
"{'question': \"who wrote he ain't heavy he's my brother lyrics\", 'answer': ['The song was written by', 'The lyr was written by', 'The Hol was written by', 'Neil song was written by', 'Neil lyr was written by']}\n",
|
333 |
+
"================\n",
|
334 |
+
"{'question': 'how many seasons of the bastard executioner are there', 'answer': ['1', 'There1', 'there1', '1', 'There1']}\n",
|
335 |
+
"================\n",
|
336 |
+
"{'question': 'when did the eagles win last super bowl', 'answer': ['2018', 'The2018', '1018', '2017', 'the2018']}\n",
|
337 |
+
"================\n",
|
338 |
+
"{'question': \"who won last year's ncaa women's basketball\", 'answer': ['the university of connecticut', 'The university of connecticut', 'university of connecticut', 'the University of connecticut', 'The University of connecticut']}\n",
|
339 |
+
"================\n",
|
340 |
+
"{'question': 'when did the isle of wight become an island', 'answer': ['1207', 'when1207', '1287', '1277', 'when1287']}\n",
|
341 |
+
"================\n",
|
342 |
+
"{'question': 'love yourself by justin bieber is about who', 'answer': ['love yourself by justin b', 'love yourself is justin b', 'Justin yourself by justin b', 'Justin yourself is justin b', 'It yourself by justin b']}\n",
|
343 |
+
"================\n",
|
344 |
+
"{'question': 'who was the ruler of england in 1616', 'answer': ['James I', 'James I of', 'King I', 'j I', 'James I']}\n",
|
345 |
+
"================\n",
|
346 |
+
"{'question': 'what is the hot coffee mod in san andreas', 'answer': ['The Hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a mod for Grand', 'The hot Coffee mod is a modification for Grand', 'The Hot Coffee mod is a modification that Grand', 'It Hot Coffee mod is a modification for Grand']}\n",
|
347 |
+
"================\n",
|
348 |
+
"{'question': 'what is the maximum data rate for the 802.11a standard select one', 'answer': ['54 Mbps', '54Mbps', '54 mbps', '54 Mbps', '54 Mbps']}\n",
|
349 |
+
"================\n"
|
350 |
+
]
|
351 |
+
}
|
352 |
+
],
|
353 |
+
"source": [
|
354 |
+
"# Print some results\n",
|
355 |
+
"counter = 0\n",
|
356 |
+
"for k in predictions:\n",
|
357 |
+
" if counter >= 10:\n",
|
358 |
+
" break\n",
|
359 |
+
" print(k)\n",
|
360 |
+
" counter += 1\n",
|
361 |
+
" print(\"================\")"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "markdown",
|
366 |
+
"metadata": {},
|
367 |
+
"source": [
|
368 |
+
"# Saving"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"cell_type": "code",
|
373 |
+
"execution_count": 54,
|
374 |
+
"metadata": {},
|
375 |
+
"outputs": [
|
376 |
+
{
|
377 |
+
"name": "stdout",
|
378 |
+
"output_type": "stream",
|
379 |
+
"text": [
|
380 |
+
"dict_keys(['1', '2', '3', '4', '5'])\n"
|
381 |
+
]
|
382 |
+
}
|
383 |
+
],
|
384 |
+
"source": [
|
385 |
+
"# Save results\n",
|
386 |
+
"os.makedirs(\"../../nq/\", exist_ok=True)\n",
|
387 |
+
"print(precisions.keys())\n",
|
388 |
+
"for prec in range(1, n_drafts+1):\n",
|
389 |
+
" out_path = f\"../nq/eval_{model_type}_{prec}_test.jsonl\"\n",
|
390 |
+
" with open(out_path, \"w\") as f:\n",
|
391 |
+
" for obj in precisions[str(prec)]: \n",
|
392 |
+
" f.write(json.dumps(obj) + \"\\n\")"
|
393 |
+
]
|
394 |
+
}
|
395 |
+
],
|
396 |
+
"metadata": {
|
397 |
+
"kernelspec": {
|
398 |
+
"display_name": "Python 3 (ipykernel)",
|
399 |
+
"language": "python",
|
400 |
+
"name": "python3"
|
401 |
+
},
|
402 |
+
"language_info": {
|
403 |
+
"codemirror_mode": {
|
404 |
+
"name": "ipython",
|
405 |
+
"version": 3
|
406 |
+
},
|
407 |
+
"file_extension": ".py",
|
408 |
+
"mimetype": "text/x-python",
|
409 |
+
"name": "python",
|
410 |
+
"nbconvert_exporter": "python",
|
411 |
+
"pygments_lexer": "ipython3",
|
412 |
+
"version": "3.11.5"
|
413 |
+
}
|
414 |
+
},
|
415 |
+
"nbformat": 4,
|
416 |
+
"nbformat_minor": 4
|
417 |
+
}
|
superposed/notebooks/triviaqa.ipynb
ADDED
@@ -0,0 +1,404 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
14 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcWriteOptions size changed, may indicate binary incompatibility. Expected 72 from C header, got 88 from PyObject\n",
|
15 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow.lib.IpcReadOptions size changed, may indicate binary incompatibility. Expected 96 from C header, got 104 from PyObject\n",
|
16 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileInfo size changed, may indicate binary incompatibility. Expected 64 from C header, got 88 from PyObject\n",
|
17 |
+
"<frozen importlib._bootstrap>:241: RuntimeWarning: pyarrow._fs.FileSelector size changed, may indicate binary incompatibility. Expected 48 from C header, got 72 from PyObject\n",
|
18 |
+
"2024-05-30 01:35:17.813978: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
|
19 |
+
"2024-05-30 01:35:20.452213: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
20 |
+
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
21 |
+
"2024-05-30 01:35:41.833487: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"%load_ext autoreload\n",
|
27 |
+
"%autoreload 2\n",
|
28 |
+
"\n",
|
29 |
+
"import copy\n",
|
30 |
+
"import json\n",
|
31 |
+
"import pickle\n",
|
32 |
+
"import os\n",
|
33 |
+
"import random\n",
|
34 |
+
"import re\n",
|
35 |
+
"import string\n",
|
36 |
+
"import math\n",
|
37 |
+
"from datetime import datetime\n",
|
38 |
+
"\n",
|
39 |
+
"import evaluate\n",
|
40 |
+
"import torch\n",
|
41 |
+
"import numpy as np\n",
|
42 |
+
"from datasets import load_dataset\n",
|
43 |
+
"from transformers import LlamaTokenizer\n",
|
44 |
+
"from tqdm import tqdm\n",
|
45 |
+
"\n",
|
46 |
+
"from eval import *\n",
|
47 |
+
"from superposed.llama.metrics import *\n",
|
48 |
+
"from superposed.llama.generation import Llama\n",
|
49 |
+
"from superposed.llama.superposed_generation import SuperposedLlama\n",
|
50 |
+
"from superposed.llama.tokenizer import Tokenizer\n",
|
51 |
+
"from superposed.ngrams.ngram_models import make_models"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "markdown",
|
56 |
+
"metadata": {},
|
57 |
+
"source": [
|
58 |
+
"# Setup"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": 3,
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [
|
66 |
+
{
|
67 |
+
"name": "stdout",
|
68 |
+
"output_type": "stream",
|
69 |
+
"text": [
|
70 |
+
"Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}\n"
|
71 |
+
]
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"source": [
|
75 |
+
"# Params\n",
|
76 |
+
"param_file = \"../../params/p15_d3_mixed.json\"\n",
|
77 |
+
"with open(param_file, \"r\") as f:\n",
|
78 |
+
" params = json.load(f)\n",
|
79 |
+
" print(f\"Parameters: {params}\")\n",
|
80 |
+
"alpha = params[\"alpha\"]\n",
|
81 |
+
"temp = params[\"temp\"]\n",
|
82 |
+
"n_drafts = params[\"n_drafts\"]\n",
|
83 |
+
"prompt_len = params[\"prompt_len\"]\n",
|
84 |
+
"n_token_sample = params[\"n_token_sample\"]\n",
|
85 |
+
"i_weights = params[\"i_weights\"]\n",
|
86 |
+
"i_length = params[\"i_length\"]"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": 5,
|
92 |
+
"metadata": {
|
93 |
+
"scrolled": true
|
94 |
+
},
|
95 |
+
"outputs": [
|
96 |
+
{
|
97 |
+
"name": "stdout",
|
98 |
+
"output_type": "stream",
|
99 |
+
"text": [
|
100 |
+
"Making bigram...\n",
|
101 |
+
"1310800\n",
|
102 |
+
"Making trigram...\n",
|
103 |
+
"671088728\n",
|
104 |
+
"Making fourgram...\n",
|
105 |
+
"2684354648\n",
|
106 |
+
"Making fivegram...\n",
|
107 |
+
"5368709200\n",
|
108 |
+
"Making sixgram...\n",
|
109 |
+
"5368709200\n"
|
110 |
+
]
|
111 |
+
}
|
112 |
+
],
|
113 |
+
"source": [
|
114 |
+
"ngrams = make_models(\"../../ckpts-200k\", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 10,
|
120 |
+
"metadata": {},
|
121 |
+
"outputs": [],
|
122 |
+
"source": [
|
123 |
+
"sup_device = torch.device(\"cuda:0\")\n",
|
124 |
+
"reg_device = torch.device(\"cuda:1\")"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": 11,
|
130 |
+
"metadata": {},
|
131 |
+
"outputs": [
|
132 |
+
{
|
133 |
+
"name": "stdout",
|
134 |
+
"output_type": "stream",
|
135 |
+
"text": [
|
136 |
+
"> initializing model parallel with size 1\n",
|
137 |
+
"> initializing ddp with size 1\n",
|
138 |
+
"> initializing pipeline with size 1\n"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"name": "stderr",
|
143 |
+
"output_type": "stream",
|
144 |
+
"text": [
|
145 |
+
"/gscratch/raivn/ethans/miniconda3/envs/llms_12.1/lib/python3.11/site-packages/torch/__init__.py:614: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:451.)\n",
|
146 |
+
" _C._set_default_tensor_type(t)\n"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"name": "stdout",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"Loaded in 22.07 seconds\n",
|
154 |
+
"cuda:0\n"
|
155 |
+
]
|
156 |
+
}
|
157 |
+
],
|
158 |
+
"source": [
|
159 |
+
"weight_path = \"../../7B/\"\n",
|
160 |
+
"sup_model = SuperposedLlama.build(ckpt_dir=weight_path, \n",
|
161 |
+
" tokenizer_path=f'{weight_path}/tokenizer.model', \n",
|
162 |
+
" max_seq_len=1000, \n",
|
163 |
+
" max_batch_size=16,\n",
|
164 |
+
" device=sup_device,\n",
|
165 |
+
" model_parallel_size=1)"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "code",
|
170 |
+
"execution_count": 12,
|
171 |
+
"metadata": {},
|
172 |
+
"outputs": [
|
173 |
+
{
|
174 |
+
"name": "stdout",
|
175 |
+
"output_type": "stream",
|
176 |
+
"text": [
|
177 |
+
"0\n",
|
178 |
+
"Loaded in 22.76 seconds\n"
|
179 |
+
]
|
180 |
+
}
|
181 |
+
],
|
182 |
+
"source": [
|
183 |
+
"reg_model = Llama.build(ckpt_dir=weight_path, \n",
|
184 |
+
" tokenizer_path=f'{weight_path}/tokenizer.model', \n",
|
185 |
+
" max_seq_len=1000, \n",
|
186 |
+
" max_batch_size=16,\n",
|
187 |
+
" device=reg_device,\n",
|
188 |
+
" model_parallel_size=1)"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"execution_count": 18,
|
194 |
+
"metadata": {},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"tokenizer = Tokenizer(f\"{weight_path}/tokenizer.model\")"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "markdown",
|
202 |
+
"metadata": {},
|
203 |
+
"source": [
|
204 |
+
"# Evaluation"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 13,
|
210 |
+
"metadata": {},
|
211 |
+
"outputs": [
|
212 |
+
{
|
213 |
+
"name": "stdout",
|
214 |
+
"output_type": "stream",
|
215 |
+
"text": [
|
216 |
+
"Length: 7993\n"
|
217 |
+
]
|
218 |
+
}
|
219 |
+
],
|
220 |
+
"source": [
|
221 |
+
"trivia_path = \"../../../datasets/qa/wikipedia-dev.json\"\n",
|
222 |
+
"with open(trivia_path, \"r\") as f:\n",
|
223 |
+
" triviaqa = json.load(f)[\"Data\"]\n",
|
224 |
+
"print(f\"Length: {len(triviaqa)}\")"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": 14,
|
230 |
+
"metadata": {},
|
231 |
+
"outputs": [],
|
232 |
+
"source": [
|
233 |
+
"torch.set_default_dtype(torch.float32)"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 15,
|
239 |
+
"metadata": {},
|
240 |
+
"outputs": [],
|
241 |
+
"source": [
|
242 |
+
"model_types = [\"superposed\", \"regular\"]\n",
|
243 |
+
"model_type = model_types[0]"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": 16,
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/triviaqa/default.yaml\n",
|
253 |
+
"def evaluate_trivia(model_type, question, max_gen_len):\n",
|
254 |
+
" question = \"Question: \" + question + \"\\nAnswer:\"\n",
|
255 |
+
" text_len = len(question) # for truncating\n",
|
256 |
+
" prompt_len = len(tokenizer.encode([question], True, False)[0]) # for model\n",
|
257 |
+
" if model_type == \"regular\":\n",
|
258 |
+
" input = [question for _ in range(n_drafts)]\n",
|
259 |
+
" sequences, _ = evaluate_nucleus_losses(data=input,\n",
|
260 |
+
" model=reg_model,\n",
|
261 |
+
" tokenizer=tokenizer,\n",
|
262 |
+
" prompt_len=prompt_len,\n",
|
263 |
+
" max_gen_len=max_gen_len,\n",
|
264 |
+
" temp=0.6, # Set to 0 for greedy\n",
|
265 |
+
" bsz=8,\n",
|
266 |
+
" marker=False)\n",
|
267 |
+
" n_pd, seq_len = sequences.shape\n",
|
268 |
+
" elif model_type == \"superposed\":\n",
|
269 |
+
" sequences, _ = evaluate_mixed_losses(data=[question],\n",
|
270 |
+
" model=sup_model,\n",
|
271 |
+
" tokenizer=tokenizer,\n",
|
272 |
+
" prompt_len=prompt_len,\n",
|
273 |
+
" max_gen_len=max_gen_len,\n",
|
274 |
+
" alpha=alpha,\n",
|
275 |
+
" temp=temp,\n",
|
276 |
+
" n_drafts=n_drafts,\n",
|
277 |
+
" n_token_sample=n_token_sample,\n",
|
278 |
+
" smoothing=None, # greedy\n",
|
279 |
+
" bsz=8,\n",
|
280 |
+
" i_weights=i_weights,\n",
|
281 |
+
" i_length=i_length,\n",
|
282 |
+
" ngrams=ngrams,\n",
|
283 |
+
" marker=False)\n",
|
284 |
+
" n_p, n_d, seq_len = sequences.shape\n",
|
285 |
+
" # Process results\n",
|
286 |
+
" sequences = sequences.reshape(-1, seq_len).tolist()\n",
|
287 |
+
" for d_idx in range(len(sequences)):\n",
|
288 |
+
" draft = sequences[d_idx]\n",
|
289 |
+
" if -1 in draft:\n",
|
290 |
+
" draft = draft[:draft.index(-1)]\n",
|
291 |
+
" sequences[d_idx] = draft\n",
|
292 |
+
" decoded_seq = tokenizer.decode(sequences)\n",
|
293 |
+
" answers = []\n",
|
294 |
+
" for s in decoded_seq:\n",
|
295 |
+
" # print(s)\n",
|
296 |
+
" answers.append(re.split(\"[,.\\n]\", s[text_len:].strip())[0])\n",
|
297 |
+
" return answers\n",
|
298 |
+
" "
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"metadata": {},
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"questions = {}\n",
|
308 |
+
"predictions = {}\n",
|
309 |
+
"print(f\"Precision from 1 to {n_drafts}\")\n",
|
310 |
+
"for sample in tqdm(triviaqa):\n",
|
311 |
+
" # Adaptively select generation length\n",
|
312 |
+
" longest = 0\n",
|
313 |
+
" shortest = 1000\n",
|
314 |
+
" total = 0\n",
|
315 |
+
" for answer in sample[\"Answer\"][\"Aliases\"]:\n",
|
316 |
+
" tmp = tokenizer.encode([answer], False, False)[0]\n",
|
317 |
+
" if len(tmp) > longest:\n",
|
318 |
+
" longest = len(tmp)\n",
|
319 |
+
" if len(tmp) < shortest:\n",
|
320 |
+
" shortest = len(tmp)\n",
|
321 |
+
" total += len(tmp)\n",
|
322 |
+
" # Evaluation code\n",
|
323 |
+
" id = sample[\"QuestionId\"]\n",
|
324 |
+
" question = sample[\"Question\"]\n",
|
325 |
+
" answer = evaluate_trivia(model_type, question, max_gen_len=longest + 3)\n",
|
326 |
+
" predictions[id] = answer\n",
|
327 |
+
" questions[id] = question"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": null,
|
333 |
+
"metadata": {},
|
334 |
+
"outputs": [],
|
335 |
+
"source": [
|
336 |
+
"# Save precisions\n",
|
337 |
+
"precisions = {}\n",
|
338 |
+
"for i in range(1, n_drafts+1):\n",
|
339 |
+
" prec = str(i)\n",
|
340 |
+
" responses = {k: v[:i] for k, v in predictions.items()}\n",
|
341 |
+
" precisions[prec] = responses"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"cell_type": "code",
|
346 |
+
"execution_count": null,
|
347 |
+
"metadata": {},
|
348 |
+
"outputs": [],
|
349 |
+
"source": [
|
350 |
+
"# Print some results\n",
|
351 |
+
"counter = 0\n",
|
352 |
+
"for k in predictions:\n",
|
353 |
+
" if counter >= 10:\n",
|
354 |
+
" break\n",
|
355 |
+
" print(questions[k])\n",
|
356 |
+
" print(predictions[k])\n",
|
357 |
+
" counter += 1\n",
|
358 |
+
" print(\"================\")"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "code",
|
363 |
+
"execution_count": null,
|
364 |
+
"metadata": {},
|
365 |
+
"outputs": [],
|
366 |
+
"source": [
|
367 |
+
"# Save results\n",
|
368 |
+
"os.makedirs(\"../../trivia/\", exist_ok=True)\n",
|
369 |
+
"for prec in range(1, n_drafts+1):\n",
|
370 |
+
" out_path = f\"../nucleus_extra/trivia_extra/ngram_4trivia_{model_type}_{prec}_4.json\"\n",
|
371 |
+
" with open(out_path, \"w\") as f:\n",
|
372 |
+
" json.dump(precisions[str(prec)], f, indent=4)"
|
373 |
+
]
|
374 |
+
},
|
375 |
+
{
|
376 |
+
"cell_type": "code",
|
377 |
+
"execution_count": null,
|
378 |
+
"metadata": {},
|
379 |
+
"outputs": [],
|
380 |
+
"source": []
|
381 |
+
}
|
382 |
+
],
|
383 |
+
"metadata": {
|
384 |
+
"kernelspec": {
|
385 |
+
"display_name": "Python 3 (ipykernel)",
|
386 |
+
"language": "python",
|
387 |
+
"name": "python3"
|
388 |
+
},
|
389 |
+
"language_info": {
|
390 |
+
"codemirror_mode": {
|
391 |
+
"name": "ipython",
|
392 |
+
"version": 3
|
393 |
+
},
|
394 |
+
"file_extension": ".py",
|
395 |
+
"mimetype": "text/x-python",
|
396 |
+
"name": "python",
|
397 |
+
"nbconvert_exporter": "python",
|
398 |
+
"pygments_lexer": "ipython3",
|
399 |
+
"version": "3.11.5"
|
400 |
+
}
|
401 |
+
},
|
402 |
+
"nbformat": 4,
|
403 |
+
"nbformat_minor": 4
|
404 |
+
}
|
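
For reference, a minimal usage sketch of the evaluate_trivia helper defined in the notebook above, run as one extra cell at the end. It assumes the notebook's state is already set up (the loaded models, tokenizer, n_drafts, and the params cell); the question and alias below are made-up placeholders, and the longest + 3 generation margin simply mirrors the TriviaQA loop.

    # Hypothetical single-question call, mirroring the TriviaQA loop above.
    question = "Which planet is known as the Red Planet?"  # placeholder question
    aliases = ["Mars"]                                      # placeholder gold aliases
    # Size the generation length from the longest alias, plus the same 3-token margin.
    longest = max(len(tokenizer.encode([a], False, False)[0]) for a in aliases)
    drafts = evaluate_trivia("superposed", question, max_gen_len=longest + 3)
    print(drafts)  # one short answer string per draft (n_drafts in total)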