Kiwinicki commited on
Commit
2e786fb
·
verified ·
1 Parent(s): 9041310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -1,35 +1,35 @@
1
  import gradio as gr
2
- import torch.nn as nn
3
- from torch import tanh, Tensor
4
- from abc import ABC, abstractmethod
5
- from huggingface_hub import hf_hub_download
6
  import torch
 
7
  import json
8
  from omegaconf import OmegaConf
 
 
 
 
9
 
 
10
  repo_id = "Kiwinicki/sat2map-generator"
11
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
12
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
13
  model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
14
 
15
- import sys, os
16
  sys.path.append(os.path.dirname(model_path))
17
  from model import Generator
18
 
19
-
20
-
21
  with open(config_path, "r") as f:
22
  config_dict = json.load(f)
23
  cfg = OmegaConf.create(config_dict)
24
 
25
- generator = Generator(cfg)
26
- generator.load_state_dict(torch.load(generator_path, map_location=torch.device('cpu')))
 
 
27
  generator.eval()
28
 
29
- from PIL import Image
30
- import torchvision.transforms as transforms
31
-
32
-
33
  transform = transforms.Compose([
34
  transforms.Resize((256, 256)),
35
  transforms.ToTensor(),
@@ -37,12 +37,25 @@ transform = transforms.Compose([
37
  ])
38
 
39
  def process_image(image):
40
- image_tensor = transform(image).unsqueeze(0)
 
 
 
41
  with torch.no_grad():
42
  output_tensor = generator(image_tensor)
43
- output_image = output_tensor.squeeze(0)
 
 
 
44
  output_image = transforms.ToPILImage()(output_image)
 
45
  return output_image
46
 
47
- iface = gr.Interface(fn=process_image, inputs="image", outputs="image", title="Image Generator")
48
- iface.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
2
  import torch
3
+ from huggingface_hub import hf_hub_download
4
  import json
5
  from omegaconf import OmegaConf
6
+ import sys
7
+ import os
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
 
11
+ # Pobierz model i config
12
  repo_id = "Kiwinicki/sat2map-generator"
13
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
14
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
15
  model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
16
 
17
+ # Dodaj ścieżkę do modelu
18
  sys.path.append(os.path.dirname(model_path))
19
  from model import Generator
20
 
21
+ # Załaduj konfigurację
 
22
  with open(config_path, "r") as f:
23
  config_dict = json.load(f)
24
  cfg = OmegaConf.create(config_dict)
25
 
26
+ # Inicjalizacja modelu
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ generator = Generator(cfg).to(device)
29
+ generator.load_state_dict(torch.load(generator_path, map_location=device))
30
  generator.eval()
31
 
32
+ # Transformacje
 
 
 
33
  transform = transforms.Compose([
34
  transforms.Resize((256, 256)),
35
  transforms.ToTensor(),
 
37
  ])
38
 
39
  def process_image(image):
40
+ # Konwersja do tensora
41
+ image_tensor = transform(image).unsqueeze(0).to(device)
42
+
43
+ # Inferencja
44
  with torch.no_grad():
45
  output_tensor = generator(image_tensor)
46
+
47
+ # Przygotowanie wyjścia
48
+ output_image = output_tensor.squeeze(0).cpu()
49
+ output_image = output_image * 0.5 + 0.5 # Denormalizacja
50
  output_image = transforms.ToPILImage()(output_image)
51
+
52
  return output_image
53
 
54
+ iface = gr.Interface(
55
+ fn=process_image,
56
+ inputs=gr.Image(type="pil"),
57
+ outputs="image",
58
+ title="Satellite to Map Generator"
59
+ )
60
+
61
+ iface.launch()