Ikaros commited on
Commit
8c553b3
·
1 Parent(s): 007a3ff

refactor: simplify to pure websocket server

Browse files
Files changed (2) hide show
  1. app.py +33 -61
  2. requirements.txt +0 -2
app.py CHANGED
@@ -1,13 +1,10 @@
1
  import asyncio
2
  import websockets
3
  import json
4
- import threading
5
- from flask import Flask, request, jsonify
6
  import numpy as np
7
  from music_generator import MusicGenerator
8
 
9
- # --- Existing Flask App Setup ---
10
- app = Flask(__name__)
11
 
12
  # Load the consonance matrix
13
  with open('consonance_matrix.json') as f:
@@ -16,53 +13,32 @@ with open('consonance_matrix.json') as f:
16
  notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
17
  generator = MusicGenerator(len(notes))
18
 
19
- def note_to_index(note):
20
- return notes.index(note.split(':')[0])
21
-
22
- def index_to_note(index):
23
- return notes[index]
24
-
25
- # (Keep all the existing @app.route endpoints for now)
26
- @app.route('/predict', methods=['POST'])
27
- def predict():
28
- # This route will likely be deprecated in favor of WebSockets
29
- # but we keep it for now.
30
- data = request.get_json()
31
- history = data.get('history', [])
32
- if len(history) < 1:
33
- return jsonify({'prediction': 'N/A'})
34
- try:
35
- last_note_index = note_to_index(history[-1]['chord'])
36
- prediction_index = generator.generate([last_note_index], length=1)[-1]
37
- prediction = index_to_note(prediction_index)
38
- except (ValueError, IndexError):
39
- prediction = 'N/A'
40
- return jsonify({'prediction': prediction})
41
-
42
-
43
- # --- WebSocket Server Setup ---
44
-
45
  # In-memory storage for connected clients
46
- # We'll have two types of clients: 'extension' and 'webapp'
47
  clients = {
48
  "webapp": set()
49
  }
50
- # We only need one audio source, so we don't need a set for the extension.
51
  audio_source = None
52
 
53
  async def broadcast_to_webapps(message):
54
  """Sends a message to all connected webapp clients."""
55
  if clients["webapp"]:
56
- await asyncio.wait([client.send(message) for client in clients["webapp"]])
 
 
 
 
 
 
 
 
 
 
57
 
58
  async def handle_audio_data(data):
59
  """
60
  This is the core audio processing function.
61
  For now, it will just mock the analysis.
62
- In the future, this is where we'll plug in our TensorFlow model.
63
  """
64
- # Mock analysis: Pretend we detected a chord and generated a prediction.
65
- # We can make this more interesting by picking a random chord.
66
  import random
67
  detected_chord = random.choice(notes)
68
  predicted_chord = random.choice(notes)
@@ -90,15 +66,26 @@ async def connection_handler(websocket, path):
90
  client_type = message_data.get("type")
91
 
92
  if client_type == "extension_hello":
 
 
 
 
 
93
  audio_source = websocket
94
- clients["webapp"].add(websocket) # Also treat extension as a webapp to receive messages
95
  print("Audio capture extension connected.")
96
  await websocket.send(json.dumps({"status": "connected", "role": "audio_source"}))
 
 
97
 
98
  elif client_type == "webapp_hello":
99
  clients["webapp"].add(websocket)
100
  print("Web app client connected.")
101
  await websocket.send(json.dumps({"status": "connected", "role": "viewer"}))
 
 
 
 
 
102
 
103
  else:
104
  print(f"Unknown client type: {client_type}. Disconnecting.")
@@ -108,12 +95,10 @@ async def connection_handler(websocket, path):
108
  async for message in websocket:
109
  if websocket == audio_source:
110
  # This is audio data from the extension
111
- # For now, we assume the message is a chunk of audio data.
112
- # We will simply trigger our mock analysis.
113
  await handle_audio_data(message)
114
 
115
  except websockets.exceptions.ConnectionClosed:
116
- print("Client disconnected.")
117
  finally:
118
  # Remove the client from our sets upon disconnection
119
  if websocket in clients["webapp"]:
@@ -121,28 +106,15 @@ async def connection_handler(websocket, path):
121
  if websocket == audio_source:
122
  audio_source = None
123
  print("Audio capture extension disconnected.")
 
124
 
125
 
