TorchLossWrapper
- class lightautoml.tasks.losses.torch.TorchLossWrapper(func, flatten=False, log=False, **kwargs)
Bases: torch.nn.Module
Customize PyTorch-based loss.
- Parameters
func (Callable) – loss to customize. Example: torch.nn.MSELoss.
**kwargs – additional parameters.
- Returns
Callable loss that takes its arguments in the order (y_true, y_pred, sample_weight).
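The call format can be illustrated with a dependency-free sketch. It is not the LightAutoML implementation. It uses plain Python lists instead of tensors so that it runs without PyTorch. The names PerSampleSquaredError and LossWrapperSketch are hypothetical. Three behaviours are assumptions, not facts stated on this page: the wrapped loss takes the prediction first (as torch.nn losses do), **kwargs are forwarded to the loss constructor, and sample_weight produces a weighted mean of per-sample losses.

```python
from typing import Any, Callable, List, Optional, Sequence


class PerSampleSquaredError:
    """Hypothetical loss class: (y_pred, y_true) order, one value per sample."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, y_pred: Sequence[float], y_true: Sequence[float]) -> List[float]:
        return [self.scale * (p - t) ** 2 for p, t in zip(y_pred, y_true)]


class LossWrapperSketch:
    """Hypothetical stand-in for TorchLossWrapper (not the library code)."""

    def __init__(self, func: Callable[..., Any], **kwargs: Any) -> None:
        # Assumption: extra keyword arguments configure the wrapped loss.
        self.base_loss = func(**kwargs)

    def __call__(
        self,
        y_true: Sequence[float],
        y_pred: Sequence[float],
        sample_weight: Optional[Sequence[float]] = None,
    ) -> float:
        # Reorder arguments: the wrapped loss expects the prediction first.
        losses = self.base_loss(y_pred, y_true)
        if sample_weight is None:
            # Unweighted case: plain mean of the per-sample losses.
            return sum(losses) / len(losses)
        # Assumed weighting scheme: weighted mean of the per-sample losses.
        weighted = [value * w for value, w in zip(losses, sample_weight)]
        return sum(weighted) / sum(sample_weight)


loss = LossWrapperSketch(PerSampleSquaredError, scale=1.0)
print(loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
print(loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0], [1.0, 1.0, 2.0]))
```

Callers always pass y_true first, whatever argument order the underlying loss uses. The wrapper handles the reordering and the optional per-sample weights.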