fix int/str for conv_dim indexing
#5
by
winglian
- opened
- modeling_hymba.py +3 -3
modeling_hymba.py
CHANGED
|
@@ -396,7 +396,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
|
|
| 396 |
|
| 397 |
if has_mamba_state:
|
| 398 |
if hasattr(config, 'conv_dim'):
|
| 399 |
-
conv_dim = config.conv_dim[i]
|
| 400 |
else:
|
| 401 |
conv_dim = intermediate_size
|
| 402 |
self.conv_states += [
|
|
@@ -1523,7 +1523,7 @@ class HymbaBlock(nn.Module):
|
|
| 1523 |
num_ssm_param = 1
|
| 1524 |
|
| 1525 |
if not hasattr(config, 'conv_dim'):
|
| 1526 |
-
config.conv_dim = {i:0 for i in range(config.num_hidden_layers)}
|
| 1527 |
|
| 1528 |
self.conv1d = nn.Conv1d(
|
| 1529 |
in_channels=self.intermediate_size,
|
|
@@ -1534,7 +1534,7 @@ class HymbaBlock(nn.Module):
|
|
| 1534 |
padding=self.conv_kernel_size - 1
|
| 1535 |
)
|
| 1536 |
|
| 1537 |
-
config.conv_dim[self.layer_idx] = self.intermediate_size
|
| 1538 |
|
| 1539 |
self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
|
| 1540 |
self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])
|
|
|
|
| 396 |
|
| 397 |
if has_mamba_state:
|
| 398 |
if hasattr(config, 'conv_dim'):
|
| 399 |
+
conv_dim = config.conv_dim[str(i)]
|
| 400 |
else:
|
| 401 |
conv_dim = intermediate_size
|
| 402 |
self.conv_states += [
|
|
|
|
| 1523 |
num_ssm_param = 1
|
| 1524 |
|
| 1525 |
if not hasattr(config, 'conv_dim'):
|
| 1526 |
+
config.conv_dim = {str(i):0 for i in range(config.num_hidden_layers)}
|
| 1527 |
|
| 1528 |
self.conv1d = nn.Conv1d(
|
| 1529 |
in_channels=self.intermediate_size,
|
|
|
|
| 1534 |
padding=self.conv_kernel_size - 1
|
| 1535 |
)
|
| 1536 |
|
| 1537 |
+
config.conv_dim[str(self.layer_idx)] = self.intermediate_size
|
| 1538 |
|
| 1539 |
self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
|
| 1540 |
self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])
|