"""Metrics and loss functions for LightGBM."""
import logging
from functools import partial
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
import lightgbm as lgb
import numpy as np
from ..common_metric import _valid_str_multiclass_metric_names
from ..utils import infer_gib
from .base import Loss
from .base import fw_rmsle
from .lgb_custom import lgb_f1_loss_multiclass # , F1Factory
from .lgb_custom import softmax_ax1
logger = logging.getLogger(__name__)

_lgb_binary_metrics_dict = {
    "auc": "auc",
    "logloss": "binary_logloss",
    "accuracy": "binary_error",
}

_lgb_reg_metrics_dict = {
    "mse": "mse",
    "mae": "mae",
    "r2": "mse",
    "rmsle": "mse",
    "mape": "mape",
    "quantile": "quantile",
    "huber": "huber",
    "fair": "fair",
}

_lgb_multiclass_metrics_dict = {
    "auc": _valid_str_multiclass_metric_names["auc"],
    "auc_mu": _valid_str_multiclass_metric_names["auc_mu"],
    "crossentropy": "multi_logloss",
    "accuracy": "multi_error",
    "f1_macro": _valid_str_multiclass_metric_names["f1_macro"],
    "f1_micro": _valid_str_multiclass_metric_names["f1_micro"],
    "f1_weighted": _valid_str_multiclass_metric_names["f1_weighted"],
}

_lgb_metrics_dict = {
    "binary": _lgb_binary_metrics_dict,
    "reg": _lgb_reg_metrics_dict,
    "multiclass": _lgb_multiclass_metrics_dict,
}

# objective -> (lgb objective name or callable, forward transform, backward transform)
_lgb_loss_mapping = {
    "logloss": ("binary", None, None),
    "mse": ("regression", None, None),
    "mae": ("l1", None, None),
    "mape": ("mape", None, None),
    "crossentropy": ("multiclass", None, None),
    "rmsle": ("mse", fw_rmsle, np.expm1),
    "quantile": ("quantile", None, None),
    "huber": ("huber", None, None),
    "fair": ("fair", None, None),
    "f1": (lgb_f1_loss_multiclass, None, softmax_ax1),
}

# keyed by objective name; values rename user-facing params to LightGBM's names
_lgb_loss_params_mapping = {
    "quantile": {"q": "alpha"},
    "huber": {"a": "alpha"},
    "fair": {"c": "fair_c"},
}

_lgb_force_metric = {
    "rmsle": ("mse", None, None),
}
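

# Illustrative sketch (not part of the API) of how the tables above combine for
# "rmsle": the model is trained with a plain "mse" objective on targets passed
# through fw_rmsle (a log1p transform), and predictions are mapped back with
# np.expm1. Since RMSLE is the root of MSE in log1p space, minimizing MSE there
# minimizes RMSLE in the original space; _lgb_force_metric accordingly forces
# the callback metric to "mse":
#
#     fobj, fw, bw = _lgb_loss_mapping["rmsle"]  # ("mse", fw_rmsle, np.expm1)
#     y_log = fw(y)                              # log1p(y), train lgb on this
#     y_pred = bw(raw_pred)                      # expm1 back to target scale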


class LGBFunc:
    """Wrapper of metric function for LightGBM."""

    def __init__(self, metric_func, greater_is_better, bw_func):
        self.metric_func = metric_func
        self.greater_is_better = greater_is_better
        self.bw_func = bw_func

    def __call__(self, pred: np.ndarray, dtrain: lgb.Dataset) -> Tuple[str, float, bool]:
        """Calculate metric."""
        label = dtrain.get_label()
        weights = dtrain.get_weight()

        # multiclass case: lightgbm passes predictions as a flat array in Fortran order
        if label.shape[0] != pred.shape[0]:
            pred = pred.reshape((label.shape[0], -1), order="F")
            label = label.astype(np.int32)

        label = self.bw_func(label)
        pred = self.bw_func(pred)

        # for weighted case
        try:
            val = self.metric_func(label, pred, sample_weight=weights)
        except TypeError:
            val = self.metric_func(label, pred)

        # TODO: what about grouped case
        return "Opt metric", val, self.greater_is_better


class LGBLoss(Loss):
    """Loss used for LightGBM.

    Args:
        loss: Objective to optimize.
        loss_params: Additional loss parameters.
            Format like in :mod:`lightautoml.tasks.custom_metrics`.
        fw_func: Forward transformation.
            Used for transformation of target and item weights.
        bw_func: Backward transformation.
            Used for predict values transformation.

    Note:
        Loss can be one of the types:

            - Str: one of default losses
              ('auc', 'mse', 'mae', 'logloss', 'accuracy', 'r2',
              'rmsle', 'mape', 'quantile', 'huber', 'fair')
              or another lightgbm objective.
            - Callable: custom lightgbm style objective.

    """

    def __init__(
        self,
        loss: Union[str, Callable],
        loss_params: Optional[Dict] = None,
        fw_func: Optional[Callable] = None,
        bw_func: Optional[Callable] = None,
    ):
        if loss in _lgb_loss_mapping:
            fobj, fw_func, bw_func = _lgb_loss_mapping[loss]
            if type(fobj) is str:
                self.fobj_name = fobj
                self.fobj = None
            else:
                self.fobj_name = None
                self.fobj = fobj

            # map param name for known objectives
            if self.fobj_name in _lgb_loss_params_mapping and loss_params is not None:
                param_mapping = _lgb_loss_params_mapping[self.fobj_name]
                loss_params = {param_mapping[x]: loss_params[x] for x in loss_params}
        else:
            # set lgb style objective
            if type(loss) is str:
                self.fobj_name = loss
                self.fobj = None
            else:
                self.fobj_name = None
                self.fobj = loss

        # set forward and backward transformations
        if fw_func is not None:
            self._fw_func = fw_func
        if bw_func is not None:
            self._bw_func = bw_func

        self.fobj_params = {}
        if loss_params is not None:
            self.fobj_params = loss_params

        self.metric = None
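
    # Construction sketch (illustrative): user-facing parameter names are
    # renamed to LightGBM's names for known objectives via _lgb_loss_params_mapping:
    #
    #     loss = LGBLoss("quantile", loss_params={"q": 0.8})
    #     loss.fobj_name    # "quantile"
    #     loss.fobj_params  # {"alpha": 0.8} - forwarded into the lgb params dict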

    def metric_wrapper(
        self,
        metric_func: Callable,
        greater_is_better: Optional[bool],
        metric_params: Optional[Dict] = None,
    ) -> Callable:
        """Customize metric.

        Args:
            metric_func: Callable metric.
            greater_is_better: Whether or not higher value is better.
            metric_params: Additional metric parameters.

        Returns:
            Callable metric, that returns ('Opt metric', value, greater_is_better).

        """
        if greater_is_better is None:
            greater_is_better = infer_gib(metric_func)

        if metric_params is not None:
            metric_func = partial(metric_func, **metric_params)

        return LGBFunc(metric_func, greater_is_better, self._bw_func)
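
    # Wrapping sketch (assumes sklearn's mean_absolute_error, for which
    # infer_gib is expected to infer False): extra metric_params are bound
    # with functools.partial before wrapping:
    #
    #     feval = loss.metric_wrapper(mean_absolute_error, None)
    #     feval(pred, dtrain)  # -> ("Opt metric", mae_value, False)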

    def set_callback_metric(
        self,
        metric: Union[str, Callable],
        greater_is_better: Optional[bool] = None,
        metric_params: Optional[Dict] = None,
        task_name: Optional[str] = None,
    ):
        """Callback metric setter.

        Args:
            metric: Callback metric.
            greater_is_better: Whether or not higher value is better.
            metric_params: Additional metric parameters.
            task_name: Name of the task.

        Note:
            Value of ``task_name`` should be one of the following options:

            - `'binary'`
            - `'reg'`
            - `'multiclass'`

        """
        # force the callback metric for special losses (e.g. rmsle is trained as mse)
        # TODO: should task_name also be handled here?
        if self.fobj_name in _lgb_force_metric:
            metric, greater_is_better, metric_params = _lgb_force_metric[self.fobj_name]
            logger.info2("For lgbm {0} callback metric switched to {1}".format(self.fobj_name, metric))

        self.metric_params = {}
        if metric_params is not None:
            self.metric_params = metric_params

        # set lgb style metric
        self.metric = metric

        if type(metric) is str:
            _metric_dict = _lgb_metrics_dict[task_name]
            _metric = _metric_dict.get(metric)

            if type(_metric) is str:
                # native lightgbm metric - passed by name, no custom feval needed
                self.metric_name = _metric
                self.feval = None
            else:
                # callable from the metric table - wrap it into an lgb-style feval
                self.metric_name = None
                self.feval = self.metric_wrapper(_metric, greater_is_better, {})
        else:
            self.metric_name = None
            self.feval = self.metric_wrapper(metric, greater_is_better, self.metric_params)
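

# Usage sketch (illustrative; my_metric is a hypothetical callable with an
# (y_true, y_pred) signature): selecting the eval metric for a regression task.
# "r2" has no native LightGBM counterpart, so the table maps it to "mse"
# (monotone in r2, hence equivalent for early stopping); a callable goes
# through metric_wrapper instead:
#
#     loss = LGBLoss("mse")
#     loss.set_callback_metric("r2", task_name="reg")
#     loss.metric_name  # "mse" -> passed to lgb by name
#     loss.feval        # None  -> native metric, no custom feval needed
#
#     loss.set_callback_metric(my_metric, greater_is_better=True, task_name="reg")
#     loss.metric_name  # None
#     loss.feval        # LGBFunc wrapping my_metric, for lgb.train(feval=...)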