File size: 2,741 Bytes
c1c16cb
58da73e
 
c1c16cb
 
58da73e
 
c1c16cb
58da73e
 
c1c16cb
 
 
58da73e
 
c1c16cb
 
 
 
58da73e
8407553
c1c16cb
8407553
c1c16cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# 2023年2月23日
"""
实现web界面

>>> streamlit run app.py
"""

from io import BytesIO
from pathlib import Path

import streamlit as st
from detect import detect, opt
from PIL import Image
from util import get_all_weights

"""
# CycleGAN
功能:上传本地文件、选择转换风格
"""


def load_css(css_path="./util/streamlit/css.css"):
    """
    Inject a local CSS file into the page, if it exists.

    The file contents are wrapped in a ``<style>`` tag and emitted through
    ``st.markdown`` with ``unsafe_allow_html=True``. A path that is missing,
    or that is not a regular file, is silently ignored.

    :param css_path: path to the CSS file
    """
    path = Path(css_path)
    # is_file() rather than exists(): a directory passes exists() but would
    # crash when opened for reading.
    if not path.is_file():
        return
    # Decode explicitly as UTF-8: open()/read_text() without an encoding use the
    # locale's preferred encoding (e.g. GBK on Chinese Windows), which can fail
    # on or garble non-ASCII characters in the stylesheet.
    css = path.read_text(encoding="utf-8")
    # Insert the CSS into the page as raw HTML.
    st.markdown(
        f"""<style>{css}</style>""",
        unsafe_allow_html=True,
    )


def load_img_file(file):
    """Decode an uploaded file into a PIL image and show it in the page.

    The remaining bytes of ``file`` are read into memory, decoded with PIL,
    rendered full-column-width with ``st.image``, and the image is returned.
    """
    raw_bytes = file.read()
    picture = Image.open(BytesIO(raw_bytes))
    # Preview the input image before it is converted.
    st.image(picture, use_column_width=True)
    return picture


def set_style_options(label: str, frame=st):
    """Render a style dropdown in ``frame`` and return the selected entry.

    The first choice is ``None`` so that no style is selected by default; the
    remaining choices are whatever ``get_all_weights()`` returns.

    :param label: label shown above the dropdown
    :param frame: Streamlit container to render into (defaults to ``st``)
    :return: the selected entry, or ``None`` when nothing is chosen
    """
    choices = [None] + get_all_weights()
    return frame.selectbox(label=label, options=choices)


# Custom stylesheet is currently disabled; uncomment to apply it.
# load_css()
tab_mul2mul, tab_mul2one, tab_set = st.tabs(["多图多风格转换", "多图同风格转换", "参数"])

# Display order of the tabs is fixed by ``st.tabs`` above; the order of the
# ``with`` blocks below only controls execution order. The parameter tab runs
# first so the ``opt`` value it writes is already in place when the conversion
# tabs call ``detect`` during this same script run.
# NOTE(review): assumes ``detect`` reads ``opt.no_dropout`` at call time —
# confirm in detect.py.
with tab_set:
    # Reserve the top slot for the option listing so it renders above the
    # controls, but fill it only after the radio has updated ``opt`` so the
    # listing shows this run's values rather than the previous run's.
    options_box = st.container()
    # NOTE(review): this value is not used anywhere in this file.
    confidence_threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.01)
    opt.no_dropout = st.radio("no_dropout", [True, False])
    with options_box:
        # Read-only view of every option on ``opt``, sorted by name.
        for opt_name, opt_value in sorted(vars(opt).items()):
            st.text_input(label=opt_name, value=opt_value, disabled=True)

with tab_mul2mul:
    # Each uploaded image gets its own style selector and its own result.
    uploaded_files = st.file_uploader(label="选择本地图片", accept_multiple_files=True, key=1)
    for idx, uploaded_file in enumerate(uploaded_files or []):
        col_left, col_right = st.columns(2)
        with col_left:
            img = load_img_file(uploaded_file)
            # Index + file name keeps the label readable (instead of the raw
            # UploadedFile repr) and unique even when two uploads share a name;
            # identical labels would collide as widget IDs.
            style = set_style_options(label=f"{idx + 1}. {uploaded_file.name}", frame=st)
        with col_right:
            if style:
                fake_img = detect(img=img, style=style)
                st.image(fake_img, caption="", use_column_width=True)

with tab_mul2one:
    # All uploaded images are converted with one shared style, on button press.
    uploaded_files = st.file_uploader(label="选择本地图片", accept_multiple_files=True, key=2)
    if uploaded_files:
        col_left, col_right = st.columns(2)
        with col_left:
            imgs = [load_img_file(f) for f in uploaded_files]
        with col_right:
            style = set_style_options(label="选择风格", frame=st)
            # The button is only rendered once a style has been chosen.
            if style and st.button("♻️风格转换", use_container_width=True):
                for img in imgs:
                    fake_img = detect(img=img, style=style)
                    st.image(fake_img, caption="", use_column_width=True)