Update app.py
Browse files
app.py
CHANGED
@@ -8,14 +8,16 @@ with open('style.css') as f:
|
|
8 |
# Initialize the HuggingFace Inference Client
|
9 |
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
|
10 |
|
11 |
-
def
|
12 |
prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}. "
|
13 |
prompt += f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. Here is her background story: {background_story}."
|
14 |
-
if system_prompt:
|
15 |
-
prompt = f"[SYS] {system_prompt} [/SYS] " + prompt
|
16 |
return prompt
|
17 |
|
18 |
-
def
|
|
|
|
|
|
|
|
|
19 |
temperature = max(temperature, 1e-2)
|
20 |
generate_kwargs = dict(
|
21 |
temperature=temperature,
|
@@ -59,20 +61,32 @@ def main():
|
|
59 |
repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
|
60 |
|
61 |
# Initialize session state for generated text
|
62 |
-
if "
|
63 |
-
st.session_state.
|
|
|
|
|
64 |
|
65 |
# Generate button
|
66 |
if st.button("Generate Waifu"):
|
67 |
with st.spinner("Generating waifu character..."):
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
st.success("Waifu character generated!")
|
71 |
|
72 |
-
# Display the generated character
|
73 |
-
if st.session_state.
|
74 |
st.subheader("Generated Waifu Character")
|
75 |
-
st.write(st.session_state.
|
|
|
|
|
|
|
76 |
|
77 |
if __name__ == "__main__":
|
78 |
main()
|
|
|
8 |
# Initialize the HuggingFace Inference Client
|
9 |
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
|
10 |
|
11 |
+
def format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story):
|
12 |
prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}. "
|
13 |
prompt += f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. Here is her background story: {background_story}."
|
|
|
|
|
14 |
return prompt
|
15 |
|
16 |
+
def format_prompt_for_image(name, hair_color, personality, outfit_style):
|
17 |
+
prompt = f"Generate an image prompt for a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}."
|
18 |
+
return prompt
|
19 |
+
|
20 |
+
def generate_text(prompt, temperature=0.9, max_new_tokens=2512, top_p=0.95, repetition_penalty=1.0):
|
21 |
temperature = max(temperature, 1e-2)
|
22 |
generate_kwargs = dict(
|
23 |
temperature=temperature,
|
|
|
61 |
repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
|
62 |
|
63 |
# Initialize session state for generated text
|
64 |
+
if "character_description" not in st.session_state:
|
65 |
+
st.session_state.character_description = ""
|
66 |
+
if "image_prompt" not in st.session_state:
|
67 |
+
st.session_state.image_prompt = ""
|
68 |
|
69 |
# Generate button
|
70 |
if st.button("Generate Waifu"):
|
71 |
with st.spinner("Generating waifu character..."):
|
72 |
+
description_prompt = format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story)
|
73 |
+
image_prompt = format_prompt_for_image(name, hair_color, personality, outfit_style)
|
74 |
+
|
75 |
+
# Generate character description
|
76 |
+
st.session_state.character_description = generate_text(description_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
|
77 |
+
|
78 |
+
# Generate image prompt
|
79 |
+
st.session_state.image_prompt = generate_text(image_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
|
80 |
+
|
81 |
st.success("Waifu character generated!")
|
82 |
|
83 |
+
# Display the generated character and image prompt
|
84 |
+
if st.session_state.character_description:
|
85 |
st.subheader("Generated Waifu Character")
|
86 |
+
st.write(st.session_state.character_description)
|
87 |
+
if st.session_state.image_prompt:
|
88 |
+
st.subheader("Image Prompt")
|
89 |
+
st.write(st.session_state.image_prompt)
|
90 |
|
91 |
if __name__ == "__main__":
|
92 |
main()
|