BertEmbedder

class lightautoml.text.dl_transformers.BertEmbedder(model_name, pooling='none', **kwargs)[source]

Bases: Module

Class to compute HuggingFace transformers word or sentence embeddings (BERT-style models).

Parameters:
  • model_name (str) – Name of transformers model.

  • pooling (str) – Pooling type.

  • **kwargs (Any) – Ignored params.

Note

There are several pooling types:

  • ‘cls’: use the CLS token from the last hidden state as the sentence embedding.

  • ‘max’: maximum over the seq_len dimension for non-masked inputs from the last hidden state.

  • ‘mean’: mean over the seq_len dimension for non-masked inputs from the last hidden state.

  • ‘sum’: sum over the seq_len dimension for non-masked inputs from the last hidden state.

  • ‘none’: no pooling (for the RandomLSTM pooling strategy).
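The masked pooling variants above can be sketched in plain Python. This is a simplified illustration of the semantics only, not the library's actual torch implementation (which operates on tensors of shape batch × seq_len × hidden); the function name and signature here are hypothetical:

```python
# Illustrative sketch of the pooling types: one sequence is a list of
# hidden-state vectors plus a 0/1 attention mask marking real tokens.
def pool(hidden_states, mask, pooling="mean"):
    """hidden_states: list of seq_len vectors (lists of floats);
    mask: list of seq_len 0/1 flags (1 = non-masked token)."""
    kept = [h for h, m in zip(hidden_states, mask) if m]
    if pooling == "cls":
        # The CLS token is conventionally the first position.
        return hidden_states[0]
    if pooling == "max":
        return [max(col) for col in zip(*kept)]
    if pooling == "mean":
        return [sum(col) / len(kept) for col in zip(*kept)]
    if pooling == "sum":
        return [sum(col) for col in zip(*kept)]
    if pooling == "none":
        return hidden_states  # leave the full sequence untouched
    raise ValueError(f"Unknown pooling: {pooling}")

hs = [[1.0, 4.0], [3.0, 2.0], [9.0, 9.0]]  # last position is padding
mask = [1, 1, 0]
print(pool(hs, mask, "mean"))  # → [2.0, 3.0]
print(pool(hs, mask, "max"))   # → [3.0, 4.0]
```

Note that the masked positions (mask value 0) never contribute to ‘max’, ‘mean’, or ‘sum’, which is why padding tokens do not distort the sentence embedding.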

forward(inp)[source]

Forward pass: compute embeddings for the input batch.

Return type:

Tensor

freeze()[source]

Freeze module parameters.

get_name()[source]

Module name.

Return type:

str

Returns:

String with the module name.

get_out_shape()[source]

Output shape.

Return type:

int

Returns:

Int with the module output shape.