File size: 337 Bytes
ffc95f0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
from transformers import pipeline
import torch

def model_fn(model_dir):
    """Create the Hugging Face ``pipeline`` served by this endpoint.

    Overrides the default model-loading hook of the Hugging Face Deep
    Learning Container.

    Args:
        model_dir: Directory supplied by the container. It is not read here;
            the model is resolved by its Hub id (``fuwangwang/mpt-7b``)
            instead.

    Returns:
        The object returned by ``transformers.pipeline`` for that model.
    """
    # NOTE(review): keep this handler single-argument. The SageMaker Hugging
    # Face inference toolkit appears to decide whether to pass an extra
    # ``context`` argument from the handler's parameter count — confirm
    # before adding any parameter here.
    load_options = {
        "model": "fuwangwang/mpt-7b",
        # Load the weights as bfloat16 instead of the default dtype.
        "torch_dtype": torch.bfloat16,
        # Allows Python code shipped inside the Hub repo to execute. No
        # ``revision`` is pinned, so whatever the repo's default branch holds
        # at load time is what runs — consider pinning a commit hash.
        "trust_remote_code": True,
        # Let transformers decide how to place the weights across the
        # available devices.
        "device_map": "auto",
    }
    return pipeline(**load_options)