from pathlib import Path
from typing import Any, Callable, Dict, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import CTransformers, HuggingFacePipeline
from langchain.llms.base import LLM

from .utils import merge


def get_gptq_llm(config: Dict[str, Any]) -> LLM:
    # `auto_gptq` is an optional dependency, so import it lazily and point
    # the user at the right extra if it is missing.
    try:
        from auto_gptq import AutoGPTQForCausalLM
    except ImportError:
        raise ImportError(
            "Could not import `auto_gptq` package. "
            "Please install it with `pip install chatdocs[gptq]`"
        )
    from transformers import (
        AutoTokenizer,
        TextGenerationPipeline,
        MODEL_FOR_CAUSAL_LM_MAPPING,
    )

    local_files_only = not config["download"]
    # Work on a copy of the `gptq` sub-config so pop() doesn't mutate the
    # caller's dict.
    config = {**config["gptq"]}
    model_name_or_path = config.pop("model")
    model_file = config.pop("model_file", None)
    pipeline_kwargs = config.pop("pipeline_kwargs", None) or {}

    # Derive the quantized model basename (and safetensors flag) from the
    # model file name, if one was given.
    model_basename = None
    use_safetensors = False
    if model_file:
        model_basename = Path(model_file).stem
        use_safetensors = model_file.endswith(".safetensors")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        local_files_only=local_files_only,
    )
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        model_basename=model_basename,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        **config,
    )
    # Register the quantized model class so the transformers pipeline
    # machinery accepts it as a causal LM.
    MODEL_FOR_CAUSAL_LM_MAPPING.register("chatdocs-gptq", model.__class__)
    pipeline = TextGenerationPipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
    )
    return HuggingFacePipeline(pipeline=pipeline)


def get_llm(
    config: Dict[str, Any],
    *,
    callback: Optional[Callable[[str], None]] = None,
) -> LLM:
    # Forward each newly generated token to the optional `callback`,
    # e.g. for streaming output to a terminal or web client.
    class CallbackHandler(BaseCallbackHandler):
        def on_llm_new_token(self, token: str, **kwargs) -> None:
            callback(token)

    callbacks = [CallbackHandler()] if callback else None
    local_files_only = not config["download"]

    # Dispatch on the configured backend: ctransformers, GPTQ, or a plain
    # Hugging Face transformers pipeline as the fallback.
    if config["llm"] == "ctransformers":
        config = {**config["ctransformers"]}
        config = merge(config, {"config": {"local_files_only": local_files_only}})
        llm = CTransformers(callbacks=callbacks, **config)
    elif config["llm"] == "gptq":
        llm = get_gptq_llm(config)
    else:
        config = {**config["huggingface"]}
        config["model_id"] = config.pop("model")
        config = merge(config, {"model_kwargs": {"local_files_only": local_files_only}})
        llm = HuggingFacePipeline.from_model_id(task="text-generation", **config)
    return llm
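

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal example of the config dict
# shape that get_llm() expects. The model repo and settings below are
# hypothetical placeholders, not chatdocs defaults; real values come from
# the merged chatdocs YAML configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "download": True,  # allow fetching the model from the Hugging Face Hub
        "llm": "ctransformers",
        "ctransformers": {
            "model": "marella/gpt-2-ggml",  # hypothetical/example model repo
            "config": {"max_new_tokens": 64},
        },
    }
    # Stream tokens to stdout as they are generated.
    llm = get_llm(
        example_config,
        callback=lambda token: print(token, end="", flush=True),
    )
    llm("Hello, world!")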