File size: 4,642 Bytes
aa91c42
fa2c889
aa91c42
 
fa2c889
8016ac3
34d2b7b
8016ac3
 
 
aa91c42
 
 
 
8016ac3
47f95ed
34d2b7b
 
 
 
aa91c42
 
 
 
47f95ed
aa91c42
 
8016ac3
aa91c42
 
 
 
 
 
 
 
 
 
 
 
 
 
47f95ed
aa91c42
 
 
 
 
 
8016ac3
 
47f95ed
34d2b7b
47f95ed
 
 
 
34d2b7b
8016ac3
47f95ed
 
 
 
 
 
 
 
 
 
 
 
8016ac3
47f95ed
 
 
 
 
 
8016ac3
47f95ed
 
 
 
 
aa91c42
8016ac3
aa91c42
 
8016ac3
aa91c42
 
47f95ed
aa91c42
 
47f95ed
aa91c42
 
 
ceff522
aa91c42
 
 
fa2c889
aa91c42
fa2c889
 
aa91c42
 
 
 
 
 
 
 
 
8016ac3
 
aa91c42
 
 
fa2c889
aa91c42
 
 
fa2c889
 
aa91c42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import cv2
import gradio as gr
import AnimeGANv3_src
import numpy as np
import logging
import onnxruntime as ort  # Added this import

