File size: 6,410 Bytes
19cef8f b23b2e1 74e0f7a 707e3ee 19cef8f 5a59b16 f593040 db63f1a 74e0f7a 42f44df fe28822 19cef8f 7249a81 efdf463 b23b2e1 7249a81 efdf463 46a4fa5 707e3ee 0833433 db63f1a b23b2e1 db63f1a 74e0f7a db63f1a 707e3ee db63f1a b23b2e1 74e0f7a a73ddf4 e68091a a73ddf4 e68091a b1a6c07 74e0f7a a73ddf4 cfe386b 74e0f7a f49bf68 74e0f7a db63f1a 7249a81 19cef8f 7249a81 f593040 7249a81 77a9b4d 7249a81 77a9b4d 7249a81 77a9b4d 7249a81 77a9b4d 7249a81 77a9b4d db63f1a ab9dfa6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import re
from pathlib import Path

import streamlit as st
from gradio_client import Client
from huggingface_hub import InferenceClient
# Configure the page layout before anything is rendered; older Streamlit
# releases require set_page_config to be the first Streamlit call.
st.set_page_config(layout="wide")

# Inject the custom stylesheet. Resolve it next to this script rather than the
# process's working directory, so launching the app from another directory
# still finds it, and read it as UTF-8 instead of the locale's default.
_STYLE_PATH = Path(__file__).with_name("style.css")
with open(_STYLE_PATH, encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

# Hugging Face Inference API client for the instruction-tuned text model that
# writes both the character description and the image prompt.
text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
# Gradio Space client used for image generation.
# Previously used alternative Space: "Boboiazumi/animagine-xl-3.1".
image_client = Client("phenixrhyder/nsfw-waifu-gradio")
def format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story):
    """Build the instruction asking the LLM for a full character description.

    Every argument is interpolated verbatim into a fixed English template:
    first the appearance/personality sentence, then hobbies, favorite food
    and the background story.
    """
    appearance = (
        f"Create a waifu character named {name} with {hair_color} hair, "
        f"a {personality} personality, and wearing a {outfit_style}. "
    )
    details = (
        f"Her hobbies include {hobbies}. "
        f"Her favorite food is {favorite_food}. "
        f"Here is her background story: {background_story}."
    )
    return appearance + details
def format_prompt_for_image(name, hair_color, personality, outfit_style):
    """Build the instruction asking the LLM to write an image-generation prompt.

    Only the visual/personality attributes are used; they are interpolated
    verbatim into a fixed English sentence.
    """
    return (
        f"Generate an image prompt for a waifu character named {name} "
        f"with {hair_color} hair, a {personality} personality, "
        f"and wearing a {outfit_style}."
    )
def clean_generated_text(text):
    """Strip surrounding whitespace and trailing ``</s>`` markers from model output.

    Whitespace is removed before the marker check, so a ``</s>`` followed by
    spaces or newlines is still stripped. Repeated trailing markers are all
    removed. Occurrences of ``</s>`` elsewhere in the text are kept.

    Args:
        text: Raw generated text; ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""
    end_marker = "</s>"
    cleaned = text.strip()
    # Loop so that sequences like "answer</s></s>" or "answer </s>" collapse fully.
    while cleaned.endswith(end_marker):
        cleaned = cleaned[: -len(end_marker)].rstrip()
    return cleaned
def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0, seed=42):
    """Generate a completion for *prompt* with the Mistral instruct model.

    The prompt is wrapped in Mistral-Instruct's ``[INST] ... [/INST]``
    template unless it already starts with one, and the response is streamed
    token by token from ``text_client``.

    Args:
        prompt: Instruction text for the model.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Upper bound on generated tokens; clamped to at least 1.
        top_p: Nucleus-sampling mass; clamped into [0.01, 0.99].
        repetition_penalty: Penalty applied to repeated tokens.
        seed: Sampling seed. The default (42) keeps the previous fixed seed,
            so identical inputs reproduce identical outputs.

    Returns:
        The cleaned generated text, or ``""`` if the request failed. Failures
        are reported in the Streamlit UI via ``st.error``.
    """
    # Sampling backends reject a temperature of exactly 0.
    temperature = max(float(temperature), 1e-2)
    # Requesting zero new tokens is invalid; the UI slider can produce it.
    max_new_tokens = max(int(max_new_tokens), 1)
    # text-generation-inference only accepts top_p strictly inside (0, 1).
    # 0.99 is effectively "no nucleus filtering", i.e. a slider value of 1.0.
    top_p = min(max(float(top_p), 1e-2), 0.99)

    # Instruct-tuned Mistral expects its instruction template; without it the
    # model treats the request as plain text to continue.
    if not prompt.lstrip().startswith(("<s>", "[INST]")):
        prompt = f"[INST] {prompt} [/INST]"

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )
    try:
        stream = text_client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        # Skip special tokens such as the end-of-sequence marker; they are
        # not part of the answer.
        pieces = [
            response.token.text
            for response in stream
            if not getattr(response.token, "special", False)
        ]
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the page.
        st.error(f"Error generating text: {e}")
        return ""
    return clean_generated_text("".join(pieces))
def generate_image(prompt, seed=0):
    """Render images for *prompt* through the Gradio image Space.

    Args:
        prompt: Text prompt describing the image. An empty or blank prompt
            (for example after text generation failed) is rejected up front
            instead of being sent to the Space.
        seed: Seed forwarded to the Space; the default keeps the previous
            fixed value of 0.

    Returns:
        A non-empty list of image references on success. These are presumably
        local file paths downloaded by the Gradio client (TODO confirm); the
        caller passes them to ``st.image``. Returns ``None`` on failure. All
        failures are reported in the Streamlit UI.
    """
    if not prompt or not prompt.strip():
        st.error("Cannot generate an image from an empty prompt.")
        return None

    try:
        # NOTE(review): these keyword arguments mirror the animagine-xl-3.1
        # Space's /predict endpoint; confirm they match the Space that
        # image_client actually points at.
        result = image_client.predict(
            prompt=prompt,
            negative_prompt="",
            seed=seed,
            custom_width=1024,
            custom_height=1024,
            guidance_scale=7.0,
            num_inference_steps=28,
            sampler="Euler a",
            aspect_ratio_selector="896 x 1152",
            style_selector="(None)",
            quality_selector="Standard v3.1",
            use_upscaler=False,
            upscaler_strength=0.55,
            upscale_by=1.5,
            add_quality_tags=True,
            isImg2Img=False,
            img_path=None,
            img2img_strength=0.65,
            api_name="/predict"
        )
    except Exception as e:
        # UI boundary: report the remote failure instead of crashing the page.
        st.error(f"Error generating image: {e}")
        st.write("Full error details:", e)
        return None

    # The first output is expected to be a gallery: a list of dicts carrying
    # an "image" key. Also tolerate bare paths and (path, caption) pairs.
    gallery = result[0] if isinstance(result, (list, tuple)) and result else None
    if isinstance(gallery, str):
        # A single path must not be iterated character by character.
        gallery = [gallery]

    images = []
    for item in gallery or []:
        if isinstance(item, dict):
            path = item.get("image")
        elif isinstance(item, (list, tuple)):
            path = item[0] if item else None
        else:
            path = item
        if path:
            images.append(path)

    if not images:
        st.error("Unexpected result format from the Gradio API.")
        return None
    return images
def _init_session_state():
    """Create the session-state slots that hold generated results across reruns."""
    defaults = {"character_description": "", "image_prompt": "", "image_paths": []}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _with_system_prompt(system_prompt, prompt):
    """Prefix *prompt* with the user's optional system prompt, if one was given."""
    system_prompt = (system_prompt or "").strip()
    return f"{system_prompt}\n\n{prompt}" if system_prompt else prompt


