Spaces:
Runtime error
Runtime error
File size: 3,788 Bytes
c2164fe 3d1b36b c2164fe 79df973 8b8b671 79df973 8b8b671 79df973 8b8b671 79df973 c2164fe 8b8b671 79df973 c2164fe 79df973 8b8b671 79df973 8b8b671 c2164fe 79df973 c2164fe 79df973 8b8b671 79df973 8b8b671 79df973 8b8b671 79df973 e60fd27 79df973 8b8b671 79df973 c2164fe e60fd27 79df973 e60fd27 79df973 c2164fe 8b8b671 c2164fe 79df973 8b8b671 c2164fe 8b8b671 3d1b36b 8b8b671 c2164fe 8b8b671 79df973 f8ec29a 79df973 8b8b671 79df973 8b8b671 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# -*- coding: utf-8 -*-
#
# @File: app.py
# @Author: Haozhe Xie
# @Date: 2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-03-03 16:08:25
# @Email: [email protected]
import gradio as gr
import logging
import numpy as np
import os
import ssl
import subprocess
import sys
import torch
import urllib.request
from PIL import Image
# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed" by replacing the process-wide default HTTPS
# context factory with one that skips certificate checks.
# NOTE(security): this disables TLS verification for every HTTPS request that
# relies on the default context -- including the checkpoint download in
# get_models(), whose file is afterwards unpickled by torch.load(). That makes
# the download open to man-in-the-middle tampering; fixing the host's CA bundle
# would be the safer long-term solution.
ssl._create_default_https_context = ssl._create_unverified_context

# Put the bundled citydreamer/ directory on the import path (presumably so the
# package's internal modules can import one another by top-level name).
sys.path.extend([os.path.join(os.path.dirname(__file__), "citydreamer")])
def setup_runtime_env():
    """Log the toolchain versions and pip-install the bundled CUDA extensions.

    Every immediate sub-directory of ``citydreamer/extensions`` (next to this
    file) is installed with ``pip install .``, using the pip of the interpreter
    that is running this script. A failed install is logged as an error rather
    than raised, so start-up continues; presumably the later deferred
    ``import citydreamer.inference`` is then where the failure surfaces.

    Raises:
        FileNotFoundError: if ``nvcc`` or ``g++`` is not on PATH, or the
            extensions directory does not exist.
        subprocess.CalledProcessError: if a version query exits non-zero.
    """
    for tool, label in (("nvcc", "CUDA"), ("g++", "GCC")):
        # check_output() returns bytes; decode so the log shows readable text
        # instead of a b'...\n...' repr.
        version = subprocess.check_output([tool, "--version"])
        logging.info(
            "%s version is %s", label, version.decode(errors="replace").strip()
        )

    # Compile CUDA extensions
    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    for name in sorted(os.listdir(ext_dir)):
        ext_path = os.path.join(ext_dir, name)
        if not os.path.isdir(ext_path):
            continue
        # "python -m pip" guarantees the extension is installed into the same
        # environment that will import it; a bare "pip" on PATH may not be.
        ret = subprocess.call(
            [sys.executable, "-m", "pip", "install", "."], cwd=ext_path
        )
        if ret != 0:
            logging.error(
                "Failed to install CUDA extension %s (pip exited with %d)",
                name,
                ret,
            )
def get_models(file_name):
    """Download (if needed) and load a CityDreamer generator checkpoint.

    Args:
        file_name: checkpoint file name. When no file of that name exists in
            the current working directory it is fetched from the
            ``hzxie/city-dreamer`` Hugging Face repository.

    Returns:
        A ``citydreamer.model.GanCraftGenerator`` in eval mode with the
        ``"gancraft_g"`` weights loaded; wrapped in ``torch.nn.DataParallel``
        and moved to the GPU when CUDA is available.
    """
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file behind that the
        # os.path.exists() check above would accept on the next start.
        tmp_name = file_name + ".part"
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            tmp_name,
        )
        os.replace(tmp_name, file_name)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; load_state_dict() copies the values onto the model's
    # device afterwards.
    ckpt = torch.load(file_name, map_location="cpu")
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    state_dict = ckpt["gancraft_g"]
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
    else:
        # The checkpoint keys presumably carry DataParallel's "module." prefix
        # (the GPU path loads them into a DataParallel wrapper). With
        # strict=False those keys would be silently skipped by the unwrapped
        # model, so strip the prefix; this is a no-op when it is absent.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    # strict=False is kept from the original: keys that do not match any
    # parameter/buffer (and parameters missing from the file) are ignored.
    model.load_state_dict(state_dict, strict=False)
    # Inference only: switch to eval mode on both the GPU and the CPU path
    # (Module.eval() returns the module itself).
    return model.eval()
