File size: 6,038 Bytes
19cef8f
b23b2e1
74e0f7a
707e3ee
19cef8f
f593040
 
 
 
db63f1a
74e0f7a
 
19cef8f
46a4fa5
f593040
27bd2c4
b23b2e1
 
46a4fa5
 
 
 
707e3ee
 
 
 
 
740a7c8
db63f1a
b23b2e1
 
 
 
 
 
 
 
db63f1a
74e0f7a
db63f1a
 
 
707e3ee
db63f1a
 
 
b23b2e1
74e0f7a
 
 
cfe386b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e0f7a
 
cfe386b
 
 
 
 
 
74e0f7a
 
 
 
 
db63f1a
f593040
19cef8f
db63f1a
f593040
 
 
 
 
 
 
 
 
 
 
19cef8f
db63f1a
 
 
 
 
 
b23b2e1
74e0f7a
46a4fa5
 
 
 
74e0f7a
 
4a822b0
db63f1a
 
 
46a4fa5
 
 
 
740a7c8
46a4fa5
 
740a7c8
46a4fa5
74e0f7a
 
 
f593040
1d33274
46a4fa5
 
db63f1a
46a4fa5
 
 
 
74e0f7a
 
 
db63f1a
 
ab9dfa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
from huggingface_hub import InferenceClient
from gradio_client import Client
import re

# Load custom CSS. The file is read explicitly as UTF-8 so the result does not
# depend on the host's locale encoding. The stylesheet is purely cosmetic, so a
# missing file just skips the custom styling instead of crashing the app.
try:
    with open('style.css', encoding='utf-8') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
except FileNotFoundError:
    pass

# Remote model clients: a Hugging Face InferenceClient for text generation
# (Mistral-7B-Instruct) and a Gradio client for the image-generation Space
# "Boboiazumi/animagine-xl-3.1".
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
image_client = Client("Boboiazumi/animagine-xl-3.1")

def format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story):
    """Build the text-generation prompt asking for a full character description.

    Every argument is interpolated verbatim into a fixed English template:
    appearance first (name, hair, personality, outfit), then hobbies,
    favorite food and the background story.
    """
    appearance = (
        f"Create a waifu character named {name} with {hair_color} hair, "
        f"a {personality} personality, and wearing a {outfit_style}. "
    )
    details = (
        f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. "
        f"Here is her background story: {background_story}."
    )
    return appearance + details

def format_prompt_for_image(name, hair_color, personality, outfit_style):
    """Build the text-generation prompt that asks the model to write an image prompt.

    The four visual/character traits are interpolated verbatim into a fixed
    English sentence.
    """
    return (
        "Generate an image prompt for a waifu character "
        f"named {name} with {hair_color} hair, a {personality} personality, "
        f"and wearing a {outfit_style}."
    )

def clean_generated_text(text):
    """Return *text* with surrounding whitespace and trailing ``</s>`` markers removed.

    Any number of trailing ``</s>`` end-of-sequence markers are removed,
    including ones separated by whitespace (e.g. ``"answer </s>\\n</s> "``).
    Markers that appear elsewhere in the text are left untouched.
    """
    marker = "</s>"
    cleaned = text.strip()
    # Stripping must happen before the suffix check: otherwise trailing
    # whitespace after the marker hides it. Loop so repeated markers go too.
    while cleaned.endswith(marker):
        cleaned = cleaned[:-len(marker)].rstrip()
    return cleaned

