Spaces:
Paused
Paused
alessandro trinca tornidor
committed on
Commit
·
f182d7a
1
Parent(s):
a5b4be9
[refactor] add and use create_placeholder_variables() function
Browse files- README.md +1 -0
- main.py +10 -17
- resources/placeholders/error_happened.png +3 -0
- resources/placeholders/no_seg_out.png +3 -0
- utils/utils.py +14 -0
README.md
CHANGED
|
@@ -321,3 +321,4 @@ If you find this project useful in your research, please consider citing:
|
|
| 321 |
|
| 322 |
## Acknowledgement
|
| 323 |
- This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
|
|
|
|
|
|
| 321 |
|
| 322 |
## Acknowledgement
|
| 323 |
- This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
|
| 324 |
+
- placeholders images (error, 'no output segmentation') from Muhammad Khaleeq (https://www.vecteezy.com/members/iyikon)
|
main.py
CHANGED
|
@@ -20,9 +20,7 @@ from model.LISA import LISAForCausalLM
|
|
| 20 |
from model.llava import conversation as conversation_lib
|
| 21 |
from model.llava.mm_utils import tokenizer_image_token
|
| 22 |
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 23 |
-
from utils import constants, session_logger
|
| 24 |
-
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 25 |
-
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 26 |
|
| 27 |
|
| 28 |
session_logger.change_logging(logging.DEBUG)
|
|
@@ -34,6 +32,7 @@ FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
|
|
| 34 |
os.makedirs(FASTAPI_STATIC, exist_ok=True)
|
| 35 |
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
|
| 36 |
templates = Jinja2Templates(directory="templates")
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
@app.get("/health")
|
|
@@ -230,6 +229,7 @@ def get_inference_model_by_args(args_to_parse):
|
|
| 230 |
logging.info(f"args_to_parse:{args_to_parse}, creating model...")
|
| 231 |
model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
|
| 232 |
logging.info("created model, preparing inference function")
|
|
|
|
| 233 |
|
| 234 |
@session_logger.set_uuid_logging
|
| 235 |
def inference(input_str, input_image):
|
|
@@ -242,22 +242,19 @@ def get_inference_model_by_args(args_to_parse):
|
|
| 242 |
## input valid check
|
| 243 |
if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
|
| 244 |
output_str = "[Error] Invalid input: ", input_str
|
| 245 |
-
|
| 246 |
-
## error happened
|
| 247 |
-
output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
|
| 248 |
-
return output_image, output_str
|
| 249 |
|
| 250 |
# Model Inference
|
| 251 |
conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
|
| 252 |
conv.messages = []
|
| 253 |
|
| 254 |
prompt = input_str
|
| 255 |
-
prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
|
| 256 |
if args_to_parse.use_mm_start_end:
|
| 257 |
replace_token = (
|
| 258 |
-
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
| 259 |
)
|
| 260 |
-
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
| 261 |
|
| 262 |
conv.append_message(conv.roles[0], prompt)
|
| 263 |
conv.append_message(conv.roles[1], "")
|
|
@@ -300,7 +297,7 @@ def get_inference_model_by_args(args_to_parse):
|
|
| 300 |
max_new_tokens=512,
|
| 301 |
tokenizer=tokenizer,
|
| 302 |
)
|
| 303 |
-
output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
|
| 304 |
|
| 305 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
| 306 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
|
@@ -321,12 +318,8 @@ def get_inference_model_by_args(args_to_parse):
|
|
| 321 |
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
| 322 |
)[pred_mask]
|
| 323 |
|
| 324 |
-
output_str = f"ASSISTANT: {text_output}"
|
| 325 |
-
if save_img is not None:
|
| 326 |
-
output_image = save_img # input_image
|
| 327 |
-
else:
|
| 328 |
-
## no seg output
|
| 329 |
-
output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
|
| 330 |
logging.info(f"output_image type: {type(output_image)}.")
|
| 331 |
return output_image, output_str
|
| 332 |
|
|
|
|
| 20 |
from model.llava import conversation as conversation_lib
|
| 21 |
from model.llava.mm_utils import tokenizer_image_token
|
| 22 |
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 23 |
+
from utils import constants, session_logger, utils
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
session_logger.change_logging(logging.DEBUG)
|
|
|
|
| 32 |
os.makedirs(FASTAPI_STATIC, exist_ok=True)
|
| 33 |
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
|
| 34 |
templates = Jinja2Templates(directory="templates")
|
| 35 |
+
placeholders = utils.create_placeholder_variables()
|
| 36 |
|
| 37 |
|
| 38 |
@app.get("/health")
|
|
|
|
| 229 |
logging.info(f"args_to_parse:{args_to_parse}, creating model...")
|
| 230 |
model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
|
| 231 |
logging.info("created model, preparing inference function")
|
| 232 |
+
no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
|
| 233 |
|
| 234 |
@session_logger.set_uuid_logging
|
| 235 |
def inference(input_str, input_image):
|
|
|
|
| 242 |
## input valid check
|
| 243 |
if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
|
| 244 |
output_str = "[Error] Invalid input: ", input_str
|
| 245 |
+
return error_happened, output_str
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# Model Inference
|
| 248 |
conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
|
| 249 |
conv.messages = []
|
| 250 |
|
| 251 |
prompt = input_str
|
| 252 |
+
prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
|
| 253 |
if args_to_parse.use_mm_start_end:
|
| 254 |
replace_token = (
|
| 255 |
+
utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
|
| 256 |
)
|
| 257 |
+
prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
|
| 258 |
|
| 259 |
conv.append_message(conv.roles[0], prompt)
|
| 260 |
conv.append_message(conv.roles[1], "")
|
|
|
|
| 297 |
max_new_tokens=512,
|
| 298 |
tokenizer=tokenizer,
|
| 299 |
)
|
| 300 |
+
output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
|
| 301 |
|
| 302 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
| 303 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
|
|
|
| 318 |
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
| 319 |
)[pred_mask]
|
| 320 |
|
| 321 |
+
output_str = f"ASSISTANT: {text_output}"
|
| 322 |
+
output_image = no_seg_out if save_img is None else save_img
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
logging.info(f"output_image type: {type(output_image)}.")
|
| 324 |
return output_image, output_str
|
| 325 |
|
resources/placeholders/error_happened.png
ADDED
|
Git LFS Details
|
resources/placeholders/no_seg_out.png
ADDED
|
Git LFS Details
|
utils/utils.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
from enum import Enum
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
|
|
|
|
| 7 |
IGNORE_INDEX = -100
|
| 8 |
IMAGE_TOKEN_INDEX = -200
|
| 9 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
|
@@ -40,6 +42,7 @@ ANSWER_LIST = [
|
|
| 40 |
"Sure, the segmentation result is [SEG].",
|
| 41 |
"[SEG].",
|
| 42 |
]
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
class Summary(Enum):
|
|
@@ -161,3 +164,14 @@ def dict_to_cuda(input_dict):
|
|
| 161 |
):
|
| 162 |
input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
|
| 163 |
return input_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from enum import Enum
|
| 2 |
+
from pathlib import Path
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
import torch.distributed as dist
|
| 7 |
|
| 8 |
+
|
| 9 |
IGNORE_INDEX = -100
|
| 10 |
IMAGE_TOKEN_INDEX = -200
|
| 11 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
|
|
|
| 42 |
"Sure, the segmentation result is [SEG].",
|
| 43 |
"[SEG].",
|
| 44 |
]
|
| 45 |
+
ROOT = Path(__file__).parent.parent
|
| 46 |
|
| 47 |
|
| 48 |
class Summary(Enum):
|
|
|
|
| 164 |
):
|
| 165 |
input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
|
| 166 |
return input_dict
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def create_placeholder_variables():
    """Load the placeholder images shown when inference errors out or yields no segmentation.

    Returns:
        dict: ``{"no_seg_out": ndarray, "error_happened": ndarray}`` with the
        images as RGB arrays (``cv2.imread`` loads BGR; the ``[:, :, ::-1]``
        slice reverses the channel order).

    Raises:
        FileNotFoundError: if a placeholder image is missing or unreadable.
    """
    import cv2  # local import keeps cv2 off the module's import-time critical path

    # NOTE(review): the images are committed under resources/placeholders/ (plural);
    # the previous path used "placeholder" (singular) and could never resolve.
    placeholders_dir = ROOT / "resources" / "placeholders"
    images = {}
    for name in ("no_seg_out", "error_happened"):
        path = placeholders_dir / f"{name}.png"
        img = cv2.imread(str(path))
        if img is None:
            # cv2.imread returns None silently on failure; fail loudly instead of
            # letting the slice below raise a confusing TypeError.
            raise FileNotFoundError(f"placeholder image not found or unreadable: {path}")
        images[name] = img[:, :, ::-1]  # BGR -> RGB
    return images
|