TimmModelEmbedder

class lightautoml.image.image.TimmModelEmbedder(model_name='efficientnet_b0.ra_in1k', weights_path=None, device=torch.device)

Bases: Module

Class to compute image embeddings with timm models.

__init__(model_name='efficientnet_b0.ra_in1k', weights_path=None, device=torch.device)

PyTorch module that computes image embeddings with timm models.

Parameters:
  • model_name (str) – Name of the timm model to load (EfficientNet-B0 by default).

  • weights_path (Optional[str]) – Path to saved weights.

  • device (torch.device) – Device to run the model on.

get_shape()

Return the size of the output embedding.

Return type:

int

Returns:

Number of features in each embedding vector.

forward(x)

Forward pass: compute embeddings for a batch of images x.

Return type:

Tensor