Update app.py
Browse files
app.py
CHANGED
@@ -43,7 +43,7 @@ This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisp
|
|
43 |
# --- Constants --- (Constants can stay outside)
|
44 |
SPEAKER_ID = 0
|
45 |
MAX_CONTEXT_SEGMENTS = 5
|
46 |
-
MAX_GEMMA_LENGTH =
|
47 |
|
48 |
# --- Global Conversation History ---
|
49 |
conversation_history = []
|
@@ -62,7 +62,7 @@ def transcribe_audio(audio_path: str, whisper_model) -> str: # Pass whisper_mod
|
|
62 |
|
63 |
def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: # Pass model and tokenizer
|
64 |
try:
|
65 |
-
input_text = "
|
66 |
input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
|
67 |
generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
|
68 |
return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
|
@@ -110,7 +110,7 @@ def infer(user_audio) -> tuple[int, np.ndarray]:
|
|
110 |
whisper_model = whisper.load_model("small.en", device=device)
|
111 |
logging.info("Whisper model loaded successfully.")
|
112 |
|
113 |
-
tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-
|
114 |
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
|
115 |
logging.info("Gemma 3 1B it model loaded successfully.")
|
116 |
|
|
|
43 |
# --- Constants --- (Constants can stay outside)
|
44 |
SPEAKER_ID = 0
|
45 |
MAX_CONTEXT_SEGMENTS = 5
|
46 |
+
MAX_GEMMA_LENGTH = 150
|
47 |
|
48 |
# --- Global Conversation History ---
|
49 |
conversation_history = []
|
|
|
62 |
|
63 |
def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: # Pass model and tokenizer
|
64 |
try:
|
65 |
+
input_text = "Respond to the user: " + text
|
66 |
input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
|
67 |
generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
|
68 |
return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
|
|
|
110 |
whisper_model = whisper.load_model("small.en", device=device)
|
111 |
logging.info("Whisper model loaded successfully.")
|
112 |
|
113 |
+
tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
|
114 |
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
|
115 |
logging.info("Gemma 3 1B it model loaded successfully.")
|
116 |
|