from time import time |
from io import BytesIO |
import torch |
import streamlit as st |
import streamlit.components.v1 as components |
import numpy as np |
import torch |
import logging |
from os import environ |
from transformers import OwlViTProcessor, OwlViTForObjectDetection |
from myscaledb import Client |
from classifier import Classifier, prompt2vec, tune, SplitLayer |
from query_model import simple_query, topk_obj_query, rev_query |
from card_model import card, obj_card, style |
from box_utils import postprocess |
# Let the HuggingFace tokenizers library use its parallel implementation.
environ["TOKENIZERS_PARALLELISM"] = "true"

# Fully qualified table names used by the query helpers.
OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"

# HuggingFace id of the OwlViT checkpoint loaded by init_owlvit().
MODEL_ID = "google/owlvit-base-patch32"

# Length of a query vector (see init_random_query and the upload check).
DIMS = 512

# Duration of the last query in milliseconds; 0 means no query ran in this
# script run.
qtime = 0
def build_model(name="google/owlvit-base-patch32"):
    """Load an OwlViT detector and its matching processor.

    The detector is moved to CUDA when ``torch.cuda.is_available()`` is true
    and stays on the CPU otherwise.

    Args:
        name (str, optional): HuggingFace model id to load.
            Defaults to "google/owlvit-base-patch32".

    Returns:
        (model, processor): the OwlViT detection model and the processor
            loaded from the same checkpoint name.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    detector = OwlViTForObjectDetection.from_pretrained(name)
    detector = detector.to(target)
    return detector, OwlViTProcessor.from_pretrained(name)
@st.experimental_singleton(show_spinner=False)
def init_owlvit():
    """Build the OwlViT model for ``MODEL_ID``.

    Decorated with ``st.experimental_singleton`` so the result is cached by
    Streamlit instead of being rebuilt on every call.

    Returns:
        model, processor: as returned by ``build_model(MODEL_ID)``.
    """
    return build_model(MODEL_ID)
@st.experimental_singleton(show_spinner=False)
def init_db():
    """Open the database connection and create an empty meta list.

    Credentials are read from ``st.secrets`` (``DB_URL``, ``USER``,
    ``PASSWD``). The function is decorated with ``st.experimental_singleton``,
    so the returned pair is cached until ``init_db.clear()`` is called.

    Returns:
        meta: an initially empty list; callers extend it and pass it to
            ``query`` as the exclude list.
        client: database connection object.

    Raises:
        ConnectionError: if ``client.is_alive()`` reports a dead connection.
    """
    meta = []
    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"])
    # Explicit raise instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let a dead connection be cached silently.
    if not client.is_alive():
        raise ConnectionError("Database connection is not alive")
    return meta, client
def refresh_index():
    """Reset the session for a new search.

    Empties the meta list, zeroes ``query_num``, drops the cached database
    connection and opens a new one, then removes the classifier, the query
    vector and the stored top-k image ids from the session if present.
    """
    state = st.session_state
    del state["meta"]
    state.meta = []
    state.query_num = 0
    logging.info(f"Refresh for '{state.meta}'")
    # Clear the singleton cache so init_db() really runs again.
    init_db.clear()
    state.meta, state.index = init_db()
    for stale_key in ('clf', 'xq', 'topk_img_id'):
        if stale_key in state:
            del state[stale_key]
def query(xq, exclude_list=None):
    """Run the four database queries for a query vector.

    The vector is L2-normalised along its last axis before use. The whole
    sequence of queries is attempted up to three times; any exception is
    logged as a warning and triggers another attempt. Progress is shown in two
    temporary Streamlit placeholders that are cleared before returning.

    Args:
        xq (numpy.ndarray or list of floats): Query vector(s).
        exclude_list (optional): forwarded to ``topk_obj_query`` as
            ``exclude_list`` (the caller passes the session's meta list).

    Returns:
        matches: result of ``topk_obj_query`` with ``exclude_list`` applied.
        img_matches: result of ``rev_query`` restricted to the image ids found
            in ``matches`` (empty list when ``matches`` has no images).
        side_matches: result of ``simple_query`` with no exclude list.
        o_matches: result of ``rev_query`` restricted to
            ``st.session_state.topk_img_id``, i.e. the image ids stored by the
            first query after the session was refreshed.

        All four are empty lists when every attempt failed.
    """
    xq = xq / np.linalg.norm(xq, axis=-1, ord=2, keepdims=True)
    status_bar = [st.empty(), st.empty()]
    status_bar[0].write("Retrieving Another TopK Images...")
    pbar = status_bar[1].progress(0)

    # Pre-bind every result so the return statement is always valid, even
    # when all attempts raise before a variable is assigned.
    matches, img_matches, side_matches, o_matches = [], [], [], []
    attempt = 0
    while attempt < 3:
        try:
            matches = topk_obj_query(
                st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
                exclude_list=exclude_list, topk=5000)
            img_ids = [r['img_id'] for r in matches]
            # Remember the image ids of the first query only; later calls
            # (after training) keep re-scoring that same set for o_matches.
            if 'topk_img_id' not in st.session_state:
                st.session_state.topk_img_id = img_ids
            status_bar[0].write("Retrieving TopK Images...")
            pbar.progress(25)
            o_matches = rev_query(
                st.session_state.index, xq, st.session_state.topk_img_id,
                IMG_DB_NAME, OBJ_DB_NAME, thresh=0.1)
            status_bar[0].write("Retrieving TopKs Objects...")
            pbar.progress(50)
            side_matches = simple_query(
                st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
                thresh=-1, topk=5000)
            status_bar[0].write(
                "Retrieving Non-TopK in Another TopK Images...")
            pbar.progress(75)
            if len(img_ids) > 0:
                img_matches = rev_query(
                    st.session_state.index, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME,
                    thresh=0.1)
            else:
                img_matches = []
            status_bar[0].write("DONE!")
            pbar.progress(100)
            break
        except Exception as e:
            # Best-effort retry: log, re-fetch the connection, try again.
            # NOTE(review): init_db is a cached singleton, so this returns the
            # cached client unless the cache was cleared - confirm whether a
            # forced reconnect (init_db.clear()) is wanted here.
            logging.warning(str(e))
            st.session_state.meta, st.session_state.index = init_db()
            attempt += 1
            # Drop partial results so a failed attempt never leaks a mix of
            # fresh and missing values to the caller.
            matches, img_matches, side_matches, o_matches = [], [], [], []

    for placeholder in status_bar:
        placeholder.empty()
    if len(matches) == 0:
        logging.error(f"No matches found for '{OBJ_DB_NAME}'")
    return matches, img_matches, side_matches, o_matches
@st.experimental_singleton(show_spinner=False)
def init_random_query():
    """Draw a random query vector of unit L2 norm.

    Decorated with ``st.experimental_singleton``, so the drawn vector is
    cached by Streamlit rather than re-drawn on every call.

    Returns:
        xq: array of shape ``(1, DIMS)`` normalised along its last axis.
    """
    raw = np.random.rand(1, DIMS)
    return raw / np.linalg.norm(raw, axis=-1, keepdims=True)
def submit(meta):
    """Tune the classifier from the user's labels and re-run the query.

    Args:
        meta: items appended to ``st.session_state.meta``, the list that is
            passed to ``query`` as its exclude list.
    """
    state = st.session_state
    state.meta.extend(meta)
    state.step += 1

    # One (last box field, chosen class index) pair per displayed box; the
    # chosen class comes from the "label-<box id>" select box.
    labelled = [
        (box[-1], state.text_prompts.index(state[f"label-{box_id}"]))
        for box_id, box in state.matched_boxes.items()
    ]
    X, y = zip(*labelled)

    state.xq = tune(state.clf, X, y, iters=int(state.iters))
    results = query(state.xq, state.meta)
    (state.matches,
     state.img_matches,
     state.side_matches,
     state.o_matches) = results
# Emit the markup returned by style() (presumably the page CSS) before any
# card is rendered.
st.write(style(), unsafe_allow_html=True)

# Keep the database handle and the meta list in the session on every run.
with st.spinner("Connecting DB..."):
    _db_meta, _db_client = init_db()
    st.session_state.meta = _db_meta
    st.session_state.index = _db_client

# The model pair is used below to embed text prompts.
with st.spinner("Loading Models..."):
    model, tokenizer = init_owlvit()
# Landing page: shown only while no query vector is stored in the session.
if 'xq' not in st.session_state:
    with st.container():
        st.title('Object Detection Safari')
        # Placeholders, so the whole landing page can be blanked as soon as a
        # query starts.
        start = [st.empty() for _ in range(8)]
        start[0].info("""
We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
You can search with almost any words or phrases you can think of. Please enjoy your journey of
an adventure to COCO.
""")
        prompt = start[1].text_input(
            "Prompt:", value="", placeholder="Examples: football, billboard, stop sign, watermark ...",)
        with start[2].container():
            st.write(
                'You can search with multiple keywords. Plese separate with commas but with no space.')
            st.write('For example: `cat,dog,tree`')
            st.markdown('''
<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
''',
                        unsafe_allow_html=True)
        upld_model = start[4].file_uploader(
            "Or you can upload your previous run!", type='onnx')
        upld_btn = start[5].button(
            "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index)
        with start[3]:
            col = st.columns(8)
            # "Random" is only offered when neither a prompt nor a model
            # upload is present.
            has_no_prompt = (len(prompt) == 0 and upld_model is None)
            prompt_xq = col[6].button("Prompt", disabled=len(
                prompt) == 0, on_click=refresh_index)
            random_xq = col[7].button(
                "Random", disabled=not has_no_prompt, on_click=refresh_index)
        matches = []
        img_matches = []
        if random_xq:
            # Start from a random unit vector with a single placeholder class.
            xq = init_random_query()
            st.session_state.xq = xq
            prompt = 'unknown'
            st.session_state.text_prompts = prompt.split(',') + ['none']
            _ = [elem.empty() for elem in start]
            t0 = time()
            # Store the results in the session: the results page reads them
            # from st.session_state, so keeping them in locals only would make
            # it fail on a fresh session or show results of an older query.
            st.session_state.matches, \
                st.session_state.img_matches, \
                st.session_state.side_matches, \
                st.session_state.o_matches = query(
                    st.session_state.xq, st.session_state.meta)
            t1 = time()
            qtime = (t1-t0) * 1000
        elif prompt_xq or upld_btn:
            if upld_model is not None:
                # Resume from an uploaded classifier: class names are the
                # graph's output names, the query vectors are the transposed
                # first initializer.
                import onnx
                from onnx import numpy_helper
                _model = onnx.load(upld_model)
                st.session_state.text_prompts = [
                    node.name for node in _model.graph.output] + ['none']
                weights = _model.graph.initializer
                xq = numpy_helper.to_array(weights[0]).T
                # The upload is untrusted input, so validate with an explicit
                # raise; an assert would be stripped under `python -O`.
                if xq.shape[0] != len(st.session_state.text_prompts) - 1 \
                        or xq.shape[1] != DIMS:
                    raise ValueError(
                        "Uploaded weights have shape {}, expected ({}, {})".format(
                            xq.shape, len(st.session_state.text_prompts) - 1, DIMS))
                st.session_state.xq = xq
                _ = [elem.empty() for elem in start]
            else:
                # Embed each comma-separated prompt; the trailing 'none' class
                # is not embedded.
                logging.info(f"Input prompt is {prompt}")
                st.session_state.text_prompts = prompt.split(',') + ['none']
                input_ids, xq = prompt2vec(
                    st.session_state.text_prompts[:-1], model, tokenizer)
                st.session_state.xq = xq
                _ = [elem.empty() for elem in start]
            t0 = time()
            st.session_state.matches, \
                st.session_state.img_matches, \
                st.session_state.side_matches, \
                st.session_state.o_matches = query(
                    st.session_state.xq, st.session_state.meta)
            t1 = time()
            qtime = (t1-t0) * 1000
# Results page: rendered whenever a query vector is stored in the session.
if 'xq' in st.session_state:
    # Copy the latest query results from the session into plain locals.
    o_matches = st.session_state.o_matches
    side_matches = st.session_state.side_matches
    img_matches = st.session_state.img_matches
    matches = st.session_state.matches
    # No classifier in the session yet: build one from the query vector and
    # start the step counter at zero.
    if 'clf' not in st.session_state:
        st.session_state.clf = Classifier(st.session_state.xq)
        st.session_state.step = 0
    # qtime is set to a positive value only where a query was just timed.
    if qtime > 0:
        st.info("Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
            qtime, len(matches), sum([len(m["box_id"]) + len(im["box_id"]) for m, im in zip(matches, img_matches)])))
    # Export the current classifier model followed by SplitLayer to ONNX in an
    # in-memory buffer; the outputs are named after every prompt but the last.
    st.session_state.dnld_model = BytesIO()
    torch.onnx.export(torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
                      torch.zeros([1, len(st.session_state.xq[0])]),
                      st.session_state.dnld_model,
                      input_names=['input'],
                      output_names=st.session_state.text_prompts[:-1])
    # Default download name: the prompts joined with "_" (spaces replaced by
    # "-"), falling back to "model" when no prompts are in the session.
    dnld_nam = st.text_input('Download Name:',
                             f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
                             max_chars=50)
    dnld_btn = st.download_button('Download your classifier!',
                                  st.session_state.dnld_model,
                                  dnld_nam)
    # Thumbnail edge length for the sidebar: shrinks as the number of prompt
    # entries grows, capped at 120.
    side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
    with st.sidebar:
        with st.expander("Top-K Images"):
            with st.container():
                # NOTE(review): postprocess appears to return (per-image box
                # groups, meta); layout inferred from the unpacking below.
                boxes_w_img, _ = postprocess(o_matches, st.session_state.text_prompts,
                                             None)
                # Sort by element 4 (unpacked as img_score below), descending.
                boxes_w_img = sorted(
                    boxes_w_img, key=lambda x: x[4], reverse=True)
                for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                    args = img_url, img_w, img_h, boxes
                    st.write(card(*args), unsafe_allow_html=True)
        with st.expander("Top-K Objects", expanded=True):
            # One sidebar column per prompt, excluding the last entry.
            side_cols = st.columns(
                len(st.session_state.text_prompts[:-1]))
            for _cols, m in zip(side_cols, side_matches):
                with _cols.container():
                    for cx, cy, w, h, logit, img_url, img_w, img_h \
                            in zip(m['cx'], m['cy'], m['w'], m['h'], m['logit'],
                                   m['img_url'], m['img_w'], m['img_h']):
                        # Caption: prompt text looked up by m['label'], then
                        # the logit.
                        st.write("{:s}: {:.4f}".format(
                            st.session_state.text_prompts[m['label']], logit))
                        _html = obj_card(
                            img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len)
                        components.html(
                            _html, side_bar_len, side_bar_len)
    with st.container():
        # A single form, so the label select boxes are submitted together by
        # one of the form's submit buttons.
        with st.form("batch", clear_on_submit=False):
            col = st.columns([1, 9])
            if len(matches) <= 0:
                st.warning(
                    'Oops! We didn\'t find anything relevant to your query! Pleas try another one :/')
            else:
                # Read by submit() as the `iters` argument of tune().
                st.session_state.iters = st.slider(
                    "Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
            col[1].form_submit_button(
                "Choose a new prompt", on_click=refresh_index)
            if len(matches) > 0:
                with st.container():
                    prompt_labels = st.session_state.text_prompts
                    boxes_w_img, meta = postprocess(matches, st.session_state.text_prompts,
                                                    img_matches)
                    boxes_w_img = sorted(
                        boxes_w_img, key=lambda x: x[4], reverse=True)
                    # Box id -> box tuple for every box shown; read back by
                    # submit() to collect the training data.
                    st.session_state.matched_boxes = {}
                    for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
                        st.session_state.matched_boxes.update(
                            {b[0]: b for b in boxes})
                        args = img_url, img_w, img_h, boxes
                        with st.expander("{:s}: {:.4f}".format(img_id, img_score), expanded=True):
                            # ind_b counts boxes so they cycle through the
                            # three narrow columns next to the full image.
                            ind_b = 0
                            img_row = st.columns([4, 2, 2, 2])
                            img_row[0].write(
                                card(*args), unsafe_allow_html=True)
                            for b in boxes:
                                _id, cx, cy, w, h, label, logit, is_selected, _ = b
                                with img_row[1 + ind_b % 3].container():
                                    st.write(
                                        "{:s}: {:.4f}".format(label, logit))
                                    _html = \
                                        obj_card(img_url, img_w, img_h,
                                                 *b[1:5], dst_len=120)
                                    components.html(_html, 120, 120)
                                    # Pre-selected to the box's current label;
                                    # the key is what submit() looks up.
                                    st.selectbox(
                                        "Class",
                                        prompt_labels,
                                        index=prompt_labels.index(label),
                                        key=f"label-{_id}")
                                ind_b += 1
                # Training needs `meta` from postprocess, so the button is
                # only created when there are matches.
                col[0].form_submit_button(
                    "Train!", on_click=lambda: submit(meta))