Yongdong committed
Commit a4f228c · 1 Parent(s): b4ebf95

Implement GGUF model support with DAG visualization

Files changed (3):
  1. app.py +116 -102
  2. dag_visualizer.py +334 -0
  3. requirements.txt +5 -4
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces  # Import spaces module for ZeroGPU
 from huggingface_hub import login
 import os
 from json_processor import JsonProcessor
+from dag_visualizer import DAGVisualizer
 import json

 # 1) Read Secrets
@@ -12,8 +13,9 @@ if not hf_token:
 # 2) Login to ensure all subsequent from_pretrained calls have proper permissions
 login(hf_token)

-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer
+from huggingface_hub import hf_hub_download
+from llama_cpp import Llama
 import warnings
 import os
 warnings.filterwarnings("ignore")
@@ -22,28 +24,37 @@ warnings.filterwarnings("ignore")
 MODEL_CONFIGS = {
     "1B": {
         "name": "Dart-llm-model-1B",
-        "gguf_model": "YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf"
+        "base_model": "meta-llama/Llama-3.2-1B",  # For tokenizer
+        "gguf_model": "YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf",
+        "gguf_file": "llama_3.2_1b-lora-qlora-dart-llm_q5_k_m.gguf"
     },
     "3B": {
         "name": "Dart-llm-model-3B",
-        "gguf_model": "YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf"
+        "base_model": "meta-llama/Llama-3.2-3B",  # For tokenizer
+        "gguf_model": "YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf",
+        "gguf_file": "llama_3.2_3b-lora-qlora-dart-llm_q4_k_m.gguf"
     },
     "8B": {
         "name": "Dart-llm-model-8B",
-        "gguf_model": "YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf"
+        "base_model": "meta-llama/Llama-3.1-8B",  # For tokenizer
+        "gguf_model": "YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf",
+        "gguf_file": "llama3.1-8b-lora-qlora-dart-llm_q4_k_m_fp16.gguf"
     }
 }

 DEFAULT_MODEL = "1B"  # Set 1B as default

 # Global variables to store model and tokenizer
-model = None
+llm_model = None
 tokenizer = None
 current_model_config = None
 model_loaded = False

+# Initialize DAG visualizer
+dag_visualizer = DAGVisualizer()
+
 def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):
-    """Load tokenizer from GGUF model - executed on CPU"""
+    """Load tokenizer - executed on CPU"""
     global tokenizer, model_loaded, current_model_config

     if model_loaded and current_model_config == selected_model:
@@ -51,10 +62,10 @@ def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):

     print(f"🔄 Loading tokenizer for {MODEL_CONFIGS[selected_model]['name']}...")

-    # Load tokenizer from GGUF model repository
-    gguf_model = MODEL_CONFIGS[selected_model]["gguf_model"]
+    # Load tokenizer from base model
+    base_model = MODEL_CONFIGS[selected_model]["base_model"]
     tokenizer = AutoTokenizer.from_pretrained(
-        gguf_model,
+        base_model,
         use_fast=False,
         trust_remote_code=True
     )
@@ -66,44 +77,51 @@ def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):
     print("✅ Tokenizer loaded successfully!")

 @spaces.GPU(duration=60)  # Request GPU for loading model at startup
-def load_model_on_gpu(selected_model=DEFAULT_MODEL):
-    """Load GGUF model on GPU"""
-    global model
+def load_gguf_model_on_gpu(selected_model=DEFAULT_MODEL):
+    """Load GGUF model using llama-cpp-python"""
+    global llm_model

     # If model is already loaded and it's the same model, return it
-    if model is not None and current_model_config == selected_model:
-        return model
+    if llm_model is not None and current_model_config == selected_model:
+        return llm_model

     # Clear existing model if switching
-    if model is not None:
+    if llm_model is not None:
         print("🗑️ Clearing existing model from GPU...")
-        del model
-        torch.cuda.empty_cache()
-        model = None
+        del llm_model
+        llm_model = None

     model_config = MODEL_CONFIGS[selected_model]
-    print(f"🔄 Loading {model_config['name']} GGUF model on GPU...")
+    print(f"🔄 Loading {model_config['name']} GGUF model...")

     try:
-        # Load GGUF model directly (already quantized and merged)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_config["gguf_model"],
-            device_map="auto",
-            torch_dtype=torch.float16,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True
+        # Download GGUF model file from HuggingFace Hub
+        model_file = hf_hub_download(
+            repo_id=model_config["gguf_model"],
+            filename=model_config["gguf_file"],
+            cache_dir="./gguf_cache"
+        )
+        print(f"📦 Downloaded GGUF file: {model_file}")
+
+        # Load GGUF model with llama-cpp-python
+        llm_model = Llama(
+            model_path=model_file,
+            n_ctx=2048,  # Context length
+            n_gpu_layers=-1,  # Use all GPU layers if available
+            verbose=False
         )
-        model.eval()

-        print(f"✅ {model_config['name']} GGUF model loaded on GPU successfully!")
-        return model
+        print(f"✅ {model_config['name']} GGUF model loaded successfully!")
+        return llm_model

     except Exception as load_error:
         print(f"❌ GGUF Model loading failed: {load_error}")
         raise load_error

 def process_json_in_response(response):
-    """Process and format JSON content in the response"""
+    """Process and format JSON content in the response, and generate DAG visualization"""
+    dag_image_path = None
+
     try:
         # Check if response contains JSON-like content
         if '{' in response and '}' in response:
@@ -115,6 +133,17 @@ def process_json_in_response(response):
             if processed_json:
                 # Format the JSON nicely
                 formatted_json = json.dumps(processed_json, indent=2, ensure_ascii=False)
+
+                # Generate DAG visualization if the JSON contains tasks
+                if "tasks" in processed_json and processed_json["tasks"]:
+                    try:
+                        dag_image_path = dag_visualizer.create_dag_visualization(
+                            processed_json,
+                            title="Robot Task Dependency Graph"
+                        )
+                    except Exception as e:
+                        print(f"DAG visualization failed: {e}")
+
                 # Replace the JSON part in the response
                 import re
                 json_pattern = r'\{.*\}'
@@ -123,26 +152,22 @@ def process_json_in_response(response):
                     # Replace the matched JSON with the formatted version
                     response = response.replace(match.group(), formatted_json)

-        return response
+        return response, dag_image_path
     except Exception:
         # If processing fails, return original response
-        return response
+        return response, None

 @spaces.GPU(duration=60)  # GPU inference
 def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
-    """Generate response - executed on GPU"""
-    global model
-
-    # Ensure tokenizer is loaded
-    if tokenizer is None or current_model_config != selected_model:
-        load_model_and_tokenizer(selected_model)
+    """Generate response using GGUF model - executed on GPU"""
+    global llm_model

     # Ensure model is loaded on GPU
-    if model is None or current_model_config != selected_model:
-        model = load_model_on_gpu(selected_model)
+    if llm_model is None or current_model_config != selected_model:
+        llm_model = load_gguf_model_on_gpu(selected_model)

-    if model is None:
-        return "❌ Model failed to load. Please check the Space logs."
+    if llm_model is None:
+        return ("❌ GGUF Model failed to load. Please check the Space logs.", None)

     try:
         formatted_prompt = (
@@ -151,67 +176,44 @@ def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
             "### Response:\n"
         )

-        # Encode input
-        inputs = tokenizer(
-            formatted_prompt,
-            return_tensors="pt",
-            truncation=True,
-            max_length=2048
-        ).to(model.device)
-
-        # Generate response
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                do_sample=False,
-                temperature=None,
-                top_p=None,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                repetition_penalty=1.1,
-                early_stopping=True,
-                no_repeat_ngram_size=3
-            )
-
-        # Decode output
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Generate response using llama-cpp-python
+        output = llm_model(
+            formatted_prompt,
+            max_tokens=max_tokens,
+            stop=["### Instruction:", "###"],
+            echo=False,
+            temperature=0.1,
+            top_p=0.9,
+            repeat_penalty=1.1
+        )

-        # Extract generated part
-        if "### Response:" in response:
-            response = response.split("### Response:")[-1].strip()
-        elif len(response) > len(formatted_prompt):
-            response = response[len(formatted_prompt):].strip()
+        # Extract the generated text
+        response = output['choices'][0]['text'].strip()

-        # Process JSON if present in response
-        response = process_json_in_response(response)
+        # Process JSON if present in response and generate DAG
+        response, dag_image_path = process_json_in_response(response)

-        return response if response else "❌ No response generated. Please try again with a different prompt."
+        return (response if response else "❌ No response generated. Please try again with a different prompt.", dag_image_path)

     except Exception as generation_error:
-        return f"❌ Generation Error: {str(generation_error)}"
+        return (f"❌ Generation Error: {str(generation_error)}", None)

 def chat_interface(message, history, max_tokens, selected_model):
     """Chat interface - runs on CPU, calls GPU functions"""
     if not message.strip():
-        return history, ""
-
-    # Initialize tokenizer (if needed)
-    if tokenizer is None or current_model_config != selected_model:
-        load_model_and_tokenizer(selected_model)
+        return history, "", None

     try:
         # Call GPU function to generate response
-        response = generate_response_gpu(message, max_tokens, selected_model)
+        response, dag_image_path = generate_response_gpu(message, max_tokens, selected_model)
         history.append((message, response))
-        return history, ""
+        return history, "", dag_image_path
     except Exception as chat_error:
         error_msg = f"❌ Chat Error: {str(chat_error)}"
         history.append((message, error_msg))
-        return history, ""
+        return history, "", None

-# Load tokenizer at startup with default model
-load_model_and_tokenizer(DEFAULT_MODEL)
+# GGUF models include tokenizer, no separate loading needed

 # Create Gradio application
 with gr.Blocks(
@@ -229,27 +231,30 @@ with gr.Blocks(

     Choose from **three GGUF quantized models** specialized for **robot task planning** using QLoRA fine-tuning:

-    - **🚀 Dart-llm-model-1B** (Default): Fastest inference, optimized for speed
-    - **⚖️ Dart-llm-model-3B**: Balanced performance and quality
-    - **🎯 Dart-llm-model-8B**: Best quality output, higher latency
+    - **🚀 Dart-llm-model-1B** (Default): Fastest inference, Q5_K_M quantization
+    - **⚖️ Dart-llm-model-3B**: Balanced performance, Q4_K_M quantization
+    - **🎯 Dart-llm-model-8B**: Best quality output, Q4_K_M quantization

-    **GGUF Format**: These models are pre-quantized and optimized for efficient deployment, combining the base model and LoRA adaptations.
+    **GGUF Implementation**: Uses native GGUF format with llama-cpp-python for optimal memory efficiency and GPU acceleration.

-    **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.
+    **Capabilities**:
+    - Convert natural language robot commands into structured task sequences
+    - **NEW: Automatic DAG Visualization** - Generates visual dependency graphs for robot task sequences
+    - Support for excavators, dump trucks, and other construction robots

-    **Models**:
-    - [YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf) (Default)
-    - [YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf)
-    - [YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf)
+    **GGUF Models**:
+    - [YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf) (Default - Q5_K_M)
+    - [YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf) (Q4_K_M)
+    - [YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf) (Q4_K_M)

     ⚡ **Using ZeroGPU**: This Space uses dynamic GPU allocation (Nvidia H200). First generation might take a bit longer.
     """)

     with gr.Row():
-        with gr.Column(scale=3):
+        with gr.Column(scale=2):
             chatbot = gr.Chatbot(
                 label="Task Planning Results",
-                height=500,
+                height=400,
                 show_label=True,
                 container=True,
                 bubble_full_width=False,
@@ -269,6 +274,15 @@ with gr.Blocks(
             send_btn = gr.Button("🚀 Generate Tasks", variant="primary", size="sm")
             clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")

+        with gr.Column(scale=2):
+            dag_image = gr.Image(
+                label="Task Dependency Graph (DAG)",
+                show_label=True,
+                container=True,
+                height=400,
+                interactive=False
+            )
+
         with gr.Column(scale=1):
             gr.Markdown("### ⚙️ Generation Settings")

@@ -317,18 +331,18 @@ with gr.Blocks(
     msg.submit(
         chat_interface,
         inputs=[msg, chatbot, max_tokens, model_selector],
-        outputs=[chatbot, msg]
+        outputs=[chatbot, msg, dag_image]
     )

     send_btn.click(
         chat_interface,
         inputs=[msg, chatbot, max_tokens, model_selector],
-        outputs=[chatbot, msg]
+        outputs=[chatbot, msg, dag_image]
    )

     clear_btn.click(
-        lambda: ([], ""),
-        outputs=[chatbot, msg]
+        lambda: ([], "", None),
+        outputs=[chatbot, msg, dag_image]
     )

 if __name__ == "__main__":
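For reviewers who want to exercise the new inference path outside the Space, the following is a minimal sketch distilled from the diff above. The repo id, filename, and sampling parameters are copied from MODEL_CONFIGS and generate_response_gpu; the "### Instruction:" preamble and the sample command are assumptions, since the full formatted_prompt template is outside the shown hunks.

```python
# Minimal sketch of the new GGUF inference path (assumes llama-cpp-python and huggingface_hub installed).
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Repo and filename taken from MODEL_CONFIGS["1B"] in the diff above.
model_file = hf_hub_download(
    repo_id="YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf",
    filename="llama_3.2_1b-lora-qlora-dart-llm_q5_k_m.gguf",
    cache_dir="./gguf_cache",
)

# n_gpu_layers=-1 offloads all layers when a CUDA build of llama-cpp-python is available;
# on a CPU-only build the same call silently runs on CPU.
llm = Llama(model_path=model_file, n_ctx=2048, n_gpu_layers=-1, verbose=False)

# The instruction framing below is an assumed reconstruction of formatted_prompt;
# the excavator command is a made-up example input.
prompt = (
    "### Instruction:\nDeploy Excavator 1 to the soil area for excavation.\n\n"
    "### Response:\n"
)
output = llm(prompt, max_tokens=512, stop=["### Instruction:", "###"], echo=False,
             temperature=0.1, top_p=0.9, repeat_penalty=1.1)
print(output["choices"][0]["text"].strip())
```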
dag_visualizer.py ADDED
@@ -0,0 +1,334 @@
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for server environments
import networkx as nx
import json
import numpy as np
from loguru import logger
import os
import tempfile
from datetime import datetime

class DAGVisualizer:
    def __init__(self):
        # Configure Matplotlib to use IEEE-style parameters
        plt.rcParams.update({
            'font.family': 'DejaVu Sans',  # Use available font instead of Times New Roman
            'font.size': 10,
            'axes.linewidth': 1.2,
            'axes.labelsize': 12,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 10,
            'figure.titlesize': 14
        })

    def create_dag_from_tasks(self, task_data):
        """
        Create a directed graph from task data.

        Args:
            task_data: Dictionary containing tasks with structure like:
                {
                    "tasks": [
                        {
                            "task": "task_name",
                            "instruction_function": {
                                "name": "function_name",
                                "robot_ids": ["robot1", "robot2"],
                                "dependencies": ["dependency_task"],
                                "object_keywords": ["object1", "object2"]
                            }
                        }
                    ]
                }

        Returns:
            NetworkX DiGraph object
        """
        if not task_data or "tasks" not in task_data:
            logger.warning("Invalid task data structure")
            return None

        # Create a directed graph
        G = nx.DiGraph()

        # Add nodes and store mapping from task name to ID
        task_mapping = {}
        for i, task in enumerate(task_data["tasks"]):
            task_id = i + 1
            task_name = task["task"]
            task_mapping[task_name] = task_id

            # Add node with attributes
            G.add_node(task_id,
                       name=task_name,
                       function=task["instruction_function"]["name"],
                       robots=task["instruction_function"].get("robot_ids", []),
                       objects=task["instruction_function"].get("object_keywords", []))

        # Add dependency edges
        for i, task in enumerate(task_data["tasks"]):
            task_id = i + 1
            dependencies = task["instruction_function"]["dependencies"]
            for dep in dependencies:
                if dep in task_mapping:
                    dep_id = task_mapping[dep]
                    G.add_edge(dep_id, task_id)

        return G

    def calculate_layout(self, G):
        """
        Calculate hierarchical layout for the graph based on dependencies.
        """
        if not G:
            return {}

        # Calculate layers based on dependencies
        layers = {}

        def get_layer(node_id, visited=None):
            if visited is None:
                visited = set()
            if node_id in visited:
                return 0
            visited.add(node_id)

            predecessors = list(G.predecessors(node_id))
            if not predecessors:
                return 0
            return max(get_layer(pred, visited.copy()) for pred in predecessors) + 1

        for node in G.nodes():
            layer = get_layer(node)
            layers.setdefault(layer, []).append(node)

        # Calculate positions by layer
        pos = {}
        layer_height = 3.0
        node_width = 4.0

        for layer_idx, nodes in layers.items():
            y = layer_height * (len(layers) - 1 - layer_idx)
            start_x = -(len(nodes) - 1) * node_width / 2
            for i, node in enumerate(sorted(nodes)):
                pos[node] = (start_x + i * node_width, y)

        return pos

    def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph"):
        """
        Create a DAG visualization from task data and return the image path.

        Args:
            task_data: Task data dictionary
            title: Title for the graph

        Returns:
            str: Path to the generated image file
        """
        try:
            # Create graph
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            # Calculate layout
            pos = self.calculate_layout(G)

            # Create figure
            fig, ax = plt.subplots(1, 1, figsize=(max(12, len(G.nodes()) * 2), 8))

            # Draw edges with arrows
            nx.draw_networkx_edges(G, pos,
                                   edge_color='#2E86AB',
                                   arrows=True,
                                   arrowsize=20,
                                   arrowstyle='->',
                                   width=2,
                                   alpha=0.8,
                                   connectionstyle="arc3,rad=0.1")

            # Color nodes based on their position in the graph
            node_colors = []
            for node in G.nodes():
                if G.in_degree(node) == 0:  # Start nodes
                    node_colors.append('#F24236')
                elif G.out_degree(node) == 0:  # End nodes
                    node_colors.append('#A23B72')
                else:  # Intermediate nodes
                    node_colors.append('#F18F01')

            # Draw nodes
            nx.draw_networkx_nodes(G, pos,
                                   node_color=node_colors,
                                   node_size=3500,
                                   alpha=0.9,
                                   edgecolors='black',
                                   linewidths=2)

            # Label nodes with task IDs
            node_labels = {node: f"T{node}" for node in G.nodes()}
            nx.draw_networkx_labels(G, pos, node_labels,
                                    font_size=18,
                                    font_weight='bold',
                                    font_color='white')

            # Add detailed info text boxes for each task
            for i, node in enumerate(G.nodes()):
                x, y = pos[node]
                function_name = G.nodes[node]['function']
                robots = G.nodes[node]['robots']
                objects = G.nodes[node]['objects']

                # Create info text content
                info_text = f"Task {node}: {function_name.replace('_', ' ').title()}\n"
                if robots:
                    robot_text = ", ".join([r.replace('robot_', '').replace('_', ' ').title() for r in robots])
                    info_text += f"Robots: {robot_text}\n"
                if objects:
                    object_text = ", ".join(objects)
                    info_text += f"Objects: {object_text}"

                # Calculate offset based on node position to avoid overlaps
                offset_x = 2.2 if i % 2 == 0 else -2.2
                offset_y = 0.5 if i % 4 < 2 else -0.5

                # Choose alignment based on offset direction
                h_align = 'left' if offset_x > 0 else 'right'

                # Draw text box
                bbox_props = dict(boxstyle="round,pad=0.4",
                                  facecolor='white',
                                  edgecolor='gray',
                                  alpha=0.95,
                                  linewidth=1)

                ax.text(x + offset_x, y + offset_y, info_text,
                        bbox=bbox_props,
                        fontsize=12,
                        verticalalignment='center',
                        horizontalalignment=h_align,
                        weight='bold')

                # Draw dashed connector line from node to text box
                ax.plot([x, x + offset_x], [y, y + offset_y],
                        linestyle='--', color='gray', alpha=0.6, linewidth=1)

            # Expand axis limits to fit everything
            x_vals = [coord[0] for coord in pos.values()]
            y_vals = [coord[1] for coord in pos.values()]
            ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
            ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)

            # Set overall figure properties
            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_aspect('equal')
            ax.margins(0.2)
            ax.axis('off')

            # Add legend for node types - Hidden to avoid covering content
            # legend_elements = [
            #     plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#F24236',
            #                markersize=10, label='Start Tasks', markeredgecolor='black'),
            #     plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#A23B72',
            #                markersize=10, label='End Tasks', markeredgecolor='black'),
            #     plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#F18F01',
            #                markersize=10, label='Intermediate Tasks', markeredgecolor='black'),
            #     plt.Line2D([0], [0], color='#2E86AB', linewidth=2, label='Dependencies')
            # ]
            # ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.05, 1.05))

            # Adjust layout and save
            plt.tight_layout()

            # Create temporary file for saving the image
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_dir = tempfile.gettempdir()
            image_path = os.path.join(temp_dir, f'dag_visualization_{timestamp}.png')

            plt.savefig(image_path, dpi=400, bbox_inches='tight',
                        pad_inches=0.1, facecolor='white', edgecolor='none')
            plt.close(fig)  # Close figure to free memory

            logger.info(f"DAG visualization saved to: {image_path}")
            return image_path

        except Exception as e:
            logger.error(f"Error creating DAG visualization: {e}")
            return None

    def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
        """
        Create a simplified DAG visualization suitable for smaller displays.

        Args:
            task_data: Task data dictionary
            title: Title for the graph

        Returns:
            str: Path to the generated image file
        """
        try:
            # Create graph
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            # Calculate layout
            pos = self.calculate_layout(G)

            # Create figure for simplified graph
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))

            # Draw edges
            nx.draw_networkx_edges(G, pos,
                                   edge_color='black',
                                   arrows=True,
                                   arrowsize=15,
                                   arrowstyle='->',
                                   width=1.5)

            # Draw nodes
            nx.draw_networkx_nodes(G, pos,
                                   node_color='lightblue',
                                   node_size=3000,
                                   edgecolors='black',
                                   linewidths=1.5)

            # Add node labels with simplified names
            labels = {}
            for node in G.nodes():
                function_name = G.nodes[node]['function']
                simplified_name = function_name.replace('_', ' ').title()
                if len(simplified_name) > 15:
                    simplified_name = simplified_name[:12] + "..."
                labels[node] = f"T{node}\n{simplified_name}"

            nx.draw_networkx_labels(G, pos, labels,
                                    font_size=11,
                                    font_weight='bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

            # Adjust layout and save
            plt.tight_layout()

            # Create temporary file for saving the image
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_dir = tempfile.gettempdir()
            image_path = os.path.join(temp_dir, f'simple_dag_{timestamp}.png')

            plt.savefig(image_path, dpi=400, bbox_inches='tight')
            plt.close(fig)  # Close figure to free memory

            logger.info(f"Simplified DAG visualization saved to: {image_path}")
            return image_path

        except Exception as e:
            logger.error(f"Error creating simplified DAG visualization: {e}")
            return None
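To see what the new visualizer produces without running the full app, the snippet below follows the schema documented in the create_dag_from_tasks docstring; the task names, function names, and robot ids are made-up illustrations, not values from the repository.

```python
# Hypothetical usage of DAGVisualizer; the payload mirrors the docstring schema above.
from dag_visualizer import DAGVisualizer

task_data = {
    "tasks": [
        {
            "task": "avoid_area_soil",
            "instruction_function": {
                "name": "avoid_areas_for_all_robots",   # illustrative name
                "robot_ids": ["robot_dump_truck_01", "robot_excavator_01"],
                "dependencies": [],
                "object_keywords": ["soil_area"],
            },
        },
        {
            "task": "excavate_soil",
            "instruction_function": {
                "name": "target_area_for_specific_robots",  # illustrative name
                "robot_ids": ["robot_excavator_01"],
                "dependencies": ["avoid_area_soil"],  # draws an edge from the first task
                "object_keywords": ["soil_area"],
            },
        },
    ]
}

visualizer = DAGVisualizer()
image_path = visualizer.create_dag_visualization(task_data)
print(image_path)  # e.g. a PNG under the system temp directory
```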
requirements.txt CHANGED
@@ -1,12 +1,13 @@
 pydantic
 gradio
 transformers
-torch
-peft
-bitsandbytes
 accelerate
 scipy
 sentencepiece
 protobuf
-spaces
 loguru
+matplotlib
+networkx
+numpy
+llama-cpp-python
+huggingface_hub
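The dropped torch/peft/bitsandbytes entries follow from serving pre-merged GGUF weights through llama-cpp-python instead of transformers. A quick import check (a sketch; none of these packages is pinned here) confirms the new stack resolves in the environment:

```python
# Sanity check that the updated requirements resolve.
from importlib.metadata import version

for pkg in ["matplotlib", "networkx", "numpy", "llama-cpp-python", "huggingface_hub"]:
    print(pkg, version(pkg))

import matplotlib
# dag_visualizer.py switches the backend to 'Agg' at import time for headless rendering.
print("current matplotlib backend:", matplotlib.get_backend())
```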