randeom committed (verified)
Commit f593040 · 1 Parent(s): 57893c5

Update app.py

Files changed (1)
  1. app.py +27 -26
app.py CHANGED

Before (lines removed by this commit are prefixed with -):

@@ -2,20 +2,21 @@ import streamlit as st
 from huggingface_hub import InferenceClient
 import time
 
 # Initialize the HuggingFace Inference Client
 client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
 
-def format_prompt(message, history, system_prompt=""):
-    prompt = "<s>"
     if system_prompt:
-        prompt += f"[SYS] {system_prompt} [/SYS] "
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
     temperature = max(temperature, 1e-2)
     generate_kwargs = dict(
         temperature=temperature,
@@ -25,10 +26,8 @@ def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=
         do_sample=True,
         seed=42,
     )
-
-    formatted_prompt = format_prompt(prompt, history, system_prompt)
     try:
-        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
         output = ""
         for response in stream:
             output += response.token.text
@@ -38,14 +37,20 @@ def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=
         return ""
 
 def main():
-    st.title("Waifu Character Generator")
 
     # User inputs
-    name = st.text_input("Name of the Waifu")
-    hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
-    personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
-    outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
-    system_prompt = st.text_input("Optional System Prompt", "")
 
     # Advanced settings
     with st.expander("Advanced Settings"):
@@ -54,21 +59,17 @@ def main():
         top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
         repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
 
-    # Initialize session state for generated text
-    if "generated_text" not in st.session_state:
-        st.session_state.generated_text = ""
-
    # Generate button
    if st.button("Generate Waifu"):
        with st.spinner("Generating waifu character..."):
-            history = []
-            prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}."
-            st.session_state.generated_text = generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
 
    # Display the generated character
-    if st.session_state.generated_text:
        st.subheader("Generated Waifu Character")
-        st.write(st.session_state.generated_text)
 
 if __name__ == "__main__":
     main()
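
For reference, the removed format_prompt assembled the full chat history into one instruction-formatted string before each call, wrapping every turn in [INST] ... [/INST] tags. A minimal sketch of the string it would return for a hypothetical one-turn history, assuming the removed function above is in scope (the example conversation values are illustrative, not from the app):

# Hypothetical conversation values, used only to show the prompt shape.
history = [("Describe a kuudere.", "A kuudere stays calm and distant on the surface.")]
print(format_prompt("Create a waifu character named Rei.", history))
# Expected output (one line):
# <s>[INST] Describe a kuudere. [/INST] A kuudere stays calm and distant on the surface.</s> [INST] Create a waifu character named Rei. [/INST]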
 
After (lines added by this commit are prefixed with +):

@@ -2,20 +2,21 @@ import streamlit as st
 from huggingface_hub import InferenceClient
 import time
 
+# Load custom CSS
+with open('style.css') as f:
+    st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
 # Initialize the HuggingFace Inference Client
 client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
 
+def format_prompt(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story, system_prompt=""):
+    prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}. "
+    prompt += f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. Here is her background story: {background_story}."
     if system_prompt:
+        prompt = f"[SYS] {system_prompt} [/SYS] " + prompt
     return prompt
 
+def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
     temperature = max(temperature, 1e-2)
     generate_kwargs = dict(
         temperature=temperature,
@@ -25,10 +26,8 @@ def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=
         do_sample=True,
         seed=42,
     )
     try:
+        stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
         output = ""
         for response in stream:
             output += response.token.text
@@ -38,14 +37,20 @@ def generate(prompt, history, system_prompt="", temperature=0.9, max_new_tokens=
         return ""
 
 def main():
+    st.title("Enhanced Waifu Character Generator")
 
     # User inputs
+    col1, col2 = st.columns(2)
+    with col1:
+        name = st.text_input("Name of the Waifu")
+        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
+        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
+        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
+    with col2:
+        hobbies = st.text_input("Hobbies")
+        favorite_food = st.text_input("Favorite Food")
+        background_story = st.text_area("Background Story")
+        system_prompt = st.text_input("Optional System Prompt", "")
 
     # Advanced settings
     with st.expander("Advanced Settings"):
@@ -54,21 +59,17 @@ def main():
         top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
         repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
 
    # Generate button
    if st.button("Generate Waifu"):
        with st.spinner("Generating waifu character..."):
+            prompt = format_prompt(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story, system_prompt)
+            generated_text = generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty)
+            st.success("Waifu character generated!")
 
    # Display the generated character
+    if generated_text:
        st.subheader("Generated Waifu Character")
+        st.write(generated_text)
 
 if __name__ == "__main__":
    main()
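
In the updated flow the chat history is gone: the form fields are folded into a single prompt by format_prompt and sent through one streaming text_generation call in generate_text. A minimal sketch of exercising the two new helpers outside the Streamlit page, assuming they are in scope, style.css exists next to app.py, and the Hugging Face Inference API is reachable (the field values below are hypothetical):

# Hypothetical example values; any of the UI options would do.
prompt = format_prompt(
    name="Rei",
    hair_color="Blue",
    personality="Kuudere",
    outfit_style="School Uniform",
    hobbies="reading and archery",
    favorite_food="melon bread",
    background_story="A quiet transfer student with a mysterious past.",
    system_prompt="You are a creative character writer.",
)
text = generate_text(prompt, temperature=0.7, max_new_tokens=256)
print(text)

Note that the added top-level block opens style.css unconditionally, so the file must exist alongside app.py, and that generated_text is assigned only inside the button branch; on reruns where the button is not pressed, the display check runs against an undefined name, which the previous st.session_state pattern avoided.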