import os
import sys
from typing import Any

import torch
from dotenv import load_dotenv
from transformers import AutoModel, AutoProcessor, AutoTokenizer

from src.exception import CustomExceptionHandling
from src.logger import logging

# Load environment variables from a local .env file
load_dotenv()

# Hugging Face access token (used for gated or private model repositories)
access_token = os.environ.get("ACCESS_TOKEN")


def load_model_tokenizer_and_processor(model_name: str, device: str) -> Any:
    """
    Load the model, tokenizer, and processor.

    Args:
        model_name (str): The name of the model to load.
        device (str): The device to load the model onto.

    Returns:
        model: The loaded model.
        tokenizer: The loaded tokenizer.
        processor: The loaded processor.
    """
    try:
        # Load the multimodal model with vision enabled and audio/TTS disabled
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16,
            init_vision=True,
            init_audio=False,
            init_tts=False,
            token=access_token,
        )
        # Switch to inference mode and move the model to the target device
        model = model.eval().to(device=device)

        # Load the matching tokenizer and processor
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, token=access_token
        )
        processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, token=access_token
        )

        # Log the successful loading of the model, tokenizer, and processor
        logging.info("Model, tokenizer and processor loaded successfully.")
        # Return the loaded components as a tuple
        return model, tokenizer, processor

    except Exception as e:
        # Wrap and re-raise the error through the project's custom exception handler
        raise CustomExceptionHandling(e, sys) from e
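

# Minimal usage sketch for the loader above. The model identifier below is a
# hypothetical placeholder (not taken from this repository's configuration);
# the init_* flags suggest a MiniCPM-o-style multimodal checkpoint, but any
# compatible model id can be substituted.
if __name__ == "__main__":
    example_model_name = "your-org/your-multimodal-model"  # assumed placeholder
    example_device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer, processor = load_model_tokenizer_and_processor(
        example_model_name, example_device
    )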