Source code for lightautoml.tasks.losses.sklearn

"""Metrics and loss functions for scikit-learn models."""

import logging

from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np

from .base import Loss
from .base import fw_rmsle


logger = logging.getLogger(__name__)


_sk_loss_mapping = {"rmsle": ("mse", fw_rmsle, np.expm1)}

_sk_force_metric = {
    "rmsle": ("mse", None, None),
}


[docs]class SKLoss(Loss): """Loss used for scikit-learn. Args: loss: One of default loss function. Valid are: 'logloss', 'mse', 'crossentropy', 'rmsle'. loss_params: Addtional loss parameters. fw_func: Forward transformation. Used for transformation of target and item weights. bw_func: backward transformation. Used for predict values transformation. """ def __init__( self, loss: str, loss_params: Optional[Dict] = None, fw_func: Optional[Callable] = None, bw_func: Optional[Callable] = None, ): assert loss in [ "logloss", "mse", "mae", "crossentropy", "rmsle", ], "Not supported in sklearn in general case." self.flg_regressor = loss in ["mse", "rmsle"] if loss in _sk_loss_mapping: self.loss, fw_func, bw_func = _sk_loss_mapping[loss] else: self.loss = loss # set forward and backward transformations if fw_func is not None: self._fw_func = fw_func if bw_func is not None: self._bw_func = bw_func self.loss_params = loss_params
[docs] def set_callback_metric( self, metric: Union[str, Callable], greater_is_better: Optional[bool] = None, metric_params: Optional[Dict] = None, task_name: Optional[str] = None, ): """Callback metric setter. Uses default callback of parent class `Loss`. Args: metric: Callback metric. greater_is_better: Whether or not higher value is better. metric_params: Additional metric parameters. task_name: Name of task. """ if self.loss in _sk_force_metric: metric, greater_is_better, metric_params = _sk_force_metric[self.loss] logger.info2("For sklearn {0} callback metric switched to {1}".format(self.loss, metric)) super().set_callback_metric(metric, greater_is_better, metric_params, task_name)