File size: 2,163 Bytes
e9300a9
7227f0d
e9300a9
7227f0d
e9300a9
5e8fe36
 
 
 
e9300a9
 
 
 
 
 
 
 
 
 
7227f0d
 
 
 
 
 
 
 
 
 
 
 
 
e9300a9
 
 
 
 
5e8fe36
e9300a9
 
 
 
 
5e8fe36
3aaa3b9
7227f0d
 
 
 
 
 
 
 
 
 
 
 
e837f85
5e8fe36
3aaa3b9
5e8fe36
e9300a9
 
5e8fe36
 
e9300a9
3aaa3b9
e9300a9
 
 
 
 
 
 
5e8fe36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import logging
import os
import subprocess
import sys

import gradio as gr
import torch
from PIL import Image

# Configure the root logger so the INFO-level progress messages below are shown.
logging.basicConfig(level=logging.INFO)

# Working directories for request images, relative to the process cwd.
UPLOAD_FOLDER = "uploads"
OUTPUT_FOLDER = "outputs"

# Create both folders up front; existing folders are left as they are.
for _folder in (UPLOAD_FOLDER, OUTPUT_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Configure PyTorch's CUDA caching allocator to use expandable segments, which
# the PyTorch memory-management docs describe as reducing fragmentation-related
# out-of-memory errors. It is set on this process's environment, so child
# processes started with the default environment (such as the NAFNet
# subprocess) inherit it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ** Fix: Pull missing files from Git LFS **
def fix_git_lfs():
    try:
        logging.info("Checking and pulling missing Git LFS files...")
        subprocess.run(["git", "lfs", "pull"], check=True)
        logging.info("Git LFS files pulled successfully.")
    except subprocess.CalledProcessError as e:
        logging.error(f"Git LFS pull failed: {e}")

# Run the best-effort Git LFS pull once, at import time, before the Gradio UI
# below is built. Presumably this fetches LFS-tracked assets such as the
# NAFNet checkpoint. TODO: confirm which files this Space stores in LFS.
fix_git_lfs()

# Function to process the image
def gradio_interface(image):
    """Restore ``image`` with NAFNet and return the result.

    Steps:
    1. Write the input to ``UPLOAD_FOLDER/input.png``.
    2. Run the NAFNet demo script as a subprocess, which writes
       ``OUTPUT_FOLDER/output.png``.
    3. Load that file and return the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input component, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The restored image as an in-memory ``PIL.Image.Image``.

    Raises:
        gr.Error: if no image was given, the model subprocess fails, or any
            other step fails. Errors are raised rather than returned because
            the ``gr.Image`` output component cannot display a string;
            Gradio shows a ``gr.Error`` message to the user instead.
    """
    if image is None:
        raise gr.Error("Error: please upload an image first.")

    input_path = os.path.join(UPLOAD_FOLDER, "input.png")
    output_path = os.path.join(OUTPUT_FOLDER, "output.png")

    try:
        image.save(input_path)
        logging.info(f"Input image saved at: {input_path}")

        # Remove any result left by a previous request. Otherwise a run that
        # exits 0 without writing output would return a stale image.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass

        # Release cached CUDA memory held by this process's allocator, if any,
        # before running inference.
        torch.cuda.empty_cache()

        logging.info("Running model...")

        # Run NAFNet via subprocess (since import is failing). sys.executable
        # uses the same interpreter/virtualenv as this app; a bare "python"
        # would be whatever comes first on PATH.
        command = [
            sys.executable, "NAFNet/demo.py",
            "-opt", "NAFNet/options/test/REDS/NAFNet-width64.yml",
            "--input_path", input_path,
            "--output_path", output_path,
        ]
        result = subprocess.run(command, capture_output=True, text=True)

        if result.returncode != 0:
            logging.error(f"Model error: {result.stderr}")
            raise gr.Error(f"Error: {result.stderr}")

        logging.info("Model execution completed.")

        # Read the pixels eagerly and close the file. This releases the handle,
        # and a later request overwriting output.png cannot affect the image
        # returned here.
        with Image.open(output_path) as restored:
            return restored.copy()

    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        # Top-level boundary for the UI callback: keep the traceback in the
        # logs and surface a readable message in the UI.
        logging.exception(f"Exception: {str(e)}")
        raise gr.Error(f"Error: {str(e)}") from e

# Build the web UI: one PIL image in, one restored PIL image out.
_image_in = gr.Image(type="pil")
_image_out = gr.Image(type="pil")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_image_in,
    outputs=_image_out,
    title="Image Restoration with NAFNet",
)

# Start serving. No public share link is requested, since Hugging Face Spaces
# already exposes the app.
iface.launch()