mpt-7b / code /inference.py
fuwangwang's picture
Create code/inference.py
ffc95f0
raw
history blame contribute delete
337 Bytes
from transformers import pipeline
import torch
def model_fn(model_dir):
    """Load the model pipeline used to serve inference requests.

    Replaces the default model-loading hook of the HuggingFace Deep
    Learning Container.

    Args:
        model_dir: Directory handed to the hook by the container. It is
            accepted to satisfy the hook signature but is not read here;
            the model is always resolved from the fixed identifier below.

    Returns:
        The object produced by ``transformers.pipeline`` for the
        ``fuwangwang/mpt-7b`` model.
    """
    load_options = {
        "model": "fuwangwang/mpt-7b",
        # Load weights in bfloat16 instead of the default dtype.
        "torch_dtype": torch.bfloat16,
        # NOTE(review): this permits running model code shipped with the
        # model repository -- confirm the repository is trusted.
        "trust_remote_code": True,
        # Leave device placement to the library rather than pinning one.
        "device_map": "auto",
    }
    return pipeline(**load_options)