from typing import Optional, List, Any, Dict, Iterator
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import PrivateAttr

# used for qwen inference
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class ChatQWEN(BaseChatModel):
    """A custom chat model that invoke Qwen2.5-1.5B-Instruct.

    Example:

        .. code-block:: python

            model = ChatQWEN()
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                 [HumanMessage(content="world")]])
    """

    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    """The name of the model"""
    # generation and device placement parameters
    temperature: float = 0.7
    max_new_tokens: int = 512
    device_map: str = "auto"

    # private attributes holding the loaded Hugging Face model and tokenizer
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # load the Qwen tokenizer and model weights
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )

        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.bfloat16,
            offload_folder=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).eval()

        # report GPU memory usage after loading the model (only when CUDA is available)
        if torch.cuda.is_available():
            print(f"GPU memory used: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
            print(f"GPU memory reserved: {torch.cuda.memory_reserved()/1024**3:.2f} GB")

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        """Convert a LangChain message into the role/content dict expected by Qwen's chat template."""
        if isinstance(message, HumanMessage):
            return {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            return {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            return {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Message type not supported: {type(message)}")

    def qwen(self, messages: List[dict]) -> str:
        """Run Qwen on a list of role/content dicts and return the decoded response."""
        # build the prompt string using the model's chat template
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self._tokenizer([text], return_tensors="pt").to(
            self._model.device
        )

        # generate the completion without tracking gradients
        with torch.no_grad():
            generated_ids = self._model.generate(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,  # required for temperature to take effect
                temperature=self.temperature,
            )
            # keep only the newly generated tokens (strip the prompt tokens)
            generated_ids = [
                output_ids[len(input_ids) :]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

            # decode the generated tokens back into text
            response = self._tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]

        return response

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Args:
            messages: the prompt composed of a list of messages.
        """
        # convert LangChain messages into the dict format Qwen expects
        formatted_messages = [self._convert_message_to_dict(msg) for msg in messages]

        # call qwen
        qwen_response = self.qwen(formatted_messages)

        # truncate the response at the first occurrence of any stop sequence
        if stop:
            for stop_word in stop:
                qwen_response = qwen_response.split(stop_word)[0]

        # wrap the raw text in an AIMessage and return it as a ChatResult
        message = AIMessage(content=qwen_response.strip())

        generation = ChatGeneration(message=message, text=qwen_response.strip())
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "qwen-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        enables tracing and monitoring of LLM usage.
        """
        return {
            # The model name lets users define custom token-counting rules in
            # LLM monitoring applications (e.g., in LangSmith users can set
            # per-token pricing for their model and track costs).
            "model_name": self.model_name,
        }
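

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original class).
# A minimal example of wiring ChatQWEN into a LangChain prompt chain. It
# assumes the Qwen/Qwen2.5-1.5B-Instruct weights can be downloaded and that
# enough memory is available; the prompts below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.prompts import ChatPromptTemplate

    chat = ChatQWEN(temperature=0.7, max_new_tokens=256)

    # direct invocation with a list of messages
    reply = chat.invoke([HumanMessage(content="Explain LangChain in one sentence.")])
    print(reply.content)

    # composing the model with a prompt template via LCEL (prompt | model)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a concise assistant."),
            ("human", "{question}"),
        ]
    )
    chain = prompt | chat
    print(chain.invoke({"question": "What does the temperature setting control?"}).content)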