Update app.py
Browse files
app.py
CHANGED
@@ -48,29 +48,32 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repet
|
|
48 |
def generate_image(prompt):
|
49 |
try:
|
50 |
result = image_client.predict(
|
51 |
-
prompt,
|
52 |
-
"",
|
53 |
-
0,
|
54 |
-
1024,
|
55 |
-
1024,
|
56 |
-
7.0,
|
57 |
-
28,
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
False,
|
63 |
-
0.55,
|
64 |
-
1.5,
|
65 |
-
True,
|
66 |
isImg2Img=False,
|
67 |
-
img_path=None,
|
68 |
img2img_strength=0.65,
|
69 |
api_name="/run"
|
70 |
)
|
71 |
-
#
|
72 |
-
if result
|
73 |
-
|
|
|
|
|
|
|
74 |
else:
|
75 |
st.error("Unexpected result format from the Gradio API.")
|
76 |
return None
|
@@ -107,8 +110,8 @@ def main():
|
|
107 |
st.session_state.character_description = ""
|
108 |
if "image_prompt" not in st.session_state:
|
109 |
st.session_state.image_prompt = ""
|
110 |
-
if "
|
111 |
-
st.session_state.
|
112 |
|
113 |
# Generate button
|
114 |
if st.button("Generate Waifu"):
|
@@ -123,7 +126,7 @@ def main():
|
|
123 |
st.session_state.image_prompt = generate_text(image_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
|
124 |
|
125 |
# Generate image from image prompt
|
126 |
-
st.session_state.
|
127 |
|
128 |
st.success("Waifu character generated!")
|
129 |
|
@@ -134,9 +137,10 @@ def main():
|
|
134 |
if st.session_state.image_prompt:
|
135 |
st.subheader("Image Prompt")
|
136 |
st.write(st.session_state.image_prompt)
|
137 |
-
if st.session_state.
|
138 |
st.subheader("Generated Image")
|
139 |
-
st.
|
|
|
140 |
|
141 |
if __name__ == "__main__":
|
142 |
main()
|
|
|
48 |
def generate_image(prompt):
|
49 |
try:
|
50 |
result = image_client.predict(
|
51 |
+
prompt=prompt,
|
52 |
+
negative_prompt="",
|
53 |
+
seed=0,
|
54 |
+
custom_width=1024,
|
55 |
+
custom_height=1024,
|
56 |
+
guidance_scale=7.0,
|
57 |
+
num_inference_steps=28,
|
58 |
+
sampler="Euler a",
|
59 |
+
aspect_ratio_selector="896 x 1152",
|
60 |
+
style_selector="(None)",
|
61 |
+
quality_selector="Standard v3.1",
|
62 |
+
use_upscaler=False,
|
63 |
+
upscaler_strength=0.55,
|
64 |
+
upscale_by=1.5,
|
65 |
+
add_quality_tags=True,
|
66 |
isImg2Img=False,
|
67 |
+
img_path=None,
|
68 |
img2img_strength=0.65,
|
69 |
api_name="/run"
|
70 |
)
|
71 |
+
# Process and display the result
|
72 |
+
if result:
|
73 |
+
images = []
|
74 |
+
for image_data in result[0]:
|
75 |
+
images.append(image_data['image'])
|
76 |
+
return images
|
77 |
else:
|
78 |
st.error("Unexpected result format from the Gradio API.")
|
79 |
return None
|
|
|
110 |
st.session_state.character_description = ""
|
111 |
if "image_prompt" not in st.session_state:
|
112 |
st.session_state.image_prompt = ""
|
113 |
+
if "image_paths" not in st.session_state:
|
114 |
+
st.session_state.image_paths = []
|
115 |
|
116 |
# Generate button
|
117 |
if st.button("Generate Waifu"):
|
|
|
126 |
st.session_state.image_prompt = generate_text(image_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
|
127 |
|
128 |
# Generate image from image prompt
|
129 |
+
st.session_state.image_paths = generate_image(st.session_state.image_prompt)
|
130 |
|
131 |
st.success("Waifu character generated!")
|
132 |
|
|
|
137 |
if st.session_state.image_prompt:
|
138 |
st.subheader("Image Prompt")
|
139 |
st.write(st.session_state.image_prompt)
|
140 |
+
if st.session_state.image_paths:
|
141 |
st.subheader("Generated Image")
|
142 |
+
for image_path in st.session_state.image_paths:
|
143 |
+
st.image(image_path, caption="Generated Waifu Image")
|
144 |
|
145 |
if __name__ == "__main__":
|
146 |
main()
|