TorchLossWrapper

class lightautoml.tasks.losses.torch.TorchLossWrapper(func, flatten=False, log=False, **kwargs)

Bases: torch.nn.Module

Wrap a PyTorch loss so that it can be called in the (y_true, y_pred, sample_weight) format.

Parameters:
  • func (Callable) – PyTorch loss class to wrap, e.g. torch.nn.MSELoss.

  • **kwargs (Any) – additional keyword arguments passed to func.

Returns:

Callable loss with the call signature (y_true, y_pred, sample_weight).
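The adapter idea can be sketched in plain Python so that it runs without PyTorch installed. This is a minimal illustration under stated assumptions, not the LightAutoML implementation. SquaredError and LossWrapperSketch are hypothetical names. SquaredError stands in for torch.nn.MSELoss, which, like other PyTorch losses, is called as loss(input, target) and accepts reduction="none". Requesting unreduced values and then taking a (weighted) mean is an assumed design.

```python
from typing import Any, Callable, List, Optional, Sequence


class SquaredError:
    """Hypothetical stand-in for torch.nn.MSELoss, called as loss(prediction, target)."""

    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction

    def __call__(self, prediction: Sequence[float], target: Sequence[float]):
        per_element = [(p - t) ** 2 for p, t in zip(prediction, target)]
        if self.reduction == "none":
            return per_element  # unreduced: one value per sample
        return sum(per_element) / len(per_element)


class LossWrapperSketch:
    """Hypothetical adapter exposing the (y_true, y_pred, sample_weight) format."""

    def __init__(self, func: Callable[..., Any], **kwargs: Any) -> None:
        # Assumption: build the wrapped loss with reduction="none" so that
        # sample weights can be applied before averaging; extra kwargs go to func.
        self.base_loss = func(reduction="none", **kwargs)

    def __call__(
        self,
        y_true: Sequence[float],
        y_pred: Sequence[float],
        sample_weight: Optional[Sequence[float]] = None,
    ) -> float:
        # PyTorch-style losses take (prediction, target), so swap the order here.
        per_sample: List[float] = self.base_loss(y_pred, y_true)
        if sample_weight is None:
            return sum(per_sample) / len(per_sample)
        # Weighted mean: sum(w_i * loss_i) / sum(w_i).
        weighted = sum(w * loss_i for w, loss_i in zip(sample_weight, per_sample))
        return weighted / sum(sample_weight)


loss = LossWrapperSketch(SquaredError)
y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 5.0]
print(loss(y_true, y_pred))  # per-sample losses 0, 0, 4, so the mean is 4/3
print(loss(y_true, y_pred, sample_weight=[1.0, 1.0, 2.0]))  # (0 + 0 + 8) / 4 = 2.0
```

Asking the wrapped loss for unreduced values is what makes per-sample weighting possible: a loss that has already averaged over the batch can no longer be reweighted.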

forward(y_true, y_pred, sample_weight=None)

Forward pass: compute the wrapped loss for y_true and y_pred, optionally weighted by sample_weight.
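The signature alone does not say how per-sample losses are reduced. Assume forward uses a weighted mean when sample_weight is given, and write the per-sample loss as ℓ and the weights as w_i (notation introduced here). The result for n samples would then be the expression below. Normalizing by the sum of the weights means that uniform weights reproduce the plain mean:

```latex
\mathcal{L}(y, \hat{y}, w) = \frac{\sum_{i=1}^{n} w_i \, \ell(\hat{y}_i, y_i)}{\sum_{i=1}^{n} w_i},
\qquad
w_i = c > 0 \;\Rightarrow\; \mathcal{L} = \frac{c \sum_{i=1}^{n} \ell(\hat{y}_i, y_i)}{n c} = \frac{1}{n} \sum_{i=1}^{n} \ell(\hat{y}_i, y_i).
```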