Spaces:
Running
Running
Commit
·
d709cdc
1
Parent(s):
814fa89
更新模型初始化,添加设备参数支持,并将device_map默认值修改为None,以提高灵活性和兼容性。
Browse files
examples/simple_llm.py
CHANGED
|
@@ -16,15 +16,16 @@ if __name__ == "__main__":
|
|
| 16 |
try:
|
| 17 |
model_name = "google/gemma-3-4b-it"
|
| 18 |
use_4bit_quantization = False
|
|
|
|
| 19 |
|
| 20 |
# gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
|
| 21 |
# 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
|
| 22 |
if model_name.startswith("mlx-community"):
|
| 23 |
gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
|
| 24 |
elif model_name.startswith("microsoft"):
|
| 25 |
-
gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
|
| 26 |
else:
|
| 27 |
-
gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
|
| 28 |
|
| 29 |
print("\n--- 示例 1: 简单用户查询 ---")
|
| 30 |
messages_example1 = [
|
|
|
|
| 16 |
try:
|
| 17 |
model_name = "google/gemma-3-4b-it"
|
| 18 |
use_4bit_quantization = False
|
| 19 |
+
device = "mps"
|
| 20 |
|
| 21 |
# gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
|
| 22 |
# 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
|
| 23 |
if model_name.startswith("mlx-community"):
|
| 24 |
gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
|
| 25 |
elif model_name.startswith("microsoft"):
|
| 26 |
+
gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
|
| 27 |
else:
|
| 28 |
+
gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
|
| 29 |
|
| 30 |
print("\n--- 示例 1: 简单用户查询 ---")
|
| 31 |
messages_example1 = [
|
src/podcast_transcribe/llm/llm_base.py
CHANGED
|
@@ -182,7 +182,7 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
|
|
| 182 |
self,
|
| 183 |
model_name: str,
|
| 184 |
use_4bit_quantization: bool = False,
|
| 185 |
-
device_map: Optional[str] =
|
| 186 |
device: Optional[str] = None,
|
| 187 |
trust_remote_code: bool = True,
|
| 188 |
torch_dtype: Optional[torch.dtype] = None
|
|
|
|
| 182 |
self,
|
| 183 |
model_name: str,
|
| 184 |
use_4bit_quantization: bool = False,
|
| 185 |
+
device_map: Optional[str] = None,
|
| 186 |
device: Optional[str] = None,
|
| 187 |
trust_remote_code: bool = True,
|
| 188 |
torch_dtype: Optional[torch.dtype] = None
|
src/podcast_transcribe/llm/llm_gemma_transfomers.py
CHANGED
|
@@ -11,7 +11,7 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
|
|
| 11 |
self,
|
| 12 |
model_name: str = "google/gemma-3-4b-it",
|
| 13 |
use_4bit_quantization: bool = False,
|
| 14 |
-
device_map: Optional[str] =
|
| 15 |
device: Optional[str] = None,
|
| 16 |
trust_remote_code: bool = True
|
| 17 |
):
|
|
|
|
| 11 |
self,
|
| 12 |
model_name: str = "google/gemma-3-4b-it",
|
| 13 |
use_4bit_quantization: bool = False,
|
| 14 |
+
device_map: Optional[str] = None,
|
| 15 |
device: Optional[str] = None,
|
| 16 |
trust_remote_code: bool = True
|
| 17 |
):
|
src/podcast_transcribe/llm/llm_phi4_transfomers.py
CHANGED
|
@@ -11,7 +11,7 @@ class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
|
|
| 11 |
self,
|
| 12 |
model_name: str = "microsoft/Phi-4-mini-reasoning",
|
| 13 |
use_4bit_quantization: bool = False,
|
| 14 |
-
device_map: Optional[str] =
|
| 15 |
device: Optional[str] = None,
|
| 16 |
trust_remote_code: bool = True
|
| 17 |
):
|
|
|
|
| 11 |
self,
|
| 12 |
model_name: str = "microsoft/Phi-4-mini-reasoning",
|
| 13 |
use_4bit_quantization: bool = False,
|
| 14 |
+
device_map: Optional[str] = None,
|
| 15 |
device: Optional[str] = None,
|
| 16 |
trust_remote_code: bool = True
|
| 17 |
):
|
src/podcast_transcribe/llm/llm_router.py
CHANGED
|
@@ -379,7 +379,7 @@ def chat_completion(
|
|
| 379 |
model: Optional[str] = None,
|
| 380 |
device: Optional[str] = None,
|
| 381 |
use_4bit_quantization: bool = False,
|
| 382 |
-
device_map: Optional[str] =
|
| 383 |
trust_remote_code: bool = True,
|
| 384 |
**kwargs
|
| 385 |
) -> Dict[str, Any]:
|
|
@@ -448,7 +448,7 @@ def chat_completion(
|
|
| 448 |
params["device"] = device
|
| 449 |
if use_4bit_quantization:
|
| 450 |
params["use_4bit_quantization"] = use_4bit_quantization
|
| 451 |
-
if device_map
|
| 452 |
params["device_map"] = device_map
|
| 453 |
if not trust_remote_code:
|
| 454 |
params["trust_remote_code"] = trust_remote_code
|
|
@@ -473,7 +473,7 @@ def reasoning_completion(
|
|
| 473 |
model: Optional[str] = None,
|
| 474 |
device: Optional[str] = None,
|
| 475 |
use_4bit_quantization: bool = False,
|
| 476 |
-
device_map: Optional[str] =
|
| 477 |
trust_remote_code: bool = True,
|
| 478 |
extract_reasoning_steps: bool = True,
|
| 479 |
**kwargs
|
|
@@ -521,7 +521,7 @@ def reasoning_completion(
|
|
| 521 |
params["device"] = device
|
| 522 |
if use_4bit_quantization:
|
| 523 |
params["use_4bit_quantization"] = use_4bit_quantization
|
| 524 |
-
if device_map
|
| 525 |
params["device_map"] = device_map
|
| 526 |
if not trust_remote_code:
|
| 527 |
params["trust_remote_code"] = trust_remote_code
|
|
|
|
| 379 |
model: Optional[str] = None,
|
| 380 |
device: Optional[str] = None,
|
| 381 |
use_4bit_quantization: bool = False,
|
| 382 |
+
device_map: Optional[str] = None,
|
| 383 |
trust_remote_code: bool = True,
|
| 384 |
**kwargs
|
| 385 |
) -> Dict[str, Any]:
|
|
|
|
| 448 |
params["device"] = device
|
| 449 |
if use_4bit_quantization:
|
| 450 |
params["use_4bit_quantization"] = use_4bit_quantization
|
| 451 |
+
if device_map:
|
| 452 |
params["device_map"] = device_map
|
| 453 |
if not trust_remote_code:
|
| 454 |
params["trust_remote_code"] = trust_remote_code
|
|
|
|
| 473 |
model: Optional[str] = None,
|
| 474 |
device: Optional[str] = None,
|
| 475 |
use_4bit_quantization: bool = False,
|
| 476 |
+
device_map: Optional[str] = None,
|
| 477 |
trust_remote_code: bool = True,
|
| 478 |
extract_reasoning_steps: bool = True,
|
| 479 |
**kwargs
|
|
|
|
| 521 |
params["device"] = device
|
| 522 |
if use_4bit_quantization:
|
| 523 |
params["use_4bit_quantization"] = use_4bit_quantization
|
| 524 |
+
if device_map:
|
| 525 |
params["device_map"] = device_map
|
| 526 |
if not trust_remote_code:
|
| 527 |
params["trust_remote_code"] = trust_remote_code
|