from typing import Optional, List, Any, Dict, Iterator
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import PrivateAttr
# used for qwen inference
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class ChatQWEN(BaseChatModel):
    """A custom chat model that invokes Qwen2.5-1.5B-Instruct.

    Example:
        .. code-block:: python

            model = ChatQWEN()
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    """The name of the model to load and call."""

    # other generation params
    temperature: float = 0.7
    max_new_tokens: int = 512
    device_map: str = "auto"

    # private attributes holding the loaded model and tokenizer
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # load the Qwen tokenizer and model weights
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.bfloat16,
            offload_folder=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).eval()
        # report GPU memory usage right after loading the model
        print(f"GPU memory used: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved()/1024**3:.2f} GB")
    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        """Convert a LangChain message to the dict format expected by Qwen."""
        if isinstance(message, HumanMessage):
            return {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            return {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            return {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Message type not supported: {type(message)}")
    def qwen(self, messages):
        # build the prompt with the chat template so the model receives the
        # format it was trained on
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self._tokenizer([text], return_tensors="pt").to(
            self._model.device
        )
        # generate the completion with Qwen
        with torch.no_grad():
            generated_ids = self._model.generate(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
            )
        # drop the prompt tokens, keeping only the newly generated ones
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        # decode the response of the LLM
        response = self._tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return response
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Args:
            messages: the prompt composed of a list of messages.
        """
        # convert the LangChain messages into the format Qwen expects
        formatted_messages = [self._convert_message_to_dict(msg) for msg in messages]
        # call qwen
        qwen_response = self.qwen(formatted_messages)
        # truncate the response at the first stop sequence, if any
        if stop:
            for stop_word in stop:
                qwen_response = qwen_response.split(stop_word)[0]
        # wrap the response in an AIMessage
        message = AIMessage(content=qwen_response.strip())
        # return a ChatResult containing a single generation
        generation = ChatGeneration(message=message, text=qwen_response.strip())
        return ChatResult(generations=[generation])
    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "qwen-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing and makes it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per-token pricing for their model and monitor
            # costs for the given LLM).
            "model_name": self.model_name,
        }