freddyaboulton (HF staff) committed
Commit 19b72df · verified · Parent: 0c83ad6

Upload folder using huggingface_hub
Files changed (4):
  1. README.md +9 -6
  2. app.py +136 -0
  3. index.html +335 -0
  4. requirements.txt +6 -0
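The commit message indicates the files were pushed with huggingface_hub's folder upload. A minimal sketch of such a call, assuming the standard upload_folder helper; the repo id shown is illustrative, not taken from the commit:

    from huggingface_hub import upload_folder

    # Hypothetical invocation that would produce a commit like this one;
    # the repo_id is a guess, repo_type="space" matches a Spaces repo.
    upload_folder(
        folder_path=".",
        repo_id="freddyaboulton/talk-to-claude-gradio",  # assumption
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )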
README.md CHANGED
@@ -1,12 +1,15 @@
 ---
-title: Talk To Claude Gradio
-emoji:
-colorFrom: yellow
-colorTo: green
+title: Talk to Claude
+emoji: 👨‍🦰
+colorFrom: purple
+colorTo: red
 sdk: gradio
-sdk_version: 5.16.1
+sdk_version: 5.16.0
 app_file: app.py
 pinned: false
+license: mit
+short_description: Talk to Anthropic's Claude
+tags: [webrtc, websocket, gradio, secret|TWILIO_ACCOUNT_SID, secret|TWILIO_AUTH_TOKEN, secret|GROQ_API_KEY, secret|ANTHROPIC_API_KEY, secret|ELEVENLABS_API_KEY]
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
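The secret| entries in the tags declare the Space secrets app.py expects at runtime. A sketch of a startup check for them; the names come from the tags above, the check itself is illustrative:

    import os

    # The Anthropic, Groq, and ElevenLabs clients read their keys from these
    # variables; the Twilio pair backs get_twilio_turn_credentials().
    REQUIRED_SECRETS = [
        "ANTHROPIC_API_KEY",
        "GROQ_API_KEY",
        "ELEVENLABS_API_KEY",
        "TWILIO_ACCOUNT_SID",
        "TWILIO_AUTH_TOKEN",
    ]

    missing = [name for name in REQUIRED_SECRETS if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing secrets: {', '.join(missing)}")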
app.py ADDED
@@ -0,0 +1,136 @@
import json
import os
import time
from pathlib import Path

import anthropic
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from elevenlabs import ElevenLabs
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_tts_model,
    get_twilio_turn_credentials,
)
from fastrtc.utils import audio_to_bytes
from gradio.utils import get_space
from groq import Groq
from pydantic import BaseModel

load_dotenv()

groq_client = Groq()
claude_client = anthropic.Anthropic()
tts_client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])

curr_dir = Path(__file__).parent

tts_model = get_tts_model()


def response(
    audio: tuple[int, np.ndarray],
    chatbot: list[dict] | None = None,
):
    chatbot = chatbot or []
    messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
    # Transcribe the caller's audio with Groq's Whisper endpoint.
    prompt = groq_client.audio.transcriptions.create(
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    ).text
    print("prompt", prompt)
    chatbot.append({"role": "user", "content": prompt})
    yield AdditionalOutputs(chatbot)
    messages.append({"role": "user", "content": prompt})
    response = claude_client.messages.create(
        model="claude-3-5-haiku-20241022",
        max_tokens=512,
        messages=messages,  # type: ignore
    )
    response_text = " ".join(
        block.text  # type: ignore
        for block in response.content
        if getattr(block, "type", None) == "text"
    )
    chatbot.append({"role": "assistant", "content": response_text})

    # Debug timing: log how long each synthesized chunk takes to arrive.
    start = time.time()
    print("starting tts", start)
    for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
        print("chunk", i, time.time() - start)
        yield chunk
    print("finished tts", time.time() - start)
    yield AdditionalOutputs(chatbot)


chatbot = gr.Chatbot(type="messages")
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),  # reply whenever the caller pauses
    additional_outputs_handler=lambda a, b: b,
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=20 if get_space() else None,
)


class Message(BaseModel):
    role: str
    content: str


class InputData(BaseModel):
    webrtc_id: str
    chatbot: list[Message]


app = FastAPI()
stream.mount(app)


@app.get("/")
async def _():
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (curr_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content, status_code=200)


@app.post("/input_hook")
async def _(body: InputData):
    stream.set_input(body.webrtc_id, body.model_dump()["chatbot"])
    return {"status": "ok"}


@app.get("/outputs")
def _(webrtc_id: str):
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            chatbot = output.args[0]
            if len(chatbot) > 1:
                yield f"event: output\ndata: {json.dumps(chatbot[-2])}\n\n"
            yield f"event: output\ndata: {json.dumps(chatbot[-1])}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
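The /outputs route above streams chat updates as server-sent events; index.html consumes it with an EventSource, but any SSE client works. A rough Python equivalent, assuming httpx (which is not in requirements.txt) and a placeholder id:

    import httpx

    # Placeholder; in the app this id comes from the WebRTC offer exchange.
    webrtc_id = "abc123"

    with httpx.stream(
        "GET", "http://localhost:7860/outputs",
        params={"webrtc_id": webrtc_id}, timeout=None,
    ) as resp:
        for line in resp.iter_lines():
            if line.startswith("data: "):
                print(line[len("data: "):])  # one JSON chat message per event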
index.html ADDED
@@ -0,0 +1,335 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>RetroChat Audio</title>
    <style>
        body {
            font-family: monospace;
            background-color: #1a1a1a;
            color: #00ff00;
            margin: 0;
            padding: 20px;
            height: 100vh;
            box-sizing: border-box;
        }

        .container {
            display: grid;
            grid-template-columns: 1fr 1fr;
            gap: 20px;
            height: calc(100% - 100px);
            margin-bottom: 20px;
        }

        .visualization-container {
            border: 2px solid #00ff00;
            padding: 20px;
            display: flex;
            flex-direction: column;
            align-items: center;
            position: relative;
        }

        #visualizer {
            width: 100%;
            height: 100%;
            background-color: #000;
        }

        .chat-container {
            border: 2px solid #00ff00;
            padding: 20px;
            display: flex;
            flex-direction: column;
            height: 100%;
            box-sizing: border-box;
        }

        .chat-messages {
            flex-grow: 1;
            overflow-y: auto;
            margin-bottom: 20px;
            padding: 10px;
            border: 1px solid #00ff00;
        }

        .message {
            margin-bottom: 10px;
            padding: 8px;
            border-radius: 4px;
        }

        .message.user {
            background-color: #003300;
        }

        .message.assistant {
            background-color: #002200;
        }

        .controls {
            text-align: center;
        }

        button {
            background-color: #000;
            color: #00ff00;
            border: 2px solid #00ff00;
            padding: 10px 20px;
            font-family: monospace;
            font-size: 16px;
            cursor: pointer;
            transition: all 0.3s;
        }

        button:hover {
            background-color: #00ff00;
            color: #000;
        }

        #audio-output {
            display: none;
        }

        /* Retro CRT effect */
        .crt-overlay {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background: repeating-linear-gradient(0deg,
                    rgba(0, 255, 0, 0.03),
                    rgba(0, 255, 0, 0.03) 1px,
                    transparent 1px,
                    transparent 2px);
            pointer-events: none;
        }
    </style>
</head>

<body>
    <div class="container">
        <div class="visualization-container">
            <canvas id="visualizer"></canvas>
            <div class="crt-overlay"></div>
        </div>
        <div class="chat-container">
            <div class="chat-messages" id="chat-messages"></div>
        </div>
    </div>
    <div class="controls">
        <button id="start-button">Start</button>
    </div>
    <audio id="audio-output"></audio>

    <script>
        let audioContext;
        let analyser;
        let dataArray;
        let animationId;
        let chatHistory = [];
        let peerConnection;
        let webrtc_id;

        const visualizer = document.getElementById('visualizer');
        const ctx = visualizer.getContext('2d');
        const audioOutput = document.getElementById('audio-output');
        const startButton = document.getElementById('start-button');
        const chatMessages = document.getElementById('chat-messages');

        // Set canvas size
        function resizeCanvas() {
            visualizer.width = visualizer.offsetWidth;
            visualizer.height = visualizer.offsetHeight;
        }

        window.addEventListener('resize', resizeCanvas);
        resizeCanvas();

        // Initialize WebRTC
        async function setupWebRTC() {
            const config = __RTC_CONFIGURATION__;
            peerConnection = new RTCPeerConnection(config);

            try {
                const stream = await navigator.mediaDevices.getUserMedia({
                    audio: true
                });

                stream.getTracks().forEach(track => {
                    peerConnection.addTrack(track, stream);
                });

                // Handle incoming audio; visualization is wired up once the
                // remote (output) stream arrives.
                peerConnection.addEventListener('track', (evt) => {
                    if (audioOutput && audioOutput.srcObject !== evt.streams[0]) {
                        audioOutput.srcObject = evt.streams[0];
                        audioOutput.play();

                        // Set up audio visualization on the output stream
                        audioContext = new AudioContext();
                        analyser = audioContext.createAnalyser();
                        const source = audioContext.createMediaStreamSource(evt.streams[0]);
                        source.connect(analyser);
                        analyser.fftSize = 2048;
                        dataArray = new Uint8Array(analyser.frequencyBinCount);
                    }
                });

                // Create data channel for messages
                const dataChannel = peerConnection.createDataChannel('text');
                dataChannel.onmessage = handleMessage;

                // Create the offer, then wait for ICE gathering to finish
                // before posting it to the server.
                const offer = await peerConnection.createOffer();
                await peerConnection.setLocalDescription(offer);

                await new Promise((resolve) => {
                    if (peerConnection.iceGatheringState === "complete") {
                        resolve();
                    } else {
                        const checkState = () => {
                            if (peerConnection.iceGatheringState === "complete") {
                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
                                resolve();
                            }
                        };
                        peerConnection.addEventListener("icegatheringstatechange", checkState);
                    }
                });

                webrtc_id = Math.random().toString(36).substring(7);

                const response = await fetch('/webrtc/offer', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sdp: peerConnection.localDescription.sdp,
                        type: peerConnection.localDescription.type,
                        webrtc_id: webrtc_id
                    })
                });

                const serverResponse = await response.json();
                await peerConnection.setRemoteDescription(serverResponse);

                // Start visualization
                draw();

                // Create an event stream to receive messages from /outputs
                const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
                eventSource.addEventListener("output", (event) => {
                    const eventJson = JSON.parse(event.data);
                    addMessage(eventJson.role, eventJson.content);
                });
            } catch (err) {
                console.error('Error setting up WebRTC:', err);
            }
        }

        function handleMessage(event) {
            const eventJson = JSON.parse(event.data);
            if (eventJson.type === "send_input") {
                // The server asked for current inputs; post the chat history back.
                fetch('/input_hook', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        webrtc_id: webrtc_id,
                        chatbot: chatHistory
                    })
                });
            }
        }

        function addMessage(role, content) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', role);
            messageDiv.textContent = content;
            chatMessages.appendChild(messageDiv);
            chatMessages.scrollTop = chatMessages.scrollHeight;
            chatHistory.push({ role, content });
        }

        function draw() {
            animationId = requestAnimationFrame(draw);

            // Guard: the analyser only exists once the remote track arrives.
            if (!analyser) return;

            analyser.getByteTimeDomainData(dataArray);

            ctx.fillStyle = 'rgb(0, 0, 0)';
            ctx.fillRect(0, 0, visualizer.width, visualizer.height);

            ctx.lineWidth = 2;
            ctx.strokeStyle = 'rgb(0, 255, 0)';
            ctx.beginPath();

            const sliceWidth = visualizer.width / dataArray.length;
            let x = 0;

            for (let i = 0; i < dataArray.length; i++) {
                const v = dataArray[i] / 128.0;
                const y = v * visualizer.height / 2;

                if (i === 0) {
                    ctx.moveTo(x, y);
                } else {
                    ctx.lineTo(x, y);
                }

                x += sliceWidth;
            }

            ctx.lineTo(visualizer.width, visualizer.height / 2);
            ctx.stroke();
        }

        function stop() {
            if (peerConnection) {
                if (peerConnection.getTransceivers) {
                    peerConnection.getTransceivers().forEach(transceiver => {
                        if (transceiver.stop) {
                            transceiver.stop();
                        }
                    });
                }

                if (peerConnection.getSenders) {
                    peerConnection.getSenders().forEach(sender => {
                        if (sender.track && sender.track.stop) sender.track.stop();
                    });
                }

                setTimeout(() => {
                    peerConnection.close();
                }, 500);
            }

            if (animationId) {
                cancelAnimationFrame(animationId);
            }

            if (audioContext) {
                audioContext.close();
            }
        }

        startButton.addEventListener('click', () => {
            if (startButton.textContent === 'Start') {
                setupWebRTC();
                startButton.textContent = 'Stop';
            } else {
                stop();
                startButton.textContent = 'Start';
            }
        });
    </script>
</body>

</html>
requirements.txt ADDED
@@ -0,0 +1,6 @@
fastrtc[vad, tts]
elevenlabs
groq
anthropic
twilio
python-dotenv
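The extras on fastrtc matter here: [vad] supplies the pause detection behind ReplyOnPause and [tts] the model returned by get_tts_model, while twilio backs get_twilio_turn_credentials. A quick sanity check that an environment satisfies app.py's direct imports (a sketch; run after installing requirements.txt):

    # Imports mirror app.py's direct dependencies; gradio, fastapi, numpy,
    # and pydantic are presumably pulled in transitively by fastrtc/gradio.
    import anthropic, dotenv, elevenlabs, fastrtc, groq  # noqa: F401

    print("requirements satisfy app.py's imports")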