Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| from datetime import datetime | |
| from huggingface_hub import hf_hub_download | |
| from pynvml import * | |
# Initialise NVML and take a handle to GPU index 0; infer() queries this
# handle to print total/used/free VRAM on every request.
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
# Prompts are truncated to their last `ctx_limit` tokens before being fed to the model.
ctx_limit = 512
title = "RWKV-4 14B fp16"
desc = f'''DEMO limited to ctxlen {ctx_limit}, and slow because A10g does not have enough VRAM for 14B fp16 (some layers are computed on CPU instead). Links:
<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
<a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
<a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
'''
# These flags are set before `rwkv.model` is imported on the next lines; presumably
# the rwkv package reads them at import time, so keep this ordering (TODO confirm).
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
# Fetch the 14B checkpoint from the Hugging Face Hub; the returned local path is
# what RWKV loads below.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename="RWKV-4-Pile-14B-20230213-8019.pth")
# Split placement strategy: presumably the first 32 layers run on CUDA in fp16 and
# the remainder on CPU in fp32, matching the "some layers are computed on CPU" note
# in `desc` — confirm against the rwkv strategy documentation.
model = RWKV(model=model_path, strategy='cuda fp16 *32 -> cpu fp32')
# Commented-out alternative: the much smaller 169M model fully on CUDA fp16,
# presumably kept for quick testing.
# model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-169m", filename="RWKV-4-Pile-169M-20220807-8023.pth")
# model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Tokenizer wrapper used for encode/decode/sampling in infer(). The tokenizer path
# is relative, so it resolves against the current working directory.
pipeline = PIPELINE(model, "20B_tokenizer.json")
def infer(
    ctx,
    token_count=10,
    temperature=1.0,
    top_p=0.8,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Stream a completion of ``ctx`` from the RWKV model.

    This is a generator used as the Gradio callback. Each time the newly sampled
    tokens decode to complete text, it yields the whole output so far (with
    surrounding whitespace stripped). It yields one final time when sampling ends.

    Args:
        ctx: Prompt text typed by the user.
        token_count: Maximum number of tokens to sample (passed through ``int``).
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus-sampling threshold (passed through ``float``).
        presencePenalty: Flat logit penalty for any token already generated.
        countPenalty: Additional logit penalty per prior occurrence of a token.
    """
    sampling = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[0],  # token ids whose logits are forced to -inf
        token_stop=[],  # sampling any of these ids ends generation early
    )

    # Normalise the prompt: drop surrounding spaces, always prepend a newline, and
    # keep exactly one trailing newline only when the user's text ended with one.
    prompt = ctx.strip(' ')
    trailer = '\n' if prompt.endswith('\n') else ''
    prompt = '\n' + prompt.strip() + trailer

    mem = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {mem.total} used {mem.used} free {mem.free}')

    generated = []
    emitted_upto = 0  # index into `generated` of the first token not yet shown
    text = ''
    counts = {}  # token id -> number of times it has been sampled
    state = None
    for step in range(int(token_count)):
        # The first step feeds the (truncated) prompt; later steps feed only the
        # previously sampled token and rely on the recurrent `state`.
        if step == 0:
            feed = pipeline.encode(prompt)[-ctx_limit:]
        else:
            feed = [token]
        out, state = model.forward(feed, state)

        for banned in sampling.token_ban:
            out[banned] = -float('inf')
        for tok, seen in counts.items():
            out[tok] -= (sampling.alpha_presence + seen * sampling.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=sampling.temperature, top_p=sampling.top_p)
        if token in sampling.token_stop:
            break
        generated.append(token)
        counts[token] = counts.get(token, 0) + 1

        piece = pipeline.decode(generated[emitted_upto:])
        if '\ufffd' in piece:
            # Incomplete multi-byte character: wait for more tokens before emitting.
            continue
        text += piece
        yield text.strip()
        emitted_upto = step + 1
    yield text.strip()
# Example rows for the Gradio interface. Each row is a prompt followed by the
# shared defaults for token_count, temperature, top_p, presencePenalty and
# countPenalty, in infer()'s parameter order.
examples = [
    [prompt, 100, 1.0, 0.8, 0.1, 0.1]
    for prompt in (
        "Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nHow can we eliminate poverty?\n\nFull Answer:\n",
        "Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n",
        "Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n",
        "Here is a shell script to find all .hpp files in /home/workspace and delete the 3th row string of these files:",
        "Building a website can be done in 10 simple steps:\n1.",
        "A Chinese phrase is provided: 百闻不如一见。\nThe masterful Chinese translator flawlessly translates the phrase into English:",
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that",
    )
]
def _build_app():
    """Assemble the Gradio UI around ``infer`` and return ``(iface, demo)``.

    ``iface`` is the single generation interface, with queuing enabled on it.
    ``demo`` is the tabbed wrapper that carries the page title and is the
    object that gets launched.
    """
    # One widget per infer() parameter, in the same positional order.
    controls = [
        gr.Textbox(lines=14, label="Prompt"),        # ctx
        gr.Slider(10, 200, step=10, value=100),      # token_count
        gr.Slider(0.2, 2.0, step=0.1, value=1.0),    # temperature
        gr.Slider(0.0, 1.0, step=0.05, value=0.8),   # top_p
        gr.Slider(0.0, 1.0, step=0.1, value=0.1),    # presencePenalty
        gr.Slider(0.0, 1.0, step=0.1, value=0.1),    # countPenalty
    ]
    output_box = gr.Textbox(label="Generated Output", lines=30)
    generative = gr.Interface(
        fn=infer,
        description=desc,
        allow_flagging="never",
        inputs=controls,
        outputs=output_box,
        examples=examples,
        cache_examples=False,
    ).queue()
    tabs = gr.TabbedInterface([generative], ["Generative"], title=title)
    return generative, tabs


iface, demo = _build_app()
demo.queue()
demo.launch(share=False)