def _render_results():
    """Display whatever generated content is currently stored in session state."""
    state = st.session_state
    if state.character_description:
        st.subheader("Generated Waifu Character")
        st.write(state.character_description)
    if state.image_prompt:
        st.subheader("Image Prompt")
        st.write(state.image_prompt)
    if state.image_paths:
        st.subheader("Generated Image")
        for image_path in state.image_paths:
            st.image(image_path, caption="Generated Waifu Image")


def main():
    """Render the Streamlit UI for the waifu character generator.

    The left column collects character attributes and sampling settings.
    "Generate Waifu" produces a character description, an image prompt and
    images. Results are stored in ``st.session_state`` so they survive
    Streamlit reruns, and are displayed in the right column.
    """
    st.title("Enhanced Waifu Character Generator")
    # Done before any widget so the results column can always read the slots.
    _init_session_state()

    col1, col2 = st.columns(2)
    with col1:
        # User inputs
        name = st.text_input("Name of the Waifu")
        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
        hobbies = st.text_input("Hobbies")
        favorite_food = st.text_input("Favorite Food")
        background_story = st.text_area("Background Story")
        system_prompt = st.text_input("Optional System Prompt", "")

        with st.expander("Advanced Settings"):
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
            # Zero new tokens is not a valid request, so start at one step.
            max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
            # The text backend only accepts top_p strictly between 0 and 1.
            top_p = st.slider("Top-p (nucleus sampling)", 0.05, 0.95, 0.95, step=0.05)
            repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

        if st.button("Generate Waifu"):
            if not name.strip():
                st.warning("Please enter a name for your waifu.")
            else:
                sampling = dict(
                    temperature=temperature,
                    max_new_tokens=max_new_tokens,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )
                with st.spinner("Generating waifu character..."):
                    # Clear the previous run so a partial failure never shows
                    # stale output next to new output.
                    st.session_state.character_description = ""
                    st.session_state.image_prompt = ""
                    st.session_state.image_paths = []

                    description_prompt = _with_system_prompt(
                        system_prompt,
                        format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story),
                    )
                    image_request = _with_system_prompt(
                        system_prompt,
                        format_prompt_for_image(name, hair_color, personality, outfit_style),
                    )

                    st.session_state.character_description = generate_text(description_prompt, **sampling)
                    st.session_state.image_prompt = generate_text(image_request, **sampling)
                    # Only call the image Space when there is a prompt to send;
                    # generate_image returns None on failure, so normalize to [].
                    if st.session_state.image_prompt:
                        st.session_state.image_paths = generate_image(st.session_state.image_prompt) or []

                if (
                    st.session_state.character_description
                    and st.session_state.image_prompt
                    and st.session_state.image_paths
                ):
                    st.success("Waifu character generated!")
                else:
                    st.warning("Waifu generation was incomplete; see the messages above.")

    with col2:
        _render_results()
# Script entry point: build the UI when this file is executed as the main
# module (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()
|