BertEmbedder

class lightautoml.text.dl_transformers.BertEmbedder(model_name, pooling='none', **kwargs)

Bases: torch.nn.Module

Class to compute word or sentence embeddings with HuggingFace transformers models.

__init__(model_name, pooling='none', **kwargs)

BERT sentence or word embeddings.

Parameters
  • model_name (str) – Name of transformers model.

  • pooling (str) – Pooling type.

  • **kwargs – Ignored params.

Note

There are several pooling types:

  • ‘cls’: Use the CLS token from the last hidden state as the sentence embedding.

  • ‘max’: Maximum over the seq_len dimension of the last hidden state, computed over non-masked positions.

  • ‘mean’: Mean over the seq_len dimension of the last hidden state, computed over non-masked positions.

  • ‘sum’: Sum over the seq_len dimension of the last hidden state, computed over non-masked positions.

  • ‘none’: No pooling (for the RandomLSTM pooling strategy).
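
The pooling semantics above can be illustrated in plain Python (no torch required). This is a hedged sketch, not the library's implementation: `hidden` stands in for a last hidden state of one sequence, `mask` marks real tokens versus padding, and the `pool` helper is a hypothetical name introduced here.

```python
# Toy "last hidden state": 4 token vectors of size 2; the last
# position is padding (mask == 0) and must be excluded from pooling.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]
mask = [1, 1, 1, 0]

def pool(hidden, mask, pooling):
    # Keep only non-masked (real) token vectors.
    kept = [vec for vec, m in zip(hidden, mask) if m]
    if pooling == "cls":
        return hidden[0]  # first (CLS) token as sentence embedding
    if pooling == "max":
        return [max(col) for col in zip(*kept)]
    if pooling == "sum":
        return [sum(col) for col in zip(*kept)]
    if pooling == "mean":
        return [sum(col) / len(kept) for col in zip(*kept)]
    if pooling == "none":
        return hidden  # unchanged token embeddings, e.g. for RandomLSTM
    raise ValueError(f"Unknown pooling: {pooling}")

print(pool(hidden, mask, "mean"))  # [3.0, 4.0]
```

Note that the masked padding vector never contributes: with ‘mean’, the divisor is the number of real tokens (3), not the padded length (4).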

freeze()

Freeze module parameters.

get_name()

Module name.

Return type

str

Returns

String with module name.

get_out_shape()

Output shape.

Return type

int

Returns

Int with module output shape.
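
Taken together, the documented interface can be sketched as a minimal pure-Python stub (no torch or transformers installed). `FakeBertEmbedder` is a hypothetical stand-in introduced here; the output size 768 and the returned name string are illustrative assumptions, not guaranteed values of the real class.

```python
class FakeBertEmbedder:
    """Toy stand-in mirroring the documented BertEmbedder interface."""

    def __init__(self, model_name, pooling="none", **kwargs):
        self.model_name = model_name
        self.pooling = pooling
        self._frozen = False
        # Illustrative hidden size; the real value comes from the
        # transformers model config, not a hard-coded constant.
        self._out_shape = 768

    def freeze(self):
        # The real method sets requires_grad=False on module parameters;
        # here a flag stands in for that behavior.
        self._frozen = True

    def get_name(self):
        # Returns a string identifying the module (assumed format).
        return "BertEmbedder"

    def get_out_shape(self):
        return self._out_shape

emb = FakeBertEmbedder("bert-base-uncased", pooling="mean")
emb.freeze()
print(emb.get_name(), emb.get_out_shape())  # BertEmbedder 768
```

A typical downstream use of `get_out_shape()` is sizing the next layer of a network to match the embedder's output dimension.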