from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
import re
from typing import Any, Dict, Iterator, List, Optional

import gradio as gr
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
# Load the fine-tuned model in 4-bit with Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Ankitnau25/govtbot-llama3.1-v1",
    max_seq_length = 8192,
    load_in_4bit = True,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Attach the chat template to the tokenizer
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "alpaca",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping = {"role": "from", "content": "value", "user": "human", "assistant": "gpt"},  # ShareGPT style
    map_eos_token = True,  # Maps <|im_end|> to </s> instead
)
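
# Illustrative only: a hypothetical ShareGPT-style conversation in the
# "from"/"value" format that the mapping above expects. The sample turns are
# made up for demonstration and are not taken from the training data.
_example_sharegpt_turns = [
    {"from": "human", "value": "How do I apply for a domicile certificate?"},
    {"from": "gpt", "value": "You can apply online through the state portal ..."},
]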
FastLanguageModel.for_inference(model)  # Enable native 2x faster inference

def predict(inp_text):
    """Run a single-turn generation on the local model and return the cleaned response."""
    messages = [
        {"from": "human", "value": f"{inp_text}"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,  # Must add for generation
        return_tensors = "pt",
    ).to("cuda")
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    outputs = model.generate(input_ids = inputs, use_cache = True, temperature = 0.01, max_new_tokens = 1024)
    result = tokenizer.batch_decode(outputs)
    return filter_user_assistant_msgs(result[0])

def filter_user_assistant_msgs(text):
    """Extract the assistant's reply from the decoded prompt + completion string."""
    msg_pattern = r".*Response:\n(.*?)<\|eot_id\|>"
    match = re.match(msg_pattern, text, re.DOTALL)
    if match:
        message = match.group(1).strip()
    else:
        message = text
    return message
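
# Hedged sanity check for the filter above, kept commented out so the Space
# does not run it on startup. The sample string is a hypothetical example of
# what tokenizer.batch_decode might return with this template; the exact
# markers depend on the tokenizer configuration.
# sample_decoded = "### Instruction:\nhi\n\n### Response:\nHello! How can I help?<|eot_id|>"
# assert filter_user_assistant_msgs(sample_decoded) == "Hello! How can I help?"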

# Custom LangChain chat model wrapping the local Unsloth model
class CustomChatModelAdvanced(BaseChatModel):
    """A custom chat model that answers with the locally loaded Unsloth model.

    `_generate` sends the last message of the prompt to `predict`; `_stream`
    is kept from the LangChain example and simply echoes the first `n`
    characters of the last message.

    Example:

        .. code-block:: python

            model = CustomChatModelAdvanced(model_name="unsloth_llama3.1", n=4)
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    model_name: str
    """The name of the model."""
    n: int
    """The number of characters of the last message echoed by the streaming stub."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response by calling the local model.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Only the last message of the prompt is sent to the local model.
        last_message = messages[-1].content
        tokens = predict(last_message)
        message = AIMessage(
            content=tokens,
            additional_kwargs={},  # Used to add additional payload (e.g., function calling request)
            response_metadata={  # Placeholder response metadata
                "time_in_seconds": 3,
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This stub is kept from the LangChain custom chat model example: it does
        not stream real model output, it only echoes the first `self.n`
        characters of the last message one character at a time. The real model
        is called in `_generate`, which also serves streaming requests if this
        method is removed.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = last_message.content[: self.n]

        for token in tokens:
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                # This is optional in newer versions of LangChain
                # The on_llm_new_token will be called automatically
                run_manager.on_llm_new_token(token, chunk=chunk)
            yield chunk

        # Emit a final empty chunk carrying response metadata
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
        )
        if run_manager:
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which is
        used for tracing purposes and makes it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per-token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": self.model_name,
        }

llm_model = CustomChatModelAdvanced(model_name="unsloth_llama3.1", n=4)
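
# Optional smoke test of the wrapper, commented out so the Space does not run
# it on startup; a hypothetical single-turn call through the LangChain API:
# print(llm_model.invoke([HumanMessage(content="hello")]).content)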

def predict_chat(message, history):
    """Gradio callback: convert tuple-style chat history to LangChain messages and query the model."""
    history_langchain_format = []
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))
    gpt_response = llm_model.invoke(history_langchain_format)
    return gpt_response.content

gr.ChatInterface(predict_chat).launch(debug=True)