File size: 337 Bytes
ffc95f0 |
1 2 3 4 5 6 7 8 9 |
from transformers import pipeline
import torch
def model_fn(model_dir):
    """Create the pipeline this endpoint serves.

    Replaces the default model-loading hook of the HuggingFace Deep Learning
    Container. ``model_dir`` is part of that hook's call contract but is not
    read here; the model is instead resolved from the identifier
    ``"fuwangwang/mpt-7b"`` (presumably a Hugging Face Hub repo id -- confirm).

    Args:
        model_dir: Directory handed in by the container (unused).

    Returns:
        The object produced by ``transformers.pipeline`` for that model.
    """
    pipeline_kwargs = {
        "model": "fuwangwang/mpt-7b",
        # Load weights as bfloat16 (2 bytes per parameter).
        "torch_dtype": torch.bfloat16,
        # Permit executing model code shipped with the repo; MPT checkpoints
        # presumably rely on this -- confirm against the model card.
        "trust_remote_code": True,
        # Let transformers/accelerate decide device placement automatically.
        "device_map": "auto",
    }
    return pipeline(**pipeline_kwargs)