Samarth991 committed on
Commit 0e78cbf · 1 Parent(s): bab4b66

adding CV agent file

.gitignore ADDED
@@ -0,0 +1,164 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+ Data/
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ *.json
+ image_store/
QA_bot.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import re
+ import time
+
+ import streamlit as st
+ from PIL import Image
+
+ def reset_conversation():
+     st.session_state.messages = []
+
+ def display_mask_image(image_path):
+     if os.path.isfile(image_path):
+         image = Image.open(image_path)
+         st.image(image, caption='Final Mask', use_column_width=True)
+
+
+ def tyre_synap_bot(filter_agent, image_file_path):
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     print("Found image file path: ", image_file_path)
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # React to user input
+     if prompt := st.chat_input("What is up?"):
+         # Display user message in chat message container
+         st.chat_message("user").markdown(prompt)
+         # Add user message to chat history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+
+         ai_response = filter_agent.invoke(
+             {
+                 "input": f'{prompt}, provided image path: {image_file_path}'
+             }
+         )
+
+         response = f"Echo: {ai_response['output']}"
+         with st.chat_message("assistant"):
+             message_placeholder = st.empty()
+             full_response = ""
+             if 'mask' in ai_response['output']:
+                 display_mask_image('final_mask.png')
+
+             # Stream the answer chunk by chunk; the split pattern keeps the
+             # whitespace, so no extra separator is needed when re-joining
+             for chunk in re.split(r'(\s+)', response):
+                 full_response += chunk
+                 time.sleep(0.01)
+                 # Add a blinking cursor to simulate typing
+                 message_placeholder.markdown(full_response + "▌")
+         # Add assistant response to chat history
+         st.session_state.messages.append({"role": "assistant", "content": full_response})
+     st.button('Reset Chat', on_click=reset_conversation)
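
A note on the interface: tyre_synap_bot only assumes an agent object whose .invoke({"input": ...}) call returns a dict with an "output" key. A minimal sketch (not part of the commit; the stub agent and sample path are hypothetical) that exercises the chat widget without the real AgentExecutor:

    # stub_bot.py -- hypothetical smoke test, launched with `streamlit run stub_bot.py`
    from QA_bot import tyre_synap_bot

    class EchoAgent:
        def invoke(self, payload: dict) -> dict:
            # Mimic AgentExecutor's return shape by echoing the prompt back
            return {"output": payload["input"]}

    tyre_synap_bot(EchoAgent(), "image_store/sample.jpg")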
app.py ADDED
@@ -0,0 +1,169 @@
+ import os
+ import logging
+ import warnings
+ from pathlib import Path
+
+ import streamlit as st
+ from PIL import Image
+
+ from QA_bot import tyre_synap_bot as bot
+ from llm_service import get_llm
+ from hub_prompts import PREFIX
+ from extract_tools import get_all_tools
+
+ from langchain.agents import AgentExecutor
+ from langchain import hub
+ from langchain.agents.format_scratchpad import format_log_to_str
+ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
+ from langchain.tools.render import render_text_description
+
+ warnings.filterwarnings("ignore")
+
+ logging.basicConfig(filename="newfile.log",
+                     format='%(asctime)s %(message)s',
+                     filemode='w')
+ logger = logging.getLogger()
+
+ llm = None
+ tools = None
+ cv_agent = None
+
+ @st.cache_resource
+ def call_llmservice_model(option, api_key):
+     model = get_llm(option=option, key=api_key)
+     return model
+
+ @st.cache_resource
+ def setup_agent_prompt():
+     prompt = hub.pull("hwchase17/react-json")
+     if len(tools) == 0:
+         logger.error("No tools added")
+     else:
+         prompt = prompt.partial(
+             tools=render_text_description(tools),
+             tool_names=", ".join([t.name for t in tools]),
+             additional_kwargs={
+                 'system_message': PREFIX,
+             }
+         )
+     return prompt
+
+ @st.cache_resource
+ def agent_initialize():
+     agent_prompt = setup_agent_prompt()
+     llm_with_stop = llm.bind(stop=["\nObservation"])
+     # Alternatively, use create_react_agent:
+     # https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/react/agent.py
+     agent = (
+         {
+             "input": lambda x: x["input"],
+             "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
+         }
+         | agent_prompt
+         | llm_with_stop
+         | ReActJsonSingleInputOutputParser()
+     )
+
+     # Instantiate the AgentExecutor
+     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
+     return agent_executor
+
+ # def agent_initialize(tools, max_iterations=5):
+ #     zero_shot_agent = initialize_agent(
+ #         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ #         tools=tools,
+ #         llm=llm,
+ #         verbose=True,
+ #         max_iterations=max_iterations,
+ #         memory=None,
+ #         handle_parsing_errors=True,
+ #         agent_kwargs={
+ #             'system_message': PREFIX,
+ #             # 'format_instructions': FORMAT_INSTRUCTIONS,
+ #             # 'suffix': SUFFIX
+ #         }
+ #     )
+ #     # sys_message = PREFIX
+ #     # zero_shot_agent.agent.llm_chain.prompt.template = sys_message
+ #     return zero_shot_agent
+
+
+ def main():
+     database_store = 'image_store'
+     st.session_state.disabled = False
+     st.session_state.visibility = "visible"
+
+     st.title("Computer Vision Agent :sunglasses:")
+     st.markdown("Use the CV agent for Object Detection / Panoptic Segmentation / Image Segmentation / Image Description.")
+     st.markdown(
+         """
+         <style>
+         section[data-testid="stSidebar"] {
+             width: 350px !important; /* Set the width to your desired value */
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+     with st.sidebar:
+         st.header("About Project")
+         st.markdown(
+             """
+             - Agent to filter images on the basis of multiple factors, such as image quality, object proportion in the image, and weather in the image.
+             - This application uses multiple tools, such as an image caption tool, a DuckDuckGo search tool, a Maskformer tool, and a weather predictor.
+             """)
+         st.sidebar.subheader("Upload Image!")
+         option = st.sidebar.selectbox(
+             "Select the Large Language Model", ("deepseek-r1-distill-llama-70b",
+                                                 "gemma2-9b-it",
+                                                 "llama-3.2-3b-preview",
+                                                 "llama-3.2-1b-preview",
+                                                 "llama3-8b-8192",
+                                                 "Openai",
+                                                 "Google",
+                                                 "Ollama"),
+             index=None,
+             placeholder="Select LLM Service...",
+         )
+         api_key = st.sidebar.text_input("API_KEY", type="password", key="password")
+
+     uploaded_file = st.sidebar.file_uploader("Upload Image for Processing", type=['png', 'jpg', 'jpeg'])
+
+     if uploaded_file is not None:
+         file_path = Path(database_store, uploaded_file.name)
+         if not os.path.isdir(database_store):
+             os.makedirs(database_store)
+
+         global llm
+         llm = call_llmservice_model(option=option, api_key=api_key)
+         logger.info("\tLLM Service {} Active ... !".format(llm.get_name()))
+         # Extract tools
+         global tools
+         tools = get_all_tools()
+         logger.info("\tFound {} tools".format(len(tools)))
+         # Generate agent
+         global cv_agent
+         cv_agent = agent_initialize()
+         logger.info('\tAgent initialized with {} tools'.format(len(tools)))
+
+         with open(file_path, mode='wb') as w:
+             w.write(uploaded_file.getvalue())
+
+         if os.path.isfile(file_path):
+             st.sidebar.success("File uploaded successfully", icon="✅")
+
+         with st.sidebar.container():
+             image = Image.open(file_path)
+             st.image(image, use_container_width=True)
+         st.sidebar.subheader("""
+         Example Questions:
+         - Describe the image
+         - Tell me what you can detect in the image.
+         - How is the image quality?
+         """)
+
+         bot(cv_agent, file_path)
+
+ if __name__ == '__main__':
+     main()
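
The app itself is launched with `streamlit run app.py`. For reference, a hedged, self-contained sketch (not part of the commit) of the same LCEL composition agent_initialize() builds: a dict of lambdas feeding the hub prompt, the stop-bound chat model, and the ReAct JSON parser. It assumes GROQ_API_KEY is set and that hub.pull has network access; the empty partials stand in for the rendered tool descriptions:

    from langchain import hub
    from langchain.agents.format_scratchpad import format_log_to_str
    from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
    from langchain_groq import ChatGroq

    prompt = hub.pull("hwchase17/react-json").partial(tools="", tool_names="")
    llm = ChatGroq(model="gemma2-9b-it").bind(stop=["\nObservation"])
    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
        }
        | prompt
        | llm
        | ReActJsonSingleInputOutputParser()
    )
    # agent.invoke({"input": "...", "intermediate_steps": []}) yields the next AgentAction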
extract_tools.py ADDED
@@ -0,0 +1,254 @@
+ import os
+ import logging
+
+ import cv2
+ import requests
+ import torch
+ from PIL import Image
+
+ from llm_service import get_llm
+ from langchain_core.tools import tool, Tool
+ from langchain_community.tools import DuckDuckGoSearchResults
+ from langchain_groq import ChatGroq
+ from utils import draw_panoptic_segmentation
+
+ from tool_utils.clip_segmentation import CLIPSEG
+ from tool_utils.object_extractor import create_object_extraction_chain
+ from tool_utils.yolo_world import YoloWorld
+ from tool_utils.image_metadata import image_brightness, variance_of_laplacian, get_signal_to_noise_ratio
+
+ try:
+     from transformers import BlipProcessor, BlipForConditionalGeneration
+     from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ except ImportError as err:
+     logging.error("Import error: {}".format(err))
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ logging.info("Loading Foundation Models")
+ try:
+     clipseg_model = CLIPSEG()
+ except Exception as err:
+     logging.error("Unable to load CLIPSEG model: {}".format(err))
+ try:
+     maskformer_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+     maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+ except Exception as err:
+     logging.error("Unable to load Maskformer model: {}".format(err))
+
+
+ def get_groq_model(model_name="gemma2-9b-it"):
+     # ChatGroq reads GROQ_API_KEY from the environment
+     llm_groq = ChatGroq(model=model_name)
+     return llm_groq
+
+ @tool
+ def panoptic_image_segmentation(image_path: str) -> str:
+     """
+     Create a panoptic segmentation mask. The tool uses a Mask2Former network to build a panoptic
+     segmentation of all the objects present in the image. Use it when the user asks for a panoptic segmentation.
+     """
+     if image_path.startswith('https'):
+         image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
+     else:
+         image = Image.open(image_path).convert('RGB')
+     maskformer_model.to(device)
+     inputs = maskformer_processor(image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = maskformer_model(**inputs)
+
+     prediction = maskformer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+     save_mask_path = draw_panoptic_segmentation(maskformer_model, prediction['segmentation'], prediction['segments_info'])
+     labels = []
+     for segment in prediction['segments_info']:
+         label_name = maskformer_model.config.id2label[segment['label_id']]
+         print(label_name)
+         labels.append(label_name)
+     return 'Panoptic segmentation image {} created with labels {}'.format(save_mask_path, labels)
+
+ @tool
+ def image_description(img_path: str) -> str:
+     """Use this tool to describe the image. It also helps to identify the weather in the image."""
+     hf_model = "Salesforce/blip-image-captioning-base"
+     if img_path.startswith('https'):
+         image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
+     else:
+         image = Image.open(img_path).convert('RGB')
+     try:
+         processor = BlipProcessor.from_pretrained(hf_model)
+         caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
+     except Exception as err:
+         logging.error("Unable to load the BLIP model: {}".format(err))
+
+     logging.info("Image caption model loaded!")
+
+     # Unconditional image captioning
+     inputs = processor(image, return_tensors='pt').to(device)
+     output = caption_model.generate(**inputs, max_new_tokens=50)
+     caption = processor.decode(output[0], skip_special_tokens=True)
+
+     # Conditional image captioning
+     obj_text = "Total number of objects in image "
+     inputs_2 = processor(image, obj_text, return_tensors='pt').to(device)
+     out_2 = caption_model.generate(**inputs_2, max_new_tokens=50)
+     object_caption = processor.decode(out_2[0], skip_special_tokens=True)
+
+     # Clear the GPU cache
+     with torch.no_grad():
+         torch.cuda.empty_cache()
+     text = caption + " . " + object_caption + " ."
+     return text
+
+
+ @tool
+ def clipsegmentation_mask(input_data: str) -> str:
+     """
+     Extract object masks from the image. The input is a comma-separated string:
+     the image path followed by the objects to mask, e.g. "image.jpg,road,car".
+     """
+     data = input_data.split(",")
+     image_path = data[0]
+     object_prompts = data[1:]
+     masks = clipseg_model.get_segmentation_mask(image_path, object_prompts)
+     return masks
+
+ @tool
+ def generate_bounding_box_tool(input_data: str) -> str:
+     """Use this tool when it is required to detect objects and provide bounding boxes for the given image and list of objects."""
+     yolo_world_model = YoloWorld()
+     data = input_data.split(",")
+     image_path = data[0]
+     object_prompts = data[1:]
+     object_data = yolo_world_model.run_inference(image_path, object_prompts)
+     return object_data
+
+ @tool
+ def object_extraction(img_path: str) -> str:
+     """Use this tool to identify the objects within the image."""
+     hf_model = "Salesforce/blip-image-captioning-base"
+     if img_path.startswith('https'):
+         image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
+     else:
+         image = Image.open(img_path).convert('RGB')
+     try:
+         processor = BlipProcessor.from_pretrained(hf_model)
+         caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
+     except Exception as err:
+         logging.error("Unable to load the BLIP model: {}".format(err))
+
+     logging.info("Image caption model loaded!")
+
+     # Unconditional image captioning
+     inputs = processor(image, return_tensors='pt').to(device)
+     output = caption_model.generate(**inputs, max_new_tokens=50)
+     llm = get_groq_model()
+     getobject_chain = create_object_extraction_chain(llm=llm)
+
+     extracted_objects = getobject_chain.invoke({
+         'context': processor.decode(output[0], skip_special_tokens=True)
+     }).objects
+
+     print("Extracted objects: ", extracted_objects)
+     # Clear the GPU cache
+     with torch.no_grad():
+         torch.cuda.empty_cache()
+
+     return extracted_objects.split(',')
+
+ @tool
+ def get_image_quality(image_path: str) -> str:
+     """
+     Find out the quality parameters of the image: whether it is blurry, whether it is bright,
+     and its Signal to Noise Ratio.
+     Example outputs of the tool:
+     example 1: Image is blurry. Image is not bright. Signal to Noise is less than 1 - More Noise in image
+     example 2: Image is not blurry. Image is bright. Signal to Noise is greater than 1 - More Signal in image
+     """
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     brightness_text = image_brightness(image)
+     blurry_text = variance_of_laplacian(image)
+     snr_text = get_signal_to_noise_ratio(image)
+     final_text = "Image properties are :\n{}\n{}\n{}".format(blurry_text, brightness_text, snr_text)
+     return final_text
+
+
+ def get_all_tools():
+     # Bind tools
+     image_desc_tool = Tool(
+         name='Image_Description_Tool',
+         func=image_description,
+         description="""
+         The tool describes the image or creates a caption for it.
+         Use this tool if the user asks to describe or generate a caption for the image.
+         It can also be used to identify the weather within the image.
+         Example user questions:
+         1. Describe the image?
+         2. What does the weather look like in the image?
+         """
+     )
+
+     clipseg_tool = Tool(
+         name='ClipSegmentation-tool',
+         func=clipsegmentation_mask,
+         description="""Use this tool when the user asks to generate the segmentation mask of the objects they provide.
+         The input to the tool is the path of the image and the list of objects for which a segmentation mask is to be generated.
+         For example:
+         Query: Provide a segmentation mask of all road, car and dog in the image.
+         The tool will generate the segmentation mask of those objects in the image.
+         For such a query, first use the object extraction tool to identify the objects, then use this tool to
+         generate the segmentation mask for them.
+         """
+     )
+
+     bounding_box_generator = Tool(
+         name='Bounding Box Generator',
+         func=generate_bounding_box_tool,
+         description="The tool provides bounding boxes for the given image and list of objects. "
+                     "Use this tool when the user asks for bounding boxes for objects. If the user has not specified the object names, "
+                     "use the object extraction tool to identify the objects first, then use this tool to generate their bounding boxes. "
+                     "The input to this tool is the path of the image and the list of objects for which bounding boxes are to be generated."
+     )
+
+     object_extractor = Tool(
+         name="Object Extraction Tool",
+         func=object_extraction,
+         description="The tool extracts the objects within the image. Use this tool if the user specifically asks to identify "
+                     "what objects can be seen in the image."
+     )
+
+     image_parameters_tool = Tool(
+         name='Image Parameters_Tool',
+         func=get_image_quality,
+         description="""This tool determines
+         - whether the image is blurry or not
+         - whether the image is bright/sharp or not
+         - the SNR ratio of the image
+         Based on the tool output, take a proper decision regarding the image quality."""
+     )
+
+     panoptic_segmentation = Tool(
+         name='Panoptic_Segmentation_tool',
+         func=panoptic_image_segmentation,
+         description="The tool creates a panoptic segmentation mask. It uses a Mask2Former network to build a panoptic segmentation of all "
+                     "the objects present in the image. Use the tool when the user asks for a panoptic segmentation or to count objects in the image. "
+                     "The tool also returns the list of objects found, along with the mask image of all segmented objects."
+     )
+
+     tools = [
+         DuckDuckGoSearchResults(),
+         image_desc_tool,
+         clipseg_tool,
+         image_parameters_tool,
+         object_extractor,
+         bounding_box_generator,
+         panoptic_segmentation,
+     ]
+     return tools
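
Each @tool above is a regular langchain BaseTool, so it can be exercised in isolation before going through the agent. A hedged sketch (not part of the commit; "image_store/sample.jpg" is a hypothetical path, and importing extract_tools loads the CLIPSEG/Mask2Former weights at import time):

    from extract_tools import get_image_quality, get_all_tools

    # Single-input tools accept a plain string via .invoke() in recent langchain-core
    print(get_image_quality.invoke("image_store/sample.jpg"))
    print([t.name for t in get_all_tools()])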
final_mask.png ADDED
hub_prompts.py ADDED
@@ -0,0 +1,79 @@
+ PREFIX = """
+ You are an agent designed to filter images on the basis of image quality. You are provided with dashcam images from a car.
+ You have access to the following tools: Image_Description_Tool, Object Proportion_Tool, Image Parameters_Tool, DuckDuckGoSearch_Tool
+ Some examples are provided below:
+
+ Question: Describe the image
+ Thought: To describe the image, I need a tool that describes an image. Image_Description_Tool describes the image. I should use that tool.
+ Action : ```json
+ {
+     "action": "Image_Description_Tool",
+     "action_input": "image_store/image.jpg"
+ }
+ ```
+ Observation : "Car driving on road. The weather in the image is stormy."
+ Final Answer : The car is driving in stormy weather. Due to the stormy weather, visibility will be low.
+
+
+ Question: I need to know the quality of the image.
+ Thought: I need a tool that describes the image quality. Image Parameters_Tool can help to find out parameters like brightness, blur and noise in the image.
+ Action : ```json
+ {
+     "action": "Image Parameters_Tool",
+     "action_input": "image_store/image_path.jpg"
+ }
+ ```
+ Observation : "Image is bright enough and has a high Signal to Noise ratio >1, which means the quality of the image is good."
+ Final Answer : The quality of the image seems good.
+
+ Question: I need to know the quality of the image.
+ Thought: I need a tool that describes the image quality. Image Parameters_Tool can help to find out parameters like brightness, blur and noise in the image.
+ Action : ```json
+ {
+     "action": "Image Parameters_Tool",
+     "action_input": "image_store/image_path.jpg"
+ }
+ ```
+ Observation : "Image is not bright enough and has more noise, which means the quality of the image is bad."
+ Final Answer : The quality of the image does not seem to be good.
+
+ Question: I need to determine the cracks.
+ Thought: Before looking for cracks I should check the image quality. Image Parameters_Tool can help to find out parameters like brightness, blur and noise in the image.
+ Action : ```json
+ {
+     "action": "Image Parameters_Tool",
+     "action_input": "image_store/image_path.jpg"
+ }
+ ```
+ Observation : "Image is not bright enough and has more noise, which means the quality of the image is bad."
+ Final Answer : The quality of the image does not seem to be good.
+
+ The final method is "get_image_parameters". This tool helps to find out general properties of the image, like blurriness, sharpness,
+ brightness, and Signal to Noise ratio.
+
+ Use these tools and the information they provide to construct your final answer.
+
+ If you get an error while executing a query, rewrite the query and try again.
+ If the question does not seem related to the image, just return "I don't know" as the answer.
+ """
+
+
+ FORMAT_INSTRUCTIONS = """Use the following format:
+
+ Question: the input question you must answer
+
+ Thought: you should always think about what to do
+
+ Action: the action to take, should be one of [{tool_names}]
+
+ Action Input: the input to the action
+
+ Observation: the result of the action
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
+
+ Thought: I now know the final answer
+ Final Answer: the final answer to the original input question
+
+ """
+
+ SUFFIX = """You are a humble agent; provide information point-wise."""
llm_service.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ from langchain_groq import ChatGroq
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_openai import AzureChatOpenAI
+ from langchain_community.llms import Ollama
+
+ def azure_openai_service(key, max_retries=3):
+     os.environ["AZURE_OPENAI_API_KEY"] = key
+     os.environ["AZURE_OPENAI_ENDPOINT"] = "https://indus.api.michelin.com/openai-key-weu"
+     model = AzureChatOpenAI(
+         azure_deployment="gpt-4o",          # or your deployment
+         api_version="2023-06-01-preview",   # or your api version
+         temperature=0,
+         max_tokens=None,
+         timeout=None,
+         max_retries=max_retries)
+     return model
+
+ def get_ollama():
+     # Requires a running Ollama server (start it from a terminal first)
+     llm = Ollama(base_url="http://localhost:11434", model="mistral")
+     return llm
+
+ def get_googleGemini(key):
+     os.environ["GOOGLE_API_KEY"] = key
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-1.5-pro",
+         temperature=0,
+         max_tokens=None,
+         timeout=None,
+         max_retries=2)
+     return llm
+
+ def get_groq_model(key, model_name="gemma2-9b-it"):
+     os.environ["GROQ_API_KEY"] = key
+     llm_groq = ChatGroq(model=model_name)
+     return llm_groq
+
+
+ def get_llm(option, key):
+     llm = None
+     if option == 'deepseek-r1-distill-llama-70b':
+         llm = get_groq_model(key, model_name="deepseek-r1-distill-llama-70b")
+     elif option == 'gemma2-9b-it':
+         llm = get_groq_model(key, model_name="gemma2-9b-it")
+     elif option == 'llama-3.2-3b-preview':
+         llm = get_groq_model(key, model_name="llama-3.2-3b-preview")
+     elif option == 'llama-3.2-1b-preview':
+         llm = get_groq_model(key, model_name="llama-3.2-1b-preview")
+     elif option == 'llama3-8b-8192':
+         llm = get_groq_model(key, model_name="llama3-8b-8192")
+     elif option == 'Openai':
+         llm = azure_openai_service(key)
+     elif option == 'Google':
+         llm = get_googleGemini(key)
+     elif option == "Ollama":
+         llm = get_ollama()
+     return llm
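
A hedged usage sketch (not part of the commit): any option from app.py's selectbox works; here the Groq-hosted gemma2 model, assuming a valid key is available in the environment:

    import os
    from llm_service import get_llm

    llm = get_llm(option="gemma2-9b-it", key=os.environ["GROQ_API_KEY"])
    print(llm.invoke("Say hello in five words").content)  # chat models return an AIMessage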
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ streamlit
+ langchain
+ langchain-community
+ langchain-core
+ langchain-text-splitters
+ langchain-experimental
+ langchain-google-genai
+ langchain-openai
+ tiktoken
+ duckduckgo-search
+ torch
+ transformers
+ langchain-groq
+ jq
+ scikit-learn
+ PyWavelets
+ scikit-image
tool_utils/clip_segmentation.py ADDED
@@ -0,0 +1,55 @@
+ import logging
+ from typing import List
+
+ import cv2
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from segmentation_mask_overlay import overlay_masks
+
+
+ class CLIPSEG:
+     def __init__(self, model_name="CIDAS/clipseg-rd64-refined", threshold=0.60):
+         self.clip_processor = CLIPSegProcessor.from_pretrained(model_name)
+         self.clip_model = CLIPSegForImageSegmentation.from_pretrained(model_name)
+         self.threshold = threshold
+         self.clip_model.to('cpu')
+
+     @staticmethod
+     def create_rgb_mask(mask, color=None):
+         if color is None:
+             color = tuple(np.random.choice(range(0, 256), size=3))
+         gray_3_channel = cv2.merge((mask, mask, mask))
+         gray_3_channel[mask == 255] = color
+         return gray_3_channel.astype(np.uint8)
+
+     def get_segmentation_mask(self, image_path: str, object_prompts: List):
+         image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+         logging.info("objects to segment in the image: {}".format(object_prompts))
+
+         predicted_masks = []
+         inputs = self.clip_processor(
+             text=object_prompts,
+             images=[image] * len(object_prompts),
+             padding="max_length",
+             return_tensors="pt",
+         )
+         with torch.no_grad():  # disable gradient computation
+             outputs = self.clip_model(**inputs)
+         preds = outputs.logits.unsqueeze(1)
+
+         for i in range(preds.shape[0]):
+             predicted_mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
+             predicted_mask = np.where(predicted_mask > self.threshold, 255, 0)
+             predicted_masks.append(predicted_mask)
+
+         # CLIPSeg predicts 352x352 logits, so overlay on a resized copy of the image
+         resize_image = cv2.resize(image, (352, 352))
+         mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(object_prompts)]
+         cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
+
+         bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
+         final_mask = overlay_masks(resize_image, np.stack(bool_masks, -1), labels=mask_labels, colors=cmap, alpha=0.5, beta=0.7)
+         try:
+             cv2.imwrite('final_mask.png', final_mask)
+             return 'Segmentation image created: final_mask.png'
+         except Exception as e:
+             logging.error("Error while saving the final mask: %s", e)
+             return "unable to create a mask image"
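
A hedged usage sketch (not part of the commit): the CIDAS/clipseg-rd64-refined weights download on first use, and "image_store/sample.jpg" is a hypothetical path:

    from tool_utils.clip_segmentation import CLIPSEG

    segmenter = CLIPSEG()  # runs on CPU by construction
    print(segmenter.get_segmentation_mask("image_store/sample.jpg", ["road", "car"]))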
tool_utils/image_metadata.py ADDED
@@ -0,0 +1,43 @@
+ import logging
+
+ import cv2
+ import numpy as np
+ from skimage.restoration import estimate_sigma
+
+
+ def image_brightness(image, thresh=0.37):
+     # Callers pass RGB images (see extract_tools.get_image_quality)
+     L, A, B = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
+     norm_L = L / np.max(L)
+     L_mean = np.mean(norm_L)
+     if L_mean > thresh:
+         return "image is Bright enough "
+     else:
+         return "image is not bright enough "
+
+ def variance_of_laplacian(img, threshold=250):
+     # Compute the Laplacian of the image and then return the focus
+     # measure, which is simply the variance of the Laplacian
+     gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+
+     laplacian_value = cv2.Laplacian(gray, cv2.CV_64F).var()
+     logging.info(laplacian_value)
+     if laplacian_value <= threshold:
+         return " Image is very blurry"
+     elif laplacian_value <= 3 * threshold:
+         return " Image is visible but has some regions out of focus."
+     else:
+         return "Image is very sharp."
+
+ def get_signal_to_noise_ratio(image):
+     # estimate_sigma returns an estimate of the noise standard deviation
+     snr_value = estimate_sigma(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), average_sigmas=False)
+     logging.info(snr_value)
+     if snr_value > 1:
+         snr_text = "Signal to Noise is greater than 1 - More Signal in image "
+     else:
+         snr_text = "Signal to Noise is less than 1 - More Noise in image "
+     return snr_text
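
A hedged sketch (not part of the commit) that exercises the three metrics on a synthetic frame, so no image file is needed:

    import numpy as np
    from tool_utils.image_metadata import image_brightness, variance_of_laplacian, get_signal_to_noise_ratio

    rgb = np.full((240, 320, 3), 200, dtype=np.uint8)  # flat, uniformly lit RGB frame
    print(image_brightness(rgb))
    print(variance_of_laplacian(rgb))  # a flat image has zero Laplacian variance, so it reads as very blurry
    print(get_signal_to_noise_ratio(rgb))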
tool_utils/mask2former.py ADDED
@@ -0,0 +1,102 @@
+ import argparse
+ import warnings
+
+ import cv2
+ import numpy as np
+
+ try:
+     import torch as th
+     from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ except ImportError as error:
+     raise ImportError('Try installing the torch and transformers modules using pip.') from error
+
+ warnings.filterwarnings("ignore")
+
+
+ class MASK2FORMER:
+     def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):  # use the large checkpoint for better masks
+         self.image_processor = AutoImageProcessor.from_pretrained(model_name)
+         self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
+         self.DEVICE = "cuda" if th.cuda.is_available() else 'cpu'
+         self.segment_id = class_id
+         self.maskformer_model.to(self.DEVICE)
+
+     def create_rgb_mask(self, mask, value=255):
+         gray_3_channel = cv2.merge((mask, mask, mask))
+         gray_3_channel[mask == value] = (255, 255, 255)
+         return gray_3_channel.astype(np.uint8)
+
+     def get_mask(self, segmentation):
+         """
+         Mask out the segment of the class given by self.segment_id.
+         args:   segmentation -> torch tensor - segmentation output from the Mask2Former model
+         return: ndarray -> 2D mask of the image
+         """
+         if self.segment_id == "vehicle":
+             mask = (segmentation.cpu().numpy().copy() == 2) | (segmentation.cpu().numpy().copy() == 5) | (segmentation.cpu().numpy().copy() == 7)
+         else:
+             mask = (segmentation.cpu().numpy() == 6)
+         visual_mask = (mask * 255).astype(np.uint8)
+         return visual_mask
+
+     def generate_road_mask(self, img):
+         """
+         Extract a semantic road mask from the raw image.
+         args:   img -> np.ndarray - input image
+         return: ndarray -> masked-out road.
+         """
+         inputs = self.image_processor(img, return_tensors="pt")
+
+         inputs = inputs.to(self.DEVICE)
+         with th.no_grad():
+             outputs = self.maskformer_model(**inputs)
+
+         segmentation = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(img.shape[0], img.shape[1])])[0]
+         segmented_mask = self.get_mask(segmentation=segmentation)
+         return segmented_mask
+
+     def get_rgb_mask(self, img, segmented_mask):
+         """
+         Extract the RGB road image, removing the background.
+         args:   img -> ndarray - raw image
+                 segmented_mask -> binary mask from the semantic segmentation
+         return: ndarray -> RGB road image with background pixels set to 0.
+         """
+         predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
+         rgb_mask_img = cv2.bitwise_and(img, predicted_rgb_mask)
+         return rgb_mask_img
+
+     def run_inference(self, image_name):
+         """
+         Create a segmentation mask for the configured segment_id and return the masked
+         object's proportion of the image in percent. Uses the Mask2Former model loaded
+         in __init__ to extract the segmentation mask for the provided image.
+         args: image_name -> str - image path, read and processed by Mask2Former
+         """
+         input_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)
+         road_mask = self.generate_road_mask(input_image)
+         road_image = self.get_rgb_mask(input_image, road_mask)
+         obj_prop = round((np.count_nonzero(road_image) / np.size(road_image)) * 100, 1)
+         # Empty the GPU cache
+         with th.no_grad():
+             th.cuda.empty_cache()
+         return obj_prop
+
+
+ def main(args):
+     mask2former = MASK2FORMER()
+     input_image = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB)
+
+     road_mask = mask2former.generate_road_mask(input_image)
+     road_image = mask2former.get_rgb_mask(input_image, road_mask)
+     obj_prop = round(np.count_nonzero(road_image) / np.size(road_image) * 100, 1)
+
+     return road_mask, road_image, obj_prop
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-image_path', help='raw_image_path', required=True)
+     args = parser.parse_args()
+     main(args)
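
A hedged sketch (not part of the commit; the sample path is hypothetical, and the default class_id of 6 appears to target the road class in the ADE20K label map):

    from tool_utils.mask2former import MASK2FORMER

    model = MASK2FORMER()  # swin-small ADE semantic checkpoint, downloads on first use
    print(model.run_inference("image_store/sample.jpg"))  # percent of pixels covered by the mask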
tool_utils/object_extractor.py ADDED
@@ -0,0 +1,24 @@
+ from pydantic import BaseModel, Field
+ from langchain_core.prompts import PromptTemplate
+
+ class objects_identified(BaseModel):
+     objects: str = Field(..., description="Generate a list of objects identified from the given description of the image")
+
+ def objectextractor_prompt():
+     template = """
+     You are an AI assistant provided with a context below. The context is a description of an image. Your task is to identify the objects within the image.
+     The objects must be living beings or physical items or things that one can view, feel, and touch. These objects must be nouns, not verbs or adjectives or words describing an object like
+     'large', 'beautiful', etc.
+     Only provide the name of the object, not its description. Refer to the example input and output for better understanding.
+     If the context mentions a boy, a girl, women, or girls, they come under the category of "People", so the object for such classes will be People.
+     Example Input: Context: "A park filled with men and women, a large oak tree standing in the center, a dog running near a bench, and a bicycle leaning against a nearby fence."
+     Example Output: "People", "Dog", "Bench", "Bicycle", "Fence"
+
+     Context: {context}
+     """
+     prompt = PromptTemplate(template=template, input_variables=["context"])
+     return prompt
+
+ def create_object_extraction_chain(llm):
+     object_extraction_chain = objectextractor_prompt() | llm.with_structured_output(objects_identified)
+     return object_extraction_chain
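
A hedged sketch (not part of the commit) of the chain in isolation, assuming GROQ_API_KEY is set; the caption string is made up, and the chain returns an objects_identified instance:

    from langchain_groq import ChatGroq
    from tool_utils.object_extractor import create_object_extraction_chain

    chain = create_object_extraction_chain(llm=ChatGroq(model="gemma2-9b-it"))
    result = chain.invoke({"context": "a dog running near a bench in a crowded park"})
    print(result.objects)  # comma-separated string, as consumed by extract_tools.object_extraction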
tool_utils/yolo_world.py ADDED
@@ -0,0 +1,30 @@
+ import numpy as np
+ from typing import List
+ from ultralytics import YOLOWorld
+
+ class YoloWorld:
+     def __init__(self, model_name="yolov8x-worldv2.pt"):
+         self.model = YOLOWorld(model_name)
+         self.model.to(device='cpu')
+
+     def run_inference(self, image_path: str, object_prompts: List):
+         object_details = []
+         self.model.set_classes(object_prompts)
+         results = self.model.predict(image_path)
+         for result in results:
+             for box in result.boxes:
+                 object_data = {}
+                 x1, y1, x2, y2 = np.array(box.xyxy.cpu(), dtype=np.int32).squeeze()
+                 c1, c2 = (x1, y1), (x2, y2)
+                 confidence = round(float(box.conf.cpu()), 2)
+                 label = f'{result.names[int(box.cls)]}'
+                 print("Object Name: {} Bounding Box: {},{} Confidence score: {}\n".format(label, c1, c2, confidence))
+                 object_data[label] = {
+                     'bounding_box': [x1, y1, x2, y2],
+                     'confidence': confidence,
+                 }
+                 object_details.append(object_data)
+         return object_details
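
A hedged usage sketch (not part of the commit; the sample path and class list are hypothetical):

    from tool_utils.yolo_world import YoloWorld

    detector = YoloWorld()  # yolov8x-worldv2.pt downloads on first use
    detections = detector.run_inference("image_store/sample.jpg", ["car", "person"])
    print(detections)  # list of {label: {'bounding_box': [...], 'confidence': ...}} dicts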
utils.py ADDED
@@ -0,0 +1,26 @@
+ from collections import defaultdict
+
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
+ from matplotlib import cm
+ import torch
+
+ def draw_panoptic_segmentation(model, segmentation, segments_info):
+     # Get the color map used
+     viridis = cm.get_cmap('viridis', torch.max(segmentation))
+     fig, ax = plt.subplots()
+     ax.imshow(segmentation.cpu().numpy())
+     instances_counter = defaultdict(int)
+     handles = []
+     # For each segment, build its legend entry
+     for segment in segments_info:
+         segment_id = segment['id']
+         segment_label_id = segment['label_id']
+         segment_label = model.config.id2label[segment_label_id]
+         label = f"{segment_label}-{instances_counter[segment_label_id]}"
+         instances_counter[segment_label_id] += 1
+         color = viridis(segment_id)
+         handles.append(mpatches.Patch(color=color, label=label))
+
+     # ax.legend(handles=handles)
+     fig.savefig('final_mask.png')
+     plt.close(fig)  # release the figure so repeated calls don't leak memory
+     return 'final_mask.png'