Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,9 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) -
GT annotations or samples for successful job launch
"""

self.gt_id_attribute = "object_id"
"An additional way to match GT skeletons with input boxes"

# TODO: probably, need to also add an absolute number of minimum GT RoIs per class

def _download_input_data(self):
Expand Down Expand Up @@ -2058,7 +2061,9 @@ def _validate_boxes_annotations(self): # noqa: PLR0912
)
continue

valid_instances.append((bbox, point.wrap(group=bbox.group, id=bbox.id)))
valid_instances.append(
(bbox, point.wrap(group=bbox.group, id=bbox.id, attributes=bbox.attributes))
)
visited_ids.add(bbox.id)

excluded_boxes_info.excluded_count += len(instances) - len(valid_instances)
Expand Down Expand Up @@ -2138,8 +2143,14 @@ def _find_unambiguous_matches(
input_boxes: list[dm.Bbox],
gt_skeletons: list[dm.Skeleton],
*,
input_points: list[dm.Points],
gt_annotations: list[dm.Annotation],
) -> list[tuple[dm.Bbox, dm.Skeleton]]:
bbox_point_mapping: dict[int, dm.Points] = {
bbox.id: next(p for p in input_points if p.group == bbox.group)
for bbox in input_boxes
}

matches = [
[
(input_bbox.label == gt_skeleton.label)
Expand All @@ -2149,6 +2160,18 @@ def _find_unambiguous_matches(
self._get_skeleton_bbox(gt_skeleton, gt_annotations),
)
)
and (input_point := bbox_point_mapping[input_bbox.id])
and is_point_in_bbox(
input_point.points[0],
input_point.points[1],
self._get_skeleton_bbox(gt_skeleton, gt_annotations),
)
and (
# a way to customize matching if the default method is too rough
not (bbox_id := input_bbox.attributes.get(self.gt_id_attribute))
or not (skeleton_id := gt_skeleton.attributes.get(self.gt_id_attribute))
or bbox_id == skeleton_id
)
for gt_skeleton in gt_skeletons
]
for input_bbox in input_boxes
Expand Down Expand Up @@ -2239,10 +2262,11 @@ def _find_good_gt_skeletons(
input_boxes: list[dm.Bbox],
gt_skeletons: list[dm.Skeleton],
*,
input_points: list[dm.Points],
gt_annotations: list[dm.Annotation],
) -> list[dm.Skeleton]:
matches = _find_unambiguous_matches(
input_boxes, gt_skeletons, gt_annotations=gt_annotations
input_boxes, gt_skeletons, input_points=input_points, gt_annotations=gt_annotations
)

matched_skeletons = []
Expand Down Expand Up @@ -2293,13 +2317,18 @@ def _find_good_gt_skeletons(

gt_skeletons = [a for a in gt_sample.annotations if isinstance(a, dm.Skeleton)]
input_boxes = [a for a in boxes_sample.annotations if isinstance(a, dm.Bbox)]
input_points = [a for a in boxes_sample.annotations if isinstance(a, dm.Points)]
assert len(input_boxes) == len(input_points)

# Samples without boxes are allowed, so we just skip them without an error
if not gt_skeletons:
continue

matched_skeletons = _find_good_gt_skeletons(
input_boxes, gt_skeletons, gt_annotations=gt_sample.annotations
input_boxes,
gt_skeletons,
input_points=input_points,
gt_annotations=gt_sample.annotations,
)
if not matched_skeletons:
continue
Expand Down