Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						3dfbcaf
	
1
								Parent(s):
							
							763f1c5
								
first commit
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,101 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import html
         | 
| 5 | 
            +
            from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
         | 
| 6 | 
            +
            from threading import Thread
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            model_name_or_path = 'TencentARC/LLaMA-Pro-8B-Instruct'
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
         | 
| 11 | 
            +
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            model.half().cuda()
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            def convert_message(message):
         | 
| 16 | 
            +
                message_text = ""
         | 
| 17 | 
            +
                if message["content"] is None and message["role"] == "assistant":
         | 
| 18 | 
            +
                    message_text += "<|assistant|>\n"  # final msg
         | 
| 19 | 
            +
                elif message["role"] == "system":
         | 
| 20 | 
            +
                    message_text += "<|system|>\n" + message["content"].strip() + "\n"
         | 
| 21 | 
            +
                elif message["role"] == "user":
         | 
| 22 | 
            +
                    message_text += "<|user|>\n" + message["content"].strip() + "\n"
         | 
| 23 | 
            +
                elif message["role"] == "assistant":
         | 
| 24 | 
            +
                    message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
         | 
| 25 | 
            +
                else:
         | 
| 26 | 
            +
                    raise ValueError("Invalid role: {}".format(message["role"]))
         | 
| 27 | 
            +
                # gradio cleaning - it converts stuff to html entities
         | 
| 28 | 
            +
                # we would need special handling for where we want to keep the html...
         | 
| 29 | 
            +
                message_text = html.unescape(message_text)
         | 
| 30 | 
            +
                # it also converts newlines to <br>, undo this.
         | 
| 31 | 
            +
                message_text = message_text.replace("<br>", "\n")
         | 
| 32 | 
            +
                return message_text
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def convert_history(chat_history, max_input_length=1024):
         | 
| 35 | 
            +
                history_text = ""
         | 
| 36 | 
            +
                idx = len(chat_history) - 1
         | 
| 37 | 
            +
                # add messages in reverse order until we hit max_input_length
         | 
| 38 | 
            +
                while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
         | 
| 39 | 
            +
                    user_message, chatbot_message = chat_history[idx]
         | 
| 40 | 
            +
                    user_message = convert_message({"role": "user", "content": user_message})
         | 
| 41 | 
            +
                    chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
         | 
| 42 | 
            +
                    history_text = user_message + chatbot_message + history_text
         | 
| 43 | 
            +
                    idx = idx - 1
         | 
| 44 | 
            +
                # if nothing was added, add <|assistant|> to start generation.
         | 
| 45 | 
            +
                if history_text == "":
         | 
| 46 | 
            +
                    history_text = "<|assistant|>\n"
         | 
| 47 | 
            +
                return history_text
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            @torch.inference_mode()
         | 
| 50 | 
            +
            def instruct(instruction, max_token_output=1024):
         | 
| 51 | 
            +
                input_text = instruction
         | 
| 52 | 
            +
                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
         | 
| 53 | 
            +
                input_ids = tokenizer(input_text, return_tensors='pt', truncation=False)
         | 
| 54 | 
            +
                input_ids["input_ids"] = input_ids["input_ids"].cuda()
         | 
| 55 | 
            +
                input_ids["attention_mask"] = input_ids["attention_mask"].cuda()
         | 
| 56 | 
            +
                generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=max_token_output)
         | 
| 57 | 
            +
                thread = Thread(target=model.generate, kwargs=generation_kwargs)
         | 
| 58 | 
            +
                thread.start()
         | 
| 59 | 
            +
                return streamer
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            with gr.Blocks() as demo:
         | 
| 63 | 
            +
                # recreating the original qa demo in blocks
         | 
| 64 | 
            +
                with gr.Tab("QA Demo"):
         | 
| 65 | 
            +
                    with gr.Row():
         | 
| 66 | 
            +
                        instruction = gr.Textbox(label="Input")
         | 
| 67 | 
            +
                        output = gr.Textbox(label="Output")
         | 
| 68 | 
            +
                    greet_btn = gr.Button("Submit")
         | 
| 69 | 
            +
                    def yield_instruct(instruction):
         | 
| 70 | 
            +
                        # quick prompt hack:
         | 
| 71 | 
            +
                        instruction = "<|user|>\n" + instruction + "\n<|assistant|>\n"
         | 
| 72 | 
            +
                        output = ""
         | 
| 73 | 
            +
                        for token in instruct(instruction):
         | 
| 74 | 
            +
                            output += token
         | 
| 75 | 
            +
                            yield output
         | 
| 76 | 
            +
                    greet_btn.click(fn=yield_instruct, inputs=[instruction], outputs=output, api_name="greet")
         | 
| 77 | 
            +
                # chatbot-style model
         | 
| 78 | 
            +
                with gr.Tab("Chatbot"):
         | 
| 79 | 
            +
                    chatbot = gr.Chatbot([], elem_id="chatbot")
         | 
| 80 | 
            +
                    msg = gr.Textbox()
         | 
| 81 | 
            +
                    clear = gr.Button("Clear")
         | 
| 82 | 
            +
                    # fn to add user message to history
         | 
| 83 | 
            +
                    def user(user_message, history):
         | 
| 84 | 
            +
                        return "", history + [[user_message, None]]
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def bot(history):
         | 
| 87 | 
            +
                    prompt = convert_history(history)
         | 
| 88 | 
            +
                    streaming_out = instruct(prompt)
         | 
| 89 | 
            +
                    history[-1][1] = ""
         | 
| 90 | 
            +
                    for new_token in streaming_out:
         | 
| 91 | 
            +
                        history[-1][1] += new_token
         | 
| 92 | 
            +
                        yield history
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         | 
| 95 | 
            +
                    bot, chatbot, chatbot
         | 
| 96 | 
            +
                )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                clear.click(lambda: None, None, chatbot, queue=False)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            if __name__ == "__main__":
         | 
| 101 | 
            +
                demo.queue().launch(share=True)
         | 
