Spaces: Runtime error
Initial commit
- 100-0.txt +0 -0
- Dockerfile +11 -0
- Procfile +1 -0
- app.py +13 -0
- attention_replication.py +156 -0
- config.yaml +61 -0
- env.yaml +406 -0
- sampling.py +239 -0
- shakespeare_demo.py +105 -0
- transformer_replication.py +183 -0
- word_data.py +100 -0
100-0.txt
ADDED
The diff for this file is too large to render.
Dockerfile
ADDED
@@ -0,0 +1,11 @@
# Create environment
FROM mambaorg/micromamba:1.3.1
COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml /tmp/env.yaml
RUN micromamba install --yes --file /tmp/env.yaml && \
    micromamba clean --all --yes

# Run app
COPY . /app/
WORKDIR /app/
ARG MAMBA_DOCKERFILE_ACTIVATE=1
CMD ["python", "app.py"]
Procfile
ADDED
@@ -0,0 +1 @@
web: gunicorn app:app
app.py
ADDED
@@ -0,0 +1,13 @@
from flask import Flask
import os
from shakespeare_demo import make_demo

app = Flask(__name__)

@app.route("/")
def hello_world():
    return make_demo()

if __name__ == "__main__":
    port = int(os.environ.get('PORT', 5999))
    app.run(debug=True, host='0.0.0.0', port=port)
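For a quick local smoke test of this route (a hedged sketch: it assumes the Flask server above is running locally on its default port 5999):

import urllib.request

# Hit the root route served by hello_world(); 5999 matches the default above.
with urllib.request.urlopen("http://localhost:5999/") as response:
    print(response.status)        # expect 200 if the app is up
    print(response.read()[:200])  # first bytes of the response body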
attention_replication.py
ADDED
@@ -0,0 +1,156 @@
# %%
import torch as t
import torch.nn as nn
from typing import Union, List
from fancy_einsum import einsum
from einops import repeat, rearrange, reduce
import numpy as np
#%%
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # Ignore masking
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values

def test_single_head_attention_shape(single_head_attention):
    Q = t.randn(1, 3, 2)
    K = t.randn(1, 5, 2)
    V = t.randn(1, 5, 2)
    attention_values = single_head_attention(Q, K, V)
    assert Q.shape == attention_values.shape
    print("All tests in `test_single_head_attention_shape` passed.")

def test_single_head_attention(single_head_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_single_head_attention` passed.")

if __name__ == "__main__":
    test_single_head_attention_shape(single_head_attention)
    test_single_head_attention(single_head_attention)
# %%
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.

    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    batches, seq_Q, head_size = Q.shape
    batches, seq_K, head_size = K.shape

    q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
    k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
    mask = k_index <= q_index
    attention_scores = t.where(mask, attention_scores, -t.inf)
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values

def test_single_head_masked_attention(single_head_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_single_head_masked_attention` passed.")

if __name__ == "__main__":
    test_single_head_attention_shape(single_head_masked_attention)
    test_single_head_masked_attention(single_head_masked_attention)
# %%
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    returns: shape (batch, seq, nheads*headsize)
    '''
    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)

    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
    batches, _, seq_Q, head_size = new_Q.shape
    batches, _, seq_K, head_size = new_K.shape
    q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
    k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
    mask = k_index <= q_index
    device_inf = t.tensor(-np.inf).to(Q.device)
    device_mask = mask.to(Q.device)
    masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')

def test_multihead_masked_attention(multihead_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_multihead_masked_attention` passed.")

if __name__ == "__main__":
    test_multihead_masked_attention(multihead_masked_attention)
# %%
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert self.hidden_size % self.num_heads == 0
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        QKV = self.W_QKV(x)
        Q = QKV[..., :self.hidden_size]
        K = QKV[..., self.hidden_size:-self.hidden_size]
        V = QKV[..., -self.hidden_size:]
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)
# %%
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(1, 1)
    x = t.randn(2, 7, 1)
    output = mma.forward(x)
    assert x.shape == output.shape
    print("All tests in `test_MultiheadMaskedAttention_shape` passed.")

if __name__ == "__main__":
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
# %%
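In equation form, the masked attention implemented above is the standard scaled dot-product attention with a causal mask added to the scores before the softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top + M}{\sqrt{d_{\mathrm{head}}}}\right)V, \qquad M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}$$

The multihead version applies this per head after splitting the hidden dimension into `nheads` chunks, then concatenates the heads and projects with `W_O`.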
config.yaml
ADDED
@@ -0,0 +1,61 @@
wandb_version: 1

_wandb:
  desc: null
  value:
    cli_version: 0.13.5
    framework: huggingface
    huggingface_version: 4.24.0
    is_jupyter_run: true
    is_kaggle_kernel: false
    python_version: 3.10.6
    start_time: 1668083783.928274
    t:
      1:
      - 1
      - 11
      - 41
      - 49
      - 55
      2:
      - 1
      - 11
      - 41
      - 49
      - 55
      3:
      - 1
      - 2
      - 3
      - 23
      - 37
      4: 3.10.6
      5: 0.13.5
      6: 4.24.0
      8:
      - 1
      - 5
batch_size:
  desc: null
  value: 64
dropout:
  desc: null
  value: 0.1
epochs:
  desc: null
  value: 2
hidden_size:
  desc: null
  value: 512
lr:
  desc: null
  value: 0.001
max_seq_len:
  desc: null
  value: 60
num_heads:
  desc: null
  value: 8
num_layers:
  desc: null
  value: 6
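shakespeare_demo.py reads these hyperparameters back through the nested `value` keys that wandb writes; a minimal sketch of that access pattern:

import yaml

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

# wandb stores each hyperparameter as {'desc': ..., 'value': ...}
hidden_size = cfg['hidden_size']['value']  # 512
num_heads = cfg['num_heads']['value']      # 8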
env.yaml
ADDED
@@ -0,0 +1,406 @@
name: base
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- aiofiles=23.1.0=pyhd8ed1ab_0
- aiohttp=3.8.4=py310h1fa729e_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- alsa-lib=1.2.8=h166bdaf_0
- altair=4.2.2=pyhd8ed1ab_0
- anyio=3.6.2=pyhd8ed1ab_0
- argon2-cffi=21.3.0=pyhd8ed1ab_0
- argon2-cffi-bindings=21.2.0=py310h5764c6d_3
- arrow-cpp=11.0.0=ha770c72_4_cpu
- asttokens=2.2.1=pyhd8ed1ab_0
- async-timeout=4.0.2=pyhd8ed1ab_0
- attr=2.5.1=h166bdaf_1
- attrs=22.2.0=pyh71513ae_0
- aws-c-auth=0.6.24=h565b4ff_2
- aws-c-cal=0.5.20=h679401e_5
- aws-c-common=0.8.10=h0b41bf4_0
- aws-c-compression=0.2.16=hbe6ad0c_2
- aws-c-event-stream=0.2.18=h489b7ba_4
- aws-c-http=0.7.4=hb2c4a47_0
- aws-c-io=0.13.15=head7655_1
- aws-c-mqtt=0.8.6=haf0be06_3
- aws-c-s3=0.2.4=h05be983_0
- aws-c-sdkutils=0.1.7=hbe6ad0c_2
- aws-checksums=0.1.14=hbe6ad0c_2
- aws-crt-cpp=0.19.7=h9b63b7c_3
- aws-sdk-cpp=1.10.57=hd557813_3
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- beautifulsoup4=4.11.2=pyha770c72_0
- blas=2.116=mkl
- blas-devel=3.9.0=16_linux64_mkl
- bleach=6.0.0=pyhd8ed1ab_0
- brotli=1.0.9=h166bdaf_8
- brotli-bin=1.0.9=h166bdaf_8
- brotlipy=0.7.0=py310h5764c6d_1005
- bzip2=1.0.8=h7f98852_4
- c-ares=1.18.1=h7f98852_0
- ca-certificates=2022.12.7=ha878542_0
- cairo=1.16.0=ha61ee94_1014
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py310h255011f_3
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- click=8.1.3=unix_pyhd8ed1ab_2
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.1.2=pyhd8ed1ab_0
- contourpy=1.0.7=py310hdf3cbec_0
- cryptography=39.0.1=py310h34c0648_0
- cuda=11.6.1=0
- cuda-cccl=11.6.55=hf6102b2_0
- cuda-command-line-tools=11.6.2=0
- cuda-compiler=11.6.2=0
- cuda-cudart=11.6.55=he381448_0
- cuda-cudart-dev=11.6.55=h42ad0f4_0
- cuda-cuobjdump=11.6.124=h2eeebcb_0
- cuda-cupti=11.6.124=h86345e5_0
- cuda-cuxxfilt=11.6.124=hecbf4f6_0
- cuda-driver-dev=11.6.55=0
- cuda-gdb=12.0.140=0
- cuda-libraries=11.6.1=0
- cuda-libraries-dev=11.6.1=0
- cuda-memcheck=11.8.86=0
- cuda-nsight=12.0.140=0
- cuda-nsight-compute=12.0.1=0
- cuda-nvcc=11.6.124=hbba6d2d_0
- cuda-nvdisasm=12.0.140=0
- cuda-nvml-dev=11.6.55=haa9ef22_0
- cuda-nvprof=12.0.146=0
- cuda-nvprune=11.6.124=he22ec0a_0
- cuda-nvrtc=11.6.124=h020bade_0
- cuda-nvrtc-dev=11.6.124=h249d397_0
- cuda-nvtx=11.6.124=h0630a44_0
- cuda-nvvp=12.0.146=0
- cuda-runtime=11.6.1=0
- cuda-samples=11.6.101=h8efea70_0
- cuda-sanitizer-api=12.0.140=0
- cuda-toolkit=11.6.1=0
- cuda-tools=11.6.1=0
- cuda-visual-tools=11.6.1=0
- cycler=0.11.0=pyhd8ed1ab_0
- dataclasses=0.8=pyhc8e2a94_3
- datasets=2.9.0=pyhd8ed1ab_0
- dbus=1.13.6=h5008d03_3
- debugpy=1.6.6=py310heca2aa9_0
- decorator=5.1.1=pyhd8ed1ab_0
- defusedxml=0.7.1=pyhd8ed1ab_0
- dill=0.3.6=pyhd8ed1ab_1
- einops=0.6.0=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- expat=2.5.0=h27087fc_0
- fastapi=0.92.0=pyhd8ed1ab_0
- ffmpeg=4.3=hf484d3e_0
- ffmpy=0.3.0=pyhb6f538c_0
- fftw=3.3.10=nompi_hf0379b8_106
- filelock=3.9.0=pyhd8ed1ab_0
- flask=2.2.3=pyhd8ed1ab_0
- flit-core=3.8.0=pyhd8ed1ab_0
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=hab24e00_0
- fontconfig=2.14.2=h14ed4e7_0
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- fonttools=4.38.0=py310h5764c6d_1
- freetype=2.12.1=hca18f0e_1
- frozenlist=1.3.3=py310h5764c6d_0
- fsspec=2023.1.0=pyhd8ed1ab_0
- gds-tools=1.5.1.14=0
- gettext=0.21.1=h27087fc_0
- gflags=2.2.2=he1b5a44_1004
- glib=2.74.1=h6239696_1
- glib-tools=2.74.1=h6239696_1
- glog=0.6.0=h6f12383_0
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- gradio=3.19.1=pyhd8ed1ab_0
- graphite2=1.3.13=h58526e2_1001
- gst-plugins-base=1.22.0=h4243ec0_0
- gstreamer=1.22.0=h25f0c4b_0
- gstreamer-orc=0.4.33=h166bdaf_0
- h11=0.14.0=pyhd8ed1ab_0
- h2=4.1.0=pyhd8ed1ab_0
- harfbuzz=6.0.0=h8e241bc_0
- hpack=4.0.0=pyh9f0ad1d_0
- httpcore=0.16.3=pyhd8ed1ab_0
- httpx=0.23.3=pyhd8ed1ab_0
- huggingface_hub=0.12.1=pyhd8ed1ab_0
- hyperframe=6.0.1=pyhd8ed1ab_0
- icu=70.1=h27087fc_0
- idna=3.4=pyhd8ed1ab_0
- importlib-metadata=6.0.0=pyha770c72_0
- importlib_metadata=6.0.0=hd8ed1ab_0
- importlib_resources=5.12.0=pyhd8ed1ab_0
- ipykernel=6.21.2=pyh210e3f2_0
- ipython=8.10.0=pyh41d4057_0
- ipython_genutils=0.2.0=py_1
- ipywidgets=8.0.4=pyhd8ed1ab_0
- itsdangerous=2.1.2=pyhd8ed1ab_0
- jack=1.9.22=h11f4161_0
- jedi=0.18.2=pyhd8ed1ab_0
- jinja2=3.1.2=pyhd8ed1ab_1
- joblib=1.2.0=pyhd8ed1ab_0
- jpeg=9e=h0b41bf4_3
- jsonschema=4.17.3=pyhd8ed1ab_0
- jupyter=1.0.0=py310hff52083_8
- jupyter_client=8.0.3=pyhd8ed1ab_0
- jupyter_console=6.5.1=pyhd8ed1ab_0
- jupyter_core=5.2.0=py310hff52083_0
- jupyter_events=0.6.3=pyhd8ed1ab_0
- jupyter_server=2.3.0=pyhd8ed1ab_0
- jupyter_server_terminals=0.4.4=pyhd8ed1ab_1
- jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
- jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.4=py310hbf28c38_1
- krb5=1.20.1=h81ceb04_0
- lame=3.100=h166bdaf_1003
- lcms2=2.14=hfd0df8a_1
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20220623.0=cxx17_h05df665_6
- libarrow=11.0.0=hc42cb68_4_cpu
- libblas=3.9.0=16_linux64_mkl
- libbrotlicommon=1.0.9=h166bdaf_8
- libbrotlidec=1.0.9=h166bdaf_8
- libbrotlienc=1.0.9=h166bdaf_8
- libcap=2.66=ha37c62d_0
- libcblas=3.9.0=16_linux64_mkl
- libclang=15.0.7=default_had23c3d_1
- libclang13=15.0.7=default_h3e3d535_1
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=11.9.2.110=h5e84587_0
- libcublas-dev=11.9.2.110=h5c901ab_0
- libcufft=10.7.1.112=hf425ae0_0
- libcufft-dev=10.7.1.112=ha5ce4c0_0
- libcufile=1.5.1.14=0
- libcufile-dev=1.5.1.14=0
- libcups=2.3.3=h36d4200_3
- libcurand=10.3.1.124=0
- libcurand-dev=10.3.1.124=0
- libcurl=7.88.1=hdc1c0ab_0
- libcusolver=11.3.4.124=h33c3c4e_0
- libcusparse=11.7.2.124=h7538f96_0
- libcusparse-dev=11.7.2.124=hbbe9722_0
- libdb=6.2.32=h9c3ff4c_0
- libdeflate=1.17=h0b41bf4_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=h516909a_1
- libevent=2.1.10=h28343ad_4
- libffi=3.4.2=h7f98852_5
- libflac=1.4.2=h27087fc_0
- libgcc-ng=12.2.0=h65d4601_19
- libgcrypt=1.10.1=h166bdaf_0
- libgfortran-ng=12.2.0=h69a702a_19
- libgfortran5=12.2.0=h337968e_19
- libglib=2.74.1=h606061b_1
- libgoogle-cloud=2.7.0=h21dfe5b_1
- libgpg-error=1.46=h620e276_0
- libgrpc=1.51.1=h4fad500_1
- libhwloc=2.8.0=h32351e8_1
- libiconv=1.17=h166bdaf_0
- liblapack=3.9.0=16_linux64_mkl
- liblapacke=3.9.0=16_linux64_mkl
- libllvm15=15.0.7=hadd5161_0
- libnghttp2=1.51.0=hff17c54_0
- libnpp=11.6.3.124=hd2722f0_0
- libnpp-dev=11.6.3.124=h3c42840_0
- libnsl=2.0.0=h7f98852_0
- libnvjpeg=11.6.2.124=hd473ad6_0
- libnvjpeg-dev=11.6.2.124=hb5906b9_0
- libogg=1.3.4=h7f98852_1
- libopus=1.3.1=h7f98852_1
- libpng=1.6.39=h753d276_0
- libpq=15.2=hb675445_0
- libprotobuf=3.21.12=h3eb15da_0
- libsndfile=1.2.0=hb75c966_0
- libsodium=1.0.18=h36c2ea0_1
- libsqlite=3.40.0=h753d276_0
- libssh2=1.10.0=hf14f497_3
- libstdcxx-ng=12.2.0=h46fd767_19
- libsystemd0=252=h2a991cd_0
- libthrift=0.16.0=he500d00_2
- libtiff=4.5.0=h6adf6a1_2
- libtool=2.4.7=h27087fc_0
- libudev1=252=h166bdaf_0
- libutf8proc=2.8.0=h166bdaf_0
- libuuid=2.32.1=h7f98852_1000
- libvorbis=1.3.7=h9c3ff4c_0
- libwebp-base=1.2.4=h166bdaf_0
- libxcb=1.13=h7f98852_1004
- libxkbcommon=1.5.0=h79f4944_0
- libxml2=2.10.3=h7463322_0
- libzlib=1.2.13=h166bdaf_4
- linkify-it-py=2.0.0=pyhd8ed1ab_0
- llvm-openmp=15.0.7=h0cdce71_0
- lz4-c=1.9.4=hcb278e6_0
- markdown-it-py=2.1.0=pyhd8ed1ab_0
- markupsafe=2.1.2=py310h1fa729e_0
- matplotlib-base=3.7.0=py310he60537e_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mdit-py-plugins=0.3.3=pyhd8ed1ab_0
- mdurl=0.1.0=pyhd8ed1ab_0
- mistune=2.0.5=pyhd8ed1ab_0
- mkl=2022.1.0=h84fe81f_915
- mkl-devel=2022.1.0=ha770c72_916
- mkl-include=2022.1.0=h84fe81f_915
- mpg123=1.31.2=hcb278e6_0
- multidict=6.0.4=py310h1fa729e_0
- multiprocess=0.70.14=py310h5764c6d_3
- munkres=1.1.4=pyh9f0ad1d_0
- mysql-common=8.0.32=ha901b37_0
- mysql-libs=8.0.32=hd7da12d_0
- nbclassic=0.5.2=pyhd8ed1ab_0
- nbclient=0.7.2=pyhd8ed1ab_0
- nbconvert=7.2.9=pyhd8ed1ab_0
- nbconvert-core=7.2.9=pyhd8ed1ab_0
- nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
- nbformat=5.7.3=pyhd8ed1ab_0
- ncurses=6.3=h27087fc_1
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.6=he412f7d_0
- notebook=6.5.2=pyha770c72_1
- notebook-shim=0.2.2=pyhd8ed1ab_0
- nsight-compute=2022.4.1.6=0
- nspr=4.35=h27087fc_0
- nss=3.88=he45b914_0
- numpy=1.24.2=py310h8deb116_0
- openh264=2.1.1=h780b84a_0
- openjpeg=2.5.0=hfec8fc6_2
- openssl=3.0.8=h0b41bf4_0
- orc=1.8.2=hfdbbad2_2
- orjson=3.8.5=py310h38b9cce_1
- packaging=23.0=pyhd8ed1ab_0
- pandas=1.5.3=py310h9b08913_0
- pandoc=2.19.2=h32600fe_1
- pandocfilters=1.5.0=pyhd8ed1ab_0
- parquet-cpp=1.5.1=2
- parso=0.8.3=pyhd8ed1ab_0
- pcre2=10.40=hc3806b6_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pillow=9.4.0=py310h023d228_1
- pip=23.0.1=pyhd8ed1ab_0
- pixman=0.40.0=h36c2ea0_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
- platformdirs=3.0.0=pyhd8ed1ab_0
- ply=3.11=py_1
- prometheus_client=0.16.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=pyha770c72_0
- prompt_toolkit=3.0.36=hd8ed1ab_0
- psutil=5.9.4=py310h5764c6d_0
- pthread-stubs=0.4=h36c2ea0_1001
- ptyprocess=0.7.0=pyhd3deb0d_0
- pulseaudio=16.1=ha8d29e2_1
- pure_eval=0.2.2=pyhd8ed1ab_0
- pyarrow=11.0.0=py310h633f555_4_cpu
- pycparser=2.21=pyhd8ed1ab_0
- pycryptodome=3.16.0=py310h1419917_0
- pydantic=1.10.5=py310h1fa729e_0
- pydub=0.25.1=pyhd8ed1ab_0
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=23.0.0=pyhd8ed1ab_0
- pyparsing=3.0.9=pyhd8ed1ab_0
- pyqt=5.15.7=py310hab646b1_3
- pyqt5-sip=12.11.0=py310heca2aa9_3
- pyrsistent=0.19.3=py310h1fa729e_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.9=he550d4f_0_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-fastjsonschema=2.16.2=pyhd8ed1ab_0
- python-json-logger=2.0.6=pyhd8ed1ab_0
- python-multipart=0.0.5=py_0
- python-xxhash=3.2.0=py310h1fa729e_0
- python_abi=3.10=3_cp310
- pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
- pytorch-cuda=11.6=h867d48c_1
- pytorch-mutex=1.0=cuda
- pytz=2022.7.1=pyhd8ed1ab_0
- pyyaml=6.0=py310h5764c6d_5
- pyzmq=25.0.0=py310h059b190_0
- qt-main=5.15.8=h5d23da1_6
- qtconsole=5.4.0=pyhd8ed1ab_0
- qtconsole-base=5.4.0=pyha770c72_0
- qtpy=2.3.0=pyhd8ed1ab_0
- re2=2023.02.01=hcb278e6_0
- readline=8.1.2=h0f457ee_0
- regex=2022.10.31=py310h5764c6d_0
- requests=2.28.2=pyhd8ed1ab_0
- responses=0.18.0=pyhd8ed1ab_0
- rfc3339-validator=0.1.4=pyhd8ed1ab_0
- rfc3986=1.5.0=pyhd8ed1ab_0
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
- s2n=1.3.35=h3358134_0
- sacremoses=0.0.53=pyhd8ed1ab_0
- send2trash=1.8.0=pyhd8ed1ab_0
- setuptools=67.3.2=pyhd8ed1ab_0
- sip=6.7.7=py310heca2aa9_0
- six=1.16.0=pyh6c4a22f_0
- snappy=1.1.9=hbd366e4_2
- sniffio=1.3.0=pyhd8ed1ab_0
- soupsieve=2.3.2.post1=pyhd8ed1ab_0
- stack_data=0.6.2=pyhd8ed1ab_0
- starlette=0.25.0=pyhd8ed1ab_0
- tbb=2021.7.0=h924138e_1
- terminado=0.17.1=pyh41d4057_0
- tinycss2=1.2.1=pyhd8ed1ab_0
- tk=8.6.12=h27826a3_0
- tokenizers=0.13.2=py310he1f1126_0
- toml=0.10.2=pyhd8ed1ab_0
- toolz=0.12.0=pyhd8ed1ab_0
- torchaudio=0.13.1=py310_cu116
- torchvision=0.14.1=py310_cu116
- tornado=6.2=py310h5764c6d_1
- tqdm=4.64.1=pyhd8ed1ab_0
- traitlets=5.9.0=pyhd8ed1ab_0
- transformers=4.26.1=pyhd8ed1ab_0
- typing-extensions=4.4.0=hd8ed1ab_0
- typing_extensions=4.4.0=pyha770c72_0
- tzdata=2022g=h191b570_0
- uc-micro-py=1.0.1=pyhd8ed1ab_0
- unicodedata2=15.0.0=py310h5764c6d_0
- urllib3=1.26.14=pyhd8ed1ab_0
- uvicorn=0.20.0=py310hff52083_1
- wcwidth=0.2.6=pyhd8ed1ab_0
- webencodings=0.5.1=py_1
- websocket-client=1.5.1=pyhd8ed1ab_0
- websockets=10.4=py310h5764c6d_1
- werkzeug=2.2.3=pyhd8ed1ab_0
- wheel=0.38.4=pyhd8ed1ab_0
- widgetsnbextension=4.0.5=pyhd8ed1ab_0
- xcb-util=0.4.0=h166bdaf_0
- xcb-util-image=0.4.0=h166bdaf_0
- xcb-util-keysyms=0.4.0=h166bdaf_0
- xcb-util-renderutil=0.3.9=h166bdaf_0
- xcb-util-wm=0.4.1=h166bdaf_0
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.0.10=h7f98852_0
- xorg-libsm=1.2.3=hd9c2040_1000
- xorg-libx11=1.7.2=h7f98852_0
- xorg-libxau=1.0.9=h7f98852_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xorg-libxext=1.3.4=h7f98852_1
- xorg-libxrender=0.9.10=h7f98852_1003
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h7f98852_1002
- xorg-xproto=7.0.31=h7f98852_1007
- xxhash=0.8.1=h0b41bf4_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.8.2=py310h5764c6d_0
- zeromq=4.3.4=h9c3ff4c_1
- zipp=3.14.0=pyhd8ed1ab_0
- zlib=1.2.13=h166bdaf_4
- zstd=1.5.2=h3eb15da_6
- pip:
  - fancy-einsum==0.0.3
sampling.py
ADDED
@@ -0,0 +1,239 @@
# %%
import torch as t
import torch.nn.functional as F
import transformers
import numpy as np

gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

def apply_sampling_methods(
    input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
) -> int:
    '''
    Return the next token, sampled from the model's probability distribution with modifiers.

    input_ids: shape (seq,)
    '''
    assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
    assert temperature >= 0, "Temperature should be non-negative"
    assert 0 <= top_p <= 1.0, "Top-p must be a probability"
    assert 0 <= top_k, "Top-k must be non-negative"
    assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

    if temperature == 0:
        return greedy_search(logits)
    if temperature != 1.0:
        logits = apply_temperature(logits, temperature)
    if freq_penalty != 0.0:
        logits = apply_freq_penalty(input_ids, logits, freq_penalty)
    if top_k > 0:
        return sample_top_k(logits, top_k)
    if top_p > 0:
        return sample_top_p(logits, top_p)
    return sample_basic(logits)

def sample_tokens(
    model,
    tokenizer,
    initial_text: str,
    max_tokens_generated: int = 30,
    **kwargs
) -> str:
    '''
    Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.

    Return: the prompt and continuation concatenated
    '''
    model.eval()
    input_ids: list = tokenizer.encode(initial_text)
    generated = []
    device = next(model.parameters()).device
    for _ in range(max_tokens_generated):
        new_input_ids = t.tensor(np.array(input_ids + generated), dtype=t.int64, device=device)
        new_input_ids_truncated = new_input_ids[-min(tokenizer.model_max_length, new_input_ids.shape[0]):].unsqueeze(0)
        output = model(new_input_ids_truncated)
        all_logits = output if isinstance(output, t.Tensor) else output.logits
        logits = all_logits[0, -1]  # batch=0, seq_len=-1 -> returns vocab_size
        new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
        generated.append(new_token)
        if new_token == getattr(tokenizer, "eos_token_id", None):
            break
    return tokenizer.decode(input_ids + generated)

# %%
def greedy_search(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, )

    Return: the most likely token (as an integer)
    '''
    return logits.argmax().item()

if __name__ == "__main__":
    prompt = "Jingle bells, jingle bells, jingle all the way"
    print("Greedy decoding with prompt: ", prompt)
    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
    print(f"Your model said: {output}")
    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
    assert output == expected

    print("Greedy decoding a second time (should be deterministic): ")
    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
    print(f"Your model said: {output}")
    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
    assert output == expected

    print("Tests passed!")
# %%
def sample_basic(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities

    Return: a sampled token
    '''
    return t.distributions.categorical.Categorical(logits=logits).sample().item()

if __name__ == "__main__":
    N = 20000
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
    print("Tests passed!")
# %%
def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
    '''
    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    '''
    assert temperature > 0
    return logits / temperature

if __name__ == '__main__':
    logits = t.tensor([1, 2]).log()
    cold_logits = apply_temperature(logits, 0.001)
    print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
    t.testing.assert_close(cold_logits, 1000.0 * logits)
    hot_logits = apply_temperature(logits, 1000.0)
    print("A high temperature flattens the distribution: ", hot_logits)
    t.testing.assert_close(hot_logits, 0.001 * logits)
    print("Tests passed!")

# %%
def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
    '''
    input_ids: shape (seq, )
    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    '''
    count = input_ids.bincount(minlength=len(logits))
    logits -= count * freq_penalty
    return logits

if __name__ == "__main__":
    bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
    input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
    logits = t.ones(tokenizer.vocab_size)
    penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
    assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"
    assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"
    print("Tests passed!")
# %%
N_RUNS = 0
your_prompt = "Jingle bells, jingle bells, jingle all the way"
cases = [
    ("High freq penalty", dict(freq_penalty=100.0)),
    ("Negative freq penalty", dict(freq_penalty=-1.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]
for (name, kwargs) in cases:
    for i in range(N_RUNS):
        output = sample_tokens(gpt, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
        print(f"Sample {i} with: {name} ({kwargs}):")
        print(f"Your model said: {repr(output)}\n")
# %%
def sample_top_k(logits: t.Tensor, top_k: int) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_k: only consider this many of the most likely tokens for sampling

    Return: a sampled token
    '''
    values, indices = t.topk(logits, top_k)
    return indices[sample_basic(values)].item()

if __name__ == "__main__":
    N = 50000
    k = 3
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    expected = probs.clone()
    expected[:-k] = 0
    expected /= expected.sum()
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
    print("Tests passed!")
# %%
if __name__ == "__main__":
    your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
    print(f"Your model said: {repr(output)}")
# %%
def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities

    Return: a sampled token
    '''
    probs = t.exp(logits.double()) / t.exp(logits.double()).sum()
    sorted_probs, sorted_indices = probs.sort(descending=True)
    cum_probs = sorted_probs.cumsum(-1)
    last_index = max(min_tokens_to_keep, t.where(cum_probs >= top_p)[0][0].item() + 1)
    masked_probs = sorted_probs[:last_index]
    sample = t.distributions.categorical.Categorical(probs=masked_probs).sample()
    return sorted_indices[sample].item()

if __name__ == "__main__":
    N = 2000
    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
    samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
    print("top_p of 0.5 or lower should only return token 2: ", counts)
    assert counts[0] == 0 and counts[1] == 0

    N = 2000
    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
    samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
    print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
    assert counts[0] == 0

    N = 50000
    top_p = 0.71
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    expected = probs.clone()
    expected[0:2] = 0
    expected /= expected.sum()
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)

    print("All tests passed!")
# %%
if __name__ == "__main__":
    your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
    print(f"Your model said: {repr(output)}")
# %%
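As a toy illustration (not part of the module) of how `apply_temperature` reshapes the sampling distribution once the logits go through a softmax:

import torch as t

logits = t.tensor([1.0, 2.0, 3.0])
for temp in (0.5, 1.0, 2.0):
    probs = t.softmax(logits / temp, dim=-1)
    # Lower temperature concentrates mass on the largest logit;
    # higher temperature flattens the distribution.
    print(temp, probs)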
shakespeare_demo.py
ADDED
@@ -0,0 +1,105 @@
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
MAIN = __name__ == '__main__'
device = 'cuda' if t.cuda.is_available() else 'cpu'
#%%
shakespeare = WordData.from_file(
    '100-0.txt', device=device, start="1\n", end='ALL’S WELL THAT ENDS WELL'
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
state_dict = t.load('model_state_dict.pt')
#%%
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']
model = transformer_replication.DecoderOnlyTransformer(base_config)

model.load_state_dict(state_dict)

#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    return sampling.sample_tokens(
        model,
        shakespeare,
        text,
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )

#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    try:
        raw = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
        match = re.match(r"(?P<start>\D*)\d+\n", raw)
        if match is None:
            return raw
        return match.group('start')
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
#%%
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]
#%%
if MAIN:
    print(safe_generate('How I love thee'))
#%%
def make_demo():
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            gr.components.Textbox(lines=5, label="Input Text"),
            gr.components.Slider(
                label='max tokens generated', minimum=1, maximum=1000,
                value=300, step=1,
            ),
            gr.components.Slider(
                label='temperature', minimum=0, maximum=2, value=1, step=0.1,
            ),
            gr.components.Slider(
                label='top_k', minimum=1, maximum=100, value=10, step=1,
            ),
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples
    )
    demo.launch()
# %%
'''
FIXME:
* deploy to heroku
* link from github home
'''
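Note that `model_state_dict.pt` is loaded above but is not part of this commit; the demo assumes a checkpoint saved by a training run elsewhere, along the lines of this hypothetical save step:

# Hypothetical: run after training a DecoderOnlyTransformer with the same
# config; the filename must match what shakespeare_demo.py loads.
t.save(model.state_dict(), 'model_state_dict.pt')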
transformer_replication.py
ADDED
@@ -0,0 +1,183 @@
#%%
import transformers
import torch as t
import torch.nn as nn
from typing import Union, Optional, Callable, Tuple, List
from fancy_einsum import einsum
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from einops import rearrange
import time
# %%
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
if __name__ == "__main__":
    print(tokenizer("hello meg"))
    print(tokenizer.encode("hello meg"))
    print(tokenizer.decode([31373, 17243]))
    print(tokenizer.tokenize("hello meg"))
    print(f"'{tokenizer.decode(17243)}'")
# %%
class Embedding(nn.Module):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        self.weight = nn.Parameter(t.randn((self.num_embeddings, self.embedding_dim)))

    def forward(self, x: t.LongTensor) -> t.Tensor:
        '''For each integer in the input, return that row of the embedding.
        '''
        # return einsum('num_embeddings embedding_dim, i num_embeddings -> i embedding_dim', self.weight, nn.functional.one_hot(x, num_classes=self.num_embeddings).float())
        return self.weight[x]

    def extra_repr(self) -> str:
        return f"{self.num_embeddings}, {self.embedding_dim}"

# %%
class PositionalEncoding(nn.Module):

    def __init__(self, max_seq_len: int, embedding_dim: int):
        super().__init__()
        # Defining our positional encoding array, with `max_seq_len` rows
        # This is an advantage of using sinusoidal encoding: we can easily expand to sequences of greater length without adding more learned params
        angles = t.outer(t.arange(max_seq_len), 1 / 10000 ** (2 * t.arange(embedding_dim//2) / embedding_dim))
        pe = t.zeros((max_seq_len, embedding_dim))
        pe[:, ::2] = t.sin(angles)
        pe[:, 1::2] = t.cos(angles)
        # Register array as a buffer, rather than parameter (we don't want it to be updated by gradient descent)
        self.register_buffer('pe', pe)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """
        x: shape (batch, seq_len, embedding_dim)
        """
        batch, seq_len, embedding_dim = x.shape
        # We slice the positional encoding, so it's the same shape as x
        # This is equivalent to just using an nn.Embedding, but having the input be t.arange(seq_len)
        return x + self.pe[:seq_len, :]  # type: ignore

# %%
class LayerNorm(nn.Module):

    def __init__(self, normalized_shape: Union[int, List[int]], eps: float = 1e-05, elementwise_affine: bool = True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = nn.Parameter(t.ones(normalized_shape))
            self.bias = nn.Parameter(t.zeros(normalized_shape))

    def forward(self, x: t.Tensor) -> t.Tensor:
        normalized_shape_dims = 1 if isinstance(self.normalized_shape, int) else len(self.normalized_shape)
        x_mean = x.mean(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True)  # mean over the normalised dims
        x_var = x.var(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True, unbiased=False)  # variance over the normalised dims
        x_scaled = (x - x_mean) / t.sqrt(x_var + self.eps)
        if self.elementwise_affine:
            return x_scaled * self.weight + self.bias
        return x_scaled

    def extra_repr(self) -> str:
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"

# %%
from dataclasses import dataclass

@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    num_heads: int
    vocab_size: int
    hidden_size: int
    max_seq_len: int
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
# %%
import attention_replication

class BertMLP(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x

class DecoderBlock(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.attention = attention_replication.MultiheadMaskedAttention(config.hidden_size, config.num_heads)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
        self.mlp = BertMLP(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)

    def forward(self, x: t.Tensor) -> t.Tensor:
        y = self.attention(x)
        y = self.layer_norm1(y)
        x = x + y
        z = self.mlp(x)
        z = self.layer_norm2(z)
        x = x + z
        return x

class DecoderOnlyTransformer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.token_embedding = Embedding(config.vocab_size, config.hidden_size)
        self.positional_embedding = PositionalEncoding(config.max_seq_len, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.bert_blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.token_embedding(x)
        x = self.positional_embedding(x)
        x = self.dropout(x)
        for block in self.bert_blocks:
            x = block(x)
        x = self.layer_norm(x)
        # Unembed by reusing the token embedding weights
        x = einsum('num_embeddings embedding_dim, batch seq_len embedding_dim -> batch seq_len num_embeddings', self.token_embedding.weight, x)
        return x

# %%
from torch.utils.data import Dataset

class CustomTextDataset(Dataset):
    def __init__(self, texts, labels):
        self.labels = labels
        self.texts = texts

    @staticmethod
    def from_config(config, samples):
        texts = [t.randint(high=config.vocab_size, size=(config.max_seq_len,)) for _ in range(samples)]
        labels = [t.flip(text, (0,)) for text in texts]
        return CustomTextDataset(texts, labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.texts[idx]
        sample = (text, label)
        return sample
ADDED
@@ -0,0 +1,100 @@
|
import re
from typing import Optional, Union
import requests
from torch.utils.data import Dataset
import torch as t


class WordsDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.texts[idx]
        sample = (text, label)
        return sample

#%%
def tokenize(text):
    return re.split(r"\b", text)

def _remove_duplicates(text, string=" "):
    if string + string in text:
        text = text.replace(string + string, string)
        return _remove_duplicates(text, string)
    return text

def remove_duplicates(text):
    text = _remove_duplicates(text, ' ')
    text = _remove_duplicates(text, '\n')
    return text

# %%
class WordData():
    def __init__(self, text, start, end, device):
        self.complete_text = remove_duplicates(text)
        if start is not None and end is not None:
            self.complete_text = self.get_excerpt(start, end)
        self.complete_tokens = tokenize(self.complete_text)
        self.vocab = sorted(set(self.complete_tokens))
        self.token_to_id = dict(zip(self.vocab, list(range(len(self.vocab)))))
        self.id_to_token = dict(zip(list(range(len(self.vocab))), self.vocab))
        self.model_max_length = None
        self.device = device

    @staticmethod
    def from_link(link, device, start=None, end=None):
        return WordData(
            requests.get(link).content.decode('utf-8'),
            start,
            end,
            device=device
        )

    @staticmethod
    def from_file(filename, device, start=None, end=None):
        with open(filename, encoding='utf-8') as f:
            text = f.read()
        return WordData(text, start, end, device=device)

    def get_excerpt(self, start="THE SONNETS", end="THE END", text=None):
        if text is None:
            text = self.complete_text
        assert start in text, f'get_excerpt: cannot find {start} in text'
        l_stripped = text.split(start, maxsplit=1)[1]
        assert end in l_stripped, f'get_excerpt: cannot find {end} in text'
        r_stripped = l_stripped.split(end, maxsplit=1)[0]
        return r_stripped

    def generate_autoregressive_dataset(self, sequence_length, text=None):
        self.model_max_length = sequence_length
        if text is None:
            text = self.complete_text
        token_ids = self.encode(text, return_tensors="pt")
        inputs = [token_ids[i:i + sequence_length] for i in range(len(token_ids) - sequence_length)]
        labels = [token_ids[i + 1:i + 1 + sequence_length] for i in range(len(token_ids) - sequence_length)]
        return WordsDataset(inputs, labels)

    def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Tokenizes initial_text, then returns the token ids.

        Return type is list by default, but if return_tensors="pt" then it is returned as a tensor.
        '''
        tokens = tokenize(initial_text)
        token_ids = [self.token_to_id[tok] for tok in tokens]
        if return_tensors == "pt":
            return t.tensor(token_ids, device=self.device)
        return token_ids

    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
        '''
        Converts ids to a list of tokens, then joins them into a single string.
        '''
        tokens = [self.id_to_token[int(i)] for i in list_of_ids]
        return "".join(tokens)
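A small round-trip sketch of the WordData tokenizer interface (toy text on the 'cpu' device; assumes every queried word appears in that text):

data = WordData("to be or not to be", start=None, end=None, device='cpu')
ids = data.encode("not to be", return_tensors="pt")
print(ids)               # tensor of word-level token ids
print(data.decode(ids))  # "not to be"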