Source code for sportslabkit.image_model.torchreid

from dataclasses import field


try:
    from torchreid.utils import FeatureExtractor
except ImportError:
    print(
        "The torchreid module is not installed. Please install it using the following command:\n"
        "pip install git+https://github.com/KaiyangZhou/deep-person-reid.git"
    )

from sportslabkit.constants import CACHE_DIR
from sportslabkit.image_model.base import BaseImageModel
from sportslabkit.logger import logger
from sportslabkit.utils import (
    HiddenPrints,
    download_file_from_google_drive,
)


# Directory under the package cache where downloaded torchreid checkpoints are stored.
model_save_dir = (
    CACHE_DIR
    / "sportslabkit"
    / "models"
    / "torchreid"
)
# Maps each supported model name to the Google Drive share URL of its checkpoint.
# ``download_model`` looks names up here and takes the second-to-last URL segment
# as the identifier it hands to ``download_file_from_google_drive``.
model_dict = {
    # ShuffleNet / MobileNetV2 / MLFN
    "shufflenet": "https://drive.google.com/file/d/1RFnYcHK1TM-yt3yLsNecaKCoFO4Yb6a-/view?usp=sharing",
    "mobilenetv2_x1_0": "https://drive.google.com/file/d/1K7_CZE_L_Tf-BRY6_vVm0G-0ZKjVWh3R/view?usp=sharing",
    "mobilenetv2_x1_4": "https://drive.google.com/file/d/10c0ToIGIVI0QZTx284nJe8QfSJl5bIta/view?usp=sharing",
    "mlfn": "https://drive.google.com/file/d/1PP8Eygct5OF4YItYRfA3qypYY9xiqHuV/view?usp=sharing",
    # OSNet
    "osnet_x1_0": "https://drive.google.com/file/d/1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY/view?usp=sharing",
    "osnet_x0_75": "https://drive.google.com/file/d/1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq/view?usp=sharing",
    "osnet_x0_5": "https://drive.google.com/file/d/16DGLbZukvVYgINws8u8deSaOqjybZ83i/view?usp=sharing",
    "osnet_x0_25": "https://drive.google.com/file/d/1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs/view?usp=sharing",
    # OSNet-IBN / OSNet-AIN
    "osnet_ibn_x1_0": "https://drive.google.com/file/d/1sr90V6irlYYDd4_4ISU2iruoRG8J__6l/view?usp=sharing",
    "osnet_ain_x1_0": "https://drive.google.com/file/d/1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo/view?usp=sharing",
    "osnet_ain_x0_75": "https://drive.google.com/file/d/1apy0hpsMypqstfencdH-jKIUEFOW4xoM/view?usp=sharing",
    "osnet_ain_x0_5": "https://drive.google.com/file/d/1KusKvEYyKGDTUBVRxRiz55G31wkihB6l/view?usp=sharing",
    "osnet_ain_x0_25": "https://drive.google.com/file/d/1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt/view?usp=sharing",
    # Entries whose names carry an "MSMT17" tag
    "resnet50_MSMT17": "https://drive.google.com/file/d/1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf/view?usp=sharing",
    "osnet_x1_0_MSMT17": "https://drive.google.com/file/d/1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x/view?usp=sharing",
    "osnet_ain_x1_0_MSMT17": "https://drive.google.com/file/d/1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal/view?usp=sharing",
    "resnet50_MSMT17x": "https://drive.google.com/file/d/1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj/view?usp=sharing",
    "resnet50_fc512_MSMT17": "https://drive.google.com/file/d/1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud/view?usp=sharing",
}
def show_torchreid_models():
    """Return the names of the available models as a list.

    Returns:
        list: The keys of ``model_dict`` in their definition order. Nothing
        is printed; the caller decides what to do with the list.
    """
    available = [*model_dict]
    return available
