NguyNhu commited on
Commit
f330d7c
·
verified ·
1 Parent(s): c730e0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -106
app.py CHANGED
@@ -1,107 +1,91 @@
1
- #!/usr/bin/env python
2
-
3
- from setuptools import find_packages, setup
4
-
 
 
 
 
 
 
 
 
 
5
  import os
6
- import subprocess
7
- import time
8
-
9
- version_file = 'realesrgan/version.py'
10
-
11
-
12
- def readme():
13
- with open('README.md', encoding='utf-8') as f:
14
- content = f.read()
15
- return content
16
-
17
-
18
- def get_git_hash():
19
-
20
- def _minimal_ext_cmd(cmd):
21
- # construct minimal environment
22
- env = {}
23
- for k in ['SYSTEMROOT', 'PATH', 'HOME']:
24
- v = os.environ.get(k)
25
- if v is not None:
26
- env[k] = v
27
- # LANGUAGE is used on win32
28
- env['LANGUAGE'] = 'C'
29
- env['LANG'] = 'C'
30
- env['LC_ALL'] = 'C'
31
- out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
32
- return out
33
-
34
- try:
35
- out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
36
- sha = out.strip().decode('ascii')
37
- except OSError:
38
- sha = 'unknown'
39
-
40
- return sha
41
-
42
-
43
- def get_hash():
44
- if os.path.exists('.git'):
45
- sha = get_git_hash()[:7]
46
- else:
47
- sha = 'unknown'
48
-
49
- return sha
50
-
51
-
52
- def write_version_py():
53
- content = """# GENERATED VERSION FILE
54
- # TIME: {}
55
- __version__ = '{}'
56
- __gitsha__ = '{}'
57
- version_info = ({})
58
- """
59
- sha = get_hash()
60
- with open('VERSION', 'r') as f:
61
- SHORT_VERSION = f.read().strip()
62
- VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
63
-
64
- version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
65
- with open(version_file, 'w') as f:
66
- f.write(version_file_str)
67
-
68
-
69
- def get_version():
70
- with open(version_file, 'r') as f:
71
- exec(compile(f.read(), version_file, 'exec'))
72
- return locals()['__version__']
73
-
74
-
75
- def get_requirements(filename='requirements.txt'):
76
- here = os.path.dirname(os.path.realpath(__file__))
77
- with open(os.path.join(here, filename), 'r') as f:
78
- requires = [line.replace('\n', '') for line in f.readlines()]
79
- return requires
80
-
81
-
82
- if __name__ == '__main__':
83
- write_version_py()
84
- setup(
85
- name='realesrgan',
86
- version=get_version(),
87
- description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
88
- long_description=readme(),
89
- long_description_content_type='text/markdown',
90
- author='Xintao Wang',
91
- author_email='[email protected]',
92
- keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
93
- url='https://github.com/xinntao/Real-ESRGAN',
94
- include_package_data=True,
95
- packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
96
- classifiers=[
97
- 'Development Status :: 4 - Beta',
98
- 'License :: OSI Approved :: Apache Software License',
99
- 'Operating System :: OS Independent',
100
- 'Programming Language :: Python :: 3',
101
- 'Programming Language :: Python :: 3.7',
102
- 'Programming Language :: Python :: 3.8',
103
- ],
104
- license='BSD-3-Clause License',
105
- setup_requires=['cython', 'numpy'],
106
- install_requires=get_requirements(),
107
- zip_safe=False)
 
1
+ # -*- coding:UTF-8 -*-
2
+ # !/usr/bin/env python
3
+ import spaces
4
+ import numpy as np
5
+ import gradio as gr
6
+ import gradio.exceptions
7
+ import roop.globals
8
+ from roop.core import (
9
+ start,
10
+ decode_execution_providers,
11
+ )
12
+ from roop.processors.frame.core import get_frame_processors_modules
13
+ from roop.utilities import normalize_output_path
14
  import os
15
+ import random
16
+ from PIL import Image
17
+ import onnxruntime as ort
18
+ import cv2
19
+ from roop.face_analyser import get_one_face
20
+
21
+ @spaces.GPU
22
+ def swap_face(source_file, target_file, doFaceEnhancer):
23
+ session_dir = "temp" # Sử dụng thư mục cố định
24
+ os.makedirs(session_dir, exist_ok=True)
25
+
26
+ # Tạo tên file ngẫu nhiên
27
+ source_filename = f"source_{random.randint(1000, 9999)}.jpg"
28
+ target_filename = f"target_{random.randint(1000, 9999)}.jpg"
29
+ output_filename = f"output_{random.randint(1000, 9999)}.jpg"
30
+
31
+ source_path = os.path.join(session_dir, source_filename)
32
+ target_path = os.path.join(session_dir, target_filename)
33
+
34
+ source_image = Image.fromarray(source_file)
35
+ source_image.save(source_path)
36
+ target_image = Image.fromarray(target_file)
37
+ target_image.save(target_path)
38
+
39
+ print("source_path: ", source_path)
40
+ print("target_path: ", target_path)
41
+
42
+ # Check if a face is detected in the source image
43
+ source_face = get_one_face(cv2.imread(source_path))
44
+ if source_face is None:
45
+ raise gradio.exceptions.Error("No face in source path detected.")
46
+
47
+ # Check if a face is detected in the target image
48
+ target_face = get_one_face(cv2.imread(target_path))
49
+ if target_face is None:
50
+ raise gradio.exceptions.Error("No face in target path detected.")
51
+
52
+ output_path = os.path.join(session_dir, output_filename)
53
+ normalized_output_path = normalize_output_path(source_path, target_path, output_path)
54
+
55
+ frame_processors = ["face_swapper", "face_enhancer"] if doFaceEnhancer else ["face_swapper"]
56
+
57
+ for frame_processor in get_frame_processors_modules(frame_processors):
58
+ if not frame_processor.pre_check():
59
+ print(f"Pre-check failed for {frame_processor}")
60
+ raise gradio.exceptions.Error(f"Pre-check failed for {frame_processor}")
61
+
62
+ roop.globals.source_path = source_path
63
+ roop.globals.target_path = target_path
64
+ roop.globals.output_path = normalized_output_path
65
+ roop.globals.frame_processors = frame_processors
66
+ roop.globals.headless = True
67
+ roop.globals.keep_fps = True
68
+ roop.globals.keep_audio = True
69
+ roop.globals.keep_frames = False
70
+ roop.globals.many_faces = False
71
+ roop.globals.video_encoder = "libx264"
72
+ roop.globals.video_quality = 18
73
+ roop.globals.execution_providers = decode_execution_providers(['cpu'])
74
+ roop.globals.reference_face_position = 0
75
+ roop.globals.similar_face_distance = 0.6
76
+ roop.globals.max_memory = 60
77
+ roop.globals.execution_threads = 8
78
+
79
+ start()
80
+ return normalized_output_path
81
+
82
+ app = gr.Interface(
83
+ fn=swap_face,
84
+ inputs=[
85
+ gr.Image(),
86
+ gr.Image(),
87
+ gr.Checkbox(label="Face Enhancer?", info="Do face enhancement?")
88
+ ],
89
+ outputs="image"
90
+ )
91
+ app.launch()