Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| from typing import List | |
| import torch | |
def _char_to_token_with_fallback(tokenized, char_idx, fallback_idxs):
    """Return the token index covering ``char_idx``, retrying at ``fallback_idxs``.

    Returns ``None`` when no index resolves to a token. Only the fallback
    lookups are best-effort: an exception raised there (e.g. an index nudged
    outside the text) is treated as "unresolvable", while an exception from
    the primary lookup propagates to the caller.
    """
    pos = tokenized.char_to_token(char_idx)
    if pos is not None:
        return pos
    try:
        for idx in fallback_idxs:
            pos = tokenized.char_to_token(idx)
            if pos is not None:
                return pos
    except Exception:
        return None
    return None


def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
    """Construct a row-normalized map where ``positive_map[i, j] > 0`` iff box i is associated to token j.

    Args:
        tokenized: tokenizer output exposing ``char_to_token(char_index)``,
            which returns a token index or ``None`` (presumably a Hugging Face
            ``BatchEncoding`` holding ``input_ids`` / ``attention_mask`` of
            shape ``[1, ntokens]`` -- confirm against callers).
        token_span: list with one entry per box; each entry is a list of
            ``[start_idx, end_idx)`` character spans in the caption.
        max_text_len: number of token columns in the map; tokens at or beyond
            this index are dropped.

    Returns:
        Float tensor of shape ``[len(token_span), max_text_len]``. Each row is
        divided by its sum (plus 1e-6), so rows with any marked token sum to
        ~1 and rows with none stay all zero.
    """
    positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
    # Debug switch, read once per call: keep only the first token of the first
    # resolvable span for each box.
    only_one_pos = os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE"
    for j, tok_list in enumerate(token_span):
        for beg, end in tok_list:
            # A span edge may land on a character that maps to no token (e.g.
            # whitespace); nudge inward by up to two characters before giving up.
            beg_pos = _char_to_token_with_fallback(tokenized, beg, (beg + 1, beg + 2))
            end_pos = _char_to_token_with_fallback(tokenized, end - 1, (end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue
            if beg_pos >= max_text_len:
                # Token lies outside the map: the slice below would be empty,
                # and the single-position write would raise IndexError.
                continue
            if only_one_pos:
                positive_map[j, beg_pos] = 1
                break
            positive_map[j, beg_pos : end_pos + 1].fill_(1)
    # The epsilon keeps all-zero rows at zero instead of producing NaN.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def build_captions_and_token_span(cat_list, force_lowercase, rng=None):
    """Join category names into a caption and record each name's character spans.

    Each category contributes its space-separated words followed by `` .``,
    with categories separated by one space, e.g. ``["dog", "hot dog"]`` ->
    ``"dog . hot dog ."``. A name containing ``/`` (e.g. ``"cat/kitten"``) is
    replaced by one alias drawn at random from its ``/``-separated parts plus
    the full name.

    Args:
        cat_list: iterable of category-name strings.
        force_lowercase: lowercase every name before use.
        rng: optional object with a ``choice`` method (e.g. ``random.Random``)
            used to pick aliases; defaults to the global ``random`` module.

    Return:
        captions: str
        cat2tokenspan: dict mapping the (possibly aliased) name to its list of
            ``[start, end)`` character spans within ``captions``, e.g.
            ``{'dog': [[0, 3]], ...}``. Names that yield no words get no entry
            and add nothing to the caption.
    """
    chooser = random if rng is None else rng
    cat2tokenspan = {}
    captions = ""
    for catname in cat_list:
        class_name = catname.lower() if force_lowercase else catname
        if "/" in class_name:
            # Strip every alias so the chosen dict key carries no stray
            # whitespace (e.g. "cat / dog" -> "cat", "dog", "cat / dog").
            aliases: List[str] = [alias.strip() for alias in class_name.strip().split("/")]
            aliases.append(class_name.strip())
            class_name = chooser.choice(aliases)
        tokens_positive_i = []
        for subname in class_name.strip().split(" "):
            subname = subname.strip()
            if not subname:
                # Consecutive spaces produce empty pieces; skip them.
                continue
            if captions:
                captions += " "
            start_idx = len(captions)
            captions += subname
            tokens_positive_i.append([start_idx, len(captions)])
        if tokens_positive_i:
            captions += " ."
            cat2tokenspan[class_name] = tokens_positive_i
    return captions, cat2tokenspan
def build_id2posspan_and_caption(category_dict: List[dict]):
    """Build a caption and per-category-id token spans from category records.

    Args:
        category_dict: despite its name, an iterable of category records, each
            indexable by ``"id"`` and ``"name"`` (the code iterates it and reads
            ``item["id"]`` / ``item["name"]``); presumably a COCO-style
            ``categories`` list -- confirm against callers.

    Returns:
        id2posspan: dict mapping each category id to its list of
            ``[start, end)`` character spans in ``caption`` (an empty list when
            the name contains no words).
        caption: the lowercased category names joined as ``"name1 . name2 ."``
            in input order.
    """
    id2posspan = {}
    caption = ""
    for item in category_dict:
        catname = item["name"].lower()
        # Caption each category on its own: build_captions_and_token_span keys
        # its result by a randomly chosen alias for names containing "/", so a
        # lookup by the original name could raise KeyError. With one name per
        # call, the single returned value is this category's spans.
        sub_caption, sub_spans = build_captions_and_token_span([catname], force_lowercase=True)
        spans = next(iter(sub_spans.values()), [])
        if not sub_caption:
            id2posspan[item["id"]] = []
            continue
        # Categories are separated by one space, exactly as a single joint call
        # over all names would produce; shift spans by the text already emitted.
        offset = len(caption) + 1 if caption else 0
        caption = f"{caption} {sub_caption}" if caption else sub_caption
        id2posspan[item["id"]] = [[beg + offset, end + offset] for beg, end in spans]
    return id2posspan, caption