| import os | |
| from .ModelStrategy import ModelStrategy | |
| from langchain_openai import ChatOpenAI | |
| from langchain_mistralai.chat_models import ChatMistralAI | |
| from langchain_anthropic import ChatAnthropic | |
| from llamaapi import LlamaAPI | |
| from langchain_experimental.llms import ChatLlamaAPI | |
class MistralModel(ModelStrategy):
    """Model strategy that produces LangChain ``ChatMistralAI`` chat models."""

    def get_model(self, model_name):
        """Return a ``ChatMistralAI`` instance bound to *model_name*.

        Args:
            model_name: Mistral model identifier forwarded as ``model=``.
        """
        chat_model = ChatMistralAI(model=model_name)
        return chat_model
class OpenAIModel(ModelStrategy):
    """Model strategy that produces LangChain ``ChatOpenAI`` chat models."""

    def get_model(self, model_name):
        """Return a ``ChatOpenAI`` instance bound to *model_name*.

        Args:
            model_name: OpenAI model identifier forwarded as ``model=``.
        """
        chat_model = ChatOpenAI(model=model_name)
        return chat_model
class AnthropicModel(ModelStrategy):
    """Model strategy that produces LangChain ``ChatAnthropic`` chat models."""

    def get_model(self, model_name):
        """Return a ``ChatAnthropic`` instance bound to *model_name*.

        Args:
            model_name: Anthropic model identifier forwarded as ``model=``.
        """
        chat_model = ChatAnthropic(model=model_name)
        return chat_model
class LlamaAPIModel(ModelStrategy):
    """Model strategy that produces ``ChatLlamaAPI`` models backed by LlamaAPI."""

    def get_model(self, model_name):
        """Return a ``ChatLlamaAPI`` wrapping a fresh ``LlamaAPI`` client.

        The API key is read from the ``LLAMA_API_KEY`` environment variable
        each time this is called.

        Args:
            model_name: Llama model identifier forwarded as ``model=``.

        Returns:
            A ``ChatLlamaAPI`` instance using a ``LlamaAPI`` client built from
            the environment's API key.

        Raises:
            ValueError: If ``LLAMA_API_KEY`` is unset or empty.
        """
        api_key = os.environ.get("LLAMA_API_KEY")
        # Fail fast: without this check a None/empty key would be passed to
        # LlamaAPI and only surface later as an opaque request failure.
        if not api_key:
            raise ValueError(
                "LLAMA_API_KEY environment variable is not set; "
                "it is required to create a LlamaAPI client."
            )
        llama = LlamaAPI(api_key)
        # NOTE(review): confirm ChatLlamaAPI actually honours `model=`; if it
        # only declares `client`, the model name may be ignored here.
        return ChatLlamaAPI(client=llama, model=model_name)
class ModelManager:
    """Registry mapping provider keys to model-construction strategies.

    Registered provider keys: ``"mistral"``, ``"openai"``, ``"anthropic"``,
    ``"llama"``.
    """

    def __init__(self):
        # One strategy instance per provider key, created up front.
        self.models = {
            "mistral": MistralModel(),
            "openai": OpenAIModel(),
            "anthropic": AnthropicModel(),
            "llama": LlamaAPIModel(),
        }

    def get_model(self, provider, model_name):
        """Build a chat model named *model_name* using the *provider* strategy.

        Args:
            provider: Key into ``self.models`` selecting the strategy.
            model_name: Model identifier passed to the strategy's ``get_model``.

        Returns:
            Whatever the selected strategy's ``get_model`` returns.

        Raises:
            KeyError: If *provider* is not registered; the message lists the
                supported provider keys.
        """
        try:
            strategy = self.models[provider]
        except KeyError:
            supported = ", ".join(sorted(self.models))
            # Same exception type as before, but with an actionable message;
            # `from None` hides the redundant inner KeyError traceback.
            raise KeyError(
                f"Unknown model provider {provider!r}; expected one of: {supported}"
            ) from None
        return strategy.get_model(model_name)