zaidmehdi's picture
remove server_name and port from demo.launch()
5fb89b5 verified
raw
history blame
1.7 kB
import os
import gradio as gr
import torch
from PIL import Image
from model import MangaColorizer
from utils import pil_to_torch, torch_to_pil
def load_html_template(template_dir=None, filename="index.html"):
    """Read the HTML snippet shown at the top of the demo page.

    Args:
        template_dir: Directory containing the template. Defaults to the
            ``templates`` folder next to this file.
        filename: Name of the template file inside ``template_dir``.

    Returns:
        The file contents as a string, or ``None`` if the file does not
        exist (an error message is printed in that case).
    """
    if template_dir is None:
        template_dir = os.path.join(os.path.dirname(__file__), "templates")
    index_html_path = os.path.join(template_dir, filename)

    if not os.path.exists(index_html_path):
        print(f"Error: {index_html_path} not found.")
        # Explicit rather than falling off the end of the function.
        return None

    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252
    # on Windows) can fail on, or mis-decode, non-ASCII characters in HTML.
    with open(index_html_path, "r", encoding="utf-8") as html_file:
        return html_file.read()
def load_model():
    """Build a MangaColorizer and load its trained weights for inference.

    The checkpoint is read from ``../model/best_model_checkpoint.pth``
    relative to this file and mapped onto the CPU. If the file is missing,
    an error is printed and the model keeps its freshly initialised weights.

    Returns:
        The MangaColorizer instance, switched to evaluation mode.
    """
    model = MangaColorizer()
    models_dir = os.path.join(os.path.dirname(__file__), '..', 'model')
    model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
    if os.path.exists(model_file):
        # The loaded object is passed straight to load_state_dict, so the
        # checkpoint is expected to be a plain state dict.
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from a trusted source.
        with open(model_file, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        model.load_state_dict(checkpoint)
    else:
        print(f"Error: {model_file} not found.")
    # This model is only used for inference: switch layers such as dropout /
    # batch-norm (if the architecture has any) to their evaluation behaviour.
    # eval() is a no-op for modules without such layers.
    model.eval()
    return model
# Build the colorizer once at import time; colorize_image reads this
# module-level `model` on every call instead of reloading the checkpoint.
model = load_model()
def colorize_image(image):
    """Colorize an image supplied by the Gradio image input.

    Args:
        image: The uploaded image in a form accepted by
            ``PIL.Image.fromarray`` (presumably a numpy array, Gradio's
            default for an ``"image"`` input), or ``None`` when the user
            submits without uploading anything.

    Returns:
        The colorized image as produced by ``torch_to_pil``, or ``None``
        when no input image was given.
    """
    if image is None:
        # Image.fromarray(None) would raise; leave the output empty instead.
        return None

    # The model is fed a single-channel grayscale ("L" mode) version of the input.
    img = Image.fromarray(image).convert("L")

    # Inference only: no_grad avoids building the autograd graph, saving
    # memory and time. `model` is the module-level instance (read-only here,
    # so no `global` declaration is needed).
    with torch.no_grad():
        output = model(pil_to_torch(img)).detach().cpu()
    return torch_to_pil(output)
def main():
    """Assemble the demo page (header, colorizer interface, footer) and launch it."""
    header_html = load_html_template()  # may be None if the template file is missing
    footer_html = """
<p style="text-align: center;font-size: large;">
Checkout the <a href="https://github.com/zaidmehdi/manga-colorizer">Github Repo</a>
</p>
"""

    with gr.Blocks() as demo:
        gr.HTML(header_html)
        # Single image in, single image out; flagging is disabled.
        gr.Interface(
            fn=colorize_image,
            inputs=["image"],
            outputs=["image"],
            allow_flagging="never",
        )
        gr.HTML(footer_html)

    demo.launch()
# Launch the web demo only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()