mrcuddle committed on
Commit
6897b6a
·
verified ·
1 Parent(s): 4cecdf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -30
app.py CHANGED
@@ -1,42 +1,35 @@
1
  import gradio as gr
2
  import subprocess
3
- import spaces
4
 
5
- @spaces.GPU
6
- # Function to execute hf_merge.py with provided parameters
7
- def run_hf_merge(repo_names, output_dir, staging='./staging', p=0.5, lambda_val=1.0, dry_run=False):
8
- repo_list = "\n".join(repo_names.split(","))
9
- with open("repo_list.txt", "w") as file:
10
- file.write(repo_list)
11
-
12
  command = [
13
- "python3", "hf_merge.py", "repo_list.txt", output_dir,
14
- "-staging", staging, "-p", str(p), "-lambda", str(lambda_val)
 
 
15
  ]
16
- if dry_run:
17
- command.append("--dry")
18
-
19
  result = subprocess.run(command, capture_output=True, text=True)
20
- if result.returncode == 0:
21
- return "Merge completed successfully.\n" + result.stdout
22
- else:
23
- return "Error during merge.\n" + result.stderr
24
 
25
- # Gradio interface
26
- interface = gr.Interface(
27
- fn=run_hf_merge,
28
  inputs=[
29
- gr.Textbox(label="Repository Names (comma-separated)"),
30
- gr.Textbox(label="Output Directory"),
31
- gr.Textbox(label="Staging Directory", value="./staging"),
32
- gr.Number(label="Dropout Probability", value=0.5),
33
- gr.Number(label="Scaling Factor", value=1.0),
34
- gr.Checkbox(label="Dry Run")
35
  ],
36
  outputs="text",
37
- title="HuggingFace Model Merger",
38
- description="Merge HuggingFace models using the technique described in 'Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch'."
39
  )
40
 
41
- if __name__ == "__main__":
42
- interface.launch()
 
1
  import gradio as gr
2
  import subprocess
 
3
 
4
def merge_models(weight_drop_prob, scaling_factor, base_model, model_to_merge, output_path):
    """Run ``hf_merge.py`` on two models and return its outcome as text.

    The script is started as ``python3 hf_merge.py -p <weight_drop_prob>
    -lambda <scaling_factor> <base_model> <model_to_merge> <output_path>``.
    How the script interprets those values is defined by ``hf_merge.py``
    itself, which is not part of this file.

    Args:
        weight_drop_prob: Value forwarded to the script's ``-p`` option.
        scaling_factor: Value forwarded to the script's ``-lambda`` option.
        base_model: First positional argument passed to the script.
        model_to_merge: Second positional argument passed to the script.
        output_path: Third positional argument passed to the script.

    Returns:
        The script's standard output when it exits with status 0. Otherwise
        a message starting with ``"Error during merge."`` followed by the
        script's standard error (or the reason the process could not be
        started), so failures are visible in the UI instead of showing up
        as an empty result.
    """
    # Arguments are passed as a list (no shell), so the user-supplied text
    # fields are never interpreted by a shell.
    command = [
        "python3", "hf_merge.py",
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        base_model, model_to_merge, output_path,
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as exc:
        # Raised when the interpreter itself cannot be launched
        # (e.g. "python3" is not on PATH).
        return f"Error during merge.\nCould not start the merge process: {exc}"

    # A non-zero exit status means the script failed; its diagnostics are on
    # stderr, which would otherwise be silently discarded.
    if result.returncode != 0:
        return "Error during merge.\n" + result.stderr

    return result.stdout
 
18
 
19
+ # Define the Gradio interface
20
# Input components use the top-level Gradio classes with ``value=`` for the
# initial value; the legacy ``gr.inputs.*`` namespace and its ``default=``
# keyword are deprecated and absent from current Gradio releases.
# The inputs are listed in the same order as merge_models' parameters.
iface = gr.Interface(
    fn=merge_models,
    inputs=[
        gr.Slider(minimum=0, maximum=1, value=0.13, label="Weight Drop Probability"),
        gr.Number(value=3.0, label="Scaling Factor"),
        gr.Textbox(label="Base Model File/Folder"),
        gr.Textbox(label="Model to Merge"),
        gr.Textbox(label="Output Path"),
    ],
    outputs="text",
    title="Model Merger",
    description="Merge two models using the Super Mario merge method.",
)

# Launch the interface
iface.launch()