Spaces: Runtime error
Initial commit
Files changed:
- 100-0.txt  +0 -0
- Dockerfile  +11 -0
- Procfile  +1 -0
- app.py  +13 -0
- attention_replication.py  +156 -0
- config.yaml  +61 -0
- env.yaml  +406 -0
- sampling.py  +239 -0
- shakespeare_demo.py  +105 -0
- transformer_replication.py  +183 -0
- word_data.py  +100 -0
100-0.txt
ADDED
The diff for this file is too large to render. (See raw diff.)
Dockerfile
ADDED
@@ -0,0 +1,11 @@
+# Create environment
+FROM mambaorg/micromamba:1.3.1
+COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml /tmp/env.yaml
+RUN micromamba install --yes --file /tmp/env.yaml && \
+    micromamba clean --all --yes
+
+# Run app (CMD, not RUN, so the server starts when the container runs rather than blocking the image build)
+COPY . /app/
+WORKDIR /app/
+ARG MAMBA_DOCKERFILE_ACTIVATE=1
+CMD ["python", "app.py"]
Procfile
ADDED
@@ -0,0 +1 @@
+web: gunicorn app:app
app.py
ADDED
@@ -0,0 +1,13 @@
+from flask import Flask
+import os
+from shakespeare_demo import make_demo
+
+app = Flask(__name__)
+
+@app.route("/")
+def hello_world():
+    return make_demo()
+
+if __name__ == "__main__":
+    port = int(os.environ.get('PORT', 5999))
+    app.run(debug=True, host='0.0.0.0', port=port)
attention_replication.py
ADDED
@@ -0,0 +1,156 @@
+# %%
+import torch as t
+import torch.nn as nn
+from typing import Union, List
+from fancy_einsum import einsum
+from einops import repeat, rearrange, reduce
+import numpy as np
+#%%
+def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
+    '''
+    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).
+
+    With this function, you can ignore masking.
+
+    Q: shape (batches x seq_Q x head_size)
+    K: shape (batches x seq_K x head_size)
+    V: shape (batches x seq_K x head_size)
+
+    Return: shape (batches x seq_Q x head_size)
+    '''
+    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
+    # Ignore masking
+    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
+    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
+    return attention_values
+
+def test_single_head_attention_shape(single_head_attention):
+    Q = t.randn(1, 3, 2)
+    K = t.randn(1, 5, 2)
+    V = t.randn(1, 5, 2)
+    attention_values = single_head_attention(Q, K, V)
+    assert Q.shape == attention_values.shape
+    print("All tests in `test_single_head_attention_shape` passed.")
+
+def test_single_head_attention(single_head_attention):
+    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+    attention_values = single_head_attention(Q.float(), K.float(), V.float())
+    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+    print("All tests in `test_single_head_attention` passed.")
+
+if __name__ == "__main__":
+    test_single_head_attention_shape(single_head_attention)
+    test_single_head_attention(single_head_attention)
+# %%
+def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
+    '''
+    Should return the results of masked self-attention.
+
+    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.
+
+    Q: shape (batches x seq_Q x head_size)
+    K: shape (batches x seq_K x head_size)
+    V: shape (batches x seq_K x head_size)
+
+    Return: shape (batches x seq_Q x head_size)
+    '''
+    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
+    batches, seq_Q, head_size = Q.shape
+    batches, seq_K, head_size = K.shape
+
+    q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
+    k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
+    mask = k_index <= q_index
+    attention_scores = t.where(mask, attention_scores, -t.inf)
+    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
+    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
+    return attention_values
+
+def test_single_head_masked_attention(single_head_masked_attention):
+    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
+    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+    print("All tests in `test_single_head_masked_attention` passed.")
+
+if __name__ == "__main__":
+    test_single_head_attention_shape(single_head_masked_attention)
+    test_single_head_masked_attention(single_head_masked_attention)
+# %%
+def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
+    '''
+    Implements multihead masked attention on the matrices Q, K and V.
+
+    Q: shape (batch, seq, nheads*headsize)
+    K: shape (batch, seq, nheads*headsize)
+    V: shape (batch, seq, nheads*headsize)
+
+    returns: shape (batch, seq, nheads*headsize)
+    '''
+    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+
+    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
+    batches, _, seq_Q, head_size = new_Q.shape
+    batches, _, seq_K, head_size = new_K.shape
+    q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
+    k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
+    mask = k_index <= q_index
+    device_inf = t.tensor(-np.inf).to(Q.device)
+    device_mask = mask.to(Q.device)
+    masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
+    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
+    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
+    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')
+
+def test_multihead_masked_attention(multihead_masked_attention):
+    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
+    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+    print("All tests in `test_multihead_masked_attention` passed.")
+
+if __name__ == "__main__":
+    test_multihead_masked_attention(multihead_masked_attention)
+# %%
+class MultiheadMaskedAttention(nn.Module):
+    W_QKV: nn.Linear
+    W_O: nn.Linear
+
+    def __init__(self, hidden_size: int, num_heads: int):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        assert self.hidden_size % self.num_heads == 0
+        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
+        self.W_O = nn.Linear(hidden_size, hidden_size)
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        '''
+        x: shape (batch, seq, hidden_size)
+
+        Return: shape (batch, seq, hidden_size)
+        '''
+        QKV = self.W_QKV(x)
+        Q = QKV[..., :self.hidden_size]
+        K = QKV[..., self.hidden_size:-self.hidden_size]
+        V = QKV[..., -self.hidden_size:]
+        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
+        return self.W_O(attention_values)
+# %%
+def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
+    mma = MultiheadMaskedAttention(1, 1)
+    x = t.randn(2, 7, 1)
+    output = mma.forward(x)
+    assert x.shape == output.shape
+    print("All tests in `test_MultiheadMaskedAttention_shape` passed.")
+
+if __name__ == "__main__":
+    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
+# %%
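
A minimal usage sketch for the module above, assuming it is importable as `attention_replication`: `hidden_size` must divide evenly into `num_heads` heads, and the causal mask means each position attends only to itself and earlier positions.

import torch as t
from attention_replication import MultiheadMaskedAttention

mma = MultiheadMaskedAttention(hidden_size=512, num_heads=8)  # 8 heads of size 64
x = t.randn(2, 10, 512)   # (batch, seq, hidden_size)
out = mma(x)              # same shape; position i only attends to positions <= i
assert out.shape == x.shape
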
config.yaml
ADDED
@@ -0,0 +1,61 @@
+wandb_version: 1
+
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.13.5
+    framework: huggingface
+    huggingface_version: 4.24.0
+    is_jupyter_run: true
+    is_kaggle_kernel: false
+    python_version: 3.10.6
+    start_time: 1668083783.928274
+    t:
+      1:
+      - 1
+      - 11
+      - 41
+      - 49
+      - 55
+      2:
+      - 1
+      - 11
+      - 41
+      - 49
+      - 55
+      3:
+      - 1
+      - 2
+      - 3
+      - 23
+      - 37
+      4: 3.10.6
+      5: 0.13.5
+      6: 4.24.0
+      8:
+      - 1
+      - 5
+batch_size:
+  desc: null
+  value: 64
+dropout:
+  desc: null
+  value: 0.1
+epochs:
+  desc: null
+  value: 2
+hidden_size:
+  desc: null
+  value: 512
+lr:
+  desc: null
+  value: 0.001
+max_seq_len:
+  desc: null
+  value: 60
+num_heads:
+  desc: null
+  value: 8
+num_layers:
+  desc: null
+  value: 6
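
This is a wandb config export, so every hyperparameter sits under a `value` key. A minimal sketch of how a consumer unwraps it (mirroring what `shakespeare_demo.py` does below):

import yaml

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

# wandb wraps every hyperparameter as {desc: ..., value: ...}
hidden_size = cfg['hidden_size']['value']   # 512
num_heads = cfg['num_heads']['value']       # 8
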
env.yaml
ADDED
@@ -0,0 +1,406 @@
+name: base
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+dependencies:
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=2_kmp_llvm
+  - aiofiles=23.1.0=pyhd8ed1ab_0
+  - aiohttp=3.8.4=py310h1fa729e_0
+  - aiosignal=1.3.1=pyhd8ed1ab_0
+  - alsa-lib=1.2.8=h166bdaf_0
+  - altair=4.2.2=pyhd8ed1ab_0
+  - anyio=3.6.2=pyhd8ed1ab_0
+  - argon2-cffi=21.3.0=pyhd8ed1ab_0
+  - argon2-cffi-bindings=21.2.0=py310h5764c6d_3
+  - arrow-cpp=11.0.0=ha770c72_4_cpu
+  - asttokens=2.2.1=pyhd8ed1ab_0
+  - async-timeout=4.0.2=pyhd8ed1ab_0
+  - attr=2.5.1=h166bdaf_1
+  - attrs=22.2.0=pyh71513ae_0
+  - aws-c-auth=0.6.24=h565b4ff_2
+  - aws-c-cal=0.5.20=h679401e_5
+  - aws-c-common=0.8.10=h0b41bf4_0
+  - aws-c-compression=0.2.16=hbe6ad0c_2
+  - aws-c-event-stream=0.2.18=h489b7ba_4
+  - aws-c-http=0.7.4=hb2c4a47_0
+  - aws-c-io=0.13.15=head7655_1
+  - aws-c-mqtt=0.8.6=haf0be06_3
+  - aws-c-s3=0.2.4=h05be983_0
+  - aws-c-sdkutils=0.1.7=hbe6ad0c_2
+  - aws-checksums=0.1.14=hbe6ad0c_2
+  - aws-crt-cpp=0.19.7=h9b63b7c_3
+  - aws-sdk-cpp=1.10.57=hd557813_3
+  - backcall=0.2.0=pyh9f0ad1d_0
+  - backports=1.0=pyhd8ed1ab_3
+  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+  - beautifulsoup4=4.11.2=pyha770c72_0
+  - blas=2.116=mkl
+  - blas-devel=3.9.0=16_linux64_mkl
+  - bleach=6.0.0=pyhd8ed1ab_0
+  - brotli=1.0.9=h166bdaf_8
+  - brotli-bin=1.0.9=h166bdaf_8
+  - brotlipy=0.7.0=py310h5764c6d_1005
+  - bzip2=1.0.8=h7f98852_4
+  - c-ares=1.18.1=h7f98852_0
+  - ca-certificates=2022.12.7=ha878542_0
+  - cairo=1.16.0=ha61ee94_1014
+  - certifi=2022.12.7=pyhd8ed1ab_0
+  - cffi=1.15.1=py310h255011f_3
+  - charset-normalizer=2.1.1=pyhd8ed1ab_0
+  - click=8.1.3=unix_pyhd8ed1ab_2
+  - colorama=0.4.6=pyhd8ed1ab_0
+  - comm=0.1.2=pyhd8ed1ab_0
+  - contourpy=1.0.7=py310hdf3cbec_0
+  - cryptography=39.0.1=py310h34c0648_0
+  - cuda=11.6.1=0
+  - cuda-cccl=11.6.55=hf6102b2_0
+  - cuda-command-line-tools=11.6.2=0
+  - cuda-compiler=11.6.2=0
+  - cuda-cudart=11.6.55=he381448_0
+  - cuda-cudart-dev=11.6.55=h42ad0f4_0
+  - cuda-cuobjdump=11.6.124=h2eeebcb_0
+  - cuda-cupti=11.6.124=h86345e5_0
+  - cuda-cuxxfilt=11.6.124=hecbf4f6_0
+  - cuda-driver-dev=11.6.55=0
+  - cuda-gdb=12.0.140=0
+  - cuda-libraries=11.6.1=0
+  - cuda-libraries-dev=11.6.1=0
+  - cuda-memcheck=11.8.86=0
+  - cuda-nsight=12.0.140=0
+  - cuda-nsight-compute=12.0.1=0
+  - cuda-nvcc=11.6.124=hbba6d2d_0
+  - cuda-nvdisasm=12.0.140=0
+  - cuda-nvml-dev=11.6.55=haa9ef22_0
+  - cuda-nvprof=12.0.146=0
+  - cuda-nvprune=11.6.124=he22ec0a_0
+  - cuda-nvrtc=11.6.124=h020bade_0
+  - cuda-nvrtc-dev=11.6.124=h249d397_0
+  - cuda-nvtx=11.6.124=h0630a44_0
+  - cuda-nvvp=12.0.146=0
+  - cuda-runtime=11.6.1=0
+  - cuda-samples=11.6.101=h8efea70_0
+  - cuda-sanitizer-api=12.0.140=0
+  - cuda-toolkit=11.6.1=0
+  - cuda-tools=11.6.1=0
+  - cuda-visual-tools=11.6.1=0
+  - cycler=0.11.0=pyhd8ed1ab_0
+  - dataclasses=0.8=pyhc8e2a94_3
+  - datasets=2.9.0=pyhd8ed1ab_0
+  - dbus=1.13.6=h5008d03_3
+  - debugpy=1.6.6=py310heca2aa9_0
+  - decorator=5.1.1=pyhd8ed1ab_0
+  - defusedxml=0.7.1=pyhd8ed1ab_0
+  - dill=0.3.6=pyhd8ed1ab_1
+  - einops=0.6.0=pyhd8ed1ab_0
+  - entrypoints=0.4=pyhd8ed1ab_0
+  - executing=1.2.0=pyhd8ed1ab_0
+  - expat=2.5.0=h27087fc_0
+  - fastapi=0.92.0=pyhd8ed1ab_0
+  - ffmpeg=4.3=hf484d3e_0
+  - ffmpy=0.3.0=pyhb6f538c_0
+  - fftw=3.3.10=nompi_hf0379b8_106
+  - filelock=3.9.0=pyhd8ed1ab_0
+  - flask=2.2.3=pyhd8ed1ab_0
+  - flit-core=3.8.0=pyhd8ed1ab_0
+  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+  - font-ttf-inconsolata=3.000=h77eed37_0
+  - font-ttf-source-code-pro=2.038=h77eed37_0
+  - font-ttf-ubuntu=0.83=hab24e00_0
+  - fontconfig=2.14.2=h14ed4e7_0
+  - fonts-conda-ecosystem=1=0
+  - fonts-conda-forge=1=0
+  - fonttools=4.38.0=py310h5764c6d_1
+  - freetype=2.12.1=hca18f0e_1
+  - frozenlist=1.3.3=py310h5764c6d_0
+  - fsspec=2023.1.0=pyhd8ed1ab_0
+  - gds-tools=1.5.1.14=0
+  - gettext=0.21.1=h27087fc_0
+  - gflags=2.2.2=he1b5a44_1004
+  - glib=2.74.1=h6239696_1
+  - glib-tools=2.74.1=h6239696_1
+  - glog=0.6.0=h6f12383_0
+  - gmp=6.2.1=h58526e2_0
+  - gnutls=3.6.13=h85f3911_1
+  - gradio=3.19.1=pyhd8ed1ab_0
+  - graphite2=1.3.13=h58526e2_1001
+  - gst-plugins-base=1.22.0=h4243ec0_0
+  - gstreamer=1.22.0=h25f0c4b_0
+  - gstreamer-orc=0.4.33=h166bdaf_0
+  - h11=0.14.0=pyhd8ed1ab_0
+  - h2=4.1.0=pyhd8ed1ab_0
+  - harfbuzz=6.0.0=h8e241bc_0
+  - hpack=4.0.0=pyh9f0ad1d_0
+  - httpcore=0.16.3=pyhd8ed1ab_0
+  - httpx=0.23.3=pyhd8ed1ab_0
+  - huggingface_hub=0.12.1=pyhd8ed1ab_0
+  - hyperframe=6.0.1=pyhd8ed1ab_0
+  - icu=70.1=h27087fc_0
+  - idna=3.4=pyhd8ed1ab_0
+  - importlib-metadata=6.0.0=pyha770c72_0
+  - importlib_metadata=6.0.0=hd8ed1ab_0
+  - importlib_resources=5.12.0=pyhd8ed1ab_0
+  - ipykernel=6.21.2=pyh210e3f2_0
+  - ipython=8.10.0=pyh41d4057_0
+  - ipython_genutils=0.2.0=py_1
+  - ipywidgets=8.0.4=pyhd8ed1ab_0
+  - itsdangerous=2.1.2=pyhd8ed1ab_0
+  - jack=1.9.22=h11f4161_0
+  - jedi=0.18.2=pyhd8ed1ab_0
+  - jinja2=3.1.2=pyhd8ed1ab_1
+  - joblib=1.2.0=pyhd8ed1ab_0
+  - jpeg=9e=h0b41bf4_3
+  - jsonschema=4.17.3=pyhd8ed1ab_0
+  - jupyter=1.0.0=py310hff52083_8
+  - jupyter_client=8.0.3=pyhd8ed1ab_0
+  - jupyter_console=6.5.1=pyhd8ed1ab_0
+  - jupyter_core=5.2.0=py310hff52083_0
+  - jupyter_events=0.6.3=pyhd8ed1ab_0
+  - jupyter_server=2.3.0=pyhd8ed1ab_0
+  - jupyter_server_terminals=0.4.4=pyhd8ed1ab_1
+  - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+  - jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
+  - keyutils=1.6.1=h166bdaf_0
+  - kiwisolver=1.4.4=py310hbf28c38_1
+  - krb5=1.20.1=h81ceb04_0
+  - lame=3.100=h166bdaf_1003
+  - lcms2=2.14=hfd0df8a_1
+  - ld_impl_linux-64=2.40=h41732ed_0
+  - lerc=4.0.0=h27087fc_0
+  - libabseil=20220623.0=cxx17_h05df665_6
+  - libarrow=11.0.0=hc42cb68_4_cpu
+  - libblas=3.9.0=16_linux64_mkl
+  - libbrotlicommon=1.0.9=h166bdaf_8
+  - libbrotlidec=1.0.9=h166bdaf_8
+  - libbrotlienc=1.0.9=h166bdaf_8
+  - libcap=2.66=ha37c62d_0
+  - libcblas=3.9.0=16_linux64_mkl
+  - libclang=15.0.7=default_had23c3d_1
+  - libclang13=15.0.7=default_h3e3d535_1
+  - libcrc32c=1.1.2=h9c3ff4c_0
+  - libcublas=11.9.2.110=h5e84587_0
+  - libcublas-dev=11.9.2.110=h5c901ab_0
+  - libcufft=10.7.1.112=hf425ae0_0
+  - libcufft-dev=10.7.1.112=ha5ce4c0_0
+  - libcufile=1.5.1.14=0
+  - libcufile-dev=1.5.1.14=0
+  - libcups=2.3.3=h36d4200_3
+  - libcurand=10.3.1.124=0
+  - libcurand-dev=10.3.1.124=0
+  - libcurl=7.88.1=hdc1c0ab_0
+  - libcusolver=11.3.4.124=h33c3c4e_0
+  - libcusparse=11.7.2.124=h7538f96_0
+  - libcusparse-dev=11.7.2.124=hbbe9722_0
+  - libdb=6.2.32=h9c3ff4c_0
+  - libdeflate=1.17=h0b41bf4_0
+  - libedit=3.1.20191231=he28a2e2_2
+  - libev=4.33=h516909a_1
+  - libevent=2.1.10=h28343ad_4
+  - libffi=3.4.2=h7f98852_5
+  - libflac=1.4.2=h27087fc_0
+  - libgcc-ng=12.2.0=h65d4601_19
+  - libgcrypt=1.10.1=h166bdaf_0
+  - libgfortran-ng=12.2.0=h69a702a_19
+  - libgfortran5=12.2.0=h337968e_19
+  - libglib=2.74.1=h606061b_1
+  - libgoogle-cloud=2.7.0=h21dfe5b_1
+  - libgpg-error=1.46=h620e276_0
+  - libgrpc=1.51.1=h4fad500_1
+  - libhwloc=2.8.0=h32351e8_1
+  - libiconv=1.17=h166bdaf_0
+  - liblapack=3.9.0=16_linux64_mkl
+  - liblapacke=3.9.0=16_linux64_mkl
+  - libllvm15=15.0.7=hadd5161_0
+  - libnghttp2=1.51.0=hff17c54_0
+  - libnpp=11.6.3.124=hd2722f0_0
+  - libnpp-dev=11.6.3.124=h3c42840_0
+  - libnsl=2.0.0=h7f98852_0
+  - libnvjpeg=11.6.2.124=hd473ad6_0
+  - libnvjpeg-dev=11.6.2.124=hb5906b9_0
+  - libogg=1.3.4=h7f98852_1
+  - libopus=1.3.1=h7f98852_1
+  - libpng=1.6.39=h753d276_0
+  - libpq=15.2=hb675445_0
+  - libprotobuf=3.21.12=h3eb15da_0
+  - libsndfile=1.2.0=hb75c966_0
+  - libsodium=1.0.18=h36c2ea0_1
+  - libsqlite=3.40.0=h753d276_0
+  - libssh2=1.10.0=hf14f497_3
+  - libstdcxx-ng=12.2.0=h46fd767_19
+  - libsystemd0=252=h2a991cd_0
+  - libthrift=0.16.0=he500d00_2
+  - libtiff=4.5.0=h6adf6a1_2
+  - libtool=2.4.7=h27087fc_0
+  - libudev1=252=h166bdaf_0
+  - libutf8proc=2.8.0=h166bdaf_0
+  - libuuid=2.32.1=h7f98852_1000
+  - libvorbis=1.3.7=h9c3ff4c_0
+  - libwebp-base=1.2.4=h166bdaf_0
+  - libxcb=1.13=h7f98852_1004
+  - libxkbcommon=1.5.0=h79f4944_0
+  - libxml2=2.10.3=h7463322_0
+  - libzlib=1.2.13=h166bdaf_4
+  - linkify-it-py=2.0.0=pyhd8ed1ab_0
+  - llvm-openmp=15.0.7=h0cdce71_0
+  - lz4-c=1.9.4=hcb278e6_0
+  - markdown-it-py=2.1.0=pyhd8ed1ab_0
+  - markupsafe=2.1.2=py310h1fa729e_0
+  - matplotlib-base=3.7.0=py310he60537e_0
+  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+  - mdit-py-plugins=0.3.3=pyhd8ed1ab_0
+  - mdurl=0.1.0=pyhd8ed1ab_0
+  - mistune=2.0.5=pyhd8ed1ab_0
+  - mkl=2022.1.0=h84fe81f_915
+  - mkl-devel=2022.1.0=ha770c72_916
+  - mkl-include=2022.1.0=h84fe81f_915
+  - mpg123=1.31.2=hcb278e6_0
+  - multidict=6.0.4=py310h1fa729e_0
+  - multiprocess=0.70.14=py310h5764c6d_3
+  - munkres=1.1.4=pyh9f0ad1d_0
+  - mysql-common=8.0.32=ha901b37_0
+  - mysql-libs=8.0.32=hd7da12d_0
+  - nbclassic=0.5.2=pyhd8ed1ab_0
+  - nbclient=0.7.2=pyhd8ed1ab_0
+  - nbconvert=7.2.9=pyhd8ed1ab_0
+  - nbconvert-core=7.2.9=pyhd8ed1ab_0
+  - nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
+  - nbformat=5.7.3=pyhd8ed1ab_0
+  - ncurses=6.3=h27087fc_1
+  - nest-asyncio=1.5.6=pyhd8ed1ab_0
+  - nettle=3.6=he412f7d_0
+  - notebook=6.5.2=pyha770c72_1
+  - notebook-shim=0.2.2=pyhd8ed1ab_0
+  - nsight-compute=2022.4.1.6=0
+  - nspr=4.35=h27087fc_0
+  - nss=3.88=he45b914_0
+  - numpy=1.24.2=py310h8deb116_0
+  - openh264=2.1.1=h780b84a_0
+  - openjpeg=2.5.0=hfec8fc6_2
+  - openssl=3.0.8=h0b41bf4_0
+  - orc=1.8.2=hfdbbad2_2
+  - orjson=3.8.5=py310h38b9cce_1
+  - packaging=23.0=pyhd8ed1ab_0
+  - pandas=1.5.3=py310h9b08913_0
+  - pandoc=2.19.2=h32600fe_1
+  - pandocfilters=1.5.0=pyhd8ed1ab_0
+  - parquet-cpp=1.5.1=2
+  - parso=0.8.3=pyhd8ed1ab_0
+  - pcre2=10.40=hc3806b6_0
+  - pexpect=4.8.0=pyh1a96a4e_2
+  - pickleshare=0.7.5=py_1003
+  - pillow=9.4.0=py310h023d228_1
+  - pip=23.0.1=pyhd8ed1ab_0
+  - pixman=0.40.0=h36c2ea0_0
+  - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
+  - platformdirs=3.0.0=pyhd8ed1ab_0
+  - ply=3.11=py_1
+  - prometheus_client=0.16.0=pyhd8ed1ab_0
+  - prompt-toolkit=3.0.36=pyha770c72_0
+  - prompt_toolkit=3.0.36=hd8ed1ab_0
+  - psutil=5.9.4=py310h5764c6d_0
+  - pthread-stubs=0.4=h36c2ea0_1001
+  - ptyprocess=0.7.0=pyhd3deb0d_0
+  - pulseaudio=16.1=ha8d29e2_1
+  - pure_eval=0.2.2=pyhd8ed1ab_0
+  - pyarrow=11.0.0=py310h633f555_4_cpu
+  - pycparser=2.21=pyhd8ed1ab_0
+  - pycryptodome=3.16.0=py310h1419917_0
+  - pydantic=1.10.5=py310h1fa729e_0
+  - pydub=0.25.1=pyhd8ed1ab_0
+  - pygments=2.14.0=pyhd8ed1ab_0
+  - pyopenssl=23.0.0=pyhd8ed1ab_0
+  - pyparsing=3.0.9=pyhd8ed1ab_0
+  - pyqt=5.15.7=py310hab646b1_3
+  - pyqt5-sip=12.11.0=py310heca2aa9_3
+  - pyrsistent=0.19.3=py310h1fa729e_0
+  - pysocks=1.7.1=pyha2e5f31_6
+  - python=3.10.9=he550d4f_0_cpython
+  - python-dateutil=2.8.2=pyhd8ed1ab_0
+  - python-fastjsonschema=2.16.2=pyhd8ed1ab_0
+  - python-json-logger=2.0.6=pyhd8ed1ab_0
+  - python-multipart=0.0.5=py_0
+  - python-xxhash=3.2.0=py310h1fa729e_0
+  - python_abi=3.10=3_cp310
+  - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
+  - pytorch-cuda=11.6=h867d48c_1
+  - pytorch-mutex=1.0=cuda
+  - pytz=2022.7.1=pyhd8ed1ab_0
+  - pyyaml=6.0=py310h5764c6d_5
+  - pyzmq=25.0.0=py310h059b190_0
+  - qt-main=5.15.8=h5d23da1_6
+  - qtconsole=5.4.0=pyhd8ed1ab_0
+  - qtconsole-base=5.4.0=pyha770c72_0
+  - qtpy=2.3.0=pyhd8ed1ab_0
+  - re2=2023.02.01=hcb278e6_0
+  - readline=8.1.2=h0f457ee_0
+  - regex=2022.10.31=py310h5764c6d_0
+  - requests=2.28.2=pyhd8ed1ab_0
+  - responses=0.18.0=pyhd8ed1ab_0
+  - rfc3339-validator=0.1.4=pyhd8ed1ab_0
+  - rfc3986=1.5.0=pyhd8ed1ab_0
+  - rfc3986-validator=0.1.1=pyh9f0ad1d_0
+  - s2n=1.3.35=h3358134_0
+  - sacremoses=0.0.53=pyhd8ed1ab_0
+  - send2trash=1.8.0=pyhd8ed1ab_0
+  - setuptools=67.3.2=pyhd8ed1ab_0
+  - sip=6.7.7=py310heca2aa9_0
+  - six=1.16.0=pyh6c4a22f_0
+  - snappy=1.1.9=hbd366e4_2
+  - sniffio=1.3.0=pyhd8ed1ab_0
+  - soupsieve=2.3.2.post1=pyhd8ed1ab_0
+  - stack_data=0.6.2=pyhd8ed1ab_0
+  - starlette=0.25.0=pyhd8ed1ab_0
+  - tbb=2021.7.0=h924138e_1
+  - terminado=0.17.1=pyh41d4057_0
+  - tinycss2=1.2.1=pyhd8ed1ab_0
+  - tk=8.6.12=h27826a3_0
+  - tokenizers=0.13.2=py310he1f1126_0
+  - toml=0.10.2=pyhd8ed1ab_0
+  - toolz=0.12.0=pyhd8ed1ab_0
+  - torchaudio=0.13.1=py310_cu116
+  - torchvision=0.14.1=py310_cu116
+  - tornado=6.2=py310h5764c6d_1
+  - tqdm=4.64.1=pyhd8ed1ab_0
+  - traitlets=5.9.0=pyhd8ed1ab_0
+  - transformers=4.26.1=pyhd8ed1ab_0
+  - typing-extensions=4.4.0=hd8ed1ab_0
+  - typing_extensions=4.4.0=pyha770c72_0
+  - tzdata=2022g=h191b570_0
+  - uc-micro-py=1.0.1=pyhd8ed1ab_0
+  - unicodedata2=15.0.0=py310h5764c6d_0
+  - urllib3=1.26.14=pyhd8ed1ab_0
+  - uvicorn=0.20.0=py310hff52083_1
+  - wcwidth=0.2.6=pyhd8ed1ab_0
+  - webencodings=0.5.1=py_1
+  - websocket-client=1.5.1=pyhd8ed1ab_0
+  - websockets=10.4=py310h5764c6d_1
+  - werkzeug=2.2.3=pyhd8ed1ab_0
+  - wheel=0.38.4=pyhd8ed1ab_0
+  - widgetsnbextension=4.0.5=pyhd8ed1ab_0
+  - xcb-util=0.4.0=h166bdaf_0
+  - xcb-util-image=0.4.0=h166bdaf_0
+  - xcb-util-keysyms=0.4.0=h166bdaf_0
+  - xcb-util-renderutil=0.3.9=h166bdaf_0
+  - xcb-util-wm=0.4.1=h166bdaf_0
+  - xorg-kbproto=1.0.7=h7f98852_1002
+  - xorg-libice=1.0.10=h7f98852_0
+  - xorg-libsm=1.2.3=hd9c2040_1000
+  - xorg-libx11=1.7.2=h7f98852_0
+  - xorg-libxau=1.0.9=h7f98852_0
+  - xorg-libxdmcp=1.1.3=h7f98852_0
+  - xorg-libxext=1.3.4=h7f98852_1
+  - xorg-libxrender=0.9.10=h7f98852_1003
+  - xorg-renderproto=0.11.1=h7f98852_1002
+  - xorg-xextproto=7.3.0=h7f98852_1002
+  - xorg-xproto=7.0.31=h7f98852_1007
+  - xxhash=0.8.1=h0b41bf4_0
+  - xz=5.2.6=h166bdaf_0
+  - yaml=0.2.5=h7f98852_2
+  - yarl=1.8.2=py310h5764c6d_0
+  - zeromq=4.3.4=h9c3ff4c_1
+  - zipp=3.14.0=pyhd8ed1ab_0
+  - zlib=1.2.13=h166bdaf_4
+  - zstd=1.5.2=h3eb15da_6
+  - pip:
+    - fancy-einsum==0.0.3
sampling.py
ADDED
@@ -0,0 +1,239 @@
+# %%
+import torch as t
+import torch.nn.functional as F
+import transformers
+import numpy as np
+
+gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
+tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
+
+def apply_sampling_methods(
+    input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
+) -> int:
+    '''
+    Return the next token, sampled from the model's probability distribution with modifiers.
+
+    input_ids: shape (seq,)
+    '''
+    assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
+    assert temperature >= 0, "Temperature should be non-negative"
+    assert 0 <= top_p <= 1.0, "Top-p must be a probability"
+    assert 0 <= top_k, "Top-k must be non-negative"
+    assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"
+
+    if temperature == 0:
+        return greedy_search(logits)
+    if temperature != 1.0:
+        logits = apply_temperature(logits, temperature)
+    if freq_penalty != 0.0:
+        logits = apply_freq_penalty(input_ids, logits, freq_penalty)
+    if top_k > 0:
+        return sample_top_k(logits, top_k)
+    if top_p > 0:
+        return sample_top_p(logits, top_p)
+    return sample_basic(logits)
+
+def sample_tokens(
+    model,
+    tokenizer,
+    initial_text: str,
+    max_tokens_generated: int = 30,
+    **kwargs
+) -> str:
+    '''
+    Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.
+
+    Return: the prompt and continuation concatenated
+    '''
+    model.eval()
+    input_ids: list = tokenizer.encode(initial_text)
+    generated = []
+    device = next(model.parameters()).device
+    for _ in range(max_tokens_generated):
+        new_input_ids = t.tensor(np.array(input_ids + generated), dtype=t.int64, device=device)
+        new_input_ids_truncated = new_input_ids[-min(tokenizer.model_max_length, new_input_ids.shape[0]):].unsqueeze(0)
+        output = model(new_input_ids_truncated)
+        all_logits = output if isinstance(output, t.Tensor) else output.logits
+        logits = all_logits[0, -1]  # batch 0, final position -> shape (vocab_size,)
+        new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
+        generated.append(new_token)
+        if new_token == getattr(tokenizer, "eos_token_id", None):
+            break
+    return tokenizer.decode(input_ids + generated)
+
+# %%
+def greedy_search(logits: t.Tensor) -> int:
+    '''
+    logits: shape (vocab_size, )
+
+    Return: the most likely token (as an integer)
+    '''
+    return logits.argmax().item()
+
+if __name__ == "__main__":
+    prompt = "Jingle bells, jingle bells, jingle all the way"
+    print("Greedy decoding with prompt: ", prompt)
+    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
+    print(f"Your model said: {output}")
+    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
+    assert output == expected
+
+    print("Greedy decoding a second time (should be deterministic): ")
+    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
+    print(f"Your model said: {output}")
+    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
+    assert output == expected
+
+    print("Tests passed!")
+# %%
+def sample_basic(logits: t.Tensor) -> int:
+    '''
+    logits: shape (vocab_size, ) - unnormalized log-probabilities
+
+    Return: a sampled token
+    '''
+    return t.distributions.categorical.Categorical(logits=logits).sample()
+
+if __name__ == "__main__":
+    N = 20000
+    probs = t.linspace(0, 0.4, 5)
+    unnormalized_logits = probs.log() + 1.2345
+    samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
+    counts = t.bincount(samples, minlength=len(probs)) / N
+    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+    t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
+    print("Tests passed!")
+# %%
+def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
+    '''
+    logits: shape (vocab_size, )
+
+    Return: shape (vocab_size, )
+    '''
+    assert temperature > 0
+    return logits / temperature
+
+if __name__ == '__main__':
+    logits = t.tensor([1, 2]).log()
+    cold_logits = apply_temperature(logits, 0.001)
+    print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
+    t.testing.assert_close(cold_logits, 1000.0 * logits)
+    hot_logits = apply_temperature(logits, 1000.0)
+    print("A high temperature flattens the distribution: ", hot_logits)
+    t.testing.assert_close(hot_logits, 0.001 * logits)
+    print("Tests passed!")
+
+# %%
+def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
+    '''
+    input_ids: shape (seq, )
+    logits: shape (vocab_size, )
+
+    Return: shape (vocab_size, )
+    '''
+    count = input_ids.bincount(minlength=len(logits))
+    # Penalize each token in proportion to how often it has already appeared
+    return logits - count * freq_penalty
+
+if __name__ == "__main__":
+    bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
+    input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
+    logits = t.ones(tokenizer.vocab_size)
+    penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
+    assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"
+    assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"
+    print("Tests passed!")
+# %%
+N_RUNS = 0  # increase to print samples for each case
+your_prompt = "Jingle bells, jingle bells, jingle all the way"
+cases = [
+    ("High freq penalty", dict(freq_penalty=100.0)),
+    ("Negative freq penalty", dict(freq_penalty=-1.0)),
+    ("Too hot!", dict(temperature=2.0)),
+    ("Pleasantly cool", dict(temperature=0.7)),
+    ("Pleasantly warm", dict(temperature=0.9)),
+    ("Too cold!", dict(temperature=0.01)),
+]
+for (name, kwargs) in cases:
+    for i in range(N_RUNS):
+        output = sample_tokens(gpt, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
+        print(f"Sample {i} with: {name} ({kwargs}):")
+        print(f"Your model said: {repr(output)}\n")
+# %%
+def sample_top_k(logits: t.Tensor, top_k: int) -> int:
+    '''
+    logits: shape (vocab_size, ) - unnormalized log-probabilities
+    top_k: only consider this many of the most likely tokens for sampling
+
+    Return: a sampled token
+    '''
+    values, indices = t.topk(logits, top_k)
+    return indices[sample_basic(values)].item()
+
+if __name__ == "__main__":
+    N = 50000
+    k = 3
+    probs = t.linspace(0, 0.4, 5)
+    unnormalized_logits = probs.log() + 1.2345
+    samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
+    counts = t.bincount(samples, minlength=len(probs)) / N
+    expected = probs.clone()
+    expected[:-k] = 0
+    expected /= expected.sum()
+    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+    t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
+    print("Tests passed!")
+# %%
+if __name__ == "__main__":
+    your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
+    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
+    print(f"Your model said: {repr(output)}")
+# %%
+def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
+    '''
+    logits: shape (vocab_size, ) - unnormalized log-probabilities
+
+    Return: a sampled token
+    '''
+    probs = t.exp(logits.double()) / t.exp(logits.double()).sum()
+    sorted_probs, sorted_indices = probs.sort(descending=True)
+    cum_probs = sorted_probs.cumsum(-1)
+    # Keep the smallest prefix of sorted tokens whose cumulative probability reaches top_p
+    last_index = max(min_tokens_to_keep, t.where(cum_probs >= top_p)[0][0].item() + 1)
+    masked_probs = sorted_probs[:last_index]
+    sample = t.distributions.categorical.Categorical(probs=masked_probs).sample()
+    return sorted_indices[sample].item()
+
+if __name__ == "__main__":
+    N = 2000
+    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
+    samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
+    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
+    print("top_p of 0.5 or lower should only return token 2: ", counts)
+    assert counts[0] == 0 and counts[1] == 0
+
+    N = 2000
+    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
+    samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
+    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
+    print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
+    assert counts[0] == 0
+
+    N = 50000
+    top_p = 0.71
+    probs = t.linspace(0, 0.4, 5)
+    unnormalized_logits = probs.log() + 1.2345
+    samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
+    counts = t.bincount(samples, minlength=len(probs)) / N
+    expected = probs.clone()
+    expected[0:2] = 0
+    expected /= expected.sum()
+    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+    t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)
+
+    print("All tests passed!")
+# %%
+if __name__ == "__main__":
+    your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
+    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
+    print(f"Your model said: {repr(output)}")
+# %%
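
A sketch of how these pieces compose through `sample_tokens` (top-k and top-p are mutually exclusive, per the assert in `apply_sampling_methods`):

prompt = "Jingle bells, jingle bells, jingle all the way"

# Greedy: temperature=0 short-circuits straight to greedy_search
greedy = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=16, temperature=0.0)

# Nucleus sampling: keep the smallest set of top tokens with cumulative probability >= 0.9
nucleus = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=16, temperature=0.8, top_p=0.9)
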
shakespeare_demo.py
ADDED
@@ -0,0 +1,105 @@
+#%%
+import yaml
+import torch as t
+import gradio as gr
+import re
+from word_data import WordData
+import sampling
+import transformer_replication
+#%%
+MAIN = __name__ == '__main__'
+device = 'cuda' if t.cuda.is_available() else 'cpu'
+#%%
+shakespeare = WordData.from_file(
+    '100-0.txt', device=device, start="1\n", end='ALL’S WELL THAT ENDS WELL'
+)
+if MAIN:
+    print('Vocab size: ', len(shakespeare.vocab))
+#%%
+with open('config.yaml', 'r') as f:
+    yaml_cfg = yaml.safe_load(f)
+#%%
+# Load the trained weights (map_location so a CPU-only host can load a GPU-trained checkpoint)
+state_dict = t.load('model_state_dict.pt', map_location=device)
+#%%
+base_config = transformer_replication.TransformerConfig(
+    num_layers=yaml_cfg['num_layers']['value'],
+    num_heads=yaml_cfg['num_heads']['value'],
+    vocab_size=len(shakespeare.vocab),
+    hidden_size=yaml_cfg['hidden_size']['value'],
+    max_seq_len=yaml_cfg['max_seq_len']['value'],
+    dropout=yaml_cfg['dropout']['value'],
+)
+shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']
+model = transformer_replication.DecoderOnlyTransformer(base_config)
+
+model.load_state_dict(state_dict)
+
+#%%
+def generate(
+    text: str, max_tokens: int, temperature: float,
+    top_k: int,
+) -> str:
+    return sampling.sample_tokens(
+        model,
+        shakespeare,
+        text,
+        max_tokens_generated=max_tokens,
+        temperature=temperature,
+        top_k=top_k,
+    )
+
+#%%
+def safe_generate(
+    text: str, max_tokens: int = 300, temperature: float = 1.0,
+    top_k: int = 20,
+) -> str:
+    try:
+        raw = generate(
+            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
+        )
+        match = re.match(r"(?P<start>\D*)\d+\n", raw)
+        if match is None:
+            return raw
+        return match.group('start')
+    except KeyError as e:
+        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
+#%%
+examples = [
+    ["I sang a beautiful song"],
+    ["To be free is to"],
+    ["How I love thee"],
+]
+#%%
+if MAIN:
+    print(safe_generate('How I love thee'))
+#%%
+def make_demo():
+    demo = gr.Interface(
+        fn=safe_generate,
+        inputs=[
+            gr.components.Textbox(lines=5, label="Input Text"),
+            gr.components.Slider(
+                label='max tokens generated', minimum=1, maximum=1000,
+                value=300, step=1,
+            ),
+            gr.components.Slider(
+                label='temperature', minimum=0, maximum=2, value=1, step=0.1,
+            ),
+            gr.components.Slider(
+                label='top_k', minimum=1, maximum=100, value=10, step=1,
+            ),
+        ],
+        outputs=gr.components.Textbox(label="Generated Text"),
+        examples=examples
+    )
+    demo.launch()
+# %%
+'''
+FIXME:
+* deploy to heroku
+* link from github home
+'''
transformer_replication.py
ADDED
@@ -0,0 +1,183 @@
+#%%
+import transformers
+import torch as t
+import torch.nn as nn
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+from typing import Union, List, Optional, Callable, Tuple
+from fancy_einsum import einsum
+import numpy as np
+from einops import rearrange
+import time
+# %%
+tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
+if __name__ == "__main__":
+    print(tokenizer("hello meg"))
+    print(tokenizer.encode("hello meg"))
+    print(tokenizer.decode([31373, 17243]))
+    print(tokenizer.tokenize("hello meg"))
+    print(f"'{tokenizer.decode(17243)}'")
+# %%
+class Embedding(nn.Module):
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.weight = nn.Parameter(t.randn((self.num_embeddings, self.embedding_dim)))
+
+    def forward(self, x: t.LongTensor) -> t.Tensor:
+        '''For each integer in the input, return that row of the embedding.'''
+        # Equivalent one-hot formulation:
+        # return einsum('num_embeddings embedding_dim, i num_embeddings -> i embedding_dim', self.weight, nn.functional.one_hot(x, num_classes=self.num_embeddings).float())
+        return self.weight[x]
+
+    def extra_repr(self) -> str:
+        return f"{self.num_embeddings}, {self.embedding_dim}"
+
+# %%
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, max_seq_len: int, embedding_dim: int):
+        super().__init__()
+        # Defining our positional encoding array, with `max_seq_len` rows
+        # This is an advantage of using sinusoidal encoding: we can easily expand to sequences of greater length without adding more learned params
+        angles = t.outer(t.arange(max_seq_len), 1 / 10000 ** (2 * t.arange(embedding_dim//2) / embedding_dim))
+        pe = t.zeros((max_seq_len, embedding_dim))
+        pe[:, ::2] = t.sin(angles)
+        pe[:, 1::2] = t.cos(angles)
+        # Register array as a buffer, rather than parameter (we don't want it to be updated by gradient descent)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        """
+        x: shape (batch, seq_len, embedding_dim)
+        """
+        batch, seq_len, embedding_dim = x.shape
+        # We slice the positional encoding, so it's the same shape as x
+        # This is equivalent to just using an nn.Embedding, but having the input be t.arange(seq_len)
+        return x + self.pe[:seq_len, :]  # type: ignore
+
+# %%
+class LayerNorm(nn.Module):
+
+    def __init__(self, normalized_shape: Union[int, List[int]], eps: float = 1e-05, elementwise_affine: bool = True):
+        super().__init__()
+        self.normalized_shape = normalized_shape
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(t.ones(normalized_shape))
+            self.bias = nn.Parameter(t.zeros(normalized_shape))
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        normalized_shape_dims = 1 if isinstance(self.normalized_shape, int) else len(self.normalized_shape)
+        # Normalize over the trailing dimensions covered by normalized_shape
+        x_mean = x.mean(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True)
+        x_var = x.var(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True, unbiased=False)
+        x_scaled = (x - x_mean) / t.sqrt(x_var + self.eps)
+        if self.elementwise_affine:
+            return x_scaled * self.weight + self.bias
+        return x_scaled
+
+    def extra_repr(self) -> str:
+        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
+
+# %%
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class TransformerConfig:
+    '''Constants used throughout your decoder-only transformer model.'''
+
+    num_layers: int
+    num_heads: int
+    vocab_size: int
+    hidden_size: int
+    max_seq_len: int
+    dropout: float = 0.1
+    layer_norm_epsilon: float = 1e-05
+# %%
+import attention_replication
+
+class BertMLP(nn.Module):
+    def __init__(self, config: TransformerConfig):
+        super().__init__()
+        self.linear1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
+        self.gelu = nn.GELU()
+        self.linear2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        x = self.linear1(x)
+        x = self.gelu(x)
+        x = self.linear2(x)
+        x = self.dropout(x)
+        return x
+
+class DecoderBlock(nn.Module):
+
+    def __init__(self, config: TransformerConfig):
+        super().__init__()
+        self.attention = attention_replication.MultiheadMaskedAttention(config.hidden_size, config.num_heads)
+        self.layer_norm1 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+        self.mlp = BertMLP(config)
+        self.layer_norm2 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        y = self.attention(x)
+        y = self.layer_norm1(y)
+        x = x + y
+        z = self.mlp(x)
+        z = self.layer_norm2(z)
+        x = x + z
+        return x
+
+class DecoderOnlyTransformer(nn.Module):
+
+    def __init__(self, config: TransformerConfig):
+        super().__init__()
+        self.token_embedding = Embedding(config.vocab_size, config.hidden_size)
+        self.positional_embedding = PositionalEncoding(config.max_seq_len, config.hidden_size)
+        self.dropout = nn.Dropout(config.dropout)
+        self.bert_blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.num_layers)])
+        self.layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        x = self.token_embedding(x)
+        x = self.positional_embedding(x)
+        x = self.dropout(x)
+        for block in self.bert_blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        # Unembed by reusing the token embedding matrix (weight tying)
+        x = einsum('num_embeddings embedding_dim, batch seq_len embedding_dim -> batch seq_len num_embeddings', self.token_embedding.weight, x)
+        return x
+
+# %%
+from torch.utils.data import Dataset
+
+class CustomTextDataset(Dataset):
+    def __init__(self, texts, labels):
+        self.labels = labels
+        self.texts = texts
+
+    @staticmethod
+    def from_config(config, samples):
+        texts = [t.randint(high=config.vocab_size, size=(config.max_seq_len,)) for _ in range(samples)]
+        labels = [t.flip(text, (0,)) for text in texts]
+        return CustomTextDataset(texts, labels)
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, idx):
+        label = self.labels[idx]
+        text = self.texts[idx]
+        sample = (text, label)
+        return sample
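
A minimal sketch of building and running the model (toy sizes chosen arbitrarily here; the deployed config uses 6 layers, 8 heads, and hidden size 512 per config.yaml):

config = TransformerConfig(num_layers=2, num_heads=4, vocab_size=1000, hidden_size=64, max_seq_len=32)
model = DecoderOnlyTransformer(config)
tokens = t.randint(0, config.vocab_size, (8, 32))   # (batch, seq) of token ids
logits = model(tokens)                              # (batch, seq, vocab_size)
assert logits.shape == (8, 32, config.vocab_size)
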
word_data.py
ADDED
@@ -0,0 +1,100 @@
+import re
+from typing import Optional, Union
+import requests
+from torch.utils.data import Dataset
+import torch as t
+
+
+class WordsDataset(Dataset):
+    def __init__(self, texts, labels):
+        self.texts = texts
+        self.labels = labels
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, idx):
+        label = self.labels[idx]
+        text = self.texts[idx]
+        sample = (text, label)
+        return sample
+
+#%%
+def tokenize(text):
+    return re.split(r"\b", text)
+
+def _remove_duplicates(text, string=" "):
+    if string + string in text:
+        text = text.replace(string + string, string)
+        return _remove_duplicates(text, string)
+    return text
+
+def remove_duplicates(text):
+    text = _remove_duplicates(text, ' ')
+    text = _remove_duplicates(text, '\n')
+    return text
+
+# %%
+class WordData():
+    def __init__(self, text, start, end, device):
+        self.complete_text = remove_duplicates(text)
+        if start is not None and end is not None:
+            self.complete_text = self.get_excerpt(start, end)
+        self.complete_tokens = tokenize(self.complete_text)
+        self.vocab = sorted(set(self.complete_tokens))
+        self.token_to_id = dict(zip(self.vocab, list(range(len(self.vocab)))))
+        self.id_to_token = dict(zip(list(range(len(self.vocab))), self.vocab))
+        self.model_max_length = None
+        self.device = device
+
+    @staticmethod
+    def from_link(link, device, start=None, end=None):
+        return WordData(
+            requests.get(link).content.decode('utf-8'),
+            start,
+            end,
+            device=device
+        )
+
+    @staticmethod
+    def from_file(filename, device, start=None, end=None):
+        with open(filename, encoding='utf-8') as f:
+            text = f.read()
+        return WordData(text, start, end, device=device)
+
+    def get_excerpt(self, start="THE SONNETS", end="THE END", text=None):
+        if text is None:
+            text = self.complete_text
+        assert start in text, f'get_excerpt: cannot find {start} in text'
+        l_stripped = text.split(start, maxsplit=1)[1]
+        assert end in l_stripped, f'get_excerpt: cannot find {end} in text'
+        r_stripped = l_stripped.split(end, maxsplit=1)[0]
+        return r_stripped
+
+    def generate_autoregressive_dataset(self, sequence_length, text=None):
+        self.model_max_length = sequence_length
+        if text is None:
+            text = self.complete_text
+        token_ids = self.encode(text, return_tensors="pt")
+        inputs = [token_ids[i:i + sequence_length] for i in range(len(token_ids) - sequence_length)]
+        labels = [token_ids[i + 1:i + 1 + sequence_length] for i in range(len(token_ids) - sequence_length)]
+        return WordsDataset(inputs, labels)
+
+    def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
+        '''
+        Tokenizes initial_text, then returns the token ids.
+
+        Return type is list by default, but if return_tensors="pt" then it is returned as a tensor.
+        '''
+        tokens = tokenize(initial_text)
+        token_ids = [self.token_to_id[t] for t in tokens]
+        if return_tensors == "pt":
+            return t.tensor(token_ids, device=self.device)
+        return token_ids
+
+    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
+        '''
+        Converts ids to a list of tokens, then joins them into a single string.
+        '''
+        tokens = [self.id_to_token[int(i)] for i in list_of_ids]
+        return "".join(tokens)
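
`WordData` doubles as the tokenizer handed to `sampling.sample_tokens`, exposing the same `encode`/`decode`/`model_max_length` surface the sampler expects from a Hugging Face tokenizer. A round-trip sketch, assuming every word of the input appears in the corpus (unknown words raise a KeyError, which `safe_generate` in shakespeare_demo.py catches):

data = WordData.from_file('100-0.txt', device='cpu', start="1\n", end='ALL’S WELL THAT ENDS WELL')
ids = data.encode("How I love thee", return_tensors="pt")  # word-level token ids
print(data.decode(ids))   # round-trips to "How I love thee" if all words are in vocab
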