lyclyc52's picture
Update: add original llava code
48ca1e2
raw
history blame
966 Bytes
import os
# Registry of language-model backends.
#
# Each key is the name of a module under ``.language_model``; each value is a
# comma-separated string of the class names to import from that module
# (a causal-LM class followed by its config class).
AVAILABLE_MODELS = {
    "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
    "llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig",
    "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
    # Currently disabled:
    # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
    "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
    "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
    # Register further backends here, using the same "module": "classes" form.
}
# The LLaMA backend is imported unconditionally: unlike the backends handled
# in the loop below, an ImportError raised here is not caught and propagates
# to whoever imports this package. It does not depend on the loop variables,
# so it is executed once, before the loop.
# NOTE(review): presumably the hard failure is intentional for the base
# backend -- confirm.
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig

# Best-effort import of every backend listed in AVAILABLE_MODELS. A backend
# whose import raises ImportError is reported on stdout and skipped, so the
# remaining backends are still attempted.
for model_name, model_classes in AVAILABLE_MODELS.items():
    try:
        # ``exec`` is used because both the module name and the imported
        # class names are taken from the registry strings; the statement runs
        # in this module's namespace, so the classes become module attributes.
        # The strings come from AVAILABLE_MODELS, not from external input.
        exec(f"from .language_model.{model_name} import {model_classes}")
    except ImportError as err:
        # Include the underlying reason so a failed import can be diagnosed
        # from the message alone.
        print(f"Failed to import {model_classes} from llava.language_model.{model_name}. Error: {err}")
    else:
        print(f"import {model_classes} successfully")