|
import wx |
|
import wx.adv |
|
import wx.html |
|
import wx.html2 |
|
import subprocess |
|
import os |
|
import threading |
|
import queue |
|
import json |
|
import webbrowser |
|
import requests |
|
import sys |
|
|
|
def run_subprocess(cmd, output_queue): |
|
try: |
|
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) |
|
for line in process.stdout: |
|
output_queue.put(line) |
|
process.wait() |
|
if process.returncode == 0: |
|
output_queue.put("Process completed successfully!") |
|
else: |
|
output_queue.put(f"Process failed with return code {process.returncode}") |
|
except Exception as e: |
|
output_queue.put(f"An error occurred: {str(e)}") |
|
|
|
def update_output(output_text, output_queue): |
|
try: |
|
while True: |
|
line = output_queue.get_nowait() |
|
if wx.Window.FindWindowById(output_text.GetId()): |
|
wx.CallAfter(output_text.AppendText, line) |
|
else: |
|
return |
|
except queue.Empty: |
|
pass |
|
except RuntimeError: |
|
return |
|
wx.CallLater(100, update_output, output_text, output_queue) |
|
|
|
def open_store_folder(folder_path): |
|
if os.path.exists(folder_path): |
|
os.startfile(folder_path) |
|
else: |
|
wx.MessageBox(f"The folder {folder_path} does not exist.", "Error", wx.OK | wx.ICON_ERROR) |
|
|
|
class DarkThemedTextCtrl(wx.TextCtrl): |
|
def __init__(self, parent, id=wx.ID_ANY, value="", style=0): |
|
super().__init__(parent, id, value, style=style | wx.NO_BORDER) |
|
self.SetBackgroundColour(wx.Colour(0, 0, 0)) |
|
self.SetForegroundColour(wx.WHITE) |
|
|
|
class CollapsiblePanel(wx.Panel): |
|
def __init__(self, parent, title, *args, **kwargs): |
|
wx.Panel.__init__(self, parent, *args, **kwargs) |
|
self.SetBackgroundColour(parent.GetBackgroundColour()) |
|
|
|
self.toggle_button = wx.Button(self, label=title, style=wx.NO_BORDER) |
|
self.toggle_button.SetBackgroundColour(self.GetBackgroundColour()) |
|
self.toggle_button.Bind(wx.EVT_BUTTON, self.on_toggle) |
|
|
|
self.content_panel = wx.Panel(self) |
|
self.content_panel.SetBackgroundColour(self.GetBackgroundColour()) |
|
|
|
self.main_sizer = wx.BoxSizer(wx.VERTICAL) |
|
self.main_sizer.Add(self.toggle_button, 0, wx.EXPAND | wx.ALL, 5) |
|
self.main_sizer.Add(self.content_panel, 0, wx.EXPAND | wx.ALL, 5) |
|
|
|
self.SetSizer(self.main_sizer) |
|
self.collapsed = True |
|
self.toggle_button.SetLabel(f"▶ {title}") |
|
self.content_panel.Hide() |
|
|
|
def on_toggle(self, event): |
|
self.collapsed = not self.collapsed |
|
self.toggle_button.SetLabel(f"{'▶' if self.collapsed else '▼'} {self.toggle_button.GetLabel()[2:]}") |
|
self.content_panel.Show(not self.collapsed) |
|
self.Layout() |
|
self.GetParent().Layout() |
|
|
|
def get_content_panel(self): |
|
return self.content_panel |
|
|
|
class CustomToolTip(wx.PopupWindow): |
|
def __init__(self, parent, text): |
|
wx.PopupWindow.__init__(self, parent) |
|
|
|
|
|
panel = wx.Panel(self) |
|
self.st = wx.StaticText(panel, 1, text, pos=(10, 10)) |
|
|
|
font = wx.Font(8, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins") |
|
self.st.SetFont(font) |
|
|
|
size = self.st.GetBestSize() |
|
self.SetSize((size.width + 20, size.height + 20)) |
|
|
|
|
|
panel.SetSize(self.GetSize()) |
|
panel.SetBackgroundColour(wx.Colour(255, 255, 255)) |
|
|
|
|
|
panel.Bind(wx.EVT_PAINT, self.on_paint) |
|
|
|
def on_paint(self, event): |
|
|
|
panel = event.GetEventObject() |
|
dc = wx.PaintDC(panel) |
|
dc.SetPen(wx.Pen(wx.Colour(210, 210, 210), 1)) |
|
dc.SetBrush(wx.Brush(wx.Colour(255, 255, 255))) |
|
|
|
size = panel.GetSize() |
|
dc.DrawRectangle(0, 0, size.width, size.height) |
|
|
|
class MainFrame(wx.Frame): |
|
def __init__(self): |
|
super().__init__(parent=None, title="Music Source Separation Training & Inference GUI") |
|
self.SetSize(994, 670) |
|
self.SetBackgroundColour(wx.Colour(247, 248, 250)) |
|
|
|
icon = wx.Icon("gui/favicon.ico", wx.BITMAP_TYPE_ICO) |
|
self.SetIcon(icon) |
|
|
|
self.saved_combinations = {} |
|
|
|
|
|
self.Center() |
|
|
|
|
|
font_path = "gui/Poppins Regular 400.ttf" |
|
bold_font_path = "gui/Poppins Bold 700.ttf" |
|
wx.Font.AddPrivateFont(font_path) |
|
wx.Font.AddPrivateFont(bold_font_path) |
|
self.font = wx.Font(9, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins") |
|
self.bold_font = wx.Font(10, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_BOLD, False, "Poppins") |
|
self.SetFont(self.font) |
|
|
|
panel = wx.Panel(self) |
|
main_sizer = wx.BoxSizer(wx.VERTICAL) |
|
|
|
|
|
try: |
|
img = wx.Image("gui/mvsep.png", wx.BITMAP_TYPE_PNG) |
|
img_bitmap = wx.Bitmap(img) |
|
img_ctrl = wx.StaticBitmap(panel, -1, img_bitmap) |
|
main_sizer.Add(img_ctrl, 0, wx.ALIGN_CENTER | wx.TOP, 20) |
|
except: |
|
print("Failed to load image: gui/mvsep.png") |
|
|
|
|
|
title_text = wx.StaticText(panel, label="Music Source Separation Training && Inference GUI") |
|
title_text.SetFont(self.bold_font) |
|
title_text.SetForegroundColour(wx.BLACK) |
|
main_sizer.Add(title_text, 0, wx.ALIGN_CENTER | wx.TOP, 10) |
|
|
|
|
|
subtitle_text = wx.StaticText(panel, label="Code by ZFTurbo / GUI by Bas Curtiz") |
|
subtitle_text.SetForegroundColour(wx.BLACK) |
|
main_sizer.Add(subtitle_text, 0, wx.ALIGN_CENTER | wx.TOP, 5) |
|
|
|
|
|
github_link = wx.adv.HyperlinkCtrl(panel, -1, "GitHub Repository", "https://github.com/ZFTurbo/Music-Source-Separation-Training") |
|
github_link.SetNormalColour(wx.Colour(1, 118, 179)) |
|
github_link.SetHoverColour(wx.Colour(86, 91, 123)) |
|
main_sizer.Add(github_link, 0, wx.ALIGN_CENTER | wx.TOP, 10) |
|
|
|
|
|
download_models_btn = self.create_styled_button(panel, "Download Models", self.on_download_models) |
|
main_sizer.Add(download_models_btn, 0, wx.ALIGN_CENTER | wx.TOP | wx.BOTTOM, 10) |
|
|
|
|
|
self.training_panel = CollapsiblePanel(panel, "Training Configuration") |
|
self.training_panel.toggle_button.SetFont(self.bold_font) |
|
self.create_training_controls(self.training_panel.get_content_panel()) |
|
main_sizer.Add(self.training_panel, 0, wx.EXPAND | wx.ALL, 10) |
|
|
|
|
|
self.inference_panel = CollapsiblePanel(panel, "Inference Configuration") |
|
self.inference_panel.toggle_button.SetFont(self.bold_font) |
|
self.create_inference_controls(self.inference_panel.get_content_panel()) |
|
main_sizer.Add(self.inference_panel, 0, wx.EXPAND | wx.ALL, 10) |
|
|
|
panel.SetSizer(main_sizer) |
|
self.load_settings() |
|
|
|
def create_styled_button(self, parent, label, handler): |
|
btn = wx.Button(parent, label=label, style=wx.BORDER_NONE) |
|
btn.SetBackgroundColour(wx.Colour(1, 118, 179)) |
|
btn.SetForegroundColour(wx.WHITE) |
|
btn.SetFont(self.bold_font) |
|
|
|
def on_enter(event): |
|
btn.SetBackgroundColour(wx.Colour(86, 91, 123)) |
|
event.Skip() |
|
|
|
def on_leave(event): |
|
btn.SetBackgroundColour(wx.Colour(1, 118, 179)) |
|
event.Skip() |
|
|
|
def on_click(event): |
|
btn.SetBackgroundColour(wx.Colour(86, 91, 123)) |
|
handler(event) |
|
wx.CallLater(100, lambda: btn.SetBackgroundColour(wx.Colour(1, 118, 179))) |
|
|
|
btn.Bind(wx.EVT_ENTER_WINDOW, on_enter) |
|
btn.Bind(wx.EVT_LEAVE_WINDOW, on_leave) |
|
btn.Bind(wx.EVT_BUTTON, on_click) |
|
|
|
return btn |
|
|
|
def create_training_controls(self, panel): |
|
sizer = wx.BoxSizer(wx.VERTICAL) |
|
|
|
|
|
model_type_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
model_type_sizer.Add(wx.StaticText(panel, label="Model Type:"), 0, wx.ALIGN_CENTER_VERTICAL) |
|
self.model_type = wx.Choice(panel, choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"]) |
|
self.model_type.SetFont(self.font) |
|
model_type_sizer.Add(self.model_type, 0, wx.LEFT, 5) |
|
sizer.Add(model_type_sizer, 0, wx.EXPAND | wx.ALL, 5) |
|
|
|
|
|
self.config_entry = self.add_browse_control(panel, sizer, "Config File:", is_folder=False, is_config=True) |
|
|
|
|
|
self.checkpoint_entry = self.add_browse_control(panel, sizer, "Checkpoint:", is_folder=False, is_checkpoint=True) |
|
|
|
|
|
self.result_path_entry = self.add_browse_control(panel, sizer, "Results Path:", is_folder=True) |
|
|
|
|
|
self.data_entry = self.add_browse_control(panel, sizer, "Data Paths (separated by ';'):", is_folder=True) |
|
|
|
|
|
self.valid_entry = self.add_browse_control(panel, sizer, "Validation Paths (separated by ';'):", is_folder=True) |
|
|
|
|
|
workers_device_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
|
|
workers_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
workers_sizer.Add(wx.StaticText(panel, label="Number of Workers:"), 0, wx.ALIGN_CENTER_VERTICAL) |
|
self.workers_entry = wx.TextCtrl(panel, value="4") |
|
self.workers_entry.SetFont(self.font) |
|
workers_sizer.Add(self.workers_entry, 0, wx.LEFT, 5) |
|
workers_device_sizer.Add(workers_sizer, 0, wx.EXPAND) |
|
|
|
device_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
device_sizer.Add(wx.StaticText(panel, label="Device IDs (comma-separated):"), 0, wx.ALIGN_CENTER_VERTICAL | wx.LEFT, 20) |
|
self.device_entry = wx.TextCtrl(panel, value="0") |
|
self.device_entry.SetFont(self.font) |
|
device_sizer.Add(self.device_entry, 0, wx.LEFT, 5) |
|
workers_device_sizer.Add(device_sizer, 0, wx.EXPAND) |
|
|
|
sizer.Add(workers_device_sizer, 0, wx.EXPAND | wx.ALL, 5) |
|
|
|
|
|
self.run_button = self.create_styled_button(panel, "Run Training", self.run_training) |
|
sizer.Add(self.run_button, 0, wx.ALIGN_CENTER | wx.ALL, 10) |
|
|
|
panel.SetSizer(sizer) |
|
|
|
def create_inference_controls(self, panel): |
|
sizer = wx.BoxSizer(wx.VERTICAL) |
|
|
|
|
|
infer_model_type_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
infer_model_type_sizer.Add(wx.StaticText(panel, label="Model Type:"), 0, wx.ALIGN_CENTER_VERTICAL) |
|
self.infer_model_type = wx.Choice(panel, choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"]) |
|
self.infer_model_type.SetFont(self.font) |
|
infer_model_type_sizer.Add(self.infer_model_type, 0, wx.LEFT, 5) |
|
|
|
|
|
infer_model_type_sizer.Add(wx.StaticText(panel, label="Preset:"), 0, wx.ALIGN_CENTER_VERTICAL | wx.LEFT, 20) |
|
|
|
|
|
self.saved_combinations_dropdown = wx.Choice(panel, choices=[]) |
|
self.saved_combinations_dropdown.SetFont(self.font) |
|
self.saved_combinations_dropdown.Bind(wx.EVT_CHOICE, self.on_combination_selected) |
|
|
|
|
|
self.saved_combinations_dropdown.SetMinSize((358, -1)) |
|
|
|
|
|
infer_model_type_sizer.Add(self.saved_combinations_dropdown, 0, wx.LEFT, 5) |
|
|
|
|
|
plus_button = self.create_styled_button(panel, "+", self.on_save_combination) |
|
plus_button.SetMinSize((30, 30)) |
|
infer_model_type_sizer.Add(plus_button, 0, wx.LEFT, 5) |
|
|
|
|
|
help_button = wx.StaticText(panel, label="?") |
|
help_button.SetFont(self.bold_font) |
|
help_button.SetForegroundColour(wx.Colour(1, 118, 179)) |
|
tooltip_text = ("How to add a preset?\n\n" |
|
"1. Click Download Models\n" |
|
"2. Right-click a model's Config && Checkpoint\n" |
|
"3. Save link as && select a proper destination\n" |
|
"4. Copy the Model name\n" |
|
"5. Close Download Models\n\n" |
|
"6. Browse for the Config file\n" |
|
"7. Browse for the Checkpoint\n" |
|
"8. Select the Model Type\n" |
|
"9. Click the + button\n" |
|
"10. Paste the Model name && click OK\n\n" |
|
"On next use, just select it from the Preset dropdown.") |
|
|
|
self.tooltip = CustomToolTip(self, tooltip_text) |
|
self.tooltip.Hide() |
|
|
|
def on_help_enter(event): |
|
self.tooltip.Position(help_button.ClientToScreen((0, help_button.GetSize().height)), (0, 0)) |
|
self.tooltip.Show() |
|
|
|
def on_help_leave(event): |
|
self.tooltip.Hide() |
|
|
|
help_button.Bind(wx.EVT_ENTER_WINDOW, on_help_enter) |
|
help_button.Bind(wx.EVT_LEAVE_WINDOW, on_help_leave) |
|
|
|
infer_model_type_sizer.Add(help_button, 0, wx.LEFT | wx.ALIGN_CENTER_VERTICAL, 5) |
|
|
|
sizer.Add(infer_model_type_sizer, 0, wx.EXPAND | wx.ALL, 5) |
|
|
|
|
|
self.infer_config_entry = self.add_browse_control(panel, sizer, "Config File:", is_folder=False, is_config=True) |
|
|
|
|
|
self.infer_checkpoint_entry = self.add_browse_control(panel, sizer, "Checkpoint:", is_folder=False, is_checkpoint=True) |
|
|
|
|
|
self.infer_input_entry = self.add_browse_control(panel, sizer, "Input Folder:", is_folder=True) |
|
|
|
|
|
self.infer_store_entry = self.add_browse_control(panel, sizer, "Output Folder:", is_folder=True) |
|
|
|
|
|
self.extract_instrumental_checkbox = wx.CheckBox(panel, label="Extract Instrumental") |
|
self.extract_instrumental_checkbox.SetFont(self.font) |
|
sizer.Add(self.extract_instrumental_checkbox, 0, wx.EXPAND | wx.ALL, 5) |
|
|
|
|
|
self.run_infer_button = self.create_styled_button(panel, "Run Inference", self.run_inference) |
|
sizer.Add(self.run_infer_button, 0, wx.ALIGN_CENTER | wx.ALL, 10) |
|
|
|
panel.SetSizer(sizer) |
|
|
|
def add_browse_control(self, panel, sizer, label, is_folder=False, is_config=False, is_checkpoint=False): |
|
browse_sizer = wx.BoxSizer(wx.HORIZONTAL) |
|
browse_sizer.Add(wx.StaticText(panel, label=label), 0, wx.ALIGN_CENTER_VERTICAL) |
|
entry = wx.TextCtrl(panel) |
|
entry.SetFont(self.font) |
|
browse_sizer.Add(entry, 1, wx.EXPAND | wx.LEFT, 5) |
|
browse_button = self.create_styled_button(panel, "Browse", lambda event, entry=entry, is_folder=is_folder, is_config=is_config, is_checkpoint=is_checkpoint: self.browse(event, entry, is_folder, is_config, is_checkpoint)) |
|
browse_sizer.Add(browse_button, 0, wx.LEFT, 5) |
|
sizer.Add(browse_sizer, 0, wx.EXPAND | wx.ALL, 5) |
|
return entry |
|
|
|
def browse(self, event, entry, is_folder=False, is_config=False, is_checkpoint=False): |
|
if is_folder: |
|
dialog = wx.DirDialog(self, "Choose a directory", style=wx.DD_DEFAULT_STYLE | wx.DD_DIR_MUST_EXIST) |
|
else: |
|
wildcard = "All files (*.*)|*.*" |
|
if is_config: |
|
wildcard = "YAML files (*.yaml)|*.yaml" |
|
elif is_checkpoint: |
|
wildcard = "Checkpoint files (*.bin;*.chpt;*.ckpt;*.th)|*.bin;*.chpt;*.ckpt;*.th" |
|
|
|
dialog = wx.FileDialog(self, "Choose a file", style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST, wildcard=wildcard) |
|
|
|
dialog.SetFont(self.font) |
|
if dialog.ShowModal() == wx.ID_OK: |
|
entry.SetValue(dialog.GetPath()) |
|
dialog.Destroy() |
|
|
|
def create_output_window(self, title, folder_path): |
|
output_frame = wx.Frame(self, title=title, style=wx.DEFAULT_FRAME_STYLE | wx.STAY_ON_TOP) |
|
output_frame.SetIcon(self.GetIcon()) |
|
output_frame.SetSize(994, 670) |
|
output_frame.SetBackgroundColour(wx.Colour(0, 0, 0)) |
|
output_frame.SetFont(self.font) |
|
|
|
|
|
output_frame.SetPosition(self.GetPosition()) |
|
|
|
output_title = wx.StaticText(output_frame, label=title) |
|
output_title.SetFont(self.bold_font) |
|
output_title.SetForegroundColour(wx.WHITE) |
|
|
|
output_text = DarkThemedTextCtrl(output_frame, style=wx.TE_MULTILINE | wx.TE_READONLY) |
|
output_text.SetFont(self.font) |
|
|
|
open_folder_button = self.create_styled_button(output_frame, f"Open Output Folder", lambda event: open_store_folder(folder_path)) |
|
|
|
sizer = wx.BoxSizer(wx.VERTICAL) |
|
sizer.Add(output_title, 0, wx.ALIGN_CENTER | wx.TOP, 10) |
|
sizer.Add(output_text, 1, wx.EXPAND | wx.ALL, 10) |
|
sizer.Add(open_folder_button, 0, wx.ALIGN_CENTER | wx.BOTTOM, 10) |
|
output_frame.SetSizer(sizer) |
|
|
|
return output_frame, output_text |
|
|
|
def run_training(self, event): |
|
model_type = self.model_type.GetStringSelection() |
|
config_path = self.config_entry.GetValue() |
|
start_checkpoint = self.checkpoint_entry.GetValue() |
|
results_path = self.result_path_entry.GetValue() |
|
data_paths = self.data_entry.GetValue() |
|
valid_paths = self.valid_entry.GetValue() |
|
num_workers = self.workers_entry.GetValue() |
|
device_ids = self.device_entry.GetValue() |
|
|
|
if not model_type: |
|
wx.MessageBox("Please select a model type.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not config_path: |
|
wx.MessageBox("Please select a config file.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not results_path: |
|
wx.MessageBox("Please specify a results path.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not data_paths: |
|
wx.MessageBox("Please specify data paths.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not valid_paths: |
|
wx.MessageBox("Please specify validation paths.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
|
|
cmd = [ |
|
sys.executable, "train.py", |
|
"--model_type", model_type, |
|
"--config_path", config_path, |
|
"--results_path", results_path, |
|
"--data_path", *data_paths.split(';'), |
|
"--valid_path", *valid_paths.split(';'), |
|
"--num_workers", num_workers, |
|
"--device_ids", device_ids |
|
] |
|
|
|
if start_checkpoint: |
|
cmd += ["--start_check_point", start_checkpoint] |
|
|
|
output_queue = queue.Queue() |
|
threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start() |
|
|
|
output_frame, output_text = self.create_output_window("Training Output", results_path) |
|
output_frame.Show() |
|
update_output(output_text, output_queue) |
|
|
|
self.save_settings() |
|
|
|
def run_inference(self, event): |
|
model_type = self.infer_model_type.GetStringSelection() |
|
config_path = self.infer_config_entry.GetValue() |
|
start_checkpoint = self.infer_checkpoint_entry.GetValue() |
|
input_folder = self.infer_input_entry.GetValue() |
|
store_dir = self.infer_store_entry.GetValue() |
|
extract_instrumental = self.extract_instrumental_checkbox.GetValue() |
|
|
|
if not model_type: |
|
wx.MessageBox("Please select a model type.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not config_path: |
|
wx.MessageBox("Please select a config file.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not input_folder: |
|
wx.MessageBox("Please specify an input folder.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
if not store_dir: |
|
wx.MessageBox("Please specify an output folder.", "Input Error", wx.OK | wx.ICON_ERROR) |
|
return |
|
|
|
cmd = [ |
|
sys.executable, "inference.py", |
|
"--model_type", model_type, |
|
"--config_path", config_path, |
|
"--input_folder", input_folder, |
|
"--store_dir", store_dir |
|
] |
|
|
|
if start_checkpoint: |
|
cmd += ["--start_check_point", start_checkpoint] |
|
|
|
if extract_instrumental: |
|
cmd += ["--extract_instrumental"] |
|
|
|
output_queue = queue.Queue() |
|
threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start() |
|
|
|
output_frame, output_text = self.create_output_window("Inference Output", store_dir) |
|
output_frame.Show() |
|
update_output(output_text, output_queue) |
|
|
|
self.save_settings() |
|
|
|
def save_settings(self): |
|
settings = { |
|
"model_type": self.model_type.GetStringSelection(), |
|
"config_path": self.config_entry.GetValue(), |
|
"start_checkpoint": self.checkpoint_entry.GetValue(), |
|
"results_path": self.result_path_entry.GetValue(), |
|
"data_paths": self.data_entry.GetValue(), |
|
"valid_paths": self.valid_entry.GetValue(), |
|
"num_workers": self.workers_entry.GetValue(), |
|
"device_ids": self.device_entry.GetValue(), |
|
"infer_model_type": self.infer_model_type.GetStringSelection(), |
|
"infer_config_path": self.infer_config_entry.GetValue(), |
|
"infer_start_checkpoint": self.infer_checkpoint_entry.GetValue(), |
|
"infer_input_folder": self.infer_input_entry.GetValue(), |
|
"infer_store_dir": self.infer_store_entry.GetValue(), |
|
"extract_instrumental": self.extract_instrumental_checkbox.GetValue(), |
|
"saved_combinations": self.saved_combinations |
|
} |
|
with open("settings.json", "w") as f: |
|
json.dump(settings, f, indent=2, ensure_ascii=False) |
|
|
|
def load_settings(self): |
|
try: |
|
with open("settings.json", "r") as f: |
|
settings = json.load(f) |
|
|
|
self.model_type.SetStringSelection(settings.get("model_type", "")) |
|
self.config_entry.SetValue(settings.get("config_path", "")) |
|
self.checkpoint_entry.SetValue(settings.get("start_checkpoint", "")) |
|
self.result_path_entry.SetValue(settings.get("results_path", "")) |
|
self.data_entry.SetValue(settings.get("data_paths", "")) |
|
self.valid_entry.SetValue(settings.get("valid_paths", "")) |
|
self.workers_entry.SetValue(settings.get("num_workers", "4")) |
|
self.device_entry.SetValue(settings.get("device_ids", "0")) |
|
|
|
self.infer_model_type.SetStringSelection(settings.get("infer_model_type", "")) |
|
self.infer_config_entry.SetValue(settings.get("infer_config_path", "")) |
|
self.infer_checkpoint_entry.SetValue(settings.get("infer_start_checkpoint", "")) |
|
self.infer_input_entry.SetValue(settings.get("infer_input_folder", "")) |
|
self.infer_store_entry.SetValue(settings.get("infer_store_dir", "")) |
|
self.extract_instrumental_checkbox.SetValue(settings.get("extract_instrumental", False)) |
|
self.saved_combinations = settings.get("saved_combinations", {}) |
|
|
|
self.update_saved_combinations() |
|
except FileNotFoundError: |
|
pass |
|
|
|
def on_download_models(self, event): |
|
DownloadModelsFrame(self).Show() |
|
|
|
def on_save_combination(self, event): |
|
dialog = wx.TextEntryDialog(self, "Enter a name for this preset:", "Save Preset") |
|
if dialog.ShowModal() == wx.ID_OK: |
|
name = dialog.GetValue() |
|
if name: |
|
combination = { |
|
"model_type": self.infer_model_type.GetStringSelection(), |
|
"config_path": self.infer_config_entry.GetValue(), |
|
"checkpoint": self.infer_checkpoint_entry.GetValue() |
|
} |
|
self.saved_combinations[name] = combination |
|
self.update_saved_combinations() |
|
self.save_settings() |
|
dialog.Destroy() |
|
|
|
def on_combination_selected(self, event): |
|
name = self.saved_combinations_dropdown.GetStringSelection() |
|
if name: |
|
combination = self.saved_combinations.get(name) |
|
if combination: |
|
self.infer_model_type.SetStringSelection(combination["model_type"]) |
|
self.infer_config_entry.SetValue(combination["config_path"]) |
|
self.infer_checkpoint_entry.SetValue(combination["checkpoint"]) |
|
|
|
def update_saved_combinations(self): |
|
self.saved_combinations_dropdown.Clear() |
|
for name in self.saved_combinations.keys(): |
|
self.saved_combinations_dropdown.Append(name) |
|
|
|
class DownloadModelsFrame(wx.Frame): |
|
def __init__(self, parent): |
|
super().__init__(parent, title="Download Models", size=(994, 670), style=wx.DEFAULT_FRAME_STYLE & ~(wx.RESIZE_BORDER | wx.MAXIMIZE_BOX)) |
|
self.SetBackgroundColour(wx.Colour(247, 248, 250)) |
|
self.SetFont(wx.Font(9, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins")) |
|
|
|
|
|
self.SetPosition(parent.GetPosition()) |
|
|
|
|
|
icon = wx.Icon("gui/favicon.ico", wx.BITMAP_TYPE_ICO) |
|
self.SetIcon(icon) |
|
|
|
panel = wx.Panel(self) |
|
sizer = wx.BoxSizer(wx.VERTICAL) |
|
|
|
|
|
self.web_view = wx.html2.WebView.New(panel) |
|
self.web_view.LoadURL("https://bascurtiz.x10.mx/models-checkpoint-config-urls.html") |
|
self.web_view.Bind(wx.html2.EVT_WEBVIEW_NAVIGATING, self.on_link_click) |
|
self.web_view.Bind(wx.html2.EVT_WEBVIEW_NAVIGATED, self.on_page_load) |
|
sizer.Add(self.web_view, 1, wx.EXPAND) |
|
|
|
panel.SetSizer(sizer) |
|
|
|
def on_link_click(self, event): |
|
url = event.GetURL() |
|
if not url.startswith("https://bascurtiz.x10.mx"): |
|
event.Veto() |
|
webbrowser.open(url) |
|
|
|
def on_page_load(self, event): |
|
self.inject_custom_css() |
|
|
|
def inject_custom_css(self): |
|
css = """ |
|
body { |
|
margin: 0; |
|
padding: 0; |
|
} |
|
::-webkit-scrollbar { |
|
width: 12px; |
|
} |
|
::-webkit-scrollbar-track { |
|
background: #f1f1f1; |
|
} |
|
::-webkit-scrollbar-thumb { |
|
background: #888; |
|
} |
|
::-webkit-scrollbar-thumb:hover { |
|
background: #555; |
|
} |
|
""" |
|
js = f"var style = document.createElement('style'); style.textContent = `{css}`; document.head.appendChild(style);" |
|
self.web_view.RunScript(js) |
|
|
|
if __name__ == "__main__": |
|
app = wx.App() |
|
frame = MainFrame() |
|
frame.Show() |
|
app.MainLoop() |
|
|