prithivMLmods committed on
Commit
d7f29b6
·
verified ·
1 Parent(s): d57321f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -22
app.py CHANGED
@@ -8,6 +8,22 @@ import edge_tts
8
  import asyncio
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  DESCRIPTION = """
12
  # QwQ Edge 💬
13
  """
@@ -26,25 +42,21 @@ h1 {
26
  }
27
  '''
28
 
29
- MAX_MAX_NEW_TOKENS = 2048
30
- DEFAULT_MAX_NEW_TOKENS = 1024
31
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
32
-
33
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
-
35
- model_id = "prithivMLmods/FastThink-0.5B-Tiny"
36
- tokenizer = AutoTokenizer.from_pretrained(model_id)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_id,
39
- device_map="auto",
40
- torch_dtype=torch.bfloat16,
41
- )
42
- model.eval()
43
 
44
 
45
- async def text_to_speech(text: str, output_file="output.mp3"):
46
  """Convert text to speech using Edge TTS and save as MP3"""
47
- voice = "en-US-GuyNeural" # Change this to your preferred voice
48
  communicate = edge_tts.Communicate(text, voice)
49
  await communicate.save(output_file)
50
  return output_file
@@ -62,7 +74,24 @@ def generate(
62
  ):
63
  """Generates chatbot response and handles TTS requests"""
64
  is_tts = message.strip().lower().startswith("@tts")
65
- message = message.replace("@tts", "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  conversation = [*chat_history, {"role": "user", "content": message}]
68
 
@@ -95,7 +124,8 @@ def generate(
95
  final_response = "".join(outputs)
96
 
97
  if is_tts:
98
- output_file = asyncio.run(text_to_speech(final_response))
 
99
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
100
  else:
101
  yield final_response # Return text response
@@ -112,12 +142,12 @@ demo = gr.ChatInterface(
112
  ],
113
  stop_btn=None,
114
  examples=[
115
- ["@tts Who is Nikola Tesla, and why did he die?"],
116
- ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
117
  ["Write a Python function to check if a number is prime."],
118
- ["@tts What causes rainbows to form?"],
119
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
120
- ["@tts What is the capital of France?"],
121
  ],
122
  cache_examples=False,
123
  type="messages",
 
8
  import asyncio
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
+ MAX_MAX_NEW_TOKENS = 2048
12
+ DEFAULT_MAX_NEW_TOKENS = 1024
13
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
14
+
15
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+
17
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
18
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ model_id,
21
+ device_map="auto",
22
+ torch_dtype=torch.bfloat16,
23
+ )
24
+ model.eval()
25
+
26
+
27
  DESCRIPTION = """
28
  # QwQ Edge 💬
29
  """
 
42
  }
43
  '''
44
 
45
+ # Edge TTS voices; a message prefix @ttsN selects entry N-1 (@tts1 is the first voice)
46
+ voices = [
47
+ "en-US-JennyNeural", # @tts1
48
+ "en-US-GuyNeural", # @tts2
49
+ "en-US-AriaNeural", # @tts3
50
+ "en-US-DavisNeural", # @tts4
51
+ "en-US-JaneNeural", # @tts5
52
+ "en-US-JasonNeural", # @tts6
53
+ "en-US-NancyNeural", # @tts7
54
+ "en-US-TonyNeural", # @tts8
55
+ ]
 
 
 
56
 
57
 
58
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
59
  """Convert text to speech with Edge TTS using the given voice and save it as an MP3 file"""
 
60
  communicate = edge_tts.Communicate(text, voice)
61
  await communicate.save(output_file)
62
  return output_file
 
74
  ):
75
  """Generates chatbot response and handles TTS requests"""
76
  is_tts = message.strip().lower().startswith("@tts")
77
+ tts_index = None
78
+
79
+ if is_tts:
80
+ # Extract the number after @tts
81
+ tts_part = message.strip().lower().split()[0] # Get the @ttsX part
82
+ if len(tts_part) > 4: # Check if it's @ttsX (e.g., @tts1, @tts2, etc.)
83
+ try:
84
+ tts_index = int(tts_part[4:]) - 1 # Convert to 0-based index
85
+ if tts_index < 0 or tts_index >= len(voices):
86
+ gr.Warning(f"Invalid TTS voice index. Using default voice.")
87
+ tts_index = 0
88
+ except ValueError:
89
+ gr.Warning(f"Invalid TTS voice index. Using default voice.")
90
+ tts_index = 0
91
+ else:
92
+ tts_index = 0 # Default to the first voice if no number is provided
93
+
94
+ message = message.replace(tts_part, "").strip() # Remove @ttsX from the message
95
 
96
  conversation = [*chat_history, {"role": "user", "content": message}]
97
 
 
124
  final_response = "".join(outputs)
125
 
126
  if is_tts:
127
+ voice = voices[tts_index] # Select the voice based on the index
128
+ output_file = asyncio.run(text_to_speech(final_response, voice))
129
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
130
  else:
131
  yield final_response # Return text response
 
142
  ],
143
  stop_btn=None,
144
  examples=[
145
+ ["@tts1 Who is Nikola Tesla, and why did he die?"],
146
+ ["@tts2 A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
147
  ["Write a Python function to check if a number is prime."],
148
+ ["@tts3 What causes rainbows to form?"],
149
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
150
+ ["@tts4 What is the capital of France?"],
151
  ],
152
  cache_examples=False,
153
  type="messages",