from transformers import pipeline | |
import torch | |
def model_fn(model_dir):
    """Build and return the inference pipeline for this endpoint.

    Overrides the default model load function in the HuggingFace Deep
    Learning Container.

    Args:
        model_dir: Directory handed in by the container. It is not read here.
            The model is loaded by the identifier "fuwangwang/mpt-7b", not
            from this directory.

    Returns:
        The object produced by ``transformers.pipeline`` for that model. No
        ``task`` is passed, so ``transformers`` decides the task itself.
    """
    # NOTE(review): the HF inference toolkit appears to inspect this hook's
    # parameter count to decide whether to also pass a context object. Confirm
    # that before adding parameters, even defaulted ones.
    # NOTE(review): model_dir is unused. Confirm that loading by repo id
    # (rather than from the packaged model artifacts) is intended.
    # NOTE(review): trust_remote_code=True runs Python code shipped in the
    # model repo. This is a third-party repo and no revision is pinned, so
    # consider adding revision=<commit sha> to freeze what gets executed.
    pipeline_kwargs = {
        "model": "fuwangwang/mpt-7b",
        "torch_dtype": torch.bfloat16,  # load weights in bfloat16
        "trust_remote_code": True,
        "device_map": "auto",  # let transformers/accelerate place the weights
    }
    return pipeline(**pipeline_kwargs)