Source code for sportslabkit.vector_model.sklearn
from typing import Any
from joblib import load
from sklearn.pipeline import Pipeline
from sportslabkit.types import Vector
from sportslabkit.utils import fetch_or_cache_model
from sportslabkit.vector_model.base import BaseVectorModel
class SklearnVectorModel(BaseVectorModel):
    """
    A specialized subclass of BaseVectorModel for scikit-learn pipelines.

    This class wraps a scikit-learn ``Pipeline`` that has been serialized with
    joblib so it can be used as a vector-based model within SportsLabKit.
    It provides ``forward`` (delegating to the pipeline's ``predict``) and
    ``_load_model`` (deserializing and validating the pipeline).

    Attributes:
        model_path (str): The path given at construction time.
        model (Pipeline | None): The loaded scikit-learn pipeline model. None if the model is not loaded.
    """

    def __init__(
        self,
        model_path: str = "",
        input_vector_size: int | None = None,
        output_vector_size: int | None = None,
    ) -> None:
        """
        Initialize the wrapper and load the pipeline.

        Args:
            model_path (str): Location of the serialized pipeline. It is stored on the
                instance and passed to ``self.load``.
            input_vector_size (int | None): Forwarded to ``BaseVectorModel.__init__``.
            output_vector_size (int | None): Forwarded to ``BaseVectorModel.__init__``.
        """
        super().__init__(input_vector_size, output_vector_size)
        self.model_path = model_path
        # NOTE(review): `load` is inherited and is presumably what dispatches to
        # `_load_model`; the default empty path is passed through unchanged --
        # confirm the base class / `fetch_or_cache_model` handles "" as intended.
        self.load(model_path)

    def forward(self, inputs: Vector, **kwargs: Any) -> Vector:
        """
        Implement the forward pass specific to scikit-learn pipelines.

        The input is passed to the pipeline's `predict` method, together with
        any additional keyword arguments.

        Args:
            inputs (Vector): The input vector, which should match the expected input shape of the pipeline.
            **kwargs (Any): Additional keyword arguments to pass to the pipeline's `predict` method.

        Returns:
            Vector: The output of the pipeline's `predict` method.

        Raises:
            ValueError: If the model attribute is None, indicating that the model has not been loaded.
        """
        if self.model is None:
            raise ValueError("The model is as empty as a politician's promise. Load it first.")
        return self.model.predict(inputs, **kwargs)

    def _load_model(self, path: str) -> None:
        """
        Load a scikit-learn pipeline model from disk using joblib.

        The path is first resolved through `fetch_or_cache_model`, then the file is
        deserialized with joblib. The loaded object is stored in the `model` attribute
        only after it has been verified to be a scikit-learn `Pipeline`, so a failed
        load never leaves an invalid object on the instance.

        Args:
            path (str): The file path to the pre-trained scikit-learn pipeline.

        Raises:
            TypeError: If the loaded model is not a scikit-learn pipeline.
        """
        actual_path = fetch_or_cache_model(path)
        # SECURITY: joblib.load unpickles the file and can therefore execute
        # arbitrary code; only point this at model files from a trusted source.
        loaded = load(actual_path)
        # Validate before assigning so that a TypeError does not leave a
        # non-pipeline object (or clobber a previously valid model) in self.model.
        if not isinstance(loaded, Pipeline):
            raise TypeError(f"Oops, you loaded something that's not a pipeline. Got a {type(loaded)} instead.")
        self.model = loaded