# Configure the root logger once at import time: INFO and above,
# each record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class AnimeGANv3:
    """Frame-by-frame video stylizer built on ``AnimeGANv3_src.Convert``.

    Creating an instance makes sure the ``output`` and ``frames`` directories
    exist and logs which ONNX Runtime execution providers are available.
    """

    # Style names shown in the UI mapped to the single-letter codes handed to
    # AnimeGANv3_src.Convert.
    STYLE_CODES = {
        "AnimeGANv3_Arcane": "A",
        "AnimeGANv3_Trump v1.0": "T",
        "AnimeGANv3_Shinkai": "S",
        "AnimeGANv3_PortraitSketch": "P",
        "AnimeGANv3_Hayao": "H",
        "AnimeGANv3_Disney v1.0": "D",
        "AnimeGANv3_JP_face v1.0": "J",
        "AnimeGANv3_Kpop v2.0": "K",
    }
    # Code passed to Convert when the requested style is not in STYLE_CODES.
    DEFAULT_STYLE_CODE = "U"
    # Frame rate used when the input container reports no usable FPS
    # (0, negative or NaN), which would otherwise break the video writer.
    FALLBACK_FPS = 25.0

    def __init__(self):
        os.makedirs('output', exist_ok=True)
        os.makedirs('frames', exist_ok=True)
        # Query the providers once and reuse the result for both log lines.
        providers = ort.get_available_providers()
        logging.info(f"Available ONNX Runtime providers: {providers}")
        if 'CUDAExecutionProvider' in providers:
            logging.info("Running on GPU with CUDA")
        else:
            logging.info("Running on CPU")

    def process_frame(self, frame, style_code, det_face):
        """Stylize a single BGR frame and return the result in BGR order.

        Args:
            frame: Image as read by ``cv2.VideoCapture.read`` (BGR).
            style_code: Single-letter style code passed to ``Convert``.
            det_face: Boolean flag forwarded to ``Convert``.

        Returns:
            The converted image with its last axis reversed, as a contiguous
            array suitable for ``cv2.VideoWriter.write``.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # NOTE(review): assumes Convert returns an H x W x 3 RGB array - confirm.
        output = AnimeGANv3_src.Convert(frame_rgb, style_code, det_face)
        # Reversing the channel axis yields a negative-stride view; hand back
        # a contiguous copy so downstream OpenCV calls get a plain buffer.
        return np.ascontiguousarray(output[:, :, ::-1])

    def inference(self, video_path, style, if_face=None):
        """Convert a video to the chosen style and write it to ``output/out.mp4``.

        Frames are read, stylized and written one at a time, so memory use
        does not grow with the length of the video.

        Args:
            video_path: Path of the input video.
            style: Style name; names missing from ``STYLE_CODES`` fall back to
                ``DEFAULT_STYLE_CODE``.
            if_face: ``"Yes"`` enables the face flag passed to ``Convert``;
                any other value disables it.

        Returns:
            The path of the written video, or ``None`` if anything failed
            (the error is logged rather than raised).
        """
        logging.info(f"Starting inference: video={video_path}, style={style}, face_detection={if_face}")
        cap = None
        out = None
        try:
            style_code = self.STYLE_CODES.get(style, self.DEFAULT_STYLE_CODE)
            det_face = if_face == "Yes"

            # Open video
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise Exception("Could not open video file")

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:  # also true for NaN
                logging.warning(f"Invalid FPS {fps} reported; falling back to {self.FALLBACK_FPS}")
                fps = self.FALLBACK_FPS
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            logging.info(f"Extracted {frame_count} frames at {fps} FPS to process")

            save_path = "output/out.mp4"
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out_size = None  # (width, height) the writer was opened with

            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_idx += 1

                stylized_frame = self.process_frame(frame, style_code, det_face)
                height, width = stylized_frame.shape[:2]

                if out is None:
                    # Size the writer from the stylized frame, not the input:
                    # VideoWriter silently drops frames of any other size.
                    out_size = (width, height)
                    out = cv2.VideoWriter(save_path, fourcc, fps, out_size)
                    if not out.isOpened():
                        raise Exception("Could not open video writer")
                elif (width, height) != out_size:
                    # Keep every frame writable if the converter's output
                    # size varies between frames.
                    stylized_frame = cv2.resize(stylized_frame, out_size)

                out.write(stylized_frame)
                logging.info(f"Processed frame {frame_idx}/{frame_count}")

            if out is None:
                raise Exception("No frames processed")

            # Finalize the file before reporting success to the caller.
            out.release()
            out = None

            logging.info(f"Video created: {save_path}")
            return save_path
        except Exception as error:
            # UI boundary: log with traceback and signal failure with None.
            logging.exception(f"Error: {str(error)}")
            return None
        finally:
            # Always free the capture/writer handles, including on errors.
            if cap is not None:
                cap.release()
            if out is not None:
                out.release()

# One shared converter instance backs every request made through the UI
anime_gan = AnimeGANv3()

# Page heading and HTML blurb rendered above the input widgets
title = "AnimeGANv3: Video to Anime Converter"
description = r"""Upload a video to convert it into anime style using AnimeGANv3.<br>
Select a style and choose whether to optimize for faces.<br>
credits to fine tuner Asher_Chan
<a href='https://github.com/TachibanaYoshino/AnimeGANv3' target='_blank'><b>AnimeGANv3 GitHub</b></a> | 
<a href='https://www.patreon.com/Asher_Chan' target='_blank'><b>Patreon</b></a>"""

# Input widgets, listed in the positional order of
# AnimeGANv3.inference(video_path, style, if_face)
video_input = gr.Video(label="Input Video")
style_input = gr.Dropdown(
    choices=[
        'AnimeGANv3_Hayao',
        'AnimeGANv3_Shinkai',
        'AnimeGANv3_Arcane',
        'AnimeGANv3_Trump v1.0',
        'AnimeGANv3_Disney v1.0',
        'AnimeGANv3_PortraitSketch',
        'AnimeGANv3_JP_face v1.0',
        'AnimeGANv3_Kpop v2.0',
    ],
    label='AnimeGANv3 Style',
    value='AnimeGANv3_Arcane',
)
face_input = gr.Radio(choices=["Yes", "No"], label='Extract face', value="No")

# Wire the widgets to the converter; the returned file path feeds the output player
iface = gr.Interface(
    fn=anime_gan.inference,
    inputs=[video_input, style_input, face_input],
    outputs=[gr.Video(label="Output Video")],
    title=title,
    description=description,
    allow_flagging="never",
)

# Start the web server as soon as the module is executed
iface.launch()