khoantap committed on
Commit
927cd21
·
verified ·
1 Parent(s): 0fdfda8

Update modeling_phi.py

Browse files

Fix dtype for inference: cast each expert's output to the dtype of the output buffer `y` before writing it into `y`.

Files changed (1) hide show
  1. modeling_phi.py +1 -1
modeling_phi.py CHANGED
@@ -312,7 +312,7 @@ class MoE(nn.Module):
312
  x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
313
  y = torch.empty_like(x)
314
  for i, expert in enumerate(self.mlp):
315
- y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
316
  y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
317
  return y.view(*orig_shape)
318
 
 
312
  x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
313
  y = torch.empty_like(x)
314
  for i, expert in enumerate(self.mlp):
315
+ y[flat_expert_indices == i] = expert(x[flat_expert_indices == i]).to(dtype=y.dtype)
316
  y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
317
  return y.view(*orig_shape)
318