Source code for sportslabkit.types.detections

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from PIL import Image

from sportslabkit.logger import logger
from sportslabkit.types.detection import Detection
from sportslabkit.utils import increment_path, read_image
from sportslabkit.utils.draw import draw_bounding_boxes


class Detections:
    """Container for the detection results of a single image.

    Every prediction is normalised to a length-6 vector and the vectors are
    stacked into ``self.preds``, an ``(N, 6)`` array whose columns are
    ``bbox_left, bbox_top, bbox_width, bbox_height, conf, class``.

    Attributes:
        im: The image as returned by ``read_image``.
        preds: ``(N, 6)`` array of predictions; ``(0, 6)`` when there are none.
        names: Value passed as ``names`` (presumably class names -- this class
            stores it but never reads it).
        times: Value passed as ``times`` (presumably timing information --
            this class stores it but never reads it).
    """

    # Column order of ``self.preds``; also the keys read from a prediction dict.
    _COLUMNS = (
        "bbox_left",
        "bbox_top",
        "bbox_width",
        "bbox_height",
        "conf",
        "class",
    )

    def __init__(
        self,
        preds: list[dict] | list[list] | list[Detection],
        im: str | Path | Image.Image | np.ndarray,
        times: tuple[float, float, float] = (0, 0, 0),
        names: list[str] | None = None,
    ):
        """Normalise the image and predictions and store them.

        Args:
            preds: Iterable of predictions; each item may be a dict, a list,
                a ``Detection`` or a 1-D array of 6 values.
            im: Image, or a reference to one, in any form ``read_image`` accepts.
            times: Stored as-is on the instance.
            names: Stored as-is on the instance.
        """
        self.im = self._process_im(im)
        self.preds = self._process_preds(preds)
        self.names = names
        self.times = times

    def _process_im(self, im: str | Path | Image.Image | np.ndarray) -> np.ndarray:
        """Load or convert ``im`` by delegating to ``read_image``."""
        return read_image(im)

    def _process_pred(self, pred: dict | list | Detection | np.ndarray) -> np.ndarray:
        """Convert one prediction to a length-6 array in ``_COLUMNS`` order.

        Raises:
            ValueError: If a dict or list does not have exactly 6 items, or an
                array does not have shape ``(6,)``.
            KeyError: If a 6-item dict lacks one of the expected keys.
            TypeError: If ``pred`` is of an unsupported type.
        """
        if isinstance(pred, dict):
            if len(pred) != 6:
                raise ValueError("The prediction dictionary should contain exactly 6 items")
            return np.stack([pred[key] for key in self._COLUMNS], axis=0)
        if isinstance(pred, list):
            if len(pred) != 6:
                raise ValueError("The prediction list should contain exactly 6 items")
            return np.array(pred)
        # Arrays are what merge() feeds back in (rows of an (N, 6) array).
        if isinstance(pred, np.ndarray):
            if pred.shape != (6,):
                raise ValueError(f"pred should have the shape (6, ), but got {pred.shape}")
            return pred
        if isinstance(pred, Detection):
            return np.array(
                [
                    pred.box[0],
                    pred.box[1],
                    pred.box[2],
                    pred.box[3],
                    pred.score,
                    pred.class_id,
                ]
            )
        raise TypeError(f"Unsupported prediction type: {type(pred)}")

    def _process_preds(self, preds: list[Any]) -> np.ndarray:
        """Convert every prediction and stack them into an ``(N, 6)`` array.

        An empty input yields an array of shape ``(0, 6)`` so that column
        slicing such as ``preds[:, :4]`` keeps working.
        """
        stacked = np.array([self._process_pred(pred) for pred in preds])
        if not stacked.size:
            stacked = np.zeros((0, 6))
        return stacked

    def show(self, **kwargs) -> Image.Image:
        """Draw all boxes on the image and return the result as a PIL image.

        Each box is labelled ``"<class> <conf>"``. Extra keyword arguments are
        forwarded to ``draw_bounding_boxes``.
        """
        im = self.im
        boxes = self.preds[:, :4]
        labels = [f"{int(c)} {conf:.2f}" for conf, c in self.preds[:, 4:]]
        draw_im = draw_bounding_boxes(im, boxes, labels, **kwargs)
        return Image.fromarray(draw_im)

    def save_image(self, path: str | Path, **kwargs):
        """Render the detections with :meth:`show` and save the image to ``path``."""
        image = self.show(**kwargs)
        image.save(path)

    def save_boxes(self, path: str | Path):
        """Write one ``left,top,width,height`` line per box to ``path``."""
        with open(path, "w") as f:
            f.writelines(",".join(map(str, box)) + "\n" for box in self.preds[:, :4])

    def crop(
        self,
        save: bool = True,
        save_dir: str | Path = "runs/detect/exp",
        exist_ok: bool = False,
    ) -> list[np.ndarray]:
        """Cut every detected box out of the image.

        Args:
            save: When true, also write each crop as a PNG into ``save_dir``.
            save_dir: Target directory, passed through ``increment_path``.
            exist_ok: Forwarded to ``increment_path``.

        Returns:
            The cropped regions as arrays (slices of ``self.im``), in the
            order of ``self.preds``.
        """
        save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None

        images = []
        for box in self.preds[:, :4]:
            # ``preds`` is usually a float array (``conf`` is a float) and
            # NumPy rejects float slice bounds, so convert explicitly.
            left, top, width, height = (int(value) for value in box)
            # Clamp at 0: a negative bound would index from the opposite edge.
            x0, y0 = max(left, 0), max(top, 0)
            x1, y1 = max(left + width, 0), max(top + height, 0)
            cropped_im = self.im[y0:y1, x0:x1]

            if save_dir is not None:
                Image.fromarray(cropped_im).save(Path(save_dir) / f"{box}.png")
            images.append(cropped_im)
        return images

    def to_df(self):
        """Return the predictions as a DataFrame with one row per detection."""
        return pd.DataFrame(self.preds, columns=list(self._COLUMNS))

    def to_list(self):
        """Return the predictions as a list of ``Detection`` objects."""
        if len(self.preds) == 0:
            logger.debug("No results to show.")
            return []
        return [Detection([x, y, w, h], conf, class_id) for x, y, w, h, conf, class_id in self.preds]

    def merge(self, other):
        """Return a new ``Detections`` holding this object's and ``other``'s predictions.

        Args:
            other: Another ``Detections`` or an iterable of predictions in any
                form accepted by the constructor (including an ``(N, 6)`` array).

        Returns:
            ``self`` when ``other`` is empty, otherwise a new ``Detections``
            sharing this object's image, ``times`` and ``names``.
        """
        other_preds = other.preds if isinstance(other, Detections) else self._process_preds(other)
        if len(other_preds) == 0:
            return self

        merged = np.concatenate((self.preds, other_preds), axis=0)
        # Keywords are required here: the constructor order is
        # (preds, im, times, names), so positional (names, times) would swap them.
        return Detections(merged, self.im, times=self.times, names=self.names)

    def __len__(self):
        """Return the number of detections."""
        return len(self.preds)