from transformers import pipeline
import torch

def model_fn(model_dir):
    """Build the ``transformers`` pipeline served by this endpoint.

    Replaces the default model-loading hook of the Hugging Face Deep
    Learning Container.

    Args:
        model_dir: Directory handed in by the container. It is not read
            here; the model is resolved from the ``"fuwangwang/mpt-7b"``
            identifier instead (presumably a Hugging Face Hub repo id —
            NOTE(review): confirm it is not meant to live under
            ``model_dir``).

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    load_options = {
        "model": "fuwangwang/mpt-7b",
        # Load the weights in 16-bit bfloat16 precision.
        "torch_dtype": torch.bfloat16,
        # Allows Python code shipped with the model repository to run;
        # only acceptable when that repository is trusted.
        "trust_remote_code": True,
        # Let transformers/accelerate choose where the weights are placed.
        "device_map": "auto",
    }
    return pipeline(**load_options)