Update dataset.py

dataset.py  CHANGED  (+1, -57)
@@ -34,39 +34,7 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
     model = SentenceTransformer("all-MiniLM-L6-v2")
     cond_embs = model.encode(condition)
 
-
-    if not os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
-
-        data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')
-
-        #fix seed
-        torch.manual_seed(0)
-
-        #random permute and split
-        n = len(data)
-        indices = torch.randperm(n)
-
-        if domain == 'eco':
-            train_indices = indices[:4].repeat(50)
-            val_indices = indices[4:5].repeat(50)
-            test_indices = indices[5:]
-        else:
-            train_indices = indices[:int(0.7 * n)]
-            val_indices = indices[int(0.7 * n):int(0.8 * n)]
-            test_indices = indices[int(0.8 * n):]
-
-        train_data = [data[_] for _ in train_indices]
-        val_data = [data[_] for _ in val_indices]
-        test_data = [data[_] for _ in test_indices]
-
-        torch.save(train_indices, f'{hydra_path}/graphs/{domain}/train_indices.pt')
-        torch.save(val_indices, f'{hydra_path}/graphs/{domain}/val_indices.pt')
-        torch.save(test_indices, f'{hydra_path}/graphs/{domain}/test_indices.pt')
-
-        torch.save(train_data, f'{hydra_path}/graphs/{domain}/train.pt')
-        torch.save(val_data, f'{hydra_path}/graphs/{domain}/val.pt')
-        torch.save(test_data, f'{hydra_path}/graphs/{domain}/test.pt')
-
+
 
     train_data, val_data, test_data = [], [], []
 
@@ -99,30 +67,6 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
         test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_data, val_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_data, test_indices)]
 
 
-    elif dataname == 'all':
-        for i, domain in enumerate(domains):
-            train_d = torch.load(f'{hydra_path}/graphs/{domain}/train.pt')
-            val_d = torch.load(f'{hydra_path}/graphs/{domain}/val.pt')
-            test_d = torch.load(f'{hydra_path}/graphs/{domain}/test.pt')
-
-            train_indices = torch.load(f'{hydra_path}/graphs/{domain}/train_indices.pt')
-            val_indices = torch.load(f'{hydra_path}/graphs/{domain}/val_indices.pt')
-            test_indices = torch.load(f'{hydra_path}/graphs/{domain}/test_indices.pt')
-
-            # text_prompt = torch.load(f'{hydra_path}/graphs/{domain}/text_prompt_order.pt')
-
-            with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
-                text_prompt = f.readlines()
-            text_prompt = [x.strip() for x in text_prompt]
-
-            print(domain, text_prompt[0])
-
-            text_embs = model.encode(text_prompt)
-
-            train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
-            val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
-            test_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)])
-            print(i, domain, len(train_data), len(val_data), len(test_data))
 
     print('Size of dataset', len(train_data), len(val_data), len(test_data))
 
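For reference, the block removed in the first hunk was a one-time preprocessing step: it fixed the seed, randomly split each domain's graph list into train/val/test, and cached both the indices and the split data under {hydra_path}/graphs/{domain}/. A minimal standalone sketch of that step is below; the function name split_and_cache is hypothetical and the directory layout is assumed to match the paths in the diff.

# Hypothetical sketch of the split-and-cache preprocessing removed above.
# Assumes raw graphs are stored at {hydra_path}/graphs/{domain}/{domain}.pt.
import os
import torch

def split_and_cache(hydra_path, domain):
    base = f'{hydra_path}/graphs/{domain}'
    if os.path.exists(f'{base}/train.pt'):
        return  # splits already cached

    data = torch.load(f'{base}/{domain}.pt')

    # Fix the seed so the permutation (and hence the split) is reproducible.
    torch.manual_seed(0)
    n = len(data)
    indices = torch.randperm(n)

    if domain == 'eco':
        # Tiny domain: repeat a handful of graphs for train/val, rest for test.
        splits = {'train': indices[:4].repeat(50),
                  'val': indices[4:5].repeat(50),
                  'test': indices[5:]}
    else:
        # Standard 70/10/20 split.
        splits = {'train': indices[:int(0.7 * n)],
                  'val': indices[int(0.7 * n):int(0.8 * n)],
                  'test': indices[int(0.8 * n):]}

    for name, idx in splits.items():
        torch.save(idx, f'{base}/{name}_indices.pt')
        torch.save([data[i] for i in idx], f'{base}/{name}.pt')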
|
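The second hunk drops the dataname == 'all' branch, which looped over every domain, read the cached splits and per-graph text prompts, encoded the prompts with the same SentenceTransformer model, and paired each graph with the embedding of its prompt via arrange_data. A sketch of that aggregation logic follows; aggregate_all_domains is a hypothetical name, and domains and arrange_data are assumed to be provided by the surrounding module.

# Hypothetical sketch of the multi-domain aggregation removed in the second hunk.
# Assumes the cached splits and text_prompt_order.txt files referenced in the diff.
import torch
from sentence_transformers import SentenceTransformer

def aggregate_all_domains(hydra_path, domains, arrange_data):
    model = SentenceTransformer("all-MiniLM-L6-v2")
    train_data, val_data, test_data = [], [], []
    buckets = {'train': train_data, 'val': val_data, 'test': test_data}

    for domain in domains:
        base = f'{hydra_path}/graphs/{domain}'

        # One text prompt per graph, stored in the order the graphs were saved.
        with open(f'{base}/text_prompt_order.txt', 'r') as f:
            text_prompt = [x.strip() for x in f.readlines()]
        text_embs = model.encode(text_prompt)

        for split, bucket in buckets.items():
            graphs = torch.load(f'{base}/{split}.pt')
            indices = torch.load(f'{base}/{split}_indices.pt')
            # Pair each graph with the embedding of its original prompt.
            bucket.extend(arrange_data(d, text_embs[ind.item()], ind.item())
                          for d, ind in zip(graphs, indices))

    return train_data, val_data, test_data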