126
- def run_flask_app():
127
- """Runs the Flask app in a separate thread."""
128
- # Note: Using Flask's development server is not ideal for production.
129
- # A proper WSGI server like Gunicorn should be used.
130
- # But for Hugging Face Spaces, this is often sufficient.
131
- app.run(host='0.0.0.0', port=5000)
132
-
133
-
134
- if __name__ == "__main__":
135
- # Start the Flask app in a background thread
136
- flask_thread = threading.Thread(target=run_flask_app)
137
- flask_thread.daemon = True
138
- flask_thread.start()
139
-
140
- # Start the WebSocket server
141
- # Hugging Face Spaces exposes port 7860 by default for web traffic.
142
- # We will use this port for our WebSocket server.
143
  websocket_port = 7860
144
- print(f"Starting WebSocket server on port {websocket_port}...")
145
- start_server = websockets.serve(connection_handler, "0.0.0.0", websocket_port)
 
146
 
147
- asyncio.get_event_loop().run_until_complete(start_server)
148
- asyncio.get_event_loop().run_forever()
 
1
  import asyncio
2
  import websockets
3
  import json
 
 
4
  import numpy as np
5
  from music_generator import MusicGenerator
6
 
7
+ # --- WebSocket Server Setup ---
 
8
 
9
  # Load the consonance matrix
10
  with open('consonance_matrix.json') as f:
 
13
  notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
14
  generator = MusicGenerator(len(notes))
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # In-memory storage for connected clients
 
17
  clients = {
18
  "webapp": set()
19
  }
 
20
  audio_source = None
21
 
22
  async def broadcast_to_webapps(message):
23
  """Sends a message to all connected webapp clients."""
24
  if clients["webapp"]:
25
+ # Create a copy of the set to avoid issues if a client disconnects during broadcast
26
+ disconnected_clients = set()
27
+ for client in clients["webapp"]:
28
+ try:
29
+ await client.send(message)
30
+ except websockets.exceptions.ConnectionClosed:
31
+ disconnected_clients.add(client)
32
+ # Remove clients that have disconnected
33
+ for client in disconnected_clients:
34
+ clients["webapp"].remove(client)
35
+
36
 
37
  async def handle_audio_data(data):
38
  """
39
  This is the core audio processing function.
40
  For now, it will just mock the analysis.
 
41
  """
 
 
42
  import random
43
  detected_chord = random.choice(notes)
44
  predicted_chord = random.choice(notes)
 
66
  client_type = message_data.get("type")
67
 
68
  if client_type == "extension_hello":
69
+ if audio_source is not None:
70
+ # If there's already an extension connected, disconnect the old one.
71
+ print("An extension is already connected. Disconnecting the old one.")
72
+ await audio_source.close(reason="New extension connected.")
73
+
74
  audio_source = websocket
 
75
  print("Audio capture extension connected.")
76
  await websocket.send(json.dumps({"status": "connected", "role": "audio_source"}))
77
+ await broadcast_to_webapps(json.dumps({"type": "status_update", "message": "Audio source connected."}))
78
+
79
 
80
  elif client_type == "webapp_hello":
81
  clients["webapp"].add(websocket)
82
  print("Web app client connected.")
83
  await websocket.send(json.dumps({"status": "connected", "role": "viewer"}))
84
+ if audio_source is not None:
85
+ await websocket.send(json.dumps({"type": "status_update", "message": "Audio source connected."}))
86
+ else:
87
+ await websocket.send(json.dumps({"type": "status_update", "message": "Waiting for audio source..."}))
88
+
89
 
90
  else:
91
  print(f"Unknown client type: {client_type}. Disconnecting.")
 
95
  async for message in websocket:
96
  if websocket == audio_source:
97
  # This is audio data from the extension
 
 
98
  await handle_audio_data(message)
99
 
100
  except websockets.exceptions.ConnectionClosed:
101
+ print(f"Client disconnected: {websocket.remote_address}")
102
  finally:
103
  # Remove the client from our sets upon disconnection
104
  if websocket in clients["webapp"]:
 
106
  if websocket == audio_source:
107
  audio_source = None
108
  print("Audio capture extension disconnected.")
109
+ await broadcast_to_webapps(json.dumps({"type": "status_update", "message": "Audio source disconnected."}))
110
 
111
 
112
+ async def main():
113
+ """Starts the WebSocket server."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  websocket_port = 7860
115
+ print(f"Starting pure WebSocket server on port {websocket_port}...")
116
+ async with websockets.serve(connection_handler, "0.0.0.0", websocket_port):
117
+ await asyncio.Future() # run forever
118
 
119
+ if __name__ == "__main__":
120
+ asyncio.run(main())
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
  networkx==3.3
2
  numpy==1.26.4
3
- flask
4
- flask-cors
5
  tensorflow
6
  websockets
 
1
  networkx==3.3
2
  numpy==1.26.4
 
 
3
  tensorflow
4
  websockets