Source code for sportslabkit.mot.callbacks

"""Defines the Callback base class and utility decorators for use with the Trainer class.

The Callback class provides a dynamic way to hook into various stages of the Trainer's operations.
It uses Python's __getattr__ method to dynamically handle calls to methods that are not explicitly defined,
allowing it to handle arbitrary `on_<event_name>_start` and `on_<event_name>_end` methods.

Example:
    class MyPrintingCallback(Callback):
        def on_train_start(self, trainer):
            print("Training is starting")
"""


from scipy import stats

from sportslabkit.logger import logger
from sportslabkit.mot.base import Callback, MultiObjectTracker
from sportslabkit.types import Vector
from sportslabkit.vector_model import BaseVectorModel


class TeamClassificationCallback(Callback):
    """Assign a team ID to every tracklet once a track sequence finishes.

    At the end of a tracking run, each tracklet's per-observation feature
    vectors are passed through ``vector_model``. The most frequent predicted
    label is stored on the tracklet as ``tracklet.team_id``.
    """

    def __init__(self, vector_model: BaseVectorModel):
        """Initialize TeamClassificationCallback.

        Args:
            vector_model (BaseVectorModel): A trained model responsible for
                classifying teams. The callback invokes it directly as
                ``vector_model(features)``, not via a ``.predict`` attribute.
                It is therefore expected to be callable with this contract:

                - Input: an array of shape ``(num_samples, num_features)``.
                  ``num_samples`` is the number of feature observations in a
                  tracklet, and ``num_features`` is the feature dimension of
                  each observation.
                - Output: an array of shape ``(num_samples,)`` with one
                  predicted team ID per sample (e.g. 0 or 1 in a 2-class
                  problem).

                Such a model is typically built around a trained scikit-learn
                estimator (e.g. an SVM ``Pipeline``) that was saved with pickle
                and loaded for reuse.

        Note:
            Only load pickled models from trusted sources: unpickling can
            execute arbitrary code.
        """
        super().__init__()
        self.vector_model = vector_model

    def on_track_sequence_end(self, tracker: MultiObjectTracker) -> None:
        """Classify every tracklet into a team ID at the end of a track sequence.

        For each alive and dead tracklet of ``tracker``, the feature vectors
        returned by ``tracklet.get_observations("feature")`` are passed to
        ``self.vector_model``, which yields one team prediction per observation.
        The most frequent prediction (a majority vote) is then written to
        ``tracklet.team_id``.

        Args:
            tracker (MultiObjectTracker): The tracker whose tracklets are
                classified.

        Notes:
            - Tracklets are modified in place; nothing is returned.
            - Voting over all observations is meant to give each tracklet a
              single team ID that tolerates occasional per-observation
              misclassifications.
        """
        logger.debug("Applying team classification method...")
        all_tracklets = tracker.alive_tracklets + tracker.dead_tracklets
        for tracklet in all_tracklets:
            tracklet_features: Vector = tracklet.get_observations("feature")
            # The model is invoked through its __call__, producing one prediction
            # per observation row of `tracklet_features`.
            predicted_team_id = self.vector_model(tracklet_features)
            # Majority vote along the observation axis. `keepdims=False` (available
            # since SciPy 1.9) drops the reduced axis, so 1-D predictions give a
            # scalar mode.
            most_frequent_team_id = stats.mode(predicted_team_id, axis=0, keepdims=False).mode
            tracklet.team_id = most_frequent_team_id