Models Zoo

class models.ExtractorPredictor(extractor: FeatureExtractor, predictor: FeaturePredictor, class_encoder=None, classes=None)

Bases: Model

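A model that chains a FeatureExtractor, which computes features from the input data, with a FeaturePredictor, which predicts labels from those features.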

classmethod load_model(path, extractor: FeatureExtractor, predictor: FeaturePredictor)

Load a pre-trained model from the specified path.

Parameters:
  • path (str) – The path to the saved model.

  • extractor (FeatureExtractor) – The feature extractor used by the model; it should correspond to the extractor recorded in the saved config.

  • predictor (FeaturePredictor) – The feature predictor used by the model; it should correspond to the predictor recorded in the saved config.

Returns:

An instance of the model class with the loaded feature extractor and predictor.

Return type:

ExtractorPredictor
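The signature suggests that the saved config records which components were used, while the component objects themselves are passed in at load time. Below is a minimal sketch of that pattern. It assumes a JSON config file that stores the component class names and the `classes` list. The config layout, the name check, the `ValueError`, and the stand-in `FeatureExtractor`/`FeaturePredictor` classes are all hypothetical and not taken from the library.

```python
import json
import os
import tempfile


class FeatureExtractor:
    """Hypothetical stand-in for the library's FeatureExtractor."""


class FeaturePredictor:
    """Hypothetical stand-in for the library's FeaturePredictor."""


class ExtractorPredictor:
    def __init__(self, extractor, predictor, class_encoder=None, classes=None):
        self.extractor = extractor
        self.predictor = predictor
        self.class_encoder = class_encoder
        self.classes = classes

    def save_model(self, path):
        # Hypothetical config layout: component class names plus the class list
        # (class_encoder is not persisted in this sketch).
        config = {
            "extractor": type(self.extractor).__name__,
            "predictor": type(self.predictor).__name__,
            "classes": self.classes,
        }
        with open(path, "w") as f:
            json.dump(config, f)

    @classmethod
    def load_model(cls, path, extractor, predictor):
        with open(path) as f:
            config = json.load(f)
        # The caller supplies the components; check them against the config.
        for role, component in (("extractor", extractor), ("predictor", predictor)):
            if type(component).__name__ != config[role]:
                raise ValueError(
                    f"config expects {role} {config[role]!r}, "
                    f"got {type(component).__name__!r}"
                )
        return cls(extractor, predictor, classes=config["classes"])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    model = ExtractorPredictor(FeatureExtractor(), FeaturePredictor(), classes=["cat", "dog"])
    model.save_model(path)
    loaded = ExtractorPredictor.load_model(path, FeatureExtractor(), FeaturePredictor())
    print(loaded.classes)  # ['cat', 'dog']
```

Keeping the components out of the saved file lets callers construct them however they like (for example, with pretrained weights) and only requires that they match what the model was saved with.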

predict(dataloader, logger, return_features=False)

Makes predictions using the provided dataloader.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – The dataloader containing the data for prediction.

  • logger (logging.Logger) – The logger object for logging information.

  • return_features (bool, optional) – Whether to return the extracted features along with predictions. Defaults to False.

Returns:

  • predictions (numpy.ndarray) – An array of predicted labels.

  • true_labels (numpy.ndarray) – An array of true labels.

  • features_list (numpy.ndarray, optional) – An array of extracted features; returned only if return_features is True.

Return type:

tuple of numpy.ndarray
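The contract of `predict` can be illustrated with a minimal sketch, written as a free function rather than a method. It makes three assumptions that do not come from the library itself: the dataloader yields `(inputs, labels)` pairs, the extractor and predictor are plain callables, and Python lists stand in for the `numpy.ndarray` objects the real method returns.

```python
import logging


def predict(extractor, predictor, dataloader, logger, return_features=False):
    """Sketch of predict(): returns (predictions, true_labels[, features])."""
    predictions, true_labels, features_list = [], [], []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        features = extractor(inputs)   # inputs -> features
        preds = predictor(features)    # features -> predicted labels
        predictions.extend(preds)
        true_labels.extend(labels)
        if return_features:
            features_list.extend(features)
        logger.debug("Predicted batch %d", batch_idx)
    if return_features:
        return predictions, true_labels, features_list
    return predictions, true_labels


def extract(inputs):
    # Toy "extractor": doubles each input value.
    return [2 * x for x in inputs]


def classify(features):
    # Toy "predictor": label 1 if the feature exceeds 5, else 0.
    return [int(f > 5) for f in features]


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("predict-sketch")
batches = [([1, 2], [0, 0]), ([3, 4], [1, 1])]

preds, labels = predict(extract, classify, batches, logger)
print(preds, labels)  # [0, 0, 1, 1] [0, 0, 1, 1]
```

Returning the true labels alongside the predictions keeps the two aligned even if the dataloader shuffles, which is convenient for computing metrics afterwards.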

save_model(path)

Save the model to the specified path.

Parameters:
  • path (str) – The path to save the model to.

Returns:

None

train(dataloader, logger)

Trains the model using the provided dataloader.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – The dataloader containing the training data.

  • logger (logging.Logger) – The logger object for logging training progress.
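One plausible shape for `train` is sketched below. It assumes that the extractor stays fixed and that the predictor is fitted once on features gathered from the whole dataloader. This two-stage extract-then-fit flow is an assumption, not documented behavior. The `CentroidPredictor` class and its `fit`/`predict` methods are hypothetical stand-ins for a `FeaturePredictor`.

```python
import logging


class CentroidPredictor:
    """Hypothetical FeaturePredictor stand-in: nearest class mean on 1-D features."""

    def __init__(self):
        self.centroids = {}

    def fit(self, features, labels):
        # Accumulate per-class sums and counts, then average them into centroids.
        sums, counts = {}, {}
        for feature, label in zip(features, labels):
            sums[label] = sums.get(label, 0.0) + feature
            counts[label] = counts.get(label, 0) + 1
        self.centroids = {label: sums[label] / counts[label] for label in sums}

    def predict(self, features):
        # Assign each feature to the class whose centroid is closest.
        return [
            min(self.centroids, key=lambda label: abs(feature - self.centroids[label]))
            for feature in features
        ]


def train(extractor, predictor, dataloader, logger):
    """Hypothetical two-stage training: extract features for every batch, then fit once."""
    all_features, all_labels = [], []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        all_features.extend(extractor(inputs))
        all_labels.extend(labels)
        logger.info("Extracted features for batch %d", batch_idx)
    predictor.fit(all_features, all_labels)
    logger.info("Fitted predictor on %d samples", len(all_features))


def identity(inputs):
    # Toy "extractor" that passes inputs through unchanged.
    return list(inputs)


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train-sketch")

predictor = CentroidPredictor()
batches = [([1.0, 2.0], [0, 0]), ([8.0, 9.0], [1, 1])]
train(identity, predictor, batches, logger)
print(predictor.predict([1.5, 8.5]))  # [0, 1]
```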

class models.Model

Bases: object

A base class for machine learning models.

evaluate(dataloader)

classmethod load_model(path)

Load a saved model from the specified path.

Parameters:
  • path (str) – The path to the saved model file.

Returns:

The loaded model.

Return type:

Model

Raises:

AssertionError – If the loaded model is not an instance of the current class.

save_model(path)

Save the model to the specified path.

Parameters:
  • path (str) – The path to save the model to.

Returns:

None
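The documented AssertionError suggests that load_model deserializes an object and then checks its type against the calling class. A minimal sketch of that save/load pair follows. It assumes `pickle` as the serialization format, although the library may use a different one. The `ModelA` and `ModelB` subclasses are hypothetical and exist only for the demonstration.

```python
import os
import pickle
import tempfile


class Model:
    """Minimal sketch of the base class; pickle-based persistence is an assumption."""

    def save_model(self, path):
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load_model(cls, path):
        with open(path, "rb") as f:
            model = pickle.load(f)
        # Mirrors the documented AssertionError when the file holds another model type.
        assert isinstance(model, cls), f"Expected {cls.__name__}, got {type(model).__name__}"
        return model


class ModelA(Model):
    pass


class ModelB(Model):
    pass


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    ModelA().save_model(path)
    print(type(ModelA.load_model(path)).__name__)  # ModelA
    try:
        ModelB.load_model(path)
    except AssertionError as err:
        print(err)  # Expected ModelB, got ModelA
```

Because `cls` is the class the method is called on, `Model.load_model` accepts any subclass, while `ModelB.load_model` rejects a saved `ModelA`. Note that `assert` statements are skipped when Python runs with the `-O` flag. An explicit `isinstance` check that raises would survive optimization, at the cost of changing the documented exception type.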

train(dataloader)