Update app.py
app.py CHANGED
@@ -54,6 +54,7 @@ class CFG():
     num_workers=1

 if st.button('predict'):
+    st.progress(0)

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

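The new st.progress(0) call draws a bar at 0% but discards the handle that st.progress returns. In Streamlit, the usual pattern is to keep that handle and call its own .progress() method, which redraws the same bar in place; calling st.progress() again inside a loop adds a new bar each time. A minimal sketch of that pattern, with progress_bar and n_steps as placeholder names (not from the commit):

import streamlit as st

if st.button('predict'):
    progress_bar = st.progress(0)                        # draw the bar once, at 0%
    n_steps = 10                                         # stand-in for len(test_loader)
    for i in range(n_steps):
        ...                                              # one batch of inference would run here
        progress_bar.progress(int((i + 1) / n_steps * 100))  # update the same bar, int 0-100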
@@ -127,12 +128,13 @@ if st.button('predict'):
     preds = []
     model.eval()
     model.to(device)
-    tk0 =
-    for inputs in tk0:
+    tk0 = enumerate(test_loader, total=len(test_loader))
+    for i, inputs in tk0:
         for k, v in inputs.items():
             inputs[k] = v.to(device)
         with torch.no_grad():
             y_preds = model(inputs)
+        st.progress((i+1)*CFG.batch_size)
         preds.append(y_preds.to('cpu').numpy())
     predictions = np.concatenate(preds)
     return predictions
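Two caveats in the loop above. enumerate() accepts only an iterable and an integer start, so the total=len(test_loader) keyword raises a TypeError at runtime; total= is a tqdm argument, which suggests the deleted assignment (its right-hand side is not visible in the captured diff) wrapped the loader in tqdm. Also, st.progress() only accepts an int in 0-100 or a float in 0.0-1.0, so (i+1)*CFG.batch_size goes out of range once more than 100 rows have been processed. The hedged sketch below assumes a progress_bar handle returned by st.progress(0) is passed in; the function name inference_fn is illustrative, not taken from app.py.

import numpy as np
import torch

def inference_fn(test_loader, model, device, progress_bar):
    preds = []
    model.eval()
    model.to(device)
    n_batches = len(test_loader)
    for i, inputs in enumerate(test_loader):             # plain enumerate: no total= kwarg
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            y_preds = model(inputs)
        progress_bar.progress(min((i + 1) / n_batches, 1.0))  # float kept in [0.0, 1.0]
        preds.append(y_preds.to('cpu').numpy())
    return np.concatenate(preds)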
@@ -165,10 +167,11 @@ if st.button('predict'):
     )

 else:
+    CFG.batch_size=1
     test_ds = pd.DataFrame.from_dict({'input': CFG.data}, orient='index').T
     test_dataset = TestDataset(CFG, test_ds)
     test_loader = DataLoader(test_dataset,
-                             batch_size=
+                             batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

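This branch appears to handle a single text entered in the app rather than an uploaded file: it pins CFG.batch_size to 1 and passes it to the DataLoader explicitly. A small self-contained sketch of what the one-row frame construction yields; the sample string stands in for CFG.data, which is assumed here to be the raw input text:

import pandas as pd

data = 'example input text'                              # placeholder for CFG.data
test_ds = pd.DataFrame.from_dict({'input': data}, orient='index').T
print(test_ds.shape)                                     # (1, 1): one row, one 'input' column
print(test_ds['input'][0])                               # the original text, ready for TestDataset

With one row and batch_size=1 the loader yields a single batch, so (i+1)*CFG.batch_size evaluates to 1 and stays inside st.progress's allowed range; the bar effectively makes one jump rather than filling gradually.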