TorchLossWrapper

class lightautoml.tasks.losses.torch.TorchLossWrapper(func, flatten=False, log=False, **kwargs)

Bases: torch.nn.Module

Wrap a PyTorch-based loss so that it can be called as loss(y_true, y_pred, sample_weight).

Parameters
  • func (Callable) – PyTorch loss to wrap, for example torch.nn.MSELoss.

  • **kwargs – additional parameters.

Returns

Callable loss that takes its arguments in the order (y_true, y_pred, sample_weight).
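The interface described above (a PyTorch loss class such as torch.nn.MSELoss turned into a callable taking (y_true, y_pred, sample_weight)) can be illustrated with the minimal sketch below. It is not LightAutoML's implementation: the class name WeightedLossWrapper, building the loss with reduction="none", and the weighted-mean reduction are assumptions made for illustration only.

```python
import torch
from torch import nn


class WeightedLossWrapper(nn.Module):
    """Hypothetical sketch: adapt a torch loss class to (y_true, y_pred, sample_weight)."""

    def __init__(self, loss_cls, **kwargs):
        super().__init__()
        # Assumption: keep per-element values (reduction="none") so they can be
        # weighted before averaging.
        self.base_loss = loss_cls(reduction="none", **kwargs)

    def forward(self, y_true, y_pred, sample_weight=None):
        # PyTorch losses take (input, target), so reorder the arguments.
        per_element = self.base_loss(y_pred, y_true)
        # Sum over any trailing dimensions to get one value per sample.
        per_sample = per_element.reshape(per_element.shape[0], -1).sum(dim=1)
        if sample_weight is None:
            return per_sample.mean()
        # Assumed weighted mean: sum(w_i * l_i) / sum(w_i).
        return (per_sample * sample_weight).sum() / sample_weight.sum()


loss_fn = WeightedLossWrapper(nn.MSELoss)
y_true = torch.tensor([[1.0], [2.0], [3.0]])
y_pred = torch.tensor([[1.5], [2.0], [2.0]])
weights = torch.tensor([1.0, 1.0, 2.0])

unweighted = loss_fn(y_true, y_pred).item()
weighted = loss_fn(y_true, y_pred, weights).item()
print(unweighted, weighted)
```

Requesting reduction="none" from the underlying loss is what makes per-sample weighting possible; with the default reduction="mean", the loss would already be collapsed to a scalar before the weights could be applied.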

forward(y_true, y_pred, sample_weight=None)

Forward pass: compute the loss from y_true and y_pred, optionally weighted by sample_weight.
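One common way to reduce per-sample losses when weights are supplied is a weighted mean. The formula below is an assumption offered for orientation, not a statement of what LightAutoML computes; here \ell_i denotes the per-sample value of the wrapped loss and w_i the entry of sample_weight for sample i:

```latex
\mathcal{L} =
\begin{cases}
\dfrac{\sum_{i=1}^{N} w_i \, \ell_i}{\sum_{i=1}^{N} w_i}, & \text{if sample\_weight is given,} \\[2ex]
\dfrac{1}{N} \sum_{i=1}^{N} \ell_i, & \text{if sample\_weight is None.}
\end{cases}
```

With all weights equal, the two cases coincide, so passing uniform weights does not change the result.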