I am attempting to refine masks using the batched_inference function. While it works well with boxes, the points input triggers the same error seen in Issue #383. I followed the steps recommended there but cannot get past the error "RuntimeError: Tensors must have same number of dimensions: got 3 and 2". My points array has shape (163, 1, 2) and my labels array has shape (163, 1, 1), where 163 is the number of segmented objects in the image.
Here is how I have implemented the code:
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch

from micro_sam.automatic_segmentation import get_predictor_and_segmenter
from micro_sam.inference import batched_inference


def run_automatic_instance_segmentation(
    image: np.ndarray,
    checkpoint_path: Union[os.PathLike, str],
    model_type: str = "vit_b_em_organelles",
    device: Optional[Union[str, torch.device]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,  # e.g. (512, 512) or (1000, 1000)
    halo: Optional[Tuple[int, int]] = None,  # e.g. (10, 10) or (24, 24)
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
):
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,  # choice of the Segment Anything model
        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model
        device=device,  # the device to run the model inference
        is_tiled=(tile_shape is not None),  # whether to use tiled inference
    )
    prediction = batched_inference(
        predictor=predictor,
        image=image,  # was `im`, a global; use the function argument instead
        points=points,
        point_labels=point_labels,
        batch_size=32,
        return_instance_segmentation=True,
    )
    return prediction
# Uses a previous iteration of segmentation to find centroids of each object as positive points.
points, point_labels = find_point_prompts(prediction, img_info["file_name"])
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)

prediction = run_automatic_instance_segmentation(
    image=im,
    checkpoint_path=best_checkpoint,
    model_type=model_type,
    device=device,
    points=points,
    point_labels=point_labels,
)
Here is the error message:
Compute Image Embeddings 2D: 100%|██████████| 1/1 [00:10<00:00, 10.62s/it]
Traceback (most recent call last):
File "/home/groups/cedar-mmtert-general-scratch/micro_sam/segment_inside_bbox_with_points.py", line 257, in <module>
new_prediction = run_automatic_instance_segmentation(
image=im,
...<7 lines>...
box_ind=0,
)
File "/home/groups/cedar-mmtert-general-scratch/micro_sam/segment_inside_bbox_with_points.py", line 172, in run_automatic_instance_segmentation
prediction = batched_inference(
predictor=predictor,
...<4 lines>...
return_instance_segmentation=True,
)
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/micro_sam/inference.py", line 139, in batched_inference
batch_masks, batch_ious, batch_logits = predictor.predict_torch(
~~~~~~~~~~~~~~~~~~~~~~~^
point_coords=batch_points,
^^^^^^^^^^^^^^^^^^^^^^^^^^
...<3 lines>...
multimask_output=multimasking
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/predictor.py", line 222, in predict_torch
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
~~~~~~~~~~~~~~~~~~~~~~~~~^
points=points,
^^^^^^^^^^^^^^
boxes=boxes,
^^^^^^^^^^^^
masks=mask_input,
^^^^^^^^^^^^^^^^^
)
^
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/modeling/prompt_encoder.py", line 155, in forward
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/modeling/prompt_encoder.py", line 85, in _embed_points
labels = torch.cat([labels, padding_label], dim=1)
RuntimeError: Tensors must have same number of dimensions: got 3 and 2
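For what it's worth, the failing torch.cat seems reproducible in isolation with my label shape. This is only a sketch of what I think happens inside _embed_points, assuming padding_label there is built as a 2-D (B, 1) tensor, as the traceback suggests:

```python
import torch

labels = torch.ones(163, 1, 1)       # my point_labels, 3-D after expand_dims
padding_label = -torch.ones(163, 1)  # 2-D padding, as built in _embed_points

try:
    torch.cat([labels, padding_label], dim=1)
except RuntimeError as e:
    print(e)  # same "got 3 and 2" mismatch as in my traceback

# With 2-D labels of shape (163, 1), the concatenation goes through:
labels_2d = labels.squeeze(-1)  # (163, 1)
padded = torch.cat([labels_2d, padding_label], dim=1)
print(padded.shape)  # torch.Size([163, 2])
```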
Has this issue been resolved? I also saw the suggested fix of adding an image_shape parameter to inference.py, which I have also tried without success. Thank you for your help!