Spaces:
Runtime error
Runtime error
Commit
·
bb7e9ea
1
Parent(s):
a7e7927
Update chat.py
Browse files
chat.py
CHANGED
|
@@ -11,9 +11,10 @@ from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokeni
|
|
| 11 |
from vcoder_llava.constants import (
|
| 12 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
|
| 13 |
SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
|
| 14 |
-
DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
|
| 15 |
)
|
| 16 |
from transformers import TextIteratorStreamer
|
|
|
|
| 17 |
|
| 18 |
class Chat:
|
| 19 |
def __init__(self, model_path, model_base, model_name,
|
|
@@ -35,7 +36,7 @@ class Chat:
|
|
| 35 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
| 36 |
self.is_multimodal = 'llava' in self.model_name.lower()
|
| 37 |
self.is_seg = "vcoder" in self.model_name.lower()
|
| 38 |
-
self.is_depth =
|
| 39 |
|
| 40 |
@torch.inference_mode()
|
| 41 |
def generate_stream(self, params):
|
|
@@ -167,21 +168,21 @@ class Chat:
|
|
| 167 |
"text": server_error_msg,
|
| 168 |
"error_code": 1,
|
| 169 |
}
|
| 170 |
-
yield json.dumps(ret).encode()
|
| 171 |
except torch.cuda.CudaError as e:
|
| 172 |
print("Caught torch.cuda.CudaError:", e)
|
| 173 |
ret = {
|
| 174 |
"text": server_error_msg,
|
| 175 |
"error_code": 1,
|
| 176 |
}
|
| 177 |
-
yield json.dumps(ret).encode()
|
| 178 |
except Exception as e:
|
| 179 |
print("Caught Unknown Error", e)
|
| 180 |
ret = {
|
| 181 |
"text": server_error_msg,
|
| 182 |
"error_code": 1,
|
| 183 |
}
|
| 184 |
-
yield json.dumps(ret).encode()
|
| 185 |
|
| 186 |
|
| 187 |
if __name__ == "__main__":
|
|
|
|
| 11 |
from vcoder_llava.constants import (
|
| 12 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
|
| 13 |
SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
|
| 14 |
+
DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN,
|
| 15 |
)
|
| 16 |
from transformers import TextIteratorStreamer
|
| 17 |
+
from threading import Thread
|
| 18 |
|
| 19 |
class Chat:
|
| 20 |
def __init__(self, model_path, model_base, model_name,
|
|
|
|
| 36 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
| 37 |
self.is_multimodal = 'llava' in self.model_name.lower()
|
| 38 |
self.is_seg = "vcoder" in self.model_name.lower()
|
| 39 |
+
self.is_depth = "ds" in self.model_name.lower()
|
| 40 |
|
| 41 |
@torch.inference_mode()
|
| 42 |
def generate_stream(self, params):
|
|
|
|
| 168 |
"text": server_error_msg,
|
| 169 |
"error_code": 1,
|
| 170 |
}
|
| 171 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 172 |
except torch.cuda.CudaError as e:
|
| 173 |
print("Caught torch.cuda.CudaError:", e)
|
| 174 |
ret = {
|
| 175 |
"text": server_error_msg,
|
| 176 |
"error_code": 1,
|
| 177 |
}
|
| 178 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 179 |
except Exception as e:
|
| 180 |
print("Caught Unknown Error", e)
|
| 181 |
ret = {
|
| 182 |
"text": server_error_msg,
|
| 183 |
"error_code": 1,
|
| 184 |
}
|
| 185 |
+
yield json.dumps(ret).encode() + b"\0"
|
| 186 |
|
| 187 |
|
| 188 |
if __name__ == "__main__":
|