YuWang0103 committed on
Commit b31c308 · verified · 1 Parent(s): 6b59850

Update dataset.py

Files changed (1)
  1. dataset.py +1 -57
dataset.py CHANGED
@@ -34,39 +34,7 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
     model = SentenceTransformer("all-MiniLM-L6-v2")
     cond_embs = model.encode(condition)

-    for domain in domains:
-        if not os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
-
-            data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')
-
-            #fix seed
-            torch.manual_seed(0)
-
-            #random permute and split
-            n = len(data)
-            indices = torch.randperm(n)
-
-            if domain == 'eco':
-                train_indices = indices[:4].repeat(50)
-                val_indices = indices[4:5].repeat(50)
-                test_indices = indices[5:]
-            else:
-                train_indices = indices[:int(0.7 * n)]
-                val_indices = indices[int(0.7 * n):int(0.8 * n)]
-                test_indices = indices[int(0.8 * n):]
-
-            train_data = [data[_] for _ in train_indices]
-            val_data = [data[_] for _ in val_indices]
-            test_data = [data[_] for _ in test_indices]
-
-            torch.save(train_indices, f'{hydra_path}/graphs/{domain}/train_indices.pt')
-            torch.save(val_indices, f'{hydra_path}/graphs/{domain}/val_indices.pt')
-            torch.save(test_indices, f'{hydra_path}/graphs/{domain}/test_indices.pt')
-
-            torch.save(train_data, f'{hydra_path}/graphs/{domain}/train.pt')
-            torch.save(val_data, f'{hydra_path}/graphs/{domain}/val.pt')
-            torch.save(test_data, f'{hydra_path}/graphs/{domain}/test.pt')
-
+

     train_data, val_data, test_data = [], [], []

@@ -99,30 +67,6 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
         test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_data, val_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_data, test_indices)]


-    elif dataname == 'all':
-        for i, domain in enumerate(domains):
-            train_d = torch.load(f'{hydra_path}/graphs/{domain}/train.pt')
-            val_d = torch.load(f'{hydra_path}/graphs/{domain}/val.pt')
-            test_d = torch.load(f'{hydra_path}/graphs/{domain}/test.pt')
-
-            train_indices = torch.load(f'{hydra_path}/graphs/{domain}/train_indices.pt')
-            val_indices = torch.load(f'{hydra_path}/graphs/{domain}/val_indices.pt')
-            test_indices = torch.load(f'{hydra_path}/graphs/{domain}/test_indices.pt')
-
-            # text_prompt = torch.load(f'{hydra_path}/graphs/{domain}/text_prompt_order.pt')
-
-            with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
-                text_prompt = f.readlines()
-                text_prompt = [x.strip() for x in text_prompt]
-
-            print(domain, text_prompt[0])
-
-            text_embs = model.encode(text_prompt)
-
-            train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
-            val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
-            test_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)])
-            print(i, domain, len(train_data), len(val_data), len(test_data))

     print('Size of dataset', len(train_data), len(val_data), len(test_data))

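The first deleted hunk removes the on-the-fly construction of per-domain train/val/test splits, so after this commit load_dataset_cc expects the split files under {hydra_path}/graphs/{domain}/ to already exist. If they ever need to be regenerated, the deleted logic can live as a small standalone preprocessing script; the sketch below is one way to do that, assuming each {domain}.pt file holds a list of graphs. HYDRA_PATH, DOMAINS and build_splits are illustrative names, not identifiers from the repository.

# Standalone sketch of the removed split-building step; HYDRA_PATH and
# DOMAINS are placeholder values, not names taken from dataset.py.
import os
import torch

HYDRA_PATH = '.'      # assumed root, playing the role of hydra_path
DOMAINS = ['eco']     # assumed domain list; 'eco' is the only name visible in the diff

def build_splits(hydra_path, domains):
    for domain in domains:
        if os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
            continue  # splits already materialised for this domain

        data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')

        # fixed seed, then a random permutation of the graph list
        torch.manual_seed(0)
        n = len(data)
        indices = torch.randperm(n)

        if domain == 'eco':
            # tiny domain: oversample 4 train / 1 val graphs 50x, rest is test
            train_idx = indices[:4].repeat(50)
            val_idx = indices[4:5].repeat(50)
            test_idx = indices[5:]
        else:
            # 70 / 10 / 20 split for the remaining domains
            train_idx = indices[:int(0.7 * n)]
            val_idx = indices[int(0.7 * n):int(0.8 * n)]
            test_idx = indices[int(0.8 * n):]

        # persist both the index tensors and the materialised graph lists
        for name, idx in (('train', train_idx), ('val', val_idx), ('test', test_idx)):
            torch.save(idx, f'{hydra_path}/graphs/{domain}/{name}_indices.pt')
            torch.save([data[i] for i in idx], f'{hydra_path}/graphs/{domain}/{name}.pt')

if __name__ == '__main__':
    build_splits(HYDRA_PATH, DOMAINS)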
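The second deleted hunk drops the dataname == 'all' branch, which loaded every domain's saved splits, read text_prompt_order.txt, encoded the prompts with the SentenceTransformer created earlier in the function, and attached the matching embedding to each graph via arrange_data. Below is a rough free-standing restatement of that branch, for reference only; load_all_domains is an illustrative name, while arrange_data and model are assumed to be the helper and encoder from dataset.py.

# Sketch of the removed 'all' branch as a helper; assumes the split files
# written by the step above plus a text_prompt_order.txt per domain, and that
# arrange_data(graph, text_emb, index) is the helper defined in dataset.py.
import torch

def load_all_domains(hydra_path, domains, model, arrange_data):
    train_data, val_data, test_data = [], [], []
    buckets = {'train': train_data, 'val': val_data, 'test': test_data}

    for i, domain in enumerate(domains):
        # one prompt per graph, in the original (pre-shuffle) order,
        # encoded with the shared SentenceTransformer instance
        with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
            text_prompt = [x.strip() for x in f.readlines()]
        text_embs = model.encode(text_prompt)

        for name, bucket in buckets.items():
            graphs = torch.load(f'{hydra_path}/graphs/{domain}/{name}.pt')
            indices = torch.load(f'{hydra_path}/graphs/{domain}/{name}_indices.pt')
            # the saved indices map each graph back to its prompt embedding
            bucket.extend(arrange_data(d, text_embs[ind.item()], ind.item())
                          for d, ind in zip(graphs, indices))

        print(i, domain, len(train_data), len(val_data), len(test_data))

    return train_data, val_data, test_data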