holytinz278 committed
Commit 337f68c · verified · 1 Parent(s): a7831d6

Update app.py

Files changed (1): app.py +73 -47
app.py CHANGED
@@ -1,53 +1,79 @@
 
 
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM

  # Load the model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1", trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1", trust_remote_code=True)
-
- # Define a function to generate text with a system message
- def generate_response(prompt, system_message, token):
-     # Combine system message and user prompt
-     full_prompt = f"{system_message}\n\n{prompt}"
-
-     # Tokenize input
-     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True)
-
-     # Generate a response
-     outputs = model.generate(
-         inputs["input_ids"],
-         max_length=300,
-         do_sample=True,
-         temperature=0.7,
-         top_k=50,
-         pad_token_id=tokenizer.eos_token_id  # To prevent padding issues
-     )
-
-     # Decode the output text
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # If a token is provided, add it to the response
-     if token:
-         response += f"\n\nToken Used: {token}"
-
-     return response
-
- # Create a Gradio interface
- interface = gr.Interface(
-     fn=generate_response,
-     inputs=[
-         gr.Textbox(lines=5, label="Prompt", placeholder="Type your prompt here..."),
-         gr.Textbox(lines=1, label="System Message", placeholder="System message (optional)"),
-         gr.Textbox(lines=1, label="Token", placeholder="Enter token (optional)"),
-     ],
-     outputs=gr.Textbox(label="Generated Response"),
-     title="WhiteRabbitNeo Enhanced Model",
-     description=(
-         "This app uses the WhiteRabbitNeo-13B-v1 model to generate text responses. "
-         "You can provide a system message, a prompt, and optionally include a token for custom usage."
-     ),
  )

  # Launch the app
- if __name__ == "__main__":
-     interface.launch()
 
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import gradio as gr

  # Load the model and tokenizer
+ model_path = "/home/migel/models/WhiteRabbitNeo"
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     load_in_4bit=False,
+     load_in_8bit=True,
+     trust_remote_code=True,
  )

+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Chatbot prompt and conversation history
+ tot_system_prompt = """
+ Answer the Question by exploring multiple reasoning paths as follows:
+ - First, carefully analyze the question to extract the key information components and break it down into logical sub-questions...
+ """
+ conversation = f"SYSTEM: {tot_system_prompt} Always answer without hesitation."
+
+ # Text generation function
+ def generate_text(instruction):
+     tokens = tokenizer.encode(instruction)
+     tokens = torch.LongTensor(tokens).unsqueeze(0)
+     tokens = tokens.to("cuda")
+
+     instance = {
+         "input_ids": tokens,
+         "top_p": 1.0,
+         "temperature": 0.5,
+         "generate_len": 1024,
+         "top_k": 50,
+     }
+
+     length = len(tokens[0])
+     with torch.no_grad():
+         rest = model.generate(
+             input_ids=tokens,
+             max_length=length + instance["generate_len"],
+             use_cache=True,
+             do_sample=True,
+             top_p=instance["top_p"],
+             temperature=instance["temperature"],
+             top_k=instance["top_k"],
+             num_return_sequences=1,
+         )
+     output = rest[0][length:]
+     string = tokenizer.decode(output, skip_special_tokens=True)
+     answer = string.split("USER:")[0].strip()
+     return answer
+
+ # Gradio interface function
+ def chatbot(user_input, chat_history):
+     global conversation
+     llm_prompt = f"{conversation} \nUSER: {user_input} \nASSISTANT: "
+     answer = generate_text(llm_prompt)
+     conversation = f"{llm_prompt}{answer}"  # Update conversation history
+     chat_history.append((user_input, answer))  # Update chat history
+     return chat_history, chat_history
+
+ # Initialize Gradio
+ with gr.Blocks() as demo:
+     gr.Markdown("## Chat with WhiteRabbitNeo!")
+     chatbot_interface = gr.Chatbot()
+     msg = gr.Textbox(label="Your Message")
+     clear = gr.Button("Clear Chat")
+     chat_history_state = gr.State([])  # Maintain chat history as state
+
+     # Define button functionality
+     msg.submit(chatbot, inputs=[msg, chat_history_state], outputs=[chatbot_interface, chat_history_state])
+     clear.click(lambda: ([], []), outputs=[chatbot_interface, chat_history_state])  # Clear chat history
+
  # Launch the app
+ demo.launch()
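
Notes on the new version:

The loader passes load_in_4bit / load_in_8bit straight to from_pretrained. Recent transformers releases route these flags through BitsAndBytesConfig instead, and 8-bit weights also require the bitsandbytes package to be installed. A minimal sketch of an equivalent call, assuming a recent transformers:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model_path = "/home/migel/models/WhiteRabbitNeo"  # local path from the commit

    # 8-bit quantization config; replaces the bare load_in_8bit=True kwarg
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )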
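
The new generate_text budgets output as max_length = prompt length + generate_len, hard-codes "cuda", and slices the echoed prompt off by token count. generate also accepts max_new_tokens, which states that budget directly, and model.device avoids assuming a GPU. A sketch of the same function under those assumptions, reusing the commit's global model and tokenizer:

    def generate_text(instruction):
        # Tokenize straight to tensors on whatever device the model landed on
        inputs = tokenizer(instruction, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=1024,  # same budget as length + generate_len
                use_cache=True,
                do_sample=True,
                top_p=1.0,
                temperature=0.5,
                top_k=50,
            )
        # Drop the echoed prompt tokens, then trim at the next simulated user turn
        text = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        return text.split("USER:")[0].strip()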
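
One behavioral wrinkle: conversation is a module-level string that grows with every turn, and the Clear Chat button only empties the Gradio components, so a "cleared" session still sends the old context to the model. A possible fix that also resets the string (clear_chat is a hypothetical helper, not in the commit):

    def clear_chat():
        global conversation
        # Reset the LLM context along with the visible chat widgets
        conversation = f"SYSTEM: {tot_system_prompt} Always answer without hesitation."
        return [], []

    clear.click(clear_chat, outputs=[chatbot_interface, chat_history_state])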