def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate a completion for *prompt* with the module-level ``text_client``.

    Args:
        prompt: Text sent verbatim to the model.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Maximum number of generated tokens; values below 1 are
            raised to 1.
        top_p: Nucleus-sampling probability mass. Values >= 1.0 disable nucleus
            filtering (sent as ``None``); values below 0.01 are raised to 0.01.
        repetition_penalty: Passed through unchanged.

    Returns:
        The generated text with special tokens dropped and the result passed
        through ``clean_generated_text``, or ``""`` if the request failed (the
        error is shown in the Streamlit UI via ``st.error``).
    """
    # The UI sliders can produce 0.0 / 0 / 1.0. Text-generation-inference
    # backends validate temperature > 0, max_new_tokens > 0 and 0 < top_p < 1,
    # so map edge values to the nearest accepted setting instead of failing.
    temperature = max(temperature, 1e-2)
    max_new_tokens = max(max_new_tokens, 1)
    top_p = None if top_p >= 1.0 else max(top_p, 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: identical inputs and settings reproduce the same output.
        seed=42,
    )
    try:
        stream = text_client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        # With details=True each streamed token reports whether it is special
        # (e.g. the end-of-sequence marker); drop those so they never reach
        # the displayed text.
        pieces = [response.token.text for response in stream if not response.token.special]
        return clean_generated_text("".join(pieces))
    except Exception as e:
        # UI boundary: report the failure to the user and return an empty result.
        st.error(f"Error generating text: {e}")
        return ""

def _extract_image_path(result):
    """Return the first image found in a Gradio ``/run`` result, or None.

    Accepts a gallery item dict (its ``'image'`` value is returned), a bare
    path string, or a list/tuple whose first element is one of those (searched
    recursively), so both a bare gallery list and a multi-output tuple such as
    ``(gallery, metadata)`` are handled.
    """
    if isinstance(result, dict):
        return result.get('image') or None
    if isinstance(result, str):
        return result or None
    if isinstance(result, (list, tuple)) and result:
        return _extract_image_path(result[0])
    return None


def generate_image(prompt):
    """Generate an image for *prompt* via the module-level ``image_client``.

    Args:
        prompt: Positive text-to-image prompt. An empty prompt is rejected
            without contacting the Space.

    Returns:
        The image value from the Space's result (a local file path, suitable
        for ``st.image``), or ``None`` on failure. Failures are reported in
        the Streamlit UI.
    """
    if not prompt:
        st.error("Cannot generate an image without an image prompt.")
        return None
    try:
        # Positional arguments follow the Space's /run signature.
        # NOTE(review): width/height say 1024x1024 while the aspect-ratio
        # preset says 896 x 1152; presumably the preset wins unless a custom
        # ratio is selected. Confirm on the Space's API page.
        result = image_client.predict(
            prompt,  # Image prompt
            "",  # Negative prompt
            0,  # Seed
            1024,  # Width
            1024,  # Height
            7.0,  # Guidance scale
            28,  # Number of inference steps
            'Euler a',  # Sampler
            '896 x 1152',  # Aspect Ratio
            '(None)',  # Style Preset
            'Standard v3.1',  # Quality Tags Presets
            False,  # Use Upscaler
            0.55,  # Upscaler strength
            1.5,  # Upscale by
            True,  # Add Quality Tags
            api_name="/run"
        )
    except Exception as e:
        # UI boundary: show the message plus the traceback, then give up.
        st.error(f"Error generating image: {e}")
        st.exception(e)
        return None

    image = _extract_image_path(result)
    if image is None:
        st.error("Unexpected result format from the Gradio API.")
    return image

def _prepend_system_prompt(system_prompt, prompt):
    """Return *prompt* prefixed by the stripped *system_prompt*, or unchanged if blank."""
    instructions = system_prompt.strip()
    return f"{instructions}\n\n{prompt}" if instructions else prompt


def main():
    """Render the waifu generator UI and run generation when requested.

    Collects character attributes, an optional system prompt and sampling
    settings. When "Generate Waifu" is pressed it generates a description,
    an image prompt and an image, and stores all three in
    ``st.session_state`` so they survive Streamlit reruns.
    """
    st.title("Enhanced Waifu Character Generator")

    # User inputs
    col1, col2 = st.columns(2)
    with col1:
        name = st.text_input("Name of the Waifu")
        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
    with col2:
        hobbies = st.text_input("Hobbies")
        favorite_food = st.text_input("Favorite Food")
        background_story = st.text_area("Background Story")
        system_prompt = st.text_input("Optional System Prompt", "")

    # Advanced settings
    with st.expander("Advanced Settings"):
        temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
        max_new_tokens = st.slider("Max new tokens", 0, 8192, 512, step=64)
        top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
        repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

    # Results live in session state so they persist across reruns; seed them once.
    for key in ("character_description", "image_prompt", "image_path"):
        if key not in st.session_state:
            st.session_state[key] = ""

    if st.button("Generate Waifu"):
        with st.spinner("Generating waifu character..."):
            sampling = (temperature, max_new_tokens, top_p, repetition_penalty)
            # The system prompt is a general instruction for the text model,
            # so it is prepended to both text-generation requests.
            description_prompt = _prepend_system_prompt(
                system_prompt,
                format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story),
            )
            image_request = _prepend_system_prompt(
                system_prompt,
                format_prompt_for_image(name, hair_color, personality, outfit_style),
            )

            st.session_state.character_description = generate_text(description_prompt, *sampling)
            st.session_state.image_prompt = generate_text(image_request, *sampling)
            # Skip the image Space entirely if the image prompt failed (empty).
            st.session_state.image_path = (
                generate_image(st.session_state.image_prompt)
                if st.session_state.image_prompt
                else None
            )

        # Only claim success when every stage produced output; failing stages
        # have already reported their own errors.
        if st.session_state.character_description and st.session_state.image_path:
            st.success("Waifu character generated!")
        else:
            st.warning("Generation finished with errors; see the messages above.")

    # Display the generated character, image prompt and image
    if st.session_state.character_description:
        st.subheader("Generated Waifu Character")
        st.write(st.session_state.character_description)
    if st.session_state.image_prompt:
        st.subheader("Image Prompt")
        st.write(st.session_state.image_prompt)
    if st.session_state.image_path:
        st.subheader("Generated Image")
        st.image(st.session_state.image_path, caption="Generated Waifu Image")

# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()