Spaces:
Runtime error
Runtime error
File size: 4,393 Bytes
14f034e d85229e 14f034e d85229e 14f034e a0ac53d 14f034e d6fd1f0 14f034e d85229e 14f034e 6c420e0 14f034e 7564980 6c420e0 d6fd1f0 14f034e d6fd1f0 14f034e a0ac53d 14f034e 1c7aced 1c1b839 1c7aced 14f034e 6c420e0 d6fd1f0 6c420e0 14f034e 6c420e0 d6fd1f0 6c420e0 d6fd1f0 14f034e 8fe0728 14f034e 6c420e0 14f034e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import sys
import os
import argparse
import time
import subprocess
import llava.serve.gradio_web_server as gws
# Install (or upgrade, via -U) flash-attn at startup using the pip module of
# the interpreter that is running this script. check_call raises
# subprocess.CalledProcessError if pip exits non-zero, which aborts the script.
# --no-build-isolation makes pip build against the packages already installed
# in this environment instead of a fresh isolated build environment.
# NOTE(review): this runs after `llava.serve.gradio_web_server` has already
# been imported above — confirm that import does not itself need flash-attn.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
def start_controller():
    """Spawn the LLaVA controller as a child process.

    The controller is launched as ``python -m llava.serve.controller`` with
    the interpreter running this script, listening on 0.0.0.0:10000.

    Returns:
        The ``subprocess.Popen`` handle of the spawned controller; the caller
        is responsible for terminating it.
    """
    print("Starting the controller")
    launch_args = [sys.executable, "-m", "llava.serve.controller"]
    launch_args += ["--host", "0.0.0.0"]
    launch_args += ["--port", "10000"]
    print(launch_args)
    return subprocess.Popen(launch_args)
def start_worker(model_path: str, bits=16, revision='main', model_base=None, port=21002):
    """Spawn a LLaVA model worker as a child process.

    The worker is launched as ``python -m llava.serve.model_worker`` with the
    interpreter running this script and registers with the controller at
    ``http://localhost:10000``.

    Args:
        model_path: Model path or hub id. Its last path component (trailing
            slashes ignored) becomes the worker's model name.
        bits: Load precision; must be 4, 8 or 16. For 4 or 8 the model name
            gets a ``-{bits}bit`` suffix and ``--load-{bits}bit`` is passed.
        revision: Value forwarded as ``--revision``.
        model_base: Optional value forwarded as ``--model-base``; omitted
            from the command when falsy.
        port: Port the worker listens on. May be an ``int`` or a ``str``.

    Returns:
        The ``subprocess.Popen`` handle of the spawned worker.

    Raises:
        AssertionError: If ``bits`` is not one of 4, 8, 16.
    """
    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
    quantized = bits != 16
    if quantized:
        model_name += f"-{bits}bit"

    # Popen only accepts str/bytes/path-like argv entries; the default port is
    # an int, so normalise it here instead of relying on the caller.
    port = str(port)

    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--port",
        port,
        "--worker-address",
        f"http://127.0.0.1:{port}",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
        "--revision",
        revision,
    ]
    if model_base:
        worker_command += ["--model-base", model_base]
    if quantized:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    return subprocess.Popen(worker_command)
if __name__ == "__main__":
    # Parse the command-line flags and hand them to the gradio web server
    # module, which reads its configuration from module-level `args`/`models`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """
ONLY WORKS WITH GPU! By default, we load the model with 4-bit quantization to make it fit in smaller hardwares. Set the environment variable `bits` to control the quantization.
Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""

    print(f"args: {gws.args}")

    # Deployment settings come from environment variables. `model`,
    # `revision` and `model_base` are ';'-separated lists, one entry per
    # worker to launch.
    model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b").split(';')
    revisions = os.getenv("revision", "main").split(';')
    model_bases = os.getenv("model_base", '').split(';')
    bits = int(os.getenv("bits", 4))
    # NOTE(review): the --concurrency-count flag is parsed above but this
    # environment variable is what is passed to build_demo — confirm intended.
    concurrency_count = int(os.getenv("concurrency_count", 5))

    # An unset/empty `model_base` means "no base model" for every worker.
    # Without this, zip() below would stop after the first model.
    if model_bases == ['']:
        model_bases = [''] * len(model_paths)

    # Validate before spawning anything so a bad configuration cannot leave
    # orphaned child processes behind.
    if len(revisions) != len(model_paths):
        raise ValueError(
            f"'revision' has {len(revisions)} entries but 'model' has "
            f"{len(model_paths)}; provide one revision per model."
        )
    if len(model_bases) != len(model_paths):
        raise ValueError(
            f"'model_base' has {len(model_bases)} entries but 'model' has "
            f"{len(model_paths)}; provide one base per model or leave it unset."
        )

    start_worker_port = 21002
    controller_proc = None
    worker_procs = []
    exit_status = 0
    try:
        controller_proc = start_controller()

        # One worker per model, on consecutive ports starting at 21002.
        for offset, (model_path, revision, base) in enumerate(zip(model_paths, revisions, model_bases)):
            print(model_path, revision, base)
            worker_procs.append(
                start_worker(
                    model_path,
                    bits=bits,
                    revision=revision,
                    model_base=base,
                    port=str(start_worker_port + offset),
                )
            )

        # Wait for worker and controller to start
        time.sleep(10)

        # NOTE(review): embed_mode is fixed to False although --embed is
        # accepted on the command line — confirm intended.
        demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        # Only kill what was actually started; either list may be partial if
        # startup failed midway.
        for w in worker_procs:
            w.kill()
        if controller_proc is not None:
            controller_proc.kill()

    sys.exit(exit_status)
|