File size: 1,703 Bytes
01d43bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fb89b5
01d43bf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os

import gradio as gr
import torch
from PIL import Image

from model import MangaColorizer
from utils import pil_to_torch, torch_to_pil

def load_html_template():
    """Return the contents of ``templates/index.html`` located next to this file.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Returns:
        The template text, or ``None`` (after printing an error message) when
        the template file does not exist.
    """
    html_dir = os.path.join(os.path.dirname(__file__), "templates")
    index_html_path = os.path.join(html_dir, "index.html")

    # Missing template is reported but not fatal; the caller receives None.
    if not os.path.exists(index_html_path):
        print(f"Error: {index_html_path} not found.")
        return None

    with open(index_html_path, "r", encoding="utf-8") as html_file:
        return html_file.read()


def load_model():
    """Build a ``MangaColorizer`` and load its checkpoint if one is present.

    The checkpoint is looked up at ``../model/best_model_checkpoint.pth``
    relative to this file and is mapped onto the CPU. When the file is missing
    an error message is printed and the freshly constructed model is returned
    without loaded weights.

    Returns:
        The ``MangaColorizer`` instance, switched to evaluation mode.
    """
    model = MangaColorizer()
    models_dir = os.path.join(os.path.dirname(__file__), '..', 'model')
    model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
    if os.path.exists(model_file):
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be placed at this path.
        with open(model_file, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        # The file handle is no longer needed once the state dict is in memory.
        model.load_state_dict(checkpoint)
    else:
        print(f"Error: {model_file} not found.")

    # This model is only used for inference; evaluation mode makes layers such
    # as dropout and batch normalization behave deterministically.
    model.eval()
    return model
# Built once at import time; colorize_image reads this module-level instance.
model = load_model()

def colorize_image(image):
    """Run the module-level ``model`` on an image and return the result.

    Args:
        image: Array data accepted by ``PIL.Image.fromarray``, or ``None``
            when no image was supplied.

    Returns:
        The value produced by ``torch_to_pil`` from the model output, or
        ``None`` when ``image`` is ``None``.
    """
    # Nothing to colorize (e.g. the form was submitted without an image).
    if image is None:
        return None

    # The input is reduced to PIL mode "L" (8-bit grayscale) before inference.
    img = Image.fromarray(image).convert("L")

    # Gradients are never needed here; skipping autograd bookkeeping saves
    # memory during the forward pass.
    with torch.no_grad():
        output = model(pil_to_torch(img)).detach().cpu()

    return torch_to_pil(output)


def main():
    """Assemble the Gradio page and launch it.

    The page consists of the HTML returned by ``load_html_template``, an
    image-to-image interface backed by ``colorize_image``, and a footer
    linking to the project repository.
    """
    page_header = load_html_template()
    footer_html = """
                <p style="text-align: center;font-size: large;">
                Checkout the <a href="https://github.com/zaidmehdi/manga-colorizer">Github Repo</a>
                </p>
                """

    with gr.Blocks() as demo:
        gr.HTML(page_header)
        gr.Interface(
            fn=colorize_image,
            inputs=["image"],
            outputs=["image"],
            allow_flagging="never",
        )
        gr.HTML(footer_html)

    demo.launch()


# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()