from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import wraps
from typing import Any
import numpy as np
import optuna
import pandas as pd
from sportslabkit import Tracklet
from sportslabkit.detection_model.dummy import DummyDetectionModel
from sportslabkit.logger import logger, tqdm
from sportslabkit.metrics import hota_score
def with_callbacks(func):
    """Wrap a method so callbacks are dispatched right before and after it.

    The hook names are built from the wrapped method's ``__name__``:
    ``on_<name>_start`` is dispatched before the call and ``on_<name>_end``
    once it has returned. Dispatch goes through ``self._invoke_callbacks``,
    so the owning object must provide that method. If the wrapped method
    raises, the end hook is not dispatched.

    Args:
        func (callable): The method to wrap.

    Returns:
        callable: The wrapped method; it returns whatever ``func`` returns.
    """

    @wraps(func)
    def wrapped(self, *args, **kwargs):
        hook_prefix = f"on_{func.__name__}"
        self._invoke_callbacks(hook_prefix + "_start")
        outcome = func(self, *args, **kwargs)
        self._invoke_callbacks(hook_prefix + "_end")
        return outcome

    return wrapped
class Callback:
    """Base class for tracker callbacks.

    The class itself defines no hooks. A subclass opts in to an event by
    defining a method with the hook's name (for example
    ``on_track_sequence_start`` or ``on_track_sequence_end``). The tracker
    looks hooks up by name with ``getattr`` and, when one is present, calls
    it with the tracker instance as its only argument; hooks that are not
    defined are skipped.
    """
