# Source code for lightautoml.tasks.losses.sklearn

"""Metrics and loss functions for scikit-learn models."""

import logging

from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np

from .base import Loss


# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def fw_rmsle(x, y):
    """Forward transformation used when RMSLE is trained as MSE.

    Moves the first argument into log space with ``log1p`` and passes the
    second argument through untouched.

    Args:
        x: Values to transform (the target, per the ``fw_func`` contract).
        y: Passed through unchanged (presumably item weights — confirm).

    Returns:
        Tuple ``(log1p(x), y)``.
    """
    log_target = np.log1p(x)
    return log_target, y


# Losses that sklearn cannot optimize directly, rewritten as
#   name -> (substitute loss, forward transform, backward transform).
# RMSLE is fitted as MSE on log1p-transformed targets; predictions are
# mapped back with expm1.
_sk_loss_mapping = {
    "rmsle": ("mse", fw_rmsle, np.expm1),
}

# Losses whose callback metric is overridden:
#   name -> (metric, greater_is_better, metric_params).
_sk_force_metric = {
    "rmsle": ("mse", None, None),
}


class SKLoss(Loss):
    """Loss used for scikit-learn."""

    def __init__(
        self,
        loss: str,
        loss_params: Optional[Dict] = None,
        fw_func: Optional[Callable] = None,
        bw_func: Optional[Callable] = None,
    ):
        """
        Args:
            loss: One of default loss function.
                Valid are: 'logloss', 'mse', 'crossentropy', 'rmsle'.
            loss_params: Additional loss parameters.
            fw_func: Forward transformation.
                Used for transformation of target and item weights.
                Overridden by the built-in transformation pair when ``loss``
                is in the loss mapping (currently only 'rmsle').
            bw_func: backward transformation.
                Used for predict values transformation.
                Overridden by the built-in transformation pair when ``loss``
                is in the loss mapping (currently only 'rmsle').

        Raises:
            AssertionError: If ``loss`` is not one of the supported names.
        """
        # Explicit check instead of a bare ``assert`` so validation is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers that catch it keep working.
        if loss not in ("logloss", "mse", "crossentropy", "rmsle"):
            raise AssertionError("Not supported in sklearn in general case.")

        self.flg_regressor = loss in ["mse", "rmsle"]
        # Remember the user-facing name: ``self.loss`` may be remapped below
        # (e.g. 'rmsle' -> 'mse'), while forced callback metrics are keyed by
        # the original name.
        self._orig_loss = loss

        if loss in _sk_loss_mapping:
            self.loss, fw_func, bw_func = _sk_loss_mapping[loss]
        else:
            self.loss = loss

        # set forward and backward transformations
        if fw_func is not None:
            self._fw_func = fw_func
        if bw_func is not None:
            self._bw_func = bw_func

        self.loss_params = loss_params

    def set_callback_metric(
        self,
        metric: Union[str, Callable],
        greater_is_better: Optional[bool] = None,
        metric_params: Optional[Dict] = None,
        task_name: Optional[str] = None,
    ):
        """
        Callback metric setter.

        Uses default callback of parent class `Loss`. For losses listed in
        the forced-metric table (currently 'rmsle'), the given metric
        arguments are replaced by the table entry before delegating.

        Args:
            metric: Callback metric.
            greater_is_better: Whether or not higher value is better.
            metric_params: Additional metric parameters.
            task_name: Name of task.
        """
        # ``self.loss`` holds the remapped sklearn loss ('rmsle' -> 'mse'), so
        # checking it would never match the 'rmsle' entry; use the original
        # name instead. Fall back to ``self.loss`` for instances that were not
        # built through ``__init__`` (e.g. objects unpickled from older versions).
        loss_name = getattr(self, "_orig_loss", self.loss)
        if loss_name in _sk_force_metric:
            metric, greater_is_better, metric_params = _sk_force_metric[loss_name]
            logger.info2("For sklearn {0} callback metric switched to {1}".format(loss_name, metric))

        super().set_callback_metric(metric, greater_is_better, metric_params, task_name)