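"""Build LangChain LLM objects from a chatdocs-style config dict.

Supports three backends selected by ``config["llm"]``: ``ctransformers``,
``gptq`` (via the optional ``auto_gptq`` package), and a Hugging Face
``text-generation`` pipeline as the fallback.
"""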
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import CTransformers, HuggingFacePipeline
from langchain.llms.base import LLM

from .utils import merge


def get_gptq_llm(config: Dict[str, Any]) -> LLM:
    try:
        from auto_gptq import AutoGPTQForCausalLM
    except ImportError:
        raise ImportError(
            "Could not import `auto_gptq` package. "
            "Please install it with `pip install chatdocs[gptq]`"
        )

    from transformers import (
        AutoTokenizer,
        TextGenerationPipeline,
        MODEL_FOR_CAUSAL_LM_MAPPING,
    )

    local_files_only = not config["download"]
    config = {**config["gptq"]}
    model_name_or_path = config.pop("model")
    model_file = config.pop("model_file", None)
    pipeline_kwargs = config.pop("pipeline_kwargs", None) or {}

    # Derive the quantized checkpoint basename and format from the file name.
    model_basename = None
    use_safetensors = False
    if model_file:
        model_basename = Path(model_file).stem
        use_safetensors = model_file.endswith(".safetensors")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        local_files_only=local_files_only,
    )
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        model_basename=model_basename,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        **config,  # remaining keys are forwarded to from_quantized()
    )
    # Register the quantized model class so the transformers pipeline accepts it.
    MODEL_FOR_CAUSAL_LM_MAPPING.register("chatdocs-gptq", model.__class__)
    pipeline = TextGenerationPipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
    )
    return HuggingFacePipeline(pipeline=pipeline)
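
# A minimal sketch (hypothetical values) of the ``gptq`` section this function
# consumes; any keys besides ``model``, ``model_file`` and ``pipeline_kwargs``
# are forwarded verbatim to ``AutoGPTQForCausalLM.from_quantized``:
#
#   config["gptq"] = {
#       "model": "TheBloke/some-model-GPTQ",   # hypothetical Hub repo id
#       "model_file": "model.safetensors",     # sets model_basename / use_safetensors
#       "device": "cuda:0",                    # forwarded to from_quantized()
#       "pipeline_kwargs": {"max_new_tokens": 256},
#   }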


def get_llm(
    config: Dict[str, Any],
    *,
    callback: Optional[Callable[[str], None]] = None,
) -> LLM:
    # Stream generated tokens to the caller-supplied callback, if any.
    class CallbackHandler(BaseCallbackHandler):
        def on_llm_new_token(self, token: str, **kwargs) -> None:
            callback(token)

    callbacks = [CallbackHandler()] if callback else None
    local_files_only = not config["download"]
    if config["llm"] == "ctransformers":
        config = {**config["ctransformers"]}
        config = merge(config, {"config": {"local_files_only": local_files_only}})
        llm = CTransformers(callbacks=callbacks, **config)
    elif config["llm"] == "gptq":
        llm = get_gptq_llm(config)
    else:
        config = {**config["huggingface"]}
        config["model_id"] = config.pop("model")
        config = merge(config, {"model_kwargs": {"local_files_only": local_files_only}})
        llm = HuggingFacePipeline.from_model_id(task="text-generation", **config)
    return llm
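

# Usage sketch (illustrative; the config keys mirror those read above, and the
# model name and module path are hypothetical):
#
#   from chatdocs.llms import get_llm   # assumed module path
#
#   config = {
#       "download": False,
#       "llm": "ctransformers",
#       "ctransformers": {
#           "model": "TheBloke/some-model-GGML",   # hypothetical
#           "model_type": "llama",
#       },
#   }
#   llm = get_llm(config, callback=lambda token: print(token, end="", flush=True))
#   print(llm("Hello!"))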