Spestly commited on
Commit
ebac108
Β·
verified Β·
1 Parent(s): e1807a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -74
app.py CHANGED
@@ -30,7 +30,6 @@ MODELS = {
30
  "is_vision": False,
31
  "system_prompt_env": "ATLAS_PRO_0403",
32
  },
33
-
34
  }
35
 
36
  # Load default model
@@ -53,34 +52,21 @@ tokenizer, model = load_model(default_model)
53
  # Generate response function
54
  def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
55
  global tokenizer, model
56
- # Load the selected model
57
  selected_model = MODELS[model_key]["sizes"][model_size]
58
  if selected_model != default_model:
59
  tokenizer, model = load_model(selected_model)
60
 
61
- # Get the system prompt from the environment
62
  system_prompt_env = MODELS[model_key]["system_prompt_env"]
63
  system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
64
 
65
- # Construct instruction
66
  if MODELS[model_key]["is_vision"]:
67
- # If a vision model, include the image information
68
  image_info = "An image has been provided as input."
69
- instruction = (
70
- f"{system_prompt}\n\n"
71
- f"### Instruction:\n{message}\n{image_info}\n\n### Response:"
72
- )
73
  else:
74
- # For non-vision models
75
- instruction = (
76
- f"{system_prompt}\n\n"
77
- f"### Instruction:\n{message}\n\n### Response:"
78
- )
79
 
80
- # Tokenize input
81
  inputs = tokenizer(instruction, return_tensors="pt")
82
 
83
- # Generate response
84
  with torch.no_grad():
85
  outputs = model.generate(
86
  **inputs,
@@ -91,74 +77,61 @@ def generate_response(message, image, history, model_key, model_size, temperatur
91
  do_sample=True
92
  )
93
 
94
- # Decode response
95
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
96
  response = response.split("### Response:")[-1].strip()
97
  return response
98
 
99
- # User interface
100
  def create_interface():
101
- # Define input components
102
- message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
103
- model_key_selector = gr.Dropdown(
104
- label="Model",
105
- choices=list(MODELS.keys()),
106
- value=default_model_key
107
- )
108
- model_size_selector = gr.Dropdown(
109
- label="Model Size",
110
- choices=list(MODELS[default_model_key]["sizes"].keys()),
111
- value=default_size
112
- )
113
- temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
114
- top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
115
- max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
116
- image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
117
 
118
- # Function to toggle visibility of image input
119
- def toggle_image_input(model_key):
120
- return MODELS[model_key]["is_vision"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- # Output components
123
- chat_output = gr.Chatbot(label="Chatbot")
 
 
 
 
 
124
 
125
- # Function to process inputs and generate output
126
- def process_inputs(message, image, model_key, model_size, temperature, top_p, max_new_tokens, history=[]):
127
- response = generate_response(
128
- message=message,
129
- image=image,
130
- history=history,
131
- model_key=model_key,
132
- model_size=model_size,
133
- temperature=temperature,
134
- top_p=top_p,
135
- max_new_tokens=max_new_tokens
136
  )
137
- history.append((message, response))
138
- return history
139
 
140
- # Interface layout
141
- iface = gr.Interface(
142
- fn=process_inputs,
143
- inputs=[
144
- message_input,
145
- image_input,
146
- model_key_selector,
147
- model_size_selector,
148
- temperature_slider,
149
- top_p_slider,
150
- max_tokens_slider
151
- ],
152
- outputs=chat_output,
153
- title="🌟 Atlas-Pro/Flash/Vision Interface",
154
- description="Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Pro (Comming Soon!). Upload images for vision models!",
155
- theme="soft",
156
- live=True
157
- )
158
 
159
- # Add event to toggle image input visibility
160
- iface.input_components[1].set_visibility(toggle_image_input(model_key_selector.value))
161
  return iface
162
 
163
- # Launch the app
164
- create_interface().launch()
 
30
  "is_vision": False,
31
  "system_prompt_env": "ATLAS_PRO_0403",
32
  },
 
33
  }
34
 
35
  # Load default model
 
52
  # Generate response function
53
  def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
54
  global tokenizer, model
 
55
  selected_model = MODELS[model_key]["sizes"][model_size]
56
  if selected_model != default_model:
57
  tokenizer, model = load_model(selected_model)
58
 
 
59
  system_prompt_env = MODELS[model_key]["system_prompt_env"]
60
  system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
61
 
 
62
  if MODELS[model_key]["is_vision"]:
 
63
  image_info = "An image has been provided as input."
64
+ instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
 
 
 
65
  else:
66
+ instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
 
 
 
 
67
 
 
68
  inputs = tokenizer(instruction, return_tensors="pt")
69
 
 
70
  with torch.no_grad():
71
  outputs = model.generate(
72
  **inputs,
 
77
  do_sample=True
78
  )
79
 
 
80
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
81
  response = response.split("### Response:")[-1].strip()
82
  return response
83
 
 
84
  def create_interface():
85
+ with gr.Blocks(title="🌟 Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
86
+ gr.Markdown("Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). Upload images for vision models!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ model_key_selector = gr.Dropdown(
89
+ label="Model",
90
+ choices=list(MODELS.keys()),
91
+ value=default_model_key
92
+ )
93
+ model_size_selector = gr.Dropdown(
94
+ label="Model Size",
95
+ choices=list(MODELS[default_model_key]["sizes"].keys()),
96
+ value=default_size
97
+ )
98
+ image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
99
+ message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
100
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
101
+ top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
102
+ max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
103
+ chat_output = gr.Chatbot(label="Chatbot")
104
+ submit_button = gr.Button("Submit")
105
 
106
+ def update_components(model_key):
107
+ model_info = MODELS[model_key]
108
+ new_sizes = list(model_info["sizes"].keys())
109
+ return [
110
+ gr.Dropdown(choices=new_sizes, value=new_sizes[0]),
111
+ gr.Image(visible=model_info["is_vision"])
112
+ ]
113
 
114
+ model_key_selector.change(
115
+ fn=update_components,
116
+ inputs=model_key_selector,
117
+ outputs=[model_size_selector, image_input]
 
 
 
 
 
 
 
118
  )
 
 
119
 
120
+ submit_button.click(
121
+ fn=generate_response,
122
+ inputs=[
123
+ message_input,
124
+ image_input,
125
+ chat_output,
126
+ model_key_selector,
127
+ model_size_selector,
128
+ temperature_slider,
129
+ top_p_slider,
130
+ max_tokens_slider
131
+ ],
132
+ outputs=chat_output
133
+ )
 
 
 
 
134
 
 
 
135
  return iface
136
 
137
+ create_interface().launch()