mlx.nn.Embedding

目录

mlx.nn.Embedding#

class Embedding(num_embeddings: int, dims: int)#

实现一个简单的查找表,将每个输入整数映射到高维向量。

通常用于嵌入离散标记以供神经网络处理。

参数:
  • num_embeddings (int) – 我们可以嵌入多少个可能的离散标记。通常称为词汇表大小。

  • dims (int) – 嵌入的维度。

方法

as_linear(x)

将嵌入层作为线性层调用。

to_quantized([group_size, bits])

返回一个 QuantizedEmbedding 层,该层近似于此嵌入层。