import torch
import torch.nn as nn

# Run on the first GPU when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
class Linear(nn.Module):
    """A fully connected layer: y = x @ W + b, written from scratch."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # nn.Parameter already sets requires_grad=True, so the extra
        # .requires_grad_() calls were redundant. The 0.1 factor keeps
        # the random initial weights small.
        self.weight = nn.Parameter(
            torch.randn((self.in_features, self.out_features), device=device) * 0.1
        )
        self.bias = nn.Parameter(
            torch.randn(self.out_features, device=device) * 0.1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias
class ReLU(nn.Module):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # clamp avoids the device mismatch that torch.max(x, torch.tensor(0))
        # hits when x lives on the GPU: the scalar 0 is created on the CPU.
        return torch.clamp(x, min=0.0)
class Sequential(nn.Module):
    """Chains layers so the output of each one feeds the next."""

    def __init__(self, *layers):
        super().__init__()
        # nn.ModuleList registers each layer so its parameters are
        # visible to optimizers, .parameters(), and .to().
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x
class Flatten(nn.Module):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # Collapse every dimension except the batch dimension,
        # e.g. (N, 1, 28, 28) -> (N, 784).
        return x.view(x.size(0), -1)
class DigitClassifier(nn.Module):
    """A small MLP for 28x28 digit images (784 = 28 * 28 input features)."""

    def __init__(self):
        super().__init__()
        self.main = Sequential(
            Flatten(),
            Linear(in_features=784, out_features=256),
            ReLU(),
            Linear(in_features=256, out_features=64),
            ReLU(),
            Linear(in_features=64, out_features=10),  # one logit per digit class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x)
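

# --- Illustrative usage (not in the original file; a minimal sketch) ---
# Assumes MNIST-style input: a batch of 28x28 images flattens to 784
# features, matching the first Linear layer. The batch size of 8 is
# arbitrary, chosen only to demonstrate the expected shapes.
if __name__ == "__main__":
    model = DigitClassifier()
    batch = torch.randn(8, 1, 28, 28, device=device)  # fake batch of 8 images
    logits = model(batch)
    print(logits.shape)  # torch.Size([8, 10]) -- one score per digit class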