def get_city_layout():
    """Load the New York City layout rasters used by the demo.

    Both files are read relative to the current working directory.

    Returns:
        (hf, seg): ``hf`` is ``assets/NYC-HghtFld.png`` as a NumPy array
        (presumably the height field, per the file name); ``seg`` is
        ``assets/NYC-SegMap.png`` converted to palette ("P") mode and then to a
        NumPy array, i.e. per-pixel palette indices.
    """
    # Context managers release each image file as soon as its pixels are copied.
    with Image.open("assets/NYC-HghtFld.png") as height_img:
        height_field = np.array(height_img)
    with Image.open("assets/NYC-SegMap.png") as seg_img:
        seg_map = np.array(seg_img.convert("P"))
    return height_field, seg_map
def get_generated_city(radius, altitude, azimuth, map_center):
    """Render a CityDreamer view for the camera parameters chosen in the UI.

    The models and layout are not passed as arguments. The ``__main__`` block
    attaches them to this function as the attributes ``fgm``, ``bgm``, ``hf``
    and ``seg``, and this function forwards them, together with the slider
    values, to ``citydreamer.inference.generate_city``. ``hf`` and ``seg`` are
    copied first, so ``generate_city`` works on private arrays and cannot alter
    the shared ones.

    Args:
        radius: camera radius slider value (UI label: metres).
        altitude: camera altitude slider value (UI label: metres).
        azimuth: camera azimuth slider value (UI label: degrees).
        map_center: map center slider value (UI label: pixels).

    Returns:
        Whatever ``generate_city`` returns; Gradio displays it through a
        ``gr.Image(type="numpy")`` output.
    """
    # Deferred import: citydreamer.inference must only be imported after the
    # CUDA extensions have been compiled.
    import citydreamer.inference

    this = get_generated_city
    # The UI has a single "Map Center" slider; its value is passed for both of
    # generate_city's center arguments (presumably the x and y coordinates).
    center_x = center_y = map_center
    return citydreamer.inference.generate_city(
        this.fgm,
        this.bgm,
        this.hf.copy(),
        this.seg.copy(),
        center_x,
        center_y,
        radius,
        altitude,
        azimuth,
    )
def main(debug):
    """Build and launch the CityDreamer Gradio demo.

    The page description comes from ``README.md`` (the text after the last
    ``---`` marker, or the whole file if there is none). The article footer
    comes from ``ARTICLE.md``. Both are read from the current working
    directory as UTF-8.

    Args:
        debug: forwarded to ``app.launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    # Decode explicitly as UTF-8: the default encoding depends on the host
    # locale, and Markdown files commonly contain non-ASCII text such as emoji.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # Keep only what follows the last "---" (presumably the end of the YAML
    # front matter). rfind() returns -1 when the marker is missing, and adding
    # 3 to that would silently chop off the first two characters, so fall back
    # to the whole file in that case.
    marker = markdown.rfind("---")
    desc = markdown if marker == -1 else markdown[marker + 3 :]
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    # Route requests through Gradio's queue; api_open=False keeps the backend's
    # REST routes closed so direct API calls cannot bypass the queue.
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    # Script entry point: compile the extensions, load both generators and the
    # city layout, hand them to get_generated_city as function attributes, and
    # then start the Gradio app.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(levelname)s] %(asctime)s %(message)s",
    )

    logging.info("Compiling CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    # Presumably the foreground and background generators (per the file names).
    fgm, bgm = (
        get_models(ckpt_name)
        for ckpt_name in ("CityDreamer-Fgnd.pth", "CityDreamer-Bgnd.pth")
    )
    get_generated_city.fgm, get_generated_city.bgm = fgm, bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf, get_generated_city.seg = hf, seg

    logging.info("Starting the main application...")
    main(debug=os.getenv("DEBUG") == "1")
|