# Source code for lightautoml.validation.utils

"""Validation utils."""

from typing import Callable
from typing import Optional
from typing import Union
from typing import cast

from ..dataset.base import LAMLDataset
from ..dataset.np_pd_dataset import CSRSparseDataset
from ..dataset.np_pd_dataset import NumpyDataset
from ..dataset.np_pd_dataset import PandasDataset
from .base import DummyIterator
from .base import HoldoutIterator
from .base import TrainValidIterator
from .np_iterators import get_numpy_iterator


# Type alias covering the three dataset classes that ``create_validation_iterator``
# hands over to ``get_numpy_iterator``; used there only as a ``cast`` target.
NpDataset = Union[CSRSparseDataset, NumpyDataset, PandasDataset]


def create_validation_iterator(
    train: LAMLDataset,
    valid: Optional[LAMLDataset] = None,
    n_folds: Optional[int] = None,
    cv_iter: Optional[Callable] = None,
) -> TrainValidIterator:
    """Build a train-validation iterator suited to the type of ``train``.

    Selection rules, checked in order:

        - ``train`` is exactly a ``PandasDataset``, ``NumpyDataset`` or
          ``CSRSparseDataset``: delegate to
          :func:`~lightautoml.validation.np_iterators.get_numpy_iterator`.
        - ``valid`` is provided: build a ``HoldoutIterator``.
        - Otherwise: build a ``DummyIterator`` over ``train`` alone.

    Args:
        train: Dataset to train.
        valid: Optional dataset for validation.
        n_folds: Maximum number of folds to iterate.
            If ``None`` - iterate through all folds.
            Only forwarded to ``get_numpy_iterator``.
        cv_iter: Takes dataset as input and returns an iterator
            of indexes of train/valid for the train dataset.
            Only forwarded to ``get_numpy_iterator``.

    Returns:
        New iterator.

    """
    # Exact class match (not ``isinstance``): subclasses of these datasets
    # fall through to the holdout/dummy handling below.
    numpy_like_types = (PandasDataset, NumpyDataset, CSRSparseDataset)

    if type(train) in numpy_like_types:
        # ``cast`` is a runtime no-op; it only narrows the types for checkers.
        np_train = cast(NpDataset, train)
        np_valid = cast(NpDataset, valid)
        return get_numpy_iterator(np_train, np_valid, n_folds, cv_iter)

    if valid is None:
        return DummyIterator(train)

    return HoldoutIterator(train, valid)