|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | from torch.nn.utils import weight_norm | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ConvRNNF0Predictor(nn.Module): | 
					
						
						|  | def __init__(self, | 
					
						
						|  | num_class: int = 1, | 
					
						
						|  | in_channels: int = 80, | 
					
						
						|  | cond_channels: int = 512 | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.num_class = num_class | 
					
						
						|  | self.condnet = nn.Sequential( | 
					
						
						|  | weight_norm( | 
					
						
						|  | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) | 
					
						
						|  | ), | 
					
						
						|  | nn.ELU(), | 
					
						
						|  | weight_norm( | 
					
						
						|  | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) | 
					
						
						|  | ), | 
					
						
						|  | nn.ELU(), | 
					
						
						|  | weight_norm( | 
					
						
						|  | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) | 
					
						
						|  | ), | 
					
						
						|  | nn.ELU(), | 
					
						
						|  | weight_norm( | 
					
						
						|  | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) | 
					
						
						|  | ), | 
					
						
						|  | nn.ELU(), | 
					
						
						|  | weight_norm( | 
					
						
						|  | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) | 
					
						
						|  | ), | 
					
						
						|  | nn.ELU(), | 
					
						
						|  | ) | 
					
						
						|  | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | x = self.condnet(x) | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  | return torch.abs(self.classifier(x).squeeze(-1)) | 
					
						
						|  |  |