Source code for sportslabkit.vector_model.base
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
import torch
from sportslabkit.types import Vector
class BaseVectorModel(ABC):
    """Abstract base class for handling vector-based models.

    The class wires together three steps around a user-defined ``forward``:
    input validation, the forward pass itself, and output validation. It also
    provides a ``load`` entry point that delegates to a user-defined
    ``_load_model`` and then verifies that ``self.model`` was populated.

    Subclasses must implement ``forward`` and ``_load_model``.
    """

    def __init__(self, input_vector_size: int | None = None, output_vector_size: int | None = None) -> None:
        """Initialize the BaseVectorModel.

        Args:
            input_vector_size: Expected ``len()`` of the inputs. A falsy value
                (``None`` or ``0``) bypasses input validation.
            output_vector_size: Expected ``len()`` of the outputs. A falsy value
                (``None`` or ``0``) bypasses output validation.
        """
        super().__init__()
        self.input_vector_size = input_vector_size
        self.output_vector_size = output_vector_size
        self.model = None  # Placeholder for the actual model; set by `_load_model`.

    def __call__(self, inputs: Vector, **kwargs: Any) -> Vector:
        """Run the forward pass, validating inputs before and outputs after.

        Args:
            inputs: The input data.
            **kwargs: Additional arguments forwarded unchanged to ``forward``.

        Returns:
            The validated (and possibly converted) output data.

        Raises:
            ValueError: If the input or output length does not match the
                configured size.
            AssertionError: If the input and output lengths differ.
        """
        inputs = self._check_and_fix_inputs(inputs)
        outputs = self.forward(inputs, **kwargs)
        return self._check_and_fix_outputs(outputs)

    def _check_and_fix_inputs(self, inputs: Vector) -> Vector:
        """Validate the inputs and normalize lists to NumPy arrays.

        Records ``len(inputs)`` on ``self.input_batch_size`` so that
        ``_check_and_fix_outputs`` can compare it with the output length.

        NOTE(review): ``len(inputs)`` is used both for the vector-size check
        and as the batch size — confirm this is the intended semantics.

        Args:
            inputs: The input data.

        Returns:
            ``np.array(inputs)`` when ``inputs`` is a list, otherwise ``inputs``
            unchanged.

        Raises:
            ValueError: If ``input_vector_size`` is set and differs from
                ``len(inputs)``.
        """
        if self.input_vector_size and len(inputs) != self.input_vector_size:
            raise ValueError(f"Input vector size mismatch. Expected {self.input_vector_size}, got {len(inputs)}.")
        self.input_batch_size = len(inputs)
        return np.array(inputs) if isinstance(inputs, list) else inputs

    def _check_and_fix_outputs(self, outputs: Vector) -> Vector:
        """Validate the outputs and convert torch tensors to NumPy arrays.

        Records ``len(outputs)`` on ``self.output_batch_size`` and requires it
        to equal the ``self.input_batch_size`` stored by
        ``_check_and_fix_inputs``.

        Args:
            outputs: The output data.

        Returns:
            A NumPy copy of ``outputs`` when it is a ``torch.Tensor``,
            otherwise ``outputs`` unchanged.

        Raises:
            ValueError: If ``output_vector_size`` is set and differs from
                ``len(outputs)``.
            AssertionError: If the input and output lengths differ.
        """
        if self.output_vector_size and len(outputs) != self.output_vector_size:
            raise ValueError(f"Output vector size mismatch. Expected {self.output_vector_size}, got {len(outputs)}.")
        self.output_batch_size = len(outputs)
        if self.input_batch_size != self.output_batch_size:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type is kept so that
            # existing handlers continue to match.
            raise AssertionError(
                f"Input({self.input_batch_size}) and output({self.output_batch_size}) batch sizes do not match."
            )
        if isinstance(outputs, torch.Tensor):
            # Detach from the autograd graph and move to host memory first:
            # converting a tensor that requires grad or lives on another
            # device directly to NumPy raises. `np.array` then makes a copy,
            # so the result does not share memory with the tensor.
            return np.array(outputs.detach().cpu().numpy())
        return outputs

    @abstractmethod
    def forward(self, inputs: Vector, **kwargs: Any) -> Vector:
        """Define the forward pass of the model. Must be overridden by subclasses.

        Args:
            inputs: The input data, as returned by ``_check_and_fix_inputs``.
            **kwargs: Additional arguments passed through from ``__call__``.

        Returns:
            The output data.
        """
        raise NotImplementedError("The forward method must be implemented by subclasses.")

    def load(self, path: str) -> None:
        """Load the model from disk and verify that it was loaded.

        Args:
            path: The path to the model file.

        Raises:
            ValueError: If ``self.model`` is still ``None`` after
                ``_load_model`` returns.
        """
        self._load_model(path)
        self._post_load_check()

    @abstractmethod
    def _load_model(self, path: str) -> None:
        """User-defined model loading logic. Must be overridden by subclasses.

        Implementations are expected to assign the loaded model to
        ``self.model``; ``load`` fails otherwise.

        Args:
            path: The path to the model file.
        """
        raise NotImplementedError("The _load_model method must be implemented by subclasses.")

    def _post_load_check(self) -> None:
        """Check whether the model has been loaded correctly.

        Raises:
            ValueError: If ``self.model`` is ``None``.
        """
        if self.model is None:
            raise ValueError("Model not loaded correctly. Fix your _load_model implementation.")