"""Source code for ``sportslabkit.sot.base``: the abstract base class for single-object trackers."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import numpy as np
import optuna

from sportslabkit import Tracklet
from sportslabkit.dataframe.bboxdataframe import BBoxDataFrame
from sportslabkit.logger import logger
from sportslabkit.metrics.object_detection import iou_scores


class SingleObjectTracker(ABC):
    """Abstract base class for single-object trackers.

    A tracker is seeded with an initial ``target`` (a mapping that must contain
    every key in :attr:`required_keys`). It then consumes a sequence of frames in
    windows of ``window_size`` frames, advancing by ``step_size``. It calls
    :meth:`update` on each window and records the returned states in a
    :class:`Tracklet`.

    Subclasses must implement :meth:`update` and :attr:`required_keys`. They may
    override the ``pre_initialize`` / ``post_initialize`` / ``pre_track`` /
    ``post_track`` hooks.
    """

    def __init__(
        self,
        target,
        window_size=1,
        step_size=None,
        pre_init_args=None,
        post_init_args=None,
    ):
        """Store the configuration and initialise the tracker via :meth:`reset`.

        Args:
            target: Initial state. It must contain every key in ``required_keys``.
            window_size: Number of frames passed to :meth:`update` per step.
            step_size: Stride between consecutive windows. A falsy value
                (``None`` or ``0``) falls back to ``window_size``.
            pre_init_args: Keyword arguments forwarded to :meth:`pre_initialize`.
                Defaults to an empty dict.
            post_init_args: Keyword arguments forwarded to :meth:`post_initialize`.
                Defaults to an empty dict.
        """
        self.target = target
        self.init_target = target
        self.window_size = window_size
        self.step_size = step_size or window_size
        # Use a fresh dict per instance. A shared ``{}`` default would leak any
        # mutation made by one tracker (or one of its hooks) into every other tracker.
        self.pre_init_args = {} if pre_init_args is None else pre_init_args
        self.post_init_args = {} if post_init_args is None else post_init_args
        self.reset()

    def pre_initialize(self, **kwargs):
        """Hook run at the start of :meth:`reset`; receives ``pre_init_args``."""
        # Hook that subclasses can override

    def post_initialize(self, **kwargs):
        """Hook run at the end of :meth:`reset`; receives ``post_init_args``."""
        # Hook that subclasses can override

    def update_tracklet_observations(self, states: dict[str, Any]):
        """Record one observation per required key and advance the tracklet counter.

        Args:
            states: Mapping that must contain every key in ``required_keys``.

        Raises:
            ValueError: If ``states`` lacks any key in ``required_keys``.
        """
        self.check_required_types(states)
        for required_type in self.required_keys:
            self.tracklet.update_observation(required_type, states[required_type])
        self.tracklet.increment_counter()

    @abstractmethod
    def update(self, current_frame: Any) -> dict[str, Any]:
        """Advance the tracker by one window of frames.

        Must return a dict containing every key in ``required_keys``. When the
        window is a 4-D ``np.ndarray`` (see :meth:`process_sequence_item`), the
        return value is iterated instead. In that case an iterable of such dicts
        is expected.
        """

    def process_sequence_item(self, sequence: Any):
        """Run :meth:`update` on one window and record the resulting state(s).

        A 4-D ``np.ndarray`` window is treated as batched: the value returned by
        :meth:`update` is iterated, and each item is recorded as a separate
        state. Any other window yields exactly one recorded state.
        ``frame_count`` is incremented once per recorded state.

        Raises:
            ValueError: If a returned state is not a dict or lacks required keys.
        """
        is_batched = isinstance(sequence, np.ndarray) and len(sequence.shape) == 4
        if is_batched:
            updated_states = self.update(sequence)
        else:
            updated_states = [self.update(sequence)]

        for updated_state in updated_states:
            self.check_updated_state(updated_state)
            self.update_tracklet_observations(updated_state)
            self.frame_count += 1

    def track(self, sequence: Iterable[Any] | np.ndarray) -> Tracklet:
        """Run the tracker over ``sequence`` and return the tracklet.

        Windows of ``window_size`` items start every ``step_size`` items. Trailing
        items that do not fill a complete window are not processed. The tracker
        is not reset here, so repeated calls keep appending to the same
        tracklet. Call :meth:`reset` first to start over.

        Args:
            sequence: Frames (or batches) to track over. One-shot iterables
                such as generators are materialised into a list first.

        Returns:
            The tracker's :class:`Tracklet`.

        Raises:
            ValueError: If ``sequence`` is neither iterable nor a numpy array.
        """
        if not isinstance(sequence, (Iterable, np.ndarray)):
            raise ValueError("Input 'sequence' must be an iterable or numpy array of frames/batches")
        # The windowing below needs len() and slicing. Materialise iterables that
        # support neither (e.g. generators) instead of failing with a TypeError.
        if not (hasattr(sequence, "__len__") and hasattr(sequence, "__getitem__")):
            sequence = list(sequence)

        self.pre_track()
        for i in range(0, len(sequence) - self.window_size + 1, self.step_size):
            logger.debug(f"Processing frames {i} to {i + self.window_size}")
            self.process_sequence_item(sequence[i : i + self.window_size])
        self.post_track()
        return self.tracklet

    def pre_track(self):
        """Hook run by :meth:`track` before the first window is processed."""
        # Hook that subclasses can override

    def post_track(self):
        """Hook run by :meth:`track` after the last window is processed."""

    def reset(self):
        """Rebuild the tracklet from ``init_target`` and rerun the init hooks.

        Calls ``pre_initialize(**pre_init_args)``. Then it creates a new
        :class:`Tracklet` with one observation type per required key, sets
        ``frame_count`` to 0, and records ``init_target`` as the first
        observation. Finally it calls ``post_initialize(**post_init_args)``.
        """
        self.pre_initialize(**self.pre_init_args)

        # Initialize the single object tracker
        logger.debug("Initializing tracker...")
        self.tracklet = Tracklet()
        for required_type in self.required_keys:
            self.tracklet.register_observation_type(required_type)
        self.frame_count = 0
        self.update_tracklet_observations(self.init_target)

        self.post_initialize(**self.post_init_args)
        logger.debug("Tracker initialized.")

    def check_required_types(self, target: dict[str, Any]):
        """Validate that ``target`` contains every key in ``required_keys``.

        Raises:
            ValueError: Listing the missing, required and present keys.
        """
        missing_types = [required_type for required_type in self.required_keys if required_type not in target]
        if missing_types:
            required_types_str = ", ".join(self.required_keys)
            missing_types_str = ", ".join(missing_types)
            # str() each key so a non-string key cannot turn this ValueError
            # into an unrelated TypeError while the message is being built.
            current_types_str = ", ".join(map(str, target.keys()))
            raise ValueError(
                f"Input 'target' is missing the following required types: {missing_types_str}.\n"
                f"Required types: {required_types_str}\n"
                f"Current types in 'target': {current_types_str}"
            )

    def check_updated_state(self, state: dict[str, Any]):
        """Validate a state returned by :meth:`update`.

        Raises:
            ValueError: If ``state`` is not a dict or lacks a required key.
        """
        if not isinstance(state, dict):
            raise ValueError("The `update` method must return a dictionary.")

        missing_types = [required_type for required_type in self.required_keys if required_type not in state]
        if missing_types:
            missing_types_str = ", ".join(missing_types)
            raise ValueError(
                f"The returned state from `update` is missing the following required types: {missing_types_str}."
            )

    @property
    def required_keys(self):
        """Keys every target/state dict must contain; subclasses must override."""
        raise NotImplementedError

    @property
    def hparam_searh_space(self):
        """Misspelled legacy name, kept for backward compatibility.

        :meth:`create_hparam_dict` looks up ``hparam_search_space`` (correct
        spelling). Subclasses should define that name.
        """
        return {}

    def create_hparam_dict(self):
        """Collect the hyperparameter search space of the tracker and its components.

        Returns:
            A dict with two kinds of entries:

            * ``"self"`` maps to the tracker's own ``hparam_search_space``, if the
              tracker defines one.
            * Every instance attribute whose value exposes
              ``hparam_search_space``, and whose name is not a key of the
              tracker's own space, maps to a ``{param_name: {"type", "values",
              "low", "high"}}`` dict. Absent optional fields become ``None``.
        """
        # Read the tracker's own space once, with a default. Previously the
        # loop accessed it unguarded and raised AttributeError when the tracker
        # had no space but one of its components did.
        own_space = getattr(self, "hparam_search_space", None)
        hparams = {"self": own_space} if own_space is not None else {}
        reserved_names = own_space or {}

        for attribute in vars(self):
            value = getattr(self, attribute)
            if not hasattr(value, "hparam_search_space") or attribute in reserved_names:
                continue
            hparams[attribute] = {
                param_name: {
                    "type": param_space["type"],
                    "values": param_space.get("values"),
                    "low": param_space.get("low"),
                    "high": param_space.get("high"),
                }
                for param_name, param_space in value.hparam_search_space.items()
            }
        return hparams

    def tune_hparams(
        self,
        frames,
        ground_truth_positions,
        n_trials=100,
        hparam_search_space=None,
        metric=iou_scores,
        verbose=False,
        return_study=False,
    ):
        """Search hyperparameters with Optuna, maximising ``metric``.

        Each trial does the following:

        1. Sample a value for every parameter.
        2. Assign it onto the tracker (for the ``"self"`` key) or onto the named
           attribute.
        3. Call :meth:`reset`, then track ``frames``.
        4. Score the ``"box"`` observations against the ground truth.

        Args:
            frames: Sequence passed to :meth:`track` on every trial.
            ground_truth_positions: A ``BBoxDataFrame``, or an indexable whose
                per-frame item ``gt[0]`` is the ground-truth box for that frame
                (presumably xywh, given ``xywh=True`` below — confirm).
            n_trials: Number of Optuna trials.
            hparam_search_space: Optional nested mapping
                ``{attribute: {param_name: spec}}``, in the format returned by
                :meth:`create_hparam_dict`. It is used instead of
                :meth:`create_hparam_dict` when given.
            metric: Callable ``metric(predictions, ground_truth, xywh=True)``
                whose return value is maximised.
            verbose: If True, set Optuna's global log level to INFO; otherwise
                set it to WARNING.
            return_study: If True, also return the Optuna study.

        Returns:
            ``(best_params, best_value)``, or ``(best_params, best_value, study)``
            when ``return_study`` is True.

        Raises:
            ValueError: From inside the objective, for an unknown parameter
                ``type`` in the search space.

        Note:
            The tracker is left configured with the last trial's
            hyperparameters, not the best ones.
        """
        # A BBoxDataFrame gains a singleton second axis and keeps its first four
        # columns, so ``gt[0]`` below picks those four values per row.
        if isinstance(ground_truth_positions, BBoxDataFrame):
            ground_truth_positions = np.expand_dims(ground_truth_positions.values, axis=1)[:, :, :4]
        # FIXME: single-object tracking should not allow multiple ground-truth
        # targets; only the first target per frame is scored. This is
        # trial-invariant, so compute it once rather than inside the objective.
        ground_truth_targets = [gt[0] for gt in ground_truth_positions]

        hparams = self.create_hparam_dict() if hparam_search_space is None else hparam_search_space

        def suggest(trial, param_name, param_values):
            """Sample one parameter from ``trial`` according to its ``type``."""
            param_type = param_values["type"]
            if param_type == "categorical":
                return trial.suggest_categorical(param_name, param_values["values"])
            if param_type == "float":
                return trial.suggest_float(param_name, param_values["low"], param_values["high"])
            if param_type == "logfloat":
                return trial.suggest_float(param_name, param_values["low"], param_values["high"], log=True)
            if param_type == "int":
                return trial.suggest_int(param_name, param_values["low"], param_values["high"])
            raise ValueError(f"Unknown parameter type: {param_values['type']}")

        def objective(trial: optuna.Trial):
            # NOTE(review): parameters are registered with Optuna under their bare
            # names. Two components sharing a parameter name therefore collide,
            # and ``study.best_params`` cannot tell them apart.
            # Sample everything first, so an invalid spec fails before any
            # attribute on the tracker has been modified.
            params = {}
            for attribute, param_space in hparams.items():
                params[attribute] = {}
                for param_name, param_values in param_space.items():
                    params[attribute][param_name] = suggest(trial, param_name, param_values)

            # Apply the hyperparameters to the attributes of `self`
            for attribute, param_values in params.items():
                for param_name, param_value in param_values.items():
                    owner = self if attribute == "self" else getattr(self, attribute)
                    setattr(owner, param_name, param_value)

            self.reset()
            tracklet = self.track(frames)
            predictions = tracklet.get_observations("box")
            # Honour the caller-supplied metric (previously iou_scores was
            # hard-coded here and `metric` was silently ignored).
            return metric(predictions, ground_truth_targets, xywh=True)

        print("Hyperparameter search space: ")
        for attribute, param_space in hparams.items():
            print(f"{attribute}:")
            for param_name, param_values in param_space.items():
                print(f"\t{param_name}: {param_values}")

        if verbose:
            optuna.logging.set_verbosity(optuna.logging.INFO)
        else:
            optuna.logging.set_verbosity(optuna.logging.WARNING)

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=n_trials)

        best_params = study.best_params
        best_iou = study.best_value

        if return_study:
            return best_params, best_iou, study
        return best_params, best_iou