Update modeling_phi.py
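Browse files
Fix dtype for inference: cast each expert's output to the dtype of the preallocated output tensor `y` before the masked assignment.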
- modeling_phi.py +1 -1
modeling_phi.py
CHANGED
@@ -312,7 +312,7 @@ class MoE(nn.Module):
|
|
312 |
x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
|
313 |
y = torch.empty_like(x)
|
314 |
for i, expert in enumerate(self.mlp):
|
315 |
-
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
|
316 |
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
|
317 |
return y.view(*orig_shape)
|
318 |
|
|
|
312 |
x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
|
313 |
y = torch.empty_like(x)
|
314 |
for i, expert in enumerate(self.mlp):
|
315 |
+
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i]).to(dtype=y.dtype)
|
316 |
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
|
317 |
return y.view(*orig_shape)
|
318 |
|