|
from argparse import ArgumentParser |
|
|
|
from lagent.llms import HFTransformer |
|
from lagent.llms.meta_template import INTERNLM2_META as META |
|
|
|
|
|
def parse_args():
    """Parse the chatbot's command-line options.

    Returns:
        An ``argparse.Namespace`` with two string attributes:
        ``path`` (model location, default ``'internlm/internlm2-chat-20b'``)
        and ``mode`` (default ``'chat'``).
    """
    cli = ArgumentParser(description='chatbot')
    # (flag, default, help) for each string-valued option.
    option_specs = (
        ('--path', 'internlm/internlm2-chat-20b', 'The path to the model'),
        ('--mode', 'chat', 'Completion through chat or generate'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value,
                         help=help_text)
    return cli.parse_args()
|
|
|
|
|
def main():
    """Run an interactive terminal chat session with an InternLM2 model.

    A prompt is read as one or more lines from stdin, and an empty line ends
    it. The model's reply is streamed to stdout and appended to
    ``history``. With ``--mode generate``, only the latest prompt is sent on
    each turn, so earlier turns are not included. Typing ``exit`` as the
    whole prompt ends the program with status 0. Closing stdin (EOF) also
    ends the session without a traceback.
    """
    args = parse_args()

    model = HFTransformer(
        path=args.path,
        meta_template=META,
        max_new_tokens=1024,
        top_p=0.8,
        top_k=None,
        temperature=0.1,
        repetition_penalty=1.0,
        stop_words=['<|im_end|>'])

    def input_prompt():
        """Read stdin lines up to the first empty line; return them joined by newlines."""
        print('\ndouble enter to end input >>> ', end='', flush=True)
        sentinel = ''
        return '\n'.join(iter(input, sentinel))

    history = []
    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        except EOFError:
            # input() raises EOFError when stdin is closed (Ctrl-D or end of
            # piped input). End the session instead of crashing.
            print('')
            return
        if prompt == 'exit':
            # Same effect as exit(0), but without relying on the
            # site-module `exit` helper, which is meant for interactive use.
            raise SystemExit(0)

        user_turn = dict(role='user', content=prompt)
        if args.mode == 'generate':
            # Stateless completion: the request holds only the latest prompt.
            history = [user_turn]
        else:
            history.append(user_turn)

        print('\nInternLm2:', end='')
        current_length = 0
        # Keep `response` bound even if the stream yields nothing.
        response = ''
        for _, response, _ in model.stream_chat(history):
            # Presumably each item carries the cumulative reply so far, so
            # only the newly generated suffix is printed.
            print(response[current_length:], end='', flush=True)
            current_length = len(response)
        history.append(dict(role='assistant', content=response))
        print('')
|
|
|
|
|
if __name__ == '__main__':
    # Run the interactive chat loop only when executed as a script,
    # not when this module is imported.
    main()
|
|