Jackoabaad commited on
Commit
0142759
·
1 Parent(s): afe2482

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -27
app.py CHANGED
@@ -1,35 +1,77 @@
1
- import subprocess
2
- import os
3
- import nerf
4
  import torch
 
 
 
5
 
6
- # Check if the directory called Model exists
7
- if not os.path.exists("Model"):
8
- # Create the directory
9
- os.mkdir("Model")
10
 
11
- # Check if the model is downloaded
12
- if not os.path.exists("Model/nerf_model.ckpt"):
13
- # Download the model using the subprocess module
14
- subprocess.run(["wget", "https://github.com/bmild/nerf/releases/download/v0.5/nerf_model.ckpt"])
15
 
16
- # Load the NeRF model
17
- nerf_model = nerf.NeRF.load("Model/nerf_model.ckpt")
 
 
 
 
 
18
 
19
- # Define the camera parameters
20
- camera = nerf.Camera(
21
- fov=60,
22
- focal_length=50,
23
- znear=1,
24
- zfar=100,
25
- principal_point=(0.5, 0.5),
26
- )
27
 
28
- # Define the viewing direction
29
- viewing_direction = torch.tensor([0, 0, 1])
30
 
31
- # Render the novel view
32
- novel_view = nerf_model.render(viewing_direction, camera)
 
 
33
 
34
- # Save the novel view
35
- torch.save(novel_view, "novel_view.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.utils.data as data
5
 
6
+ from nerf import NeRF
7
+ from dataset import Dataset
 
 
8
 
 
 
 
 
9
 
10
+ def train(nerf, dataloader, optimizer, device):
11
+ nerf.train()
12
+ for i, data in enumerate(dataloader):
13
+ # Get the input and target images.
14
+ viewdirs, radiances = data
15
+ viewdirs = viewdirs.to(device)
16
+ radiances = radiances.to(device)
17
 
18
+ # Forward pass.
19
+ outputs = nerf(viewdirs)
 
 
 
 
 
 
20
 
21
+ # Compute the loss.
22
+ loss = nn.functional.mse_loss(outputs, radiances)
23
 
24
+ # Backpropagate the loss.
25
+ optimizer.zero_grad()
26
+ loss.backward()
27
+ optimizer.step()
28
 
29
+
30
+ def test(nerf, dataloader, device):
31
+ nerf.eval()
32
+ psnrs = []
33
+ for i, data in enumerate(dataloader):
34
+ # Get the input and target images.
35
+ viewdirs, radiances = data
36
+ viewdirs = viewdirs.to(device)
37
+ radiances = radiances.to(device)
38
+
39
+ # Forward pass.
40
+ outputs = nerf(viewdirs)
41
+
42
+ # Compute the PSNR.
43
+ psnrs.append(
44
+ torch.mean(
45
+ torch.nn.functional.psnr(outputs, radiances, data["intrinsics"])
46
+ )
47
+ )
48
+
49
+ return np.mean(psnrs)
50
+
51
+
52
+ def main():
53
+ # Create the dataset.
54
+ dataset = Dataset.from_json("data/nerf_synthetic_data.json")
55
+ dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True)
56
+
57
+ # Create the NeRF model.
58
+ nerf = NeRF(32, 64, 8).to(device)
59
+
60
+ # Create the optimizer.
61
+ optimizer = optim.Adam(nerf.parameters(), lr=1e-3)
62
+
63
+ # Train the NeRF model.
64
+ for i in range(1000):
65
+ train(nerf, dataloader, optimizer, device)
66
+
67
+ # Print the loss and PSNR every 100 iterations.
68
+ if i % 100 == 0:
69
+ loss = test(nerf, dataloader, device)
70
+ print(f"Loss: {loss:.4f}")
71
+
72
+ # Save the NeRF model.
73
+ nerf.save("nerf.pth")
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()