|
import torch |
|
from session import logger, log_sys_info |
|
from transformers import AutoTokenizer, GenerationConfig, AutoModel |
|
|
|
|
|
# Hugging Face Hub repo IDs, each pinned to an explicit revision for
# reproducibility — and, since both repos are loaded with
# trust_remote_code=True, pinning also limits which remote code can run.

# Base ChatGLM-6B repo; used below only to load the tokenizer.
chatglm = 'THUDM/chatglm-6b'

chatglm_rev = '4de8efe'

# Repo supplying the model weights (presumably an int8-quantized ChatGLM
# fine-tune, given the name and the shared tokenizer — TODO confirm).
int8_model = 'KumaTea/twitter-int8'

int8_model_rev = '1136001'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Snapshot system state (memory, etc.) before the potentially slow and
# memory-hungry model download/load, so before/after can be compared.
log_sys_info()

# Download/load the quantized checkpoint and cast all weights to float32.
# NOTE(review): trust_remote_code=True executes Python code shipped inside
# the Hub repo; the pinned revision above is the only guard — keep it pinned.
model = AutoModel.from_pretrained(
    int8_model,
    trust_remote_code=True,
    revision=int8_model_rev
).float()  # fp32 cast — presumably for CPU inference; TODO confirm target device

# Tokenizer comes from the base ChatGLM-6B repo, also at a pinned revision.
tokenizer = AutoTokenizer.from_pretrained(chatglm, trust_remote_code=True, revision=chatglm_rev)

# Switch to inference mode (disables dropout / batch-norm updates).
model.eval()

# Make newly created tensors default to CPU float32 so later inputs match
# the .float() model above.
# NOTE(review): torch.set_default_tensor_type is deprecated in torch >= 2.1
# in favor of set_default_dtype()/set_default_device() — worth migrating if
# the pinned torch version allows.
torch.set_default_tensor_type(torch.FloatTensor)

logger.info('[SYS] Model loaded')

# Second snapshot: lets the log show the memory cost of loading the model.
log_sys_info()
|
|