ruslanmv committed
Commit 381d2e1 · verified · 1 Parent(s): c2dfdca

Update app.py

Files changed (1): app.py (+50 -32)
app.py CHANGED
@@ -1,22 +1,34 @@
 import gradio as gr
+import subprocess
 
-# Placeholder for model loading (adjust as needed for your specific models)
-def load_model(model_name):
-    print(f"Loading {model_name}...")
-    # Simulate different model behaviors (replace with actual model logic)
-    if model_name == "DeepSeek-R1-Distill-Qwen-32B":
-        return lambda input_text, history: f"Distilled Model Response to: {input_text}"
-    elif model_name == "DeepSeek-R1":
-        return lambda input_text, history: f"Base Model Response to: {input_text}"
-    elif model_name == "DeepSeek-R1-Zero":
-        return lambda input_text, history: f"Zero Model Response to: {input_text}"
-    else:
-        return lambda input_text, history: f"Default Response to: {input_text}"
+# Function to load a model using Hugging Face Spaces and enable GPU
+def load_model_with_gpu(model_name):
+    print(f"Attempting to load {model_name} with GPU enabled...")
+    try:
+        # Use subprocess to run hf.space_info and get GPU setting
+        result = subprocess.run(
+            ["python", "-c", f"from huggingface_hub import space_info; print(space_info('{model_name}').hardware)"],
+            capture_output=True,
+            text=True,
+            check=True
+        )
+        hardware = result.stdout.strip()
+        print(f"Hardware for {model_name}: {hardware}")
+
+        demo = gr.load(name=model_name, src="spaces")
+
+        # Return the loaded model demo
+        print(f"Successfully loaded {model_name}")
+        return demo
+
+    except Exception as e:
+        print(f"Error loading model {model_name}: {e}")
+        return None
 
-# Load the models (placeholder functions here)
-deepseek_r1_distill = load_model("DeepSeek-R1-Distill-Qwen-32B")
-deepseek_r1 = load_model("DeepSeek-R1")
-deepseek_r1_zero = load_model("DeepSeek-R1-Zero")
+# Load the models with GPU enabled (if available)
+deepseek_r1_distill = load_model_with_gpu("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
+deepseek_r1 = load_model_with_gpu("deepseek-ai/DeepSeek-R1")
+deepseek_r1_zero = load_model_with_gpu("deepseek-ai/DeepSeek-R1-Zero")
 
 # --- Chatbot function ---
 def chatbot(input_text, history, model_choice, system_message, max_new_tokens, temperature, top_p):
@@ -24,22 +36,29 @@ def chatbot(input_text, history, model_choice, system_message, max_new_tokens, t
     print(f"Input: {input_text}, History: {history}, Model: {model_choice}")
 
     # Choose the model based on user selection
-    if model_choice == "DeepSeek-R1-Distill-Qwen-32B":
-        model_function = deepseek_r1_distill
-    elif model_choice == "DeepSeek-R1":
-        model_function = deepseek_r1
-    elif model_choice == "DeepSeek-R1-Zero":
-        model_function = deepseek_r1_zero
+    if model_choice == "DeepSeek-R1-Distill-Qwen-32B" and deepseek_r1_distill:
+        model_demo = deepseek_r1_distill
+    elif model_choice == "DeepSeek-R1" and deepseek_r1:
+        model_demo = deepseek_r1
+    elif model_choice == "DeepSeek-R1-Zero" and deepseek_r1_zero:
+        model_demo = deepseek_r1_zero
     else:
-        model_function = lambda x, h: "Please select a model."
+        default_response = "Model not selected or could not be loaded."
+        history.append((input_text, default_response))
+        return history, history, "", model_choice, system_message, max_new_tokens, temperature, top_p
+
+    # Adjust the call to the model, remove default_value if not applicable
+    model_output = model_demo(input_text, history, max_new_tokens, temperature, top_p, system_message)
 
-    # Simulate model response with parameters
-    response = model_function(input_text, history)
-    # Format the response for display (without parameter details in the main chat)
-    display_response = f"{response}"
-
-    history.append((input_text, display_response))
-    return history, history, "", model_choice, system_message, max_new_tokens, temperature, top_p  # Clear input, keep other parameters
+    # Check if model_output is iterable and has expected number of elements
+    if not isinstance(model_output, (list, tuple)) or len(model_output) < 2:
+        error_message = "Model output does not have the expected format."
+        history.append((input_text, error_message))
+        return history, history, "", model_choice, system_message, max_new_tokens, temperature, top_p
+
+    response = model_output[-1][1] if model_output[-1][1] else "Model did not return a response."
+    history.append((input_text, response))
+    return history, history, "", model_choice, system_message, max_new_tokens, temperature, top_p
 
 # --- Gradio Interface ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -63,7 +82,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     # Options moved below the chat interface
     with gr.Row():
-        with gr.Accordion("Options", open=True):  # Changed label to "Options"
+        with gr.Accordion("Options", open=True):
             model_choice = gr.Radio(
                 choices=["DeepSeek-R1-Distill-Qwen-32B", "DeepSeek-R1", "DeepSeek-R1-Zero"],
                 label="Choose a Model",
@@ -84,7 +103,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             top_p = gr.Slider(
                 minimum=0.10, maximum=1.00, value=0.90, label="Top-p (nucleus sampling)"
            )
-
 
     # Maintain chat history
    chat_history = gr.State([])
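
For reference, the hardware probe in load_model_with_gpu shells out to a fresh Python interpreter just to call huggingface_hub. A minimal in-process sketch of the same lookup, assuming huggingface_hub is installed and the id names a Space; note that current huggingface_hub releases expose hardware on SpaceInfo.runtime rather than on SpaceInfo itself, so the attribute path below is an assumption to verify against your installed version:

from huggingface_hub import space_info

def get_space_hardware(space_id):
    # space_id must be a Space repo id, e.g. "user/space-name"
    info = space_info(space_id)
    # Assumption: recent huggingface_hub reports hardware on the runtime object
    runtime = getattr(info, "runtime", None)
    return getattr(runtime, "hardware", None) if runtime is not None else None

print(get_space_hardware("gradio/hello_world"))  # hypothetical target Space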
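
The chatbot function also relies on the object returned by gr.load being callable. Gradio supports using a loaded Space as a plain function whose positional arguments mirror the target Space's prediction endpoint, which is why model_demo is called with input_text, history, and the sampling parameters in a fixed order. A minimal sketch against a simple public Space (the Space id and its one-argument signature are assumptions):

import gradio as gr

# Load a Space and call it like a function; arguments must match the
# Space's own input components, in order.
hello = gr.load(name="gradio/hello_world", src="spaces")
print(hello("world"))  # the hello_world demo is expected to return a greeting string

Whether such a call returns a plain string or a structured chat history depends entirely on the target Space, which is what the isinstance check on model_output guards against.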