mlrl.testbed.data_splitting module
Author: Michael Rapp (michael.rapp.ml@gmail.com)
Provides classes for training and evaluating multi-label classifiers using either cross validation or separate training and test sets.
- class mlrl.testbed.data_splitting.CrossValidationFold(num_folds: int, fold: int, current_fold: int)
Bases: DataSplit
Provides information about a split of the available data that is used by a single fold of a cross validation.
- get_fold() → int | None
Returns the cross validation fold this split corresponds to.
- Returns:
The cross validation fold, starting at 0, or None, if no cross validation is used
- get_num_folds() → int
Returns the total number of cross validation folds.
- Returns:
The total number of cross validation folds or 1, if no cross validation is used
- class mlrl.testbed.data_splitting.CrossValidationOverall(num_folds: int)
Bases: DataSplit
Provides information about the overall splits of a cross validation.
- get_fold() → int | None
Returns the cross validation fold this split corresponds to.
- Returns:
The cross validation fold, starting at 0, or None, if no cross validation is used
- get_num_folds() → int
Returns the total number of cross validation folds.
- Returns:
The total number of cross validation folds or 1, if no cross validation is used
- class mlrl.testbed.data_splitting.CrossValidationSplitter(data_set: DataSet, num_folds: int, current_fold: int, random_state: int)
Bases: DataSplitter
Splits the available data into training and test sets corresponding to the individual folds of a cross validation.
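To illustrate what splitting data into cross validation folds involves, here is a minimal sketch. It is not the implementation of CrossValidationSplitter: the function name `k_fold_indices`, the use of Python's `random` module, and the round-robin assignment of shuffled indices to folds are all assumptions made for illustration. Only the roles of `num_folds` and `random_state` are taken from the constructor signature above.

```python
import random
from typing import Iterator, List, Tuple


def k_fold_indices(num_examples: int, num_folds: int,
                   random_state: int) -> Iterator[Tuple[int, List[int], List[int]]]:
    """Hypothetical helper: yields (fold, training indices, test indices) per fold."""
    indices = list(range(num_examples))
    # Seeding a private generator makes the shuffle reproducible for a given random_state.
    random.Random(random_state).shuffle(indices)

    for fold in range(num_folds):
        # Every num_folds-th shuffled index belongs to the test set of this fold.
        test = indices[fold::num_folds]
        test_set = set(test)
        # All remaining indices form the training set of this fold.
        train = [i for i in indices if i not in test_set]
        yield fold, train, test


for fold, train, test in k_fold_indices(num_examples=10, num_folds=5, random_state=1):
    print(f'fold {fold}: {len(train)} training examples, {len(test)} test examples')
```

Each example ends up in the test set of exactly one fold, and the same `random_state` always reproduces the same folds.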
- class mlrl.testbed.data_splitting.DataSet(data_dir: str, data_set_name: str, use_one_hot_encoding: bool)
Bases: object
Stores the properties of a data set to be used for training and evaluating multi-label classifiers.
- class mlrl.testbed.data_splitting.DataSplit
Bases: ABC
Provides information about a split of the available data that is used for training and testing.
- abstract get_fold() → int | None
Returns the cross validation fold this split corresponds to.
- Returns:
The cross validation fold, starting at 0, or None, if no cross validation is used
- abstract get_num_folds() → int
Returns the total number of cross validation folds.
- Returns:
The total number of cross validation folds or 1, if no cross validation is used
- is_cross_validation_used() → bool
Returns whether cross validation is used or not.
- Returns:
True, if cross validation is used, False otherwise
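The contract described by these three methods can be sketched with self-contained stand-in classes. The names `DataSplitSketch`, `ExampleNoSplit` and `ExampleFold` are hypothetical and do not come from the library. Deriving `is_cross_validation_used` from the number of folds is an assumption that is consistent with the documented return values, not a statement about the actual source.

```python
from abc import ABC, abstractmethod
from typing import Optional


class DataSplitSketch(ABC):
    """Stand-in for the documented abstract class, not the library's source."""

    @abstractmethod
    def get_fold(self) -> Optional[int]:
        """Returns the fold index, starting at 0, or None without cross validation."""

    @abstractmethod
    def get_num_folds(self) -> int:
        """Returns the total number of folds, or 1 without cross validation."""

    def is_cross_validation_used(self) -> bool:
        # Assumption: cross validation is in use whenever there is more than one fold.
        return self.get_num_folds() > 1


class ExampleNoSplit(DataSplitSketch):
    """Hypothetical split that does not use cross validation."""

    def get_fold(self) -> Optional[int]:
        return None

    def get_num_folds(self) -> int:
        return 1


class ExampleFold(DataSplitSketch):
    """Hypothetical split that represents one fold of a cross validation."""

    def __init__(self, num_folds: int, fold: int):
        self.num_folds = num_folds
        self.fold = fold

    def get_fold(self) -> Optional[int]:
        return self.fold

    def get_num_folds(self) -> int:
        return self.num_folds


for split in (ExampleNoSplit(), ExampleFold(num_folds=5, fold=2)):
    print(type(split).__name__, split.get_fold(), split.get_num_folds(),
          split.is_cross_validation_used())
```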
- class mlrl.testbed.data_splitting.DataSplitter
Bases: ABC
An abstract base class for all classes that split a data set into training and test data.
- class Callback
Bases: ABC
An abstract base class for all classes that train and evaluate a model given a predefined split of the available data.
- abstract train_and_evaluate(meta_data: MetaData, data_split: DataSplit, train_x, train_y, test_x, test_y)
The function that is invoked to train a model on a training set and evaluate it on a test set.
- Parameters:
meta_data – The meta-data of the training data set
data_split – Information about the split of the available data that should be used for training and evaluating the model
train_x – The feature matrix of the training examples
train_y – The label matrix of the training examples
test_x – The feature matrix of the test examples
test_y – The label matrix of the test examples
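The following sketch shows what an implementation of such a callback might look like. It is self-contained and does not import the library: `CallbackSketch` merely mirrors the documented method signature, and `MajorityBaselineCallback` is a hypothetical example. Representing the matrices as nested lists and passing `None` for `meta_data` and `data_split` are simplifications made for illustration only.

```python
from abc import ABC, abstractmethod


class CallbackSketch(ABC):
    """Stand-in that mirrors the documented signature of DataSplitter.Callback."""

    @abstractmethod
    def train_and_evaluate(self, meta_data, data_split, train_x, train_y, test_x, test_y):
        """Trains a model on the training set and evaluates it on the test set."""


class MajorityBaselineCallback(CallbackSketch):
    """Hypothetical callback that predicts the majority value of each label."""

    def __init__(self):
        self.hamming_losses = []

    def train_and_evaluate(self, meta_data, data_split, train_x, train_y, test_x, test_y):
        num_labels = len(train_y[0])
        # "Training": predict 1 for a label if at least half of the training examples have it.
        majority = [
            1 if 2 * sum(row[j] for row in train_y) >= len(train_y) else 0
            for j in range(num_labels)
        ]
        # "Evaluation": fraction of label positions that are predicted incorrectly.
        errors = sum(row[j] != majority[j] for row in test_y for j in range(num_labels))
        self.hamming_losses.append(errors / (len(test_y) * num_labels))


callback = MajorityBaselineCallback()
train_y = [[1, 0], [1, 0], [0, 1]]
test_y = [[1, 0], [0, 0]]
callback.train_and_evaluate(None, None, [[0.1], [0.2], [0.3]], train_y, [[0.4], [0.5]], test_y)
print(callback.hamming_losses)
```

A splitter would presumably invoke the callback once per split, so collecting one result per invocation, as done here, lets the caller aggregate results across folds afterwards.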
- class mlrl.testbed.data_splitting.DataType(value, names=None, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
Characterizes data as either training or test data.
- TEST = 'test'
- TRAINING = 'training'
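The long constructor signature shown for DataType is the generic call signature that Python's `Enum` provides. In practice, members are accessed by name or looked up by value. The stand-in below, named `DataTypeSketch` to make clear it is not the library's class, reproduces the two documented members to show this.

```python
from enum import Enum


class DataTypeSketch(Enum):
    """Stand-in with the same members as the documented enum."""
    TRAINING = 'training'
    TEST = 'test'


# Access a member by name, or look it up by its value.
by_name = DataTypeSketch.TRAINING
by_value = DataTypeSketch('test')

print(by_name.value)
print(by_value is DataTypeSketch.TEST)
```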
- class mlrl.testbed.data_splitting.NoSplit
Bases: DataSplit
Provides information about data that has not been split into separate training and test data.
- get_fold() → int | None
Returns the cross validation fold this split corresponds to.
- Returns:
The cross validation fold, starting at 0, or None, if no cross validation is used
- get_num_folds() → int
Returns the total number of cross validation folds.
- Returns:
The total number of cross validation folds or 1, if no cross validation is used
- class mlrl.testbed.data_splitting.NoSplitter(data_set: DataSet)
Bases: DataSplitter
Does not split the available data into separate train and test sets.
- class mlrl.testbed.data_splitting.TrainTestSplitter(data_set: DataSet, test_size: float, random_state: int)
Bases: DataSplitter
Splits the available data into a single training set and a single test set.
- class mlrl.testbed.data_splitting.TrainingTestSplit
Bases: DataSplit
Provides information about a split of the available data into training and test data.
- get_fold() → int | None
Returns the cross validation fold this split corresponds to.
- Returns:
The cross validation fold, starting at 0, or None, if no cross validation is used
- get_num_folds() → int
Returns the total number of cross validation folds.
- Returns:
The total number of cross validation folds or 1, if no cross validation is used
- mlrl.testbed.data_splitting.check_if_files_exist(directory: str, file_names: List[str]) → bool
Returns whether all given files exist or not. If some, but not all, of the files are missing, an IOError is raised.
- Parameters:
directory – The path to the directory where the files should be located
file_names – A list that contains the names of all files to be checked
- Returns:
True, if all files exist, False, if all files are missing
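The three possible outcomes described above (all files present, all files missing, only some files missing) can be sketched as follows. The function `check_files_sketch` is a hypothetical re-implementation based solely on this description, not the library's code, and the file names used in the demonstration are made up.

```python
import os
import tempfile
from typing import List


def check_files_sketch(directory: str, file_names: List[str]) -> bool:
    """Hypothetical re-implementation of the documented behavior."""
    missing = [name for name in file_names
               if not os.path.isfile(os.path.join(directory, name))]

    if not missing:
        return True  # all files exist
    if len(missing) == len(file_names):
        return False  # all files are missing
    # Only some of the files are missing, which is treated as an error.
    raise IOError('The following files do not exist: ' + ', '.join(missing))


with tempfile.TemporaryDirectory() as directory:
    names = ['example-train.arff', 'example-test.arff']  # made-up file names
    print(check_files_sketch(directory, names))  # no file has been created yet

    # Create only the first file, so that the second one is missing.
    open(os.path.join(directory, names[0]), 'w').close()
    try:
        check_files_sketch(directory, names)
    except IOError as error:
        print(error)
```

Distinguishing "all missing" from "some missing" lets a caller fall back to a different data source when nothing is there, while an incomplete set of files is reported as an error.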