# Source code for sportslabkit.calibration_model.base

from abc import ABC, abstractmethod

import numpy as np
from PIL import Image

from sportslabkit.utils import read_image


class BaseCalibrationModel(ABC):
    """Base class for calibration models.

    Subclasses must implement :meth:`forward`. The base class supplies the call
    pipeline that :meth:`__call__` runs:

    1. ``_check_and_fix_inputs`` -- converts the raw input via :meth:`read_image`,
       which delegates to ``sportslabkit.utils.read_image``.
    2. ``forward`` -- the model itself (abstract).
    3. ``_check_and_fix_outputs`` -- output-validation hook; a pass-through here.
    4. ``_postprocess`` -- post-processing hook; a pass-through here.

    Override the hooks in a subclass when extra validation or processing is needed.
    The accepted input types are whatever ``sportslabkit.utils.read_image``
    supports; :meth:`test` exercises file-name strings, ``pathlib.Path`` objects,
    URLs, numpy arrays and PIL images.

    Example:
        class CustomCalibrationModel(BaseCalibrationModel):
            def forward(self, x, **kwargs):
                # Compute and return the calibration result for ``x`` here.
                ...

    Attributes:
        input_is_batched (bool): Initialized to ``False``. Nothing in this base
            class updates it after construction.
    """

    def __init__(self):
        """Initialize the base calibration model.

        Takes no arguments; only sets ``input_is_batched`` to ``False``.
        """
        super().__init__()
        self.input_is_batched = False

    def __call__(self, inputs, **kwargs):
        """Run the full pipeline: convert inputs, forward, check outputs, post-process.

        Args:
            inputs: Raw input, converted by ``_check_and_fix_inputs``.
            **kwargs: Forwarded unchanged to :meth:`forward`.

        Returns:
            The output of :meth:`forward` after ``_check_and_fix_outputs`` and
            ``_postprocess`` have been applied.
        """
        inputs = self._check_and_fix_inputs(inputs)
        results = self.forward(inputs, **kwargs)
        results = self._check_and_fix_outputs(results, inputs)
        return self._postprocess(results)

    def _check_and_fix_inputs(self, img):
        """Convert a raw input into the form passed to :meth:`forward`.

        Delegates entirely to :meth:`read_image`. No batching detection or
        type checking is performed here.
        """
        return self.read_image(img)

    def read_image(self, img):
        """Load ``img`` with ``sportslabkit.utils.read_image`` and return the result."""
        return read_image(img)

    def _check_and_fix_outputs(self, outputs, inputs):
        """Validate or convert raw model outputs (hook; a pass-through here).

        The base implementation performs no validation and returns ``outputs``
        unchanged. Subclasses may override it, e.g. to check that the output is
        a 3x3 homography matrix.

        Args:
            outputs: Raw output of :meth:`forward`.
            inputs: The converted inputs that produced ``outputs`` (unused here).

        Returns:
            ``outputs``, unchanged.
        """
        return outputs

    def _postprocess(self, outputs):
        """Post-process checked outputs (hook; returns ``outputs`` unchanged).

        Override in subclasses for additional processing if needed.
        """
        return outputs

    @abstractmethod
    def forward(self, x, **kwargs):
        """Compute the model output for a converted input.

        Args:
            x: Input as returned by ``_check_and_fix_inputs``.
            **kwargs: Extra keyword arguments given to :meth:`__call__`.

        Returns:
            The raw model output, which is then passed to ``_check_and_fix_outputs``.
        """
        raise NotImplementedError

    def test(self):
        """Smoke-test the pipeline on sample inputs of several types and print results.

        Needs ``cv2``, the repository file ``data/samples/ney.jpeg`` and, presumably,
        network access for the URL entry. The model is run once on the whole list
        and then once per item.
        """
        import cv2

        from sportslabkit.utils.utils import get_git_root

        git_root = get_git_root()
        im_path = git_root / "data" / "samples" / "ney.jpeg"
        imgs = [
            str(im_path),  # filename
            im_path,  # Path
            "https://ultralytics.com/images/zidane.jpg",  # URI
            cv2.imread(str(im_path))[:, :, ::-1],  # OpenCV (reverse BGR channel order)
            Image.open(str(im_path)),  # PIL
            np.zeros((320, 640, 3)),  # numpy
        ]

        # Batched inference on the whole list.
        results = self(imgs)
        print(results)

        # One input at a time.
        for img in imgs:
            results = self(img)
            print(results)
if __name__ == "__main__":

    class _PassThroughCalibrationModel(BaseCalibrationModel):
        """Minimal concrete model whose ``forward`` returns its input unchanged."""

        def forward(self, x, **kwargs):
            return x

    # BaseCalibrationModel is abstract (``forward`` is an @abstractmethod), so
    # instantiating it directly raises TypeError; smoke-test via a trivial subclass.
    model = _PassThroughCalibrationModel()
    model.test()