Spaces:
Sleeping
Sleeping
feat:增加device 参数
Browse filesfix: set_kv_cache使用默认device问题
- inference.py +1 -1
- server.py +2 -2
inference.py
CHANGED
|
@@ -399,7 +399,7 @@ class OmniInference:
|
|
| 399 |
model = self.model
|
| 400 |
|
| 401 |
with self.fabric.init_tensor():
|
| 402 |
-
model.set_kv_cache(batch_size=2)
|
| 403 |
|
| 404 |
mel, leng = load_audio(audio_path)
|
| 405 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
|
| 399 |
model = self.model
|
| 400 |
|
| 401 |
with self.fabric.init_tensor():
|
| 402 |
+
model.set_kv_cache(batch_size=2,device=self.device)
|
| 403 |
|
| 404 |
mel, leng = load_audio(audio_path)
|
| 405 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
server.py
CHANGED
|
@@ -46,9 +46,9 @@ def create_app():
|
|
| 46 |
return server.server
|
| 47 |
|
| 48 |
|
| 49 |
-
def serve(ip='0.0.0.0', port=60808):
|
| 50 |
|
| 51 |
-
OmniChatServer(ip, port=port,
|
| 52 |
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|
|
|
|
| 46 |
return server.server
|
| 47 |
|
| 48 |
|
| 49 |
+
def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
|
| 50 |
|
| 51 |
+
OmniChatServer(ip, port=port,run_app=True, device=device)
|
| 52 |
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|