"""FastAPI server that streams German Orpheus-3B TTS audio over a WebSocket.

Pipeline: text prompt -> Orpheus causal LM emits SNAC audio codes ->
SNAC decoder -> 24 kHz mono PCM, streamed to the client in 100 ms chunks.
"""

import asyncio
import json
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login, snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()
if (tok := os.getenv("HF_TOKEN")):
    login(token=tok)

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pull only the weights/config we need; skip optimizer/tokenizer artifacts.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt",
    ],
)

print("Loading Orpheus…")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)


# — Helper functions —

def process_prompt(text: str, voice: str) -> torch.Tensor:
    """Tokenize "<voice>: <text>" and wrap it in Orpheus control tokens.

    128259 = start-of-human, 128009 = end-of-text, 128260 = end-of-human.
    Returns a (1, seq_len) LongTensor on `device`.
    """
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, inputs.input_ids, end], dim=1)


def parse_output(ids: torch.LongTensor) -> list[int]:
    """Extract the flat audio-code list from a generation result.

    Crops everything up to and including the LAST start-of-audio marker
    (128257) — keeping only newly generated codes — and strips the
    end-of-audio / padding token (128258).
    """
    st, rm = 128257, 128258
    idxs = (ids == st).nonzero(as_tuple=True)[1]
    cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
    row = cropped[0][cropped[0] != rm]
    return row.tolist()


def redistribute_codes(codes: list[int], snac_model: SNAC):
    """Decode flat Orpheus codes into a waveform via SNAC.

    Orpheus emits 7 codes per SNAC frame, interleaved across the three
    codebook layers, with the k-th slot offset by k*4096 to keep token ids
    disjoint. Undo the offsets, de-interleave into layers (1/2/4 codes per
    frame), and decode. Incomplete trailing frames are dropped.

    Returns a 1-D float numpy array of samples (nominally in [-1, 1]).
    """
    layer_1: list[int] = []
    layer_2: list[int] = []
    layer_3: list[int] = []
    for i in range(len(codes) // 7):
        base = 7 * i
        layer_1.append(codes[base])
        layer_2.append(codes[base + 1] - 4096)
        layer_3.append(codes[base + 2] - 2 * 4096)
        layer_3.append(codes[base + 3] - 3 * 4096)
        layer_2.append(codes[base + 4] - 4 * 4096)
        layer_3.append(codes[base + 5] - 5 * 4096)
        layer_3.append(codes[base + 6] - 6 * 4096)
    layers = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    audio = snac_model.decode(layers)
    return audio.detach().squeeze().cpu().numpy()


app = FastAPI()


@app.get("/")
async def root():
    """Liveness probe."""
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}


@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    """Receive one {"text": ..., "voice": ...} JSON message, synthesize,
    and stream raw 16-bit little-endian PCM in ~100 ms chunks (24 kHz mono).
    """
    await ws.accept()
    try:
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        ids = process_prompt(text, voice)
        # model.generate is a long, blocking call; run it off the event
        # loop so other connections and WS keepalives are not starved.
        gen = await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            max_new_tokens=2000,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        audio_np = redistribute_codes(codes, snac)
        # Clip before scaling: samples at/above 1.0 would otherwise
        # overflow int16 and wrap around audibly.
        pcm16 = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
        chunk = 2400 * 2  # 2400 samples * 2 bytes = 100 ms at 24 kHz
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)  # pace chunks at roughly real time
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            pass  # socket already closed; don't mask the original error


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)