[feat] take advantage of re-usable image embeddings in SAM model
Browse files- poetry.lock +0 -0
- pyproject.toml +3 -3
- samgis/io/wrappers_helpers.py +22 -0
- samgis/prediction_api/predictors.py +12 -6
- wrappers/fastapi_wrapper.py +5 -2
poetry.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
CHANGED
|
@@ -11,15 +11,15 @@ bson = "^0.5.10"
|
|
| 11 |
contextily = "^1.5.2"
|
| 12 |
geopandas = "^0.14.3"
|
| 13 |
loguru = "^0.7.2"
|
| 14 |
-
numpy = "
|
| 15 |
onnxruntime = "1.16.3"
|
| 16 |
opencv-python-headless = "^4.8.1.78"
|
| 17 |
pillow = "^10.2.0"
|
| 18 |
-
python = "
|
| 19 |
python-dotenv = "^1.0.1"
|
| 20 |
rasterio = "^1.3.9"
|
| 21 |
requests = "^2.31.0"
|
| 22 |
-
samgis-core = "^1.
|
| 23 |
|
| 24 |
[tool.poetry.group.aws_lambda]
|
| 25 |
optional = true
|
|
|
|
| 11 |
contextily = "^1.5.2"
|
| 12 |
geopandas = "^0.14.3"
|
| 13 |
loguru = "^0.7.2"
|
| 14 |
+
numpy = "~1.25.2"
|
| 15 |
onnxruntime = "1.16.3"
|
| 16 |
opencv-python-headless = "^4.8.1.78"
|
| 17 |
pillow = "^10.2.0"
|
| 18 |
+
python = "~3.10"
|
| 19 |
python-dotenv = "^1.0.1"
|
| 20 |
rasterio = "^1.3.9"
|
| 21 |
requests = "^2.31.0"
|
| 22 |
+
samgis-core = "^1.1.1"
|
| 23 |
|
| 24 |
[tool.poetry.group.aws_lambda]
|
| 25 |
optional = true
|
samgis/io/wrappers_helpers.py
CHANGED
|
@@ -200,3 +200,25 @@ def get_url_tile(source_type: str):
|
|
| 200 |
|
| 201 |
def check_source_type_is_terrain(source: str | TileProvider):
|
| 202 |
return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def check_source_type_is_terrain(source: str | TileProvider):
|
| 202 |
return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_source_name(source: str | TileProvider) -> str | bool:
|
| 206 |
+
try:
|
| 207 |
+
match source.lower():
|
| 208 |
+
case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
|
| 209 |
+
source_output = providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
|
| 210 |
+
case _:
|
| 211 |
+
source_output = providers.query_name(source)
|
| 212 |
+
if isinstance(source_output, str):
|
| 213 |
+
return source_output
|
| 214 |
+
try:
|
| 215 |
+
source_dict = dict(source_output)
|
| 216 |
+
app_logger.info(f"source_dict:{type(source_dict)}, {'name' in source_dict}, source_dict:{source_dict}.")
|
| 217 |
+
return source_dict["name"]
|
| 218 |
+
except KeyError as ke:
|
| 219 |
+
app_logger.error(f"ke:{ke}.")
|
| 220 |
+
except ValueError as ve:
|
| 221 |
+
app_logger.info(f"source name::{source}, ve:{ve}.")
|
| 222 |
+
app_logger.info(f"source name::{source}.")
|
| 223 |
+
|
| 224 |
+
return False
|
samgis/prediction_api/predictors.py
CHANGED
|
@@ -6,12 +6,13 @@ from samgis.io.tms2geotiff import download_extent
|
|
| 6 |
from samgis.io.wrappers_helpers import check_source_type_is_terrain
|
| 7 |
from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
|
| 8 |
from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
|
| 9 |
-
from samgis_core.prediction_api.sam_onnx import get_raster_inference
|
| 10 |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
|
| 11 |
from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
|
| 12 |
|
| 13 |
|
| 14 |
models_dict = {"fastsam": {"instance": None}}
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def samexporter_predict(
|
|
@@ -19,7 +20,8 @@ def samexporter_predict(
|
|
| 19 |
prompt: list_dict,
|
| 20 |
zoom: float,
|
| 21 |
model_name: str = "fastsam",
|
| 22 |
-
source: str = DEFAULT_URL_TILES
|
|
|
|
| 23 |
) -> dict_str_int:
|
| 24 |
"""
|
| 25 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
|
@@ -34,7 +36,8 @@ def samexporter_predict(
|
|
| 34 |
prompt: machine learning input prompt
|
| 35 |
zoom: Level of detail
|
| 36 |
model_name: machine learning model name
|
| 37 |
-
source: xyz
|
|
|
|
| 38 |
|
| 39 |
Returns:
|
| 40 |
Affine transform
|
|
@@ -62,9 +65,12 @@ def samexporter_predict(
|
|
| 62 |
|
| 63 |
app_logger.info(
|
| 64 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
return {
|
| 69 |
"n_predictions": n_predictions,
|
| 70 |
**get_vectorized_raster_as_geojson(mask, transform)
|
|
|
|
| 6 |
from samgis.io.wrappers_helpers import check_source_type_is_terrain
|
| 7 |
from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
|
| 8 |
from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
|
| 9 |
+
from samgis_core.prediction_api.sam_onnx import get_raster_inference, get_raster_inference_with_embedding_from_dict
|
| 10 |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
|
| 11 |
from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
|
| 12 |
|
| 13 |
|
| 14 |
models_dict = {"fastsam": {"instance": None}}
|
| 15 |
+
embedding_dict = {}
|
| 16 |
|
| 17 |
|
| 18 |
def samexporter_predict(
|
|
|
|
| 20 |
prompt: list_dict,
|
| 21 |
zoom: float,
|
| 22 |
model_name: str = "fastsam",
|
| 23 |
+
source: str = DEFAULT_URL_TILES,
|
| 24 |
+
source_name: str = None
|
| 25 |
) -> dict_str_int:
|
| 26 |
"""
|
| 27 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
|
|
|
| 36 |
prompt: machine learning input prompt
|
| 37 |
zoom: Level of detail
|
| 38 |
model_name: machine learning model name
|
| 39 |
+
source: xyz tile provider object
|
| 40 |
+
source_name: name of tile provider
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
Affine transform
|
|
|
|
| 65 |
|
| 66 |
app_logger.info(
|
| 67 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
| 68 |
+
app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
|
| 69 |
+
embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
|
| 70 |
+
mask, n_predictions = get_raster_inference_with_embedding_from_dict(
|
| 71 |
+
img, prompt, models_instance, model_name, embedding_key, embedding_dict)
|
| 72 |
+
app_logger.info(f"created {n_predictions} masks, type {type(mask)}, size {mask.size}: preparing geojson conversion")
|
| 73 |
+
app_logger.info(f"mask shape:{mask.shape}.")
|
| 74 |
return {
|
| 75 |
"n_predictions": n_predictions,
|
| 76 |
**get_vectorized_raster_as_geojson(mask, transform)
|
wrappers/fastapi_wrapper.py
CHANGED
|
@@ -8,13 +8,14 @@ from fastapi.staticfiles import StaticFiles
|
|
| 8 |
from pydantic import ValidationError
|
| 9 |
|
| 10 |
from samgis import PROJECT_ROOT_FOLDER
|
| 11 |
-
from samgis.io.wrappers_helpers import get_parsed_bbox_points
|
| 12 |
from samgis.utilities.type_hints import ApiRequestBody
|
| 13 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
| 14 |
from samgis.prediction_api.predictors import samexporter_predict
|
| 15 |
|
| 16 |
|
| 17 |
app_logger = setup_logging(debug=True)
|
|
|
|
| 18 |
app = FastAPI()
|
| 19 |
|
| 20 |
|
|
@@ -68,9 +69,11 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
|
|
| 68 |
body_request = get_parsed_bbox_points(request_input)
|
| 69 |
app_logger.info(f"body_request:{body_request}.")
|
| 70 |
try:
|
|
|
|
|
|
|
| 71 |
output = samexporter_predict(
|
| 72 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
| 73 |
-
source=body_request["source"]
|
| 74 |
)
|
| 75 |
duration_run = time.time() - time_start_run
|
| 76 |
app_logger.info(f"duration_run:{duration_run}.")
|
|
|
|
| 8 |
from pydantic import ValidationError
|
| 9 |
|
| 10 |
from samgis import PROJECT_ROOT_FOLDER
|
| 11 |
+
from samgis.io.wrappers_helpers import get_parsed_bbox_points, get_source_name
|
| 12 |
from samgis.utilities.type_hints import ApiRequestBody
|
| 13 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
| 14 |
from samgis.prediction_api.predictors import samexporter_predict
|
| 15 |
|
| 16 |
|
| 17 |
app_logger = setup_logging(debug=True)
|
| 18 |
+
app_logger.info(f"PROJECT_ROOT_FOLDER:{PROJECT_ROOT_FOLDER}.")
|
| 19 |
app = FastAPI()
|
| 20 |
|
| 21 |
|
|
|
|
| 69 |
body_request = get_parsed_bbox_points(request_input)
|
| 70 |
app_logger.info(f"body_request:{body_request}.")
|
| 71 |
try:
|
| 72 |
+
source_name = get_source_name(request_input.source_type)
|
| 73 |
+
app_logger.info(f"source_name = {source_name}.")
|
| 74 |
output = samexporter_predict(
|
| 75 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
| 76 |
+
source=body_request["source"], source_name=source_name
|
| 77 |
)
|
| 78 |
duration_run = time.time() - time_start_run
|
| 79 |
app_logger.info(f"duration_run:{duration_run}.")
|