def download_model(model_name):
    """Make sure the checkpoint for ``model_name`` is present locally.

    Args:
        model_name: A key of ``model_dict``.

    Returns:
        The location of the checkpoint, ``model_save_dir / (model_name + ".pth")``.
        An already existing file is returned as-is without downloading again.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_dict``.
    """
    if model_name not in model_dict:
        raise ValueError(f"Model {model_name} not available. Available models are: {show_torchreid_models()}")

    checkpoint = model_save_dir / (model_name + ".pth")
    # The target directory is created before the existence check, so it exists
    # on both the cached and the download path.
    checkpoint.parent.mkdir(parents=True, exist_ok=True)

    if not checkpoint.exists():
        # The second-to-last segment of the share URL is what the download
        # helper receives as its first argument.
        share_url = model_dict[model_name]
        download_file_from_google_drive(share_url.split("/")[-2], checkpoint)
        logger.debug(f"Model {model_name} successfully downloaded and saved to {model_save_dir}.")
    else:
        logger.debug(f"Model {model_name} already exists in {model_save_dir}.")

    return checkpoint
class BaseTorchReIDModel(BaseImageModel):
    """Image model that wraps a torchreid ``FeatureExtractor``.

    The extractor is built once in ``__init__`` (through ``load``) and stored
    on ``self.model``; ``forward`` delegates to it.

    Args:
        name: Model name. When ``path`` is empty it must be a key of
            ``model_dict`` so the checkpoint can be downloaded. A name ending
            in ``"MSMT17"`` has ``"_MSMT17"`` removed before it is handed to
            ``FeatureExtractor``.
        path: Location of a weights file. When empty (and ``name`` is not),
            the checkpoint is fetched with ``download_model``.
        device: Passed to ``FeatureExtractor`` as ``device``.
        image_size: Stored on the instance.
        pixel_mean: Stored on the instance; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Stored on the instance; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Stored on the instance.
        verbose: When false, the ``FeatureExtractor`` is constructed inside
            ``HiddenPrints()``.
    """

    def __init__(
        self,
        name: str = "osnet_x1_0",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        super().__init__()

        # ``dataclasses.field(default_factory=...)`` is only interpreted inside
        # a ``@dataclass`` body; as an ordinary default it would leave a
        # ``Field`` object in these attributes. ``None`` is the sentinel and a
        # fresh list is created per instance instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]

        self.name = name
        self.path = path
        self.device = device
        self.image_size = image_size
        self.pixel_mean = pixel_mean
        self.pixel_std = pixel_std
        self.pixel_norm = pixel_norm
        self.verbose = verbose
        # Must come last: ``load`` reads the attributes assigned above.
        self.model = self.load()

    def load(self):
        """Build the ``FeatureExtractor`` described by this instance.

        Returns:
            The ``FeatureExtractor`` created from the (possibly
            suffix-stripped) model name, the resolved weights path and
            ``self.device``.

        Raises:
            ValueError: From ``download_model`` when ``self.path`` is empty
                and ``self.name`` is not a key of ``model_dict``.
        """
        model_name = self.name
        model_path = self.path

        # No explicit weights given: fetch the checkpoint registered under
        # the model name.
        if (model_name != "") and (model_path == ""):
            model_path = download_model(model_name)
            logger.debug(model_path)

        # Names ending in "MSMT17" have the "_MSMT17" tag removed before the
        # name is handed to FeatureExtractor.
        if model_name.endswith("MSMT17"):
            model_name = model_name.replace("_MSMT17", "")

        # NOTE(review): image_size / pixel_mean / pixel_std / pixel_norm are
        # stored on the instance but not passed to FeatureExtractor here, so
        # the extractor is built with its own defaults for them - confirm
        # whether they should be forwarded.
        def build_extractor():
            return FeatureExtractor(
                model_name=model_name,
                model_path=model_path,
                device=self.device,
            )

        if self.verbose:
            return build_extractor()
        # Silence whatever is printed while the extractor is constructed.
        with HiddenPrints():
            return build_extractor()

    def forward(self, x):
        """Run the wrapped extractor on ``x``.

        Args:
            x: An iterable of inputs; it is converted to a ``list`` before
                being passed to the extractor.

        Returns:
            Whatever the wrapped ``FeatureExtractor`` returns for that list.
        """
        return self.model(list(x))
class ShuffleNet(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"shufflenet"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "shufflenet",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class MobileNetV2_x1_0(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"mobilenetv2_x1_0"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "mobilenetv2_x1_0",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class MobileNetV2_x1_4(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"mobilenetv2_x1_4"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "mobilenetv2_x1_4",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class MLFN(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"mlfn"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        # The checkpoint is registered in ``model_dict`` under "mlfn"; the
        # previous default "mfln" was a transposition that made default
        # construction fail with ValueError in ``download_model``.
        name: str = "mlfn",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_x1_0(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_x1_0"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_x1_0",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_x0_75(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_x0_75"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_x0_75",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_x0_5(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_x0_5"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_x0_5",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_x0_25(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_x0_25"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_x0_25",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_ibn_x1_0(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_ibn_x1_0"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_ibn_x1_0",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_ain_x1_0(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_ain_x1_0"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_ain_x1_0",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_ain_x0_75(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_ain_x0_75"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_ain_x0_75",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_ain_x0_5(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_ain_x0_5"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_ain_x0_5",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class OSNet_ain_x0_25(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"osnet_ain_x0_25"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class; with the empty default
            the base class downloads the checkpoint registered for ``name``.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "osnet_ain_x0_25",
        path: str = "",
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class ResNet50(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"resnet50"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class. The default is the share
            URL registered as ``model_dict["resnet50_MSMT17"]``; when that
            value is received, the checkpoint is downloaded with
            ``download_model("resnet50_MSMT17")`` and the resulting local
            path is used instead.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "resnet50",
        path: str = model_dict["resnet50_MSMT17"],
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # The base class only downloads when ``path`` is empty, so the default
        # share URL would otherwise reach FeatureExtractor as if it were a
        # local weights file. Swap it for the downloaded checkpoint.
        if path == model_dict["resnet50_MSMT17"]:
            path = str(download_model("resnet50_MSMT17"))

        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)
class ResNet50_fc512(BaseTorchReIDModel):
    """``BaseTorchReIDModel`` preset whose default ``name`` is ``"resnet50_fc512"``.

    Args:
        name: Model name handed to the base class.
        path: Weights path handed to the base class. The default is the share
            URL registered as ``model_dict["resnet50_fc512_MSMT17"]``; when
            that value is received, the checkpoint is downloaded with
            ``download_model("resnet50_fc512_MSMT17")`` and the resulting
            local path is used instead.
        device: Device string handed to the base class.
        image_size: Image size handed to the base class.
        pixel_mean: Normalisation mean; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        pixel_std: Normalisation std; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        pixel_norm: Normalisation flag handed to the base class.
        verbose: Verbosity flag handed to the base class.
    """

    def __init__(
        self,
        name: str = "resnet50_fc512",
        path: str = model_dict["resnet50_fc512_MSMT17"],
        device: str = "cpu",
        image_size: tuple[int, int] = (256, 128),
        pixel_mean: "list[float] | None" = None,
        pixel_std: "list[float] | None" = None,
        pixel_norm: bool = True,
        verbose: bool = False,
    ):
        # The base class only downloads when ``path`` is empty, so the default
        # share URL would otherwise reach FeatureExtractor as if it were a
        # local weights file. Swap it for the downloaded checkpoint.
        if path == model_dict["resnet50_fc512_MSMT17"]:
            path = str(download_model("resnet50_fc512_MSMT17"))

        # ``dataclasses.field(...)`` is not a valid default outside a
        # ``@dataclass``; it would pass a ``Field`` object on. Resolve the
        # ``None`` sentinel to a fresh list here instead.
        if pixel_mean is None:
            pixel_mean = [0.485, 0.456, 0.406]
        if pixel_std is None:
            pixel_std = [0.229, 0.224, 0.225]
        super().__init__(name, path, device, image_size, pixel_mean, pixel_std, pixel_norm, verbose)