diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 4f79ea7925..96949a1e23 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -1734,6 +1734,9 @@ def __init__(self, manifest: TaskManifest, escrow_address: str, chain_id: int) - GT annotations or samples for successful job launch """ + self.gt_id_attribute = "object_id" + "An additional way to match GT skeletons with input boxes" + # TODO: probably, need to also add an absolute number of minimum GT RoIs per class def _download_input_data(self): @@ -2058,7 +2061,9 @@ def _validate_boxes_annotations(self): # noqa: PLR0912 ) continue - valid_instances.append((bbox, point.wrap(group=bbox.group, id=bbox.id))) + valid_instances.append( + (bbox, point.wrap(group=bbox.group, id=bbox.id, attributes=bbox.attributes)) + ) visited_ids.add(bbox.id) excluded_boxes_info.excluded_count += len(instances) - len(valid_instances) @@ -2138,8 +2143,14 @@ def _find_unambiguous_matches( input_boxes: list[dm.Bbox], gt_skeletons: list[dm.Skeleton], *, + input_points: list[dm.Points], gt_annotations: list[dm.Annotation], ) -> list[tuple[dm.Bbox, dm.Skeleton]]: + bbox_point_mapping: dict[int, dm.Points] = { + bbox.id: next(p for p in input_points if p.group == bbox.group) + for bbox in input_boxes + } + matches = [ [ (input_bbox.label == gt_skeleton.label) @@ -2149,6 +2160,18 @@ def _find_unambiguous_matches( self._get_skeleton_bbox(gt_skeleton, gt_annotations), ) ) + and (input_point := bbox_point_mapping[input_bbox.id]) + and is_point_in_bbox( + input_point.points[0], + input_point.points[1], + self._get_skeleton_bbox(gt_skeleton, gt_annotations), + ) + and ( + # a way to customize matching if the default method is too rough + not (bbox_id := input_bbox.attributes.get(self.gt_id_attribute)) + or not (skeleton_id := gt_skeleton.attributes.get(self.gt_id_attribute)) + or bbox_id == skeleton_id + ) for gt_skeleton in gt_skeletons ] for input_bbox in input_boxes @@ -2239,10 +2262,11 @@ def _find_good_gt_skeletons( input_boxes: list[dm.Bbox], gt_skeletons: list[dm.Skeleton], *, + input_points: list[dm.Points], gt_annotations: list[dm.Annotation], ) -> list[dm.Skeleton]: matches = _find_unambiguous_matches( - input_boxes, gt_skeletons, gt_annotations=gt_annotations + input_boxes, gt_skeletons, input_points=input_points, gt_annotations=gt_annotations ) matched_skeletons = [] @@ -2293,13 +2317,18 @@ def _find_good_gt_skeletons( gt_skeletons = [a for a in gt_sample.annotations if isinstance(a, dm.Skeleton)] input_boxes = [a for a in boxes_sample.annotations if isinstance(a, dm.Bbox)] + input_points = [a for a in boxes_sample.annotations if isinstance(a, dm.Points)] + assert len(input_boxes) == len(input_points) # Samples without boxes are allowed, so we just skip them without an error if not gt_skeletons: continue matched_skeletons = _find_good_gt_skeletons( - input_boxes, gt_skeletons, gt_annotations=gt_sample.annotations + input_boxes, + gt_skeletons, + input_points=input_points, + gt_annotations=gt_sample.annotations, ) if not matched_skeletons: continue