Refine masks using box and points #1181

@maxeyje

Description

I am attempting to refine masks using the batched_inference function. While it works very well with boxes, the points input throws the same error seen in Issue #383. I followed the recommended steps shown there but cannot get past the error "RuntimeError: Tensors must have same number of dimensions: got 3 and 2". My points array has shape (163, 1, 2) and my labels array has shape (163, 1, 1), where 163 is the number of segmented objects in the image.
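For context on where the error fires: when no box is passed, SAM's prompt encoder concatenates a 2-D padding label of shape (B, 1) onto the labels along dim 1, so 3-D labels cannot be concatenated with it. A minimal numpy sketch of that mechanism (the real code does the analogous torch.cat; the shapes below are the ones from this issue):

```python
import numpy as np

# Labels as passed in this issue: (163, 1, 1) -- one trailing axis too many.
labels_3d = np.zeros((163, 1, 1))

# SAM's prompt encoder appends a padding label of shape (B, 1) along
# axis 1 when no boxes are given; this mirrors that step in numpy.
padding = -np.ones((labels_3d.shape[0], 1))

try:
    np.concatenate([labels_3d, padding], axis=1)  # 3-D vs 2-D: fails
except ValueError:
    print("rank mismatch: 3 vs 2")

# Dropping the trailing axis gives labels of shape (B, N), which
# concatenates cleanly with the (B, 1) padding.
labels_2d = labels_3d[..., 0]
print(np.concatenate([labels_2d, padding], axis=1).shape)  # (163, 2)
```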

Here is how I have implemented the code:

def run_automatic_instance_segmentation(
    image: np.ndarray,
    checkpoint_path: Union[os.PathLike, str],
    model_type: str = "vit_b_em_organelles",
    device: Optional[Union[str, torch.device]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,  # e.g. (512, 512) or (1000, 1000)
    halo: Optional[Tuple[int, int]] = None,  # e.g. (10, 10) or (24, 24)
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
):
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,  # choice of the Segment Anything model
        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model
        device=device,  # the device to run the model inference
        is_tiled=(tile_shape is not None),  # whether to compute tiled embeddings
    )

    prediction = batched_inference(
        predictor=predictor,
        image=image,  # use the function argument rather than the global `im`
        points=points,
        point_labels=point_labels,
        batch_size=32,
        return_instance_segmentation=True,
    )

    return prediction

# Uses a previous iteration of segmentation to find centroids of each object as positive points
points, point_labels = find_point_prompts(prediction, img_info["file_name"])
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)

prediction = run_automatic_instance_segmentation(
    image=im,
    checkpoint_path=best_checkpoint,
    model_type=model_type,
    device=device,
    points=points,
    point_labels=point_labels,
)

Here is the error message:

Compute Image Embeddings 2D: 100%|██████████| 1/1 [00:10<00:00, 10.62s/it]
Traceback (most recent call last):
  File "/home/groups/cedar-mmtert-general-scratch/micro_sam/segment_inside_bbox_with_points.py", line 257, in <module>
    new_prediction = run_automatic_instance_segmentation(
        image=im,
    ...<7 lines>...
        box_ind=0,
    )
  File "/home/groups/cedar-mmtert-general-scratch/micro_sam/segment_inside_bbox_with_points.py", line 172, in run_automatic_instance_segmentation
    prediction = batched_inference(
        predictor=predictor,
    ...<4 lines>...
        return_instance_segmentation=True,
    )
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/micro_sam/inference.py", line 139, in batched_inference
    batch_masks, batch_ious, batch_logits = predictor.predict_torch(
                                            ~~~~~~~~~~~~~~~~~~~~~~~^
        point_coords=batch_points,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        multimask_output=multimasking
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/predictor.py", line 222, in predict_torch
    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                                          ~~~~~~~~~~~~~~~~~~~~~~~~~^
        points=points,
        ^^^^^^^^^^^^^^
        boxes=boxes,
        ^^^^^^^^^^^^
        masks=mask_input,
        ^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/modeling/prompt_encoder.py", line 155, in forward
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
  File "/home/exacloud/gscratch/CEDAR/maxeyje/miniconda3/envs/micro_sam/lib/python3.13/site-packages/segment_anything/modeling/prompt_encoder.py", line 85, in _embed_points
    labels = torch.cat([labels, padding_label], dim=1)
RuntimeError: Tensors must have same number of dimensions: got 3 and 2

Has this issue been resolved? I also saw the fix that recommended adding an image_shape parameter to inference.py, which I have tried without success. Thank you for your help!