class MultiObjectTracker(ABC):
    """Abstract base class for window-based multi-object trackers.

    A concrete tracker implements :meth:`update` and the
    ``required_observation_types`` / ``required_state_types`` properties.
    This base class drives the per-frame loop, keeps the lists of alive and
    dead tracklets, handles staleness bookkeeping, dispatches callbacks and
    offers Optuna-based hyperparameter tuning.
    """

    def __init__(self, window_size=1, step_size=None, max_staleness=5, min_length=5, callbacks=None):
        """Store the tracking configuration and reset the tracker state.

        Args:
            window_size: Number of consecutive sequence items handed to
                :meth:`process_sequence_item` at once.
            step_size: Offset between the starts of consecutive windows.
                A falsy value (``None`` or ``0``) falls back to
                ``window_size``.
            max_staleness: Forwarded to every ``Tracklet`` created by
                :meth:`create_tracklet`.
            min_length: Tracklets shorter than this are dropped by
                :meth:`cleanup_tracklets`.
            callbacks: Optional list of callback objects; ``None`` becomes an
                empty list.
        """
        self.window_size = window_size
        self.step_size = step_size or window_size
        self.max_staleness = max_staleness
        self.min_length = min_length
        self.trial_params = []  # one entry per tuning trial, see tune_hparams
        self.callbacks = callbacks or []
        self.reset()

    def _check_callbacks(self, callbacks):
        """Raise ``ValueError`` unless every item is a ``Callback`` instance.

        A falsy ``callbacks`` value is accepted without inspection.

        NOTE(review): nothing inside this class calls this validator, so the
        callbacks passed to ``__init__`` are currently not type-checked.
        """
        for callback in callbacks or []:
            if not isinstance(callback, Callback):
                raise ValueError("All callbacks must be instances of Callback class.")

    def _invoke_callbacks(self, method_name):
        """Call ``method_name`` on every callback that defines it.

        Args:
            method_name (str): The name of the method to invoke on the
                callback objects. Each hook receives this tracker as its only
                argument; callbacks lacking the hook are skipped.
        """
        for callback in self.callbacks:
            method = getattr(callback, method_name, None)
            if method:
                method(self)

    def update_tracklet(self, tracklet: Tracklet, states: dict[str, Any]):
        """Record ``states`` on ``tracklet`` for the current frame.

        Args:
            tracklet: The tracklet to update in place.
            states: Mapping that must contain every required observation
                type.

        Returns:
            The same ``tracklet`` object.

        Raises:
            ValueError: If a required observation type is missing.
        """
        self._check_required_observations(states)
        tracklet.update_observations(states, self.frame_count)
        tracklet.increment_counter()
        return tracklet

    @abstractmethod
    def update(
        self, current_frame: Any, trackelts: list[Tracklet]
    ) -> tuple[list[Tracklet], list[Tracklet], list[Tracklet]]:
        """Associate the currently alive tracklets with ``current_frame``.

        Args:
            current_frame: The item (or squeezed window) being processed.
            trackelts: The tracklets that are currently alive.

        Returns:
            Three lists, in this order: assigned tracklets, newly created
            tracklets and unassigned tracklets. ``process_sequence_item``
            unpacks exactly these three values.
        """

    def process_sequence_item(self, sequence: Any):
        """Advance the tracker by one frame.

        Args:
            sequence: The frame to process. A 4-dimensional NumPy array is
                treated as a batch and rejected.

        Raises:
            NotImplementedError: If ``sequence`` is a 4-dimensional array.
        """
        # Reject unsupported input before touching any tracker state.
        if isinstance(sequence, np.ndarray) and sequence.ndim == 4:
            raise NotImplementedError("Batched tracking is not yet supported")

        self.frame_count += 1  # increment first to match steps alive
        tracklets = self.alive_tracklets

        assigned_tracklets, new_tracklets, unassigned_tracklets = self.update(sequence, tracklets)

        # Manage tracklet staleness
        assigned_tracklets = self.reset_staleness(assigned_tracklets)
        unassigned_tracklets = self.increment_staleness(unassigned_tracklets)
        non_stale_tracklets, stale_tracklets = self.separate_stale_tracklets(unassigned_tracklets)
        # Stale tracklets shorter than ``min_length`` are discarded here.
        stale_tracklets = self.cleanup_tracklets(stale_tracklets)

        # Report tracklet status
        logger.debug(
            f"assigned: {len(assigned_tracklets)}, new: {len(new_tracklets)}, unassigned: {len(non_stale_tracklets)}, stale: {len(stale_tracklets)}"
        )

        # Update alive and dead tracklets
        self.alive_tracklets = assigned_tracklets + new_tracklets + non_stale_tracklets
        self.dead_tracklets += stale_tracklets

    def track(self, sequence: Iterable[Any] | np.ndarray) -> pd.DataFrame:
        """Track a whole sequence from a freshly reset state.

        Args:
            sequence: Iterable or NumPy array of frames.

        Returns:
            The bounding-box dataframe produced by :meth:`to_bbdf`.

        Raises:
            ValueError: If ``sequence`` is not iterable, or if no tracklet
                survives (raised by ``pd.concat`` inside :meth:`to_bbdf`).
        """
        if not isinstance(sequence, (Iterable, np.ndarray)):
            raise ValueError("Input 'sequence' must be an iterable or numpy array of frames/batches")
        self.reset()
        self.track_sequence(sequence)
        self.alive_tracklets = self.cleanup_tracklets(self.alive_tracklets)
        return self.to_bbdf()

    @with_callbacks
    def track_sequence(self, sequence):
        """Slide a window over ``sequence`` and process each window.

        Windows start every ``step_size`` items and span ``window_size``
        items; each window is squeezed before being handed to
        :meth:`process_sequence_item`.

        Args:
            sequence: The frames to process. Inputs without ``len`` or
                indexing support (e.g. generators) are materialised into a
                list, and windows lacking a ``squeeze`` method (e.g. list
                slices) are converted with ``np.asarray`` first. Inputs that
                already support all three are used unchanged.
        """
        if not (hasattr(sequence, "__len__") and hasattr(sequence, "__getitem__")):
            sequence = list(sequence)

        window_starts = range(0, len(sequence) - self.window_size + 1, self.step_size)
        with tqdm(window_starts, desc="Tracking Progress") as t:
            for i in t:
                window = sequence[i : i + self.window_size]
                if not hasattr(window, "squeeze"):
                    window = np.asarray(window)
                self.process_sequence_item(window.squeeze())
                t.set_postfix_str(
                    f"Active: {len(self.alive_tracklets)}, Dead: {len(self.dead_tracklets)}", refresh=True
                )

    def cleanup_tracklets(self, tracklets):
        """Clean up every tracklet, then drop those shorter than ``min_length``.

        Args:
            tracklets: Tracklets to clean; ``cleanup()`` is called on each.

        Returns:
            A new list with the tracklets whose length is at least
            ``self.min_length``.
        """
        for tracklet in tracklets:
            tracklet.cleanup()
        return [tracklet for tracklet in tracklets if len(tracklet) >= self.min_length]

    def increment_staleness(self, tracklets):
        """Add one to the staleness of every tracklet and return the input."""
        for tracklet in tracklets:
            tracklet.staleness += 1
        return tracklets

    def reset_staleness(self, tracklets):
        """Set the staleness of every tracklet to zero and return the input."""
        for tracklet in tracklets:
            tracklet.staleness = 0
        return tracklets

    def pre_track(self):
        """Hook that subclasses can override; nothing in this class calls it."""

    def post_track(self):
        """Hook that subclasses can override; nothing in this class calls it."""

    def reset(self):
        """Drop all tracklets and restart the frame counter at zero."""
        logger.debug("Initializing tracker...")
        self.alive_tracklets = []
        self.dead_tracklets = []
        self.frame_count = 0
        logger.debug("Tracker initialized.")

    def _missing_observation_types(self, container):
        """Return the required observation types absent from ``container``."""
        return [
            required_type for required_type in self.required_observation_types if required_type not in container
        ]

    def _check_required_observations(self, target: dict[str, Any]):
        """Raise ``ValueError`` if ``target`` lacks a required observation type."""
        missing_types = self._missing_observation_types(target)
        if missing_types:
            required_types_str = ", ".join(self.required_observation_types)
            missing_types_str = ", ".join(missing_types)
            current_types_str = ", ".join(target.keys())
            raise ValueError(
                f"Input 'target' is missing the following required types: {missing_types_str}.\n"
                f"Required types: {required_types_str}\n"
                f"Current types in 'target': {current_types_str}"
            )

    def check_updated_state(self, state: dict[str, Any]):
        """Validate a state dictionary returned by ``update``.

        Raises:
            ValueError: If ``state`` is not a dict or lacks a required
                observation type.
        """
        if not isinstance(state, dict):
            raise ValueError("The `update` method must return a dictionary.")

        missing_types = self._missing_observation_types(state)
        if missing_types:
            missing_types_str = ", ".join(missing_types)
            raise ValueError(
                f"The returned state from `update` is missing the following required types: {missing_types_str}."
            )

    def create_tracklet(self, state: dict[str, Any]):
        """Build a new tracklet and seed it with ``state``.

        Every required observation and state type is registered on the
        tracklet before ``state`` is recorded for the current frame.

        Args:
            state: Initial observations; must contain every required
                observation type.

        Returns:
            The newly created tracklet.

        Raises:
            ValueError: If a required observation type is missing.
        """
        tracklet = Tracklet(max_staleness=self.max_staleness)
        for required_type in self.required_observation_types:
            tracklet.register_observation_type(required_type)
        for required_type in self.required_state_types:
            tracklet.register_state_type(required_type)

        self._check_required_observations(state)
        self.update_tracklet(tracklet, state)
        return tracklet

    def to_bbdf(self):
        """Create a bounding box dataframe.

        Concatenates the per-tracklet frames of all alive and dead tracklets
        column-wise and sorts the result by index.

        Raises:
            ValueError: From ``pd.concat`` when there are no tracklets at
                all; ``tune_hparams`` relies on this to detect empty runs.
        """
        all_tracklets = self.alive_tracklets + self.dead_tracklets
        return pd.concat([t.to_bbdf() for t in all_tracklets], axis=1).sort_index()

    def separate_stale_tracklets(self, unassigned_tracklets):
        """Split tracklets by their ``is_stale()`` result.

        Returns:
            ``(non_stale_tracklets, stale_tracklets)`` — note the order.
        """
        stale_tracklets, non_stale_tracklets = [], []
        for tracklet in unassigned_tracklets:
            bucket = stale_tracklets if tracklet.is_stale() else non_stale_tracklets
            bucket.append(tracklet)
        return non_stale_tracklets, stale_tracklets

    @property
    def required_observation_types(self):
        """Observation type names a tracklet must carry; subclasses override."""
        raise NotImplementedError

    @property
    def required_state_types(self):
        """State type names registered on new tracklets; subclasses override."""
        raise NotImplementedError

    @property
    def hparam_searh_space(self):
        """Default (empty) search space.

        NOTE(review): the name is misspelled ("searh"), while
        ``create_hparam_dict`` looks up ``hparam_search_space``, so this
        default is never picked up there. The name is kept as-is because it
        is part of the public interface.
        """
        return {}

    def create_hparam_dict(self):
        """Collect the hyperparameter search spaces reachable from this object.

        Returns:
            A dict keyed by attribute name. ``"self"`` (present only when
            this object has ``hparam_search_space``) maps to that search
            space unchanged. Every instance attribute whose value exposes
            ``hparam_search_space`` maps to a dict of
            ``{param_name: {"type", "values", "low", "high"}}``.
        """
        hparams = {"self": self.hparam_search_space} if hasattr(self, "hparam_search_space") else {}
        for attribute in vars(self):
            value = getattr(self, attribute)
            if not hasattr(value, "hparam_search_space"):
                continue
            hparams[attribute] = {
                param_name: {
                    "type": param_space["type"],
                    "values": param_space.get("values"),
                    "low": param_space.get("low"),
                    "high": param_space.get("high"),
                }
                for param_name, param_space in value.hparam_search_space.items()
            }
        return hparams

    def get_new_hyperparameters(self, hparams, trial):
        """Ask an Optuna trial for one value per entry of ``hparams``.

        Args:
            hparams: Search space in the layout returned by
                :meth:`create_hparam_dict`.
            trial: The Optuna trial used to suggest values.

        Returns:
            ``{attribute: {param_name: suggested_value}}``.

        Raises:
            ValueError: If an entry has an unknown ``"type"``.
        """
        # NOTE(review): the Optuna parameter name is ``param_name`` alone, so
        # two attributes exposing the same parameter name would share a single
        # Optuna parameter — confirm this is intended.
        params = {}
        for attribute, param_space in hparams.items():
            suggested = params[attribute] = {}
            for param_name, param_values in param_space.items():
                kind = param_values["type"]
                if kind == "categorical":
                    suggested[param_name] = trial.suggest_categorical(param_name, param_values["values"])
                elif kind == "float":
                    suggested[param_name] = trial.suggest_float(
                        param_name, param_values["low"], param_values["high"]
                    )
                elif kind == "logfloat":
                    suggested[param_name] = trial.suggest_float(
                        param_name,
                        param_values["low"],
                        param_values["high"],
                        log=True,
                    )
                elif kind == "int":
                    suggested[param_name] = trial.suggest_int(param_name, param_values["low"], param_values["high"])
                else:
                    raise ValueError(f"Unknown parameter type: {param_values['type']}")
        return params

    def apply_hyperparameters(self, params):
        """Write hyperparameter values onto this object and its attributes.

        Args:
            params: ``{attribute: {param_name: value}}``. The key ``"self"``
                targets this object; any other key names an instance
                attribute whose own attributes are updated.

        Raises:
            AttributeError: If a named attribute is not an instance attribute
                of this object.
            TypeError: If the target attribute has no instance attribute
                called ``param_name``.
        """
        for attribute, param_values in params.items():
            if not param_values:
                continue  # nothing to apply, so the attribute is not checked

            if attribute == "self":
                for param_name, param_value in param_values.items():
                    logger.debug(f"Setting {param_name} to {param_value} for {self}")
                    setattr(self, param_name, param_value)
                continue

            if attribute not in self.__dict__:
                raise AttributeError(f"{attribute=} not found in object")

            attr_obj = getattr(self, attribute)
            for param_name, param_value in param_values.items():
                # Only existing instance attributes may be overwritten.
                if param_name not in attr_obj.__dict__:
                    raise TypeError(
                        f"Cannot set {param_name=} on {attribute=}, as it is immutable or not in {list(attr_obj.__dict__.keys())}"
                    )
                setattr(attr_obj, param_name, param_value)
                logger.debug(f"Setting {param_name} to {param_value} for {attribute}")

    def tune_hparams(
        self,
        frames_list,
        bbdf_gt_list,
        n_trials=100,
        hparam_search_space=None,
        verbose=False,
        return_study=False,
        use_bbdf=False,
        reuse_detections=False,
        sampler=None,
        pruner=None,
    ):
        """Tune hyperparameters with Optuna, maximising the mean HOTA score.

        Each trial applies a suggested parameter set, tracks every sequence
        in ``frames_list`` and scores it against the matching entry of
        ``bbdf_gt_list``. The best parameters are applied to this object
        before returning.

        Args:
            frames_list: Sequences of frames to track.
            bbdf_gt_list: Ground-truth dataframes, paired with
                ``frames_list`` by position.
            n_trials: Number of Optuna trials.
            hparam_search_space: Search space to use; when falsy the result
                of :meth:`create_hparam_dict` is used.
            verbose: Use Optuna's INFO log level instead of WARNING.
            return_study: Also return the Optuna study.
            use_bbdf: Not implemented; a truthy value raises.
            reuse_detections: Run ``self.detection_model`` once per frame up
                front and replay the results through ``DummyDetectionModel``
                during the trials. Assumes the subclass defines
                ``self.detection_model``.
            sampler: Optuna sampler; defaults to a multivariate TPE sampler.
            pruner: Optuna pruner; defaults to a median pruner.

        Returns:
            ``(best_params, best_value)``, or
            ``(best_params, best_value, study)`` when ``return_study`` is
            true.

        Raises:
            NotImplementedError: If ``use_bbdf`` is true.
        """

        def objective(trial: optuna.Trial):
            params = self.get_new_hyperparameters(hparams, trial)
            self.trial_params.append(params)
            self.apply_hyperparameters(params)

            scores = []
            for i, (frames, bbdf_gt) in enumerate(zip(frames_list, bbdf_gt_list)):
                self.reset()
                if reuse_detections:
                    self.detection_model = self.detection_models[i]
                try:
                    bbdf_pred = self.track(frames)
                except ValueError as e:  # Return nan when no tracks are detected
                    logger.error(e)
                    return np.nan

                score = hota_score(bbdf_pred, bbdf_gt)["HOTA"]
                scores.append(score)

                trial.report(np.mean(scores), step=len(scores))  # Report intermediate score
                if trial.should_prune():  # Check for pruning
                    raise optuna.TrialPruned()

            return np.mean(scores)  # return the average score

        hparams = hparam_search_space or self.create_hparam_dict()

        logger.info("Hyperparameter search space:")
        for attribute, param_space in hparams.items():
            logger.info(f"{attribute}:")
            for param_name, param_values in param_space.items():
                logger.info(f"\t{param_name}: {param_values}")

        optuna.logging.set_verbosity(optuna.logging.INFO if verbose else optuna.logging.WARNING)

        if use_bbdf:
            raise NotImplementedError

        og_detection_model = None
        if reuse_detections:
            # Remember the real detector up front so it can always be
            # restored, even when ``frames_list`` is empty.
            og_detection_model = self.detection_model
            self.detection_models = []
            for frames in frames_list:
                list_of_detections = []
                for frame in tqdm(frames, desc="Detecting frames for reuse"):
                    list_of_detections.append(self.detection_model(frame)[0])
                self.detection_models.append(DummyDetectionModel(list_of_detections))

        if sampler is None:
            sampler = optuna.samplers.TPESampler(multivariate=True)
        if pruner is None:
            pruner = optuna.pruners.MedianPruner()

        self.trial_params = []  # Used to store the parameters for each trial
        study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
        try:
            study.optimize(objective, n_trials=n_trials)
        finally:
            if reuse_detections:
                # Put the real detector back even if optimisation failed.
                self.detection_model = og_detection_model

        best_value = study.best_value
        # ``objective`` appends one entry per call, so the list index is
        # assumed to line up with the trial number of this fresh study.
        self.best_params = self.trial_params[study.best_trial.number]
        self.apply_hyperparameters(self.best_params)
        if return_study:
            return self.best_params, best_value, study
        return self.best_params, best_value