Spaces:
Running
Running
Update src/data/dataset.py
Browse files- src/data/dataset.py +11 -8
src/data/dataset.py
CHANGED
|
@@ -89,11 +89,11 @@ class DruggenDataset(InMemoryDataset):
|
|
| 89 |
smiles_list (list): List of SMILES strings.
|
| 90 |
|
| 91 |
Returns:
|
| 92 |
-
|
| 93 |
filtered_smiles (list): List of valid SMILES strings.
|
| 94 |
"""
|
| 95 |
-
max_length = 0
|
| 96 |
filtered_smiles = []
|
|
|
|
| 97 |
for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
|
| 98 |
mol = Chem.MolFromSmiles(smiles)
|
| 99 |
if mol is None:
|
|
@@ -113,8 +113,9 @@ class DruggenDataset(InMemoryDataset):
|
|
| 113 |
continue
|
| 114 |
|
| 115 |
filtered_smiles.append(smiles)
|
| 116 |
-
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
def _genA(self, mol, connected=True, max_length=None):
|
| 120 |
"""
|
|
@@ -290,20 +291,22 @@ class DruggenDataset(InMemoryDataset):
|
|
| 290 |
"""
|
| 291 |
# Read raw SMILES from file (assuming CSV with no header)
|
| 292 |
smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
data_list = []
|
| 295 |
self.m_dim = len(self.atom_decoder_m)
|
| 296 |
for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
|
| 297 |
mol = Chem.MolFromSmiles(smiles)
|
| 298 |
-
A = self._genA(mol, connected=True, max_length=
|
| 299 |
if A is not None:
|
| 300 |
-
x_array = self._genX(mol, max_length=
|
| 301 |
if x_array is None:
|
| 302 |
continue
|
| 303 |
x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
|
| 304 |
x = label2onehot(x, self.m_dim).squeeze()
|
| 305 |
if self.features:
|
| 306 |
-
f = torch.from_numpy(self._genF(mol, max_length=
|
| 307 |
x = torch.concat((x, f), dim=-1)
|
| 308 |
adjacency = torch.from_numpy(A)
|
| 309 |
edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
|
|
|
|
| 89 |
smiles_list (list): List of SMILES strings.
|
| 90 |
|
| 91 |
Returns:
|
| 92 |
+
num_smiles (int): Number of filtered smiles
|
| 93 |
filtered_smiles (list): List of valid SMILES strings.
|
| 94 |
"""
|
|
|
|
| 95 |
filtered_smiles = []
|
| 96 |
+
num_smiles = 0
|
| 97 |
for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
|
| 98 |
mol = Chem.MolFromSmiles(smiles)
|
| 99 |
if mol is None:
|
|
|
|
| 113 |
continue
|
| 114 |
|
| 115 |
filtered_smiles.append(smiles)
|
| 116 |
+
num_smiles += 1
|
| 117 |
+
|
| 118 |
+
return num_smiles, filtered_smiles
|
| 119 |
|
| 120 |
def _genA(self, mol, connected=True, max_length=None):
|
| 121 |
"""
|
|
|
|
| 291 |
"""
|
| 292 |
# Read raw SMILES from file (assuming CSV with no header)
|
| 293 |
smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
|
| 294 |
+
num_smiles, filtered_smiles = self._filter_smiles(smiles_list)
|
| 295 |
+
self.num_smiles = num_smiles
|
| 296 |
+
|
| 297 |
data_list = []
|
| 298 |
self.m_dim = len(self.atom_decoder_m)
|
| 299 |
for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
|
| 300 |
mol = Chem.MolFromSmiles(smiles)
|
| 301 |
+
A = self._genA(mol, connected=True, max_length=self.max_atom)
|
| 302 |
if A is not None:
|
| 303 |
+
x_array = self._genX(mol, max_length=self.max_atom)
|
| 304 |
if x_array is None:
|
| 305 |
continue
|
| 306 |
x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
|
| 307 |
x = label2onehot(x, self.m_dim).squeeze()
|
| 308 |
if self.features:
|
| 309 |
+
f = torch.from_numpy(self._genF(mol, max_length=self.max_atom)).to(torch.long).view(x.shape[0], -1)
|
| 310 |
x = torch.concat((x, f), dim=-1)
|
| 311 |
adjacency = torch.from_numpy(A)
|
| 312 |
edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
|