UnSinnlos committed on
Commit
b5b5087
·
verified ·
1 Parent(s): 3b0237b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import torch
5
+ import numpy as np
6
+ import gradio as gr
7
+ from chatterbox.tts import ChatterboxTTS
8
+ from huggingface_hub import hf_hub_download
9
+ from safetensors.torch import load_file
10
+ from torch import nn
11
+ import re
12
+
# === Settings ===
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face repository and file name of the fine-tuned T3 checkpoint.
MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.1"
T3_CHECKPOINT_FILE = "t3_kartoffelbox.safetensors"
# Input text is truncated to this many characters; the text position
# embedding table is also grown to this size when the model is loaded.
MAX_CHARS = 5000
# Default maximum length of one text chunk passed to the model.
CHUNK_CHAR_LIMIT = 300
# Directory that holds the preset JSON files.
SETTINGS_DIR = "settings"

# === Init ===
# exist_ok=True replaces a separate exists() check, which would race with
# any other process creating the directory between the check and the call.
os.makedirs(SETTINGS_DIR, exist_ok=True)

# Model singleton; stays None until get_or_load_model() has loaded it.
MODEL = None
print(f"🚀 Running on device: {DEVICE}")
27
+
def get_or_load_model():
    """Return the shared TTS model, loading and patching it on first use.

    On the first successful call this builds a ``ChatterboxTTS`` model on
    ``DEVICE``, downloads the ``T3_CHECKPOINT_FILE`` checkpoint from
    ``MODEL_REPO`` and loads it into ``model.t3``, grows the text position
    embedding table to ``MAX_CHARS`` entries if it is smaller, and moves the
    ``t3`` and ``s3gen`` sub-modules to ``DEVICE``. Later calls return the
    cached instance.

    Returns:
        The fully initialised model stored in the module-level ``MODEL``.

    Raises:
        Whatever the model constructor, the checkpoint download or the state
        dict loading raises. In that case ``MODEL`` stays ``None`` so the
        next call retries from scratch.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    # Build into a local and publish to the global only at the very end.
    # Assigning the global first would leave a half-initialised base model
    # (without the fine-tuned weights) behind if any later step failed, and
    # every later call would silently return it.
    model = ChatterboxTTS.from_pretrained(DEVICE)
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=T3_CHECKPOINT_FILE,
        # An unset variable means "no explicit token": pass None rather than
        # an empty string so the hub client uses its default behaviour.
        token=os.environ.get("HUGGING_FACE_HUB_TOKEN") or None,
    )
    t3_state = load_file(checkpoint_path, device="cpu")
    model.t3.load_state_dict(t3_state)

    # Extend the position embeddings.
    pos_emb_module = model.t3.text_pos_emb
    old_pos = pos_emb_module.emb.num_embeddings
    if MAX_CHARS > old_pos:
        emb_dim = pos_emb_module.emb.embedding_dim
        new_emb = nn.Embedding(MAX_CHARS, emb_dim)
        with torch.no_grad():
            # Only the first old_pos rows are copied; the remaining rows keep
            # nn.Embedding's default initialisation.
            new_emb.weight[:old_pos] = pos_emb_module.emb.weight
        pos_emb_module.emb = new_emb
        print(f"Expanded position embeddings: {old_pos} → {MAX_CHARS}")

    # The replacement embedding was created on the default device, so move
    # the sub-modules (again) to the target device.
    model.t3.to(DEVICE)
    model.s3gen.to(DEVICE)
    print(f"Model loaded. Device: {model.device}")

    MODEL = model
    return MODEL
56
+
# Load the model eagerly at import time. A failure is only reported here and
# does not abort start-up; generate_tts_audio() calls get_or_load_model()
# again on every request.
try:
    get_or_load_model()
except Exception as load_error:
    print(f"CRITICAL: Failed to load model: {load_error}")
61
+
def set_seed(seed: int):
    """Seed the PyTorch, Python ``random`` and NumPy random number generators.

    When ``DEVICE`` is ``"cuda"`` all CUDA devices are seeded as well.

    Args:
        seed: The value passed unchanged to every seeding function.
    """
    # Collect the seeding functions in the order they must run, then apply
    # the same seed to each of them.
    seeders = [torch.manual_seed]
    if DEVICE == "cuda":
        seeders.append(torch.cuda.manual_seed_all)
    seeders.extend([random.seed, np.random.seed])

    for apply_seed in seeders:
        apply_seed(seed)
68
+
def _hard_wrap(sentence, limit):
    """Break one over-long sentence into pieces of at most ``limit`` characters."""
    pieces = []
    current = ""
    for word in sentence.split():
        # A single token longer than the limit has no whitespace to break on,
        # so it is sliced at fixed width.
        while len(word) > limit:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:limit])
            word = word[limit:]
        if not word:
            continue
        candidate = f"{current} {word}" if current else word
        if len(candidate) <= limit:
            current = candidate
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text, max_length=CHUNK_CHAR_LIMIT):
    """Split *text* into chunks of at most *max_length* characters.

    The text is cut after sentence-ending punctuation (``.``, ``!``, ``?``)
    followed by any whitespace, including newlines, and consecutive sentences
    are packed greedily into one chunk while they fit. A sentence that is
    longer than *max_length* on its own is wrapped at word boundaries (and,
    for a single over-long token, at fixed width) so that no chunk exceeds
    the limit.

    Args:
        text: The input string.
        max_length: Maximum chunk length in characters; values below 1 are
            treated as 1.

    Returns:
        A list of stripped chunk strings. The list is never empty: for empty
        or whitespace-only input it is ``[""]``, so callers always receive at
        least one chunk.
    """
    limit = max(int(max_length), 1)

    # Step 1: sentence units, each guaranteed to fit into the limit.
    units = []
    for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
        if not sentence:
            continue
        if len(sentence) <= limit:
            units.append(sentence)
        else:
            units.extend(_hard_wrap(sentence, limit))

    # Step 2: greedily pack units, joined by a single space, into chunks.
    chunks = []
    current = ""
    for unit in units:
        candidate = f"{current} {unit}" if current else unit
        if len(candidate) <= limit:
            current = candidate
        else:
            chunks.append(current)
            current = unit
    if current:
        chunks.append(current)

    return chunks or [""]
83
+
# === Saving / loading settings ===
def list_presets():
    """Return the names of the saved presets.

    A preset name is the file name of a ``.json`` file in ``SETTINGS_DIR``
    with the five-character extension removed. ``last.json`` is skipped. The
    order is whatever ``os.listdir`` yields.
    """
    suffix = ".json"
    names = []
    for filename in os.listdir(SETTINGS_DIR):
        if filename == "last.json" or not filename.endswith(suffix):
            continue
        names.append(filename[:-len(suffix)])
    return names
87
+
def load_preset(name):
    """Load the preset called *name* from ``SETTINGS_DIR``.

    Args:
        name: Preset name without the ``.json`` extension.

    Returns:
        The parsed JSON content of ``<SETTINGS_DIR>/<name>.json``, or
        ``None`` when that path does not exist.
    """
    path = os.path.join(SETTINGS_DIR, name + ".json")
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as preset_file:
        return json.load(preset_file)
94
+
def save_preset(name, data):
    """Write *data* as JSON to the preset *name* and mirror it to ``last``.

    The preset is stored as ``<SETTINGS_DIR>/<name>.json``. The same data is
    also written to ``last.json`` so it is available as the most recently
    used preset.

    Args:
        name: Preset name without the ``.json`` extension.
        data: JSON-serialisable settings object.
    """
    # The mirror copy is a plain second write, not a recursive call: calling
    # save_preset("last", ...) from here unconditionally would never
    # terminate. When the preset itself is "last", one write is enough.
    targets = (name,) if name == "last" else (name, "last")
    for target in targets:
        path = os.path.join(SETTINGS_DIR, target + ".json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
100
+
def generate_tts_audio(text_input, audio_prompt_path_input, exaggeration_input, temperature_input, seed_num_input, cfgw_input):
    """Synthesize speech for *text_input* chunk by chunk.

    The text is truncated to ``MAX_CHARS`` characters, split with
    ``split_text_into_chunks`` and each chunk is passed to
    ``model.generate``. The per-chunk results are concatenated.

    Args:
        text_input: Text to speak; ``None`` is treated as an empty string.
        audio_prompt_path_input: Passed to the model as ``audio_prompt_path``.
        exaggeration_input: Passed to the model as ``exaggeration``.
        temperature_input: Passed to the model as ``temperature``.
        seed_num_input: RNG seed; ``0`` or ``None`` leaves the generators
            unseeded.
        cfgw_input: Passed to the model as ``cfg_weight``.

    Returns:
        A ``(model.sr, samples)`` tuple where ``samples`` is the NumPy
        concatenation of all chunk outputs.
    """
    model = get_or_load_model()
    # Truthiness covers both 0 ("random") and None. NOTE(review): an emptied
    # number field is expected to arrive as None, which previously crashed in
    # int(None) - confirm against the Gradio version in use.
    if seed_num_input:
        set_seed(int(seed_num_input))

    full_audio = []
    # Guard against a None text value before slicing.
    chunks = split_text_into_chunks((text_input or "")[:MAX_CHARS])
    print(f"Text wird in {len(chunks)} Teile aufgeteilt…")

    for i, chunk in enumerate(chunks):
        print(f"▶️ Teil {i+1}/{len(chunks)}: {chunk[:60]}...")
        wav = model.generate(
            chunk,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        # Drop the leading dimension (presumably a batch/channel axis of
        # size 1 - TODO confirm) and move the samples to host memory.
        full_audio.append(wav.squeeze(0).cpu().numpy())

    audio_concat = np.concatenate(full_audio)
    return (model.sr, audio_concat)
123
+
# NOTE(review): the nesting below is reconstructed from statement order; the
# source had lost its indentation. Confirm the layout against the running app.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# 🥔 Kartoffel-TTS (Chatterbox)\nLangtext → Sprachstil mit Profilen")

    with gr.Row():
        # Left column: preset handling, text, reference audio and parameters.
        with gr.Column():
            preset_dropdown = gr.Dropdown(label="🔄 Preset wählen", choices=list_presets(), value=None)
            preset_name = gr.Textbox(label="📝 Name zum Speichern", value="mein-profil")

            text = gr.Textbox(
                value="Hier kannst du einen längeren deutschen Text eingeben…",
                label=f"Text (max {MAX_CHARS} Zeichen)",
                max_lines=12,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Referenz-Audiodatei (optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
            )
            exaggeration = gr.Slider(0.25, 2, step=0.05, label="Exaggeration", value=0.5)
            cfg_weight = gr.Slider(0.2, 1, step=0.05, label="CFG/Pace", value=0.3)

            with gr.Accordion("Weitere Optionen", open=False):
                seed_num = gr.Number(value=0, label="Zufalls-Seed (0 = zufällig)")
                temp = gr.Slider(0.05, 5, step=0.05, label="Temperature", value=0.6)

            save_btn = gr.Button("💾 Einstellungen speichern")
            run_btn = gr.Button("🎤 Audio generieren")

        # Right column: the generated audio.
        with gr.Column():
            audio_output = gr.Audio(label="🔊 Ergebnis")

    # Wire up the event handlers.
    def on_preset_selected(name):
        """Return the stored values of preset *name* for the four controls.

        Yields no-op updates when no name is selected or the preset cannot
        be loaded.
        """
        preset = load_preset(name) if name else None
        if not preset:
            return gr.update(), gr.update(), gr.update(), gr.update()
        return preset["exaggeration"], preset["temperature"], preset["seed"], preset["cfg"]

    preset_dropdown.change(
        on_preset_selected,
        inputs=[preset_dropdown],
        outputs=[exaggeration, temp, seed_num, cfg_weight],
    )

    def save_current_settings(name, exaggeration_value, temperature_value, seed_value, cfg_value):
        """Store the current control values as preset *name* and refresh the dropdown."""
        settings = {
            "exaggeration": exaggeration_value,
            "temperature": temperature_value,
            "seed": seed_value,
            "cfg": cfg_value,
        }
        save_preset(name, settings)
        return gr.update(choices=list_presets())

    save_btn.click(
        fn=save_current_settings,
        inputs=[preset_name, exaggeration, temp, seed_num, cfg_weight],
        outputs=[preset_dropdown],
    )

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )

    # Load the most recently used profile at start-up. load_preset() already
    # returns None when last.json does not exist.
    last = load_preset("last")
    if last:
        restored_controls = (
            (exaggeration, "exaggeration"),
            (temp, "temperature"),
            (seed_num, "seed"),
            (cfg_weight, "cfg"),
        )
        for control, key in restored_controls:
            control.value = last[key]
200
+
# 👇 Robust start-up - important for an exe without a console!
# show_error=True asks Gradio to show errors in the UI.
demo.launch(quiet=True, show_error=True, prevent_thread_lock=False)