| """OpenAI chat wrapper.""" | |
| from __future__ import annotations | |
| from typing import ( | |
| Any, | |
| AsyncIterator, | |
| Iterator, | |
| List, | |
| Optional, | |
| Union, | |
| ) | |
| from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI | |
| from langchain_community.chat_models.openai import acompletion_with_retry, _convert_delta_to_message_chunk | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForLLMRun, | |
| CallbackManagerForLLMRun, | |
| ) | |
| from langchain_core.language_models.chat_models import ( | |
| agenerate_from_stream, | |
| generate_from_stream, | |
| ) | |
| from langchain_core.messages import ( | |
| AIMessageChunk, | |
| BaseMessage, | |
| ) | |
| from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
| from langchain_core.pydantic_v1 import BaseModel | |
| from langchain_community.adapters.openai import ( | |
| convert_dict_to_message, | |
| ) | |


class H2OBaseChatOpenAI:
    """Mixin overriding the sync/async generate and stream paths of ChatOpenAI-style
    models so that, when tool calling is enabled via ``model_kwargs['tools']``,
    the tool-call arguments are surfaced as plain message content."""

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    def mod_cg_chunk(self, cg_chunk: ChatGenerationChunk) -> ChatGenerationChunk:
        """When tools are passed via ``model_kwargs``, expose the first tool call's
        arguments as the chunk text; otherwise suppress the streamed text."""
        if 'tools' in self.model_kwargs and self.model_kwargs['tools']:
            if 'tool_calls' in cg_chunk.message.additional_kwargs:
                tool_args = cg_chunk.message.additional_kwargs['tool_calls'][0]['function']['arguments']
                cg_chunk.message.content = cg_chunk.text = tool_args
            else:
                cg_chunk.text = ''
        return cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = convert_dict_to_message(res["message"])
            if 'tools' in self.model_kwargs and self.model_kwargs['tools']:
                if 'tool_calls' in message.additional_kwargs:
                    # Flatten all tool-call arguments into the message content.
                    message.content = ''.join(
                        [x['function']['arguments'] for x in message.additional_kwargs['tool_calls']]
                    )
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)


class H2OBaseAzureChatOpenAI(H2OBaseChatOpenAI, AzureChatOpenAI):
    pass
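

# --- Usage sketch (not part of the original module) -----------------------------------
# ChatOpenAI is imported above but only AzureChatOpenAI gets a concrete subclass here;
# a plain-OpenAI wrapper could be assembled the same way, e.g.
# ``class H2OChatOpenAI(H2OBaseChatOpenAI, ChatOpenAI): ...`` (hypothetical name).
# Every endpoint/deployment/key value below is a placeholder, not a real configuration.
if __name__ == "__main__":
    llm = H2OBaseAzureChatOpenAI(
        azure_endpoint="https://example.openai.azure.com/",  # placeholder
        deployment_name="my-gpt-deployment",  # placeholder
        openai_api_version="2023-05-15",  # placeholder
        openai_api_key="<azure-api-key>",  # placeholder
        streaming=True,
    )
    # With ``model_kwargs={"tools": [...]}`` the overrides above surface tool-call
    # arguments as the streamed text instead of ordinary assistant content.
    # llm.invoke("Hello")  # would go through _generate()/_stream() from the mixin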