Models Zoo
- class models.ExtractorPredictor(extractor: FeatureExtractor, predictor: FeaturePredictor, class_encoder=None, classes=None)
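Bases: Model
A model that combines a FeatureExtractor, which computes features from the input data, with a FeaturePredictor, which predicts labels from those features.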
- classmethod load_model(path, extractor: FeatureExtractor, predictor: FeaturePredictor)
Load a pre-trained model from the specified path.
- Parameters:
path (str) – The path to the saved model.
extractor (FeatureExtractor) – The feature extractor used by the model, as specified in the saved model's config.
predictor (FeaturePredictor) – The feature predictor used by the model, as specified in the saved model's config.
- Returns:
An instance of the model class with the loaded feature extractor and predictor.
- Return type:
cls
- predict(dataloader, logger, return_features=False)
Makes predictions using the provided dataloader.
- Parameters:
dataloader (torch.utils.data.DataLoader) – The dataloader containing the data for prediction.
logger (logging.Logger) – The logger object for logging information.
return_features (bool, optional) – Whether to return the extracted features along with predictions. Defaults to False.
- Returns:
predictions (numpy.ndarray) – An array of predicted labels.
true_labels (numpy.ndarray) – An array of true labels.
features_list (numpy.ndarray, optional) – An array of extracted features, returned only if return_features is True.
- Return type:
tuple
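As an illustration of the return values above, here is a minimal, self-contained sketch of a prediction loop. It is not the library's implementation: `SumExtractor`, `ThresholdPredictor`, and `predict_sketch` are hypothetical, and a plain list of `(inputs, labels)` batches stands in for the `torch.utils.data.DataLoader` (the batch format is an assumption).

```python
import logging

import numpy as np


class SumExtractor:
    """Hypothetical extractor: one feature per sample (the row sum)."""

    def extract(self, inputs):
        return np.asarray(inputs, dtype=float).sum(axis=1, keepdims=True)


class ThresholdPredictor:
    """Hypothetical predictor: label 1 if the feature exceeds a threshold, else 0."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def predict(self, features):
        return (features[:, 0] > self.threshold).astype(int)


def predict_sketch(extractor, predictor, dataloader, logger, return_features=False):
    """Hypothetical helper mirroring the documented return values of predict()."""
    preds, labels, feats = [], [], []
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        batch_features = extractor.extract(inputs)      # features for this batch
        preds.append(predictor.predict(batch_features))  # labels predicted from features
        labels.append(np.asarray(targets))               # ground-truth labels
        feats.append(batch_features)
        logger.debug("processed batch %d", batch_idx)
    predictions = np.concatenate(preds)
    true_labels = np.concatenate(labels)
    if return_features:
        return predictions, true_labels, np.concatenate(feats)
    return predictions, true_labels


batches = [([[1, 2], [-3, 0]], [1, 0]), ([[0, 0.5]], [1])]
logger = logging.getLogger("predict_sketch")
predictions, true_labels = predict_sketch(SumExtractor(), ThresholdPredictor(), batches, logger)
```

With `return_features=True` the same call returns a third array holding one feature row per sample, in dataloader order.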
- save_model(path)
Save the model to the specified path.
- Parameters:
path (str) – The path to save the model to.
- Returns:
None
- train(dataloader, logger)
Trains the model using the provided dataloader.
- Parameters:
dataloader (torch.utils.data.DataLoader) – The dataloader containing the training data.
logger (logging.Logger) – The logger object for logging training progress.
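Training can be sketched the same way, assuming the extractor is fixed and only the predictor is fitted on the extracted features (an assumption; the documentation does not say whether the extractor is also trained). `MeanExtractor`, `NearestCentroidPredictor`, and `train_sketch` are hypothetical names, and a list of `(inputs, labels)` batches again stands in for the dataloader.

```python
import logging

import numpy as np


class MeanExtractor:
    """Hypothetical extractor: one feature per sample (the row mean)."""

    def extract(self, inputs):
        return np.asarray(inputs, dtype=float).mean(axis=1, keepdims=True)


class NearestCentroidPredictor:
    """Hypothetical predictor: assigns the label of the closest class centroid."""

    def fit(self, features, labels):
        self.classes_ = np.unique(labels)
        self.centroids_ = np.stack(
            [features[labels == c].mean(axis=0) for c in self.classes_]
        )
        return self

    def predict(self, features):
        # Distance from every sample to every centroid: shape (n_samples, n_classes).
        dists = np.linalg.norm(features[:, None, :] - self.centroids_[None, :, :], axis=2)
        return self.classes_[dists.argmin(axis=1)]


def train_sketch(extractor, predictor, dataloader, logger):
    """Hypothetical helper: extract features per batch, then fit the predictor once."""
    feats, labels = [], []
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        feats.append(extractor.extract(inputs))
        labels.append(np.asarray(targets))
        logger.info("extracted features for batch %d", batch_idx)
    predictor.fit(np.concatenate(feats), np.concatenate(labels))
    return predictor


batches = [([[0, 0], [10, 10]], [0, 1]), ([[1, 1], [9, 9]], [0, 1])]
trained = train_sketch(MeanExtractor(), NearestCentroidPredictor(), batches, logging.getLogger("train_sketch"))
```

Collecting all features before a single `fit` call suits predictors that need the whole training set at once; an incremental predictor could instead be updated batch by batch.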
- class models.Model
Bases: object
A base class for machine learning models.
- evaluate(dataloader)
- classmethod load_model(path)
Load a saved model from the specified path.
- Parameters:
path (str) – The path to the saved model file.
- Returns:
The loaded model.
- Return type:
Model
- Raises:
AssertionError – If the loaded model is not an instance of the current class.
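The save/load contract of the base class, including the documented `AssertionError`, can be sketched with `pickle`. Using `pickle` as the file format is an assumption, and `ModelA`/`ModelB` are hypothetical subclasses used only to show the class check.

```python
import pickle


class Model:
    """Hypothetical sketch of the base class's save/load contract (pickle is assumed)."""

    def save_model(self, path):
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load_model(cls, path):
        with open(path, "rb") as f:
            model = pickle.load(f)
        # Mirrors the documented check: the file must hold an instance of cls.
        assert isinstance(model, cls), f"expected {cls.__name__}, got {type(model).__name__}"
        return model


class ModelA(Model):  # hypothetical subclass
    pass


class ModelB(Model):  # hypothetical subclass
    pass
```

Because `isinstance` accepts subclasses, `Model.load_model` in this sketch accepts a saved `ModelA`, while `ModelB.load_model` rejects it. Note that `assert` statements are skipped when Python runs with `-O`, so a check written this way disappears in optimized mode.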
- save_model(path)
Save the model to the specified path.
- Parameters:
path (str) – The path to save the model to.
- Returns:
None
- train(dataloader)