sagawa commited on
Commit
eb03b3e
·
1 Parent(s): 69c6094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -54,6 +54,7 @@ class CFG():
54
  num_workers=1
55
 
56
  if st.button('predict'):
 
57
 
58
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59
 
@@ -127,12 +128,13 @@ if st.button('predict'):
127
  preds = []
128
  model.eval()
129
  model.to(device)
130
- tk0 = tqdm(test_loader, total=len(test_loader))
131
- for inputs in tk0:
132
  for k, v in inputs.items():
133
  inputs[k] = v.to(device)
134
  with torch.no_grad():
135
  y_preds = model(inputs)
 
136
  preds.append(y_preds.to('cpu').numpy())
137
  predictions = np.concatenate(preds)
138
  return predictions
@@ -165,10 +167,11 @@ if st.button('predict'):
165
  )
166
 
167
  else:
 
168
  test_ds = pd.DataFrame.from_dict({'input': CFG.data}, orient='index').T
169
  test_dataset = TestDataset(CFG, test_ds)
170
  test_loader = DataLoader(test_dataset,
171
- batch_size=1,
172
  shuffle=False,
173
  num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
174
 
 
54
  num_workers=1
55
 
56
  if st.button('predict'):
57
+ st.progress(0)
58
 
59
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
 
 
128
  preds = []
129
  model.eval()
130
  model.to(device)
131
+ tk0 = enumerate(test_loader, total=len(test_loader))
132
+ for i, inputs in tk0:
133
  for k, v in inputs.items():
134
  inputs[k] = v.to(device)
135
  with torch.no_grad():
136
  y_preds = model(inputs)
137
+ st.progress((i+1)*CFG.batch_size)
138
  preds.append(y_preds.to('cpu').numpy())
139
  predictions = np.concatenate(preds)
140
  return predictions
 
167
  )
168
 
169
  else:
170
+ CFG.batch_size=1
171
  test_ds = pd.DataFrame.from_dict({'input': CFG.data}, orient='index').T
172
  test_dataset = TestDataset(CFG, test_ds)
173
  test_loader = DataLoader(test_dataset,
174
+ batch_size=CFG.batch_size,
175
  shuffle=False,
176
  num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
177