tf2里导入pretrain-embedding的方法

在深度模型训练中,我们经常会使用到pre-train embedding,比如glove、w2v等模型的输出,尤其是在nlp和推广搜业务中。
面对多种多样的embdding形式,如何方便的利用就成为了一个问题,这里记录一下我用来reload embedding的方法。

1. embedding存储方式

1.1 单层映射

pre-train embedding的存储方式多种多样,最常见的一种是直接存到在文件中,通过<key, vlaue>的方式进行存储。

# file_name
id1\t0.1,0.2,0.3,0.4,0.5
id2\t0.3,0.4,0.5,0.3,0.5
id3\t0.2,0.4,0.2,0.3,0.4

该方式可以将key, value分别存储成个一个list,方便后续查找。

1.2 多层映射

还有多层映射的embedding存储,为了节省存储空间,可能会通过多个文件进行存储

# file_name_1
id1\tattr1,attr2
id2\tattr2,attr3

# file_name_2
attr1\t0.1,0.2,0.3,0.4,0.5
attr2\t0.3,0.4,0.5,0.3,0.5
attr3\t0.2,0.4,0.2,0.3,0.4

这种情况下,需要存储多个list,分别是id的list,id对应的属性的list,属性对应的embedding的list。

如上两种embedding的存储方式,都可以通过tf2中已有方法进行快速的embeding查找。

2. 处理逻辑

主要处理逻辑分成两部分:

  1. 获取需要的embedding在embedding list中的idx
  2. 根据idx获取需要的embdding

2.1 获取idx

class InputIdxLayer(tf.keras.layers.experimental.preprocessing.PreprocessingLayer):
    """
    获取输入特征对应的特征list中的idx信息,方便后续从embeddinglist中获取需要的embdding
    1. 直接映射
        item => idx for item in inputs_list
    2. 间接映射
        item => attr => idx for item's attr in target_list
    """
    def __init__(self, inputs_list: List, mapping_list: List = None,
                 target_list: List = None, **kwargs):
        super(InputIdxLayer, self).__init__(**kwargs)
        self.inputs_list = inputs_list
        self.mapping_list = mapping_list
        self.target_list = target_list
        print(self.inputs_list, self.mapping_list, self.target_list)
        if self.mapping_list is None:
            self.target_table = self.get_table(self.inputs_list,
                                               [i for i in range(len(self.inputs_list))])
        else:
            self.mapping_table = self.get_table(self.inputs_list, self.mapping_list,
                                                default_value='')
            self.target_table = self.get_table(self.target_list,
                                               [i for i in range(len(self.target_list))])

    @staticmethod
    def get_table(keys: List, vals: List, default_value: Any = -1):
        init = tf.lookup.KeyValueTensorInitializer(keys, vals)
        table = tf.lookup.StaticHashTable(init, default_value=default_value)
        return table

    def call(self, inputs):
        if self.mapping_table is None:
            idx_ = self.target_table.lookup(inputs)
            return idx_
        else:
            target_ = self.mapping_table.lookup(inputs)
            idx_ = self.target_table.lookup(target_)
            return idx_

2.2 获取embdding

class LoadEmbeddingLayer(tf.keras.layers.Layer):
    """
    get the pre-trained embedding in embdding List
    inpust: the idx_ output from InputIdxLayer
    """
    def __init__(self, embedding, **kwargs):
        super(LoadEmbeddingLayer).__init__(**kwargs)
        self.embedding = tf.constant(embedding)

    def call(self, inputs):
        return tf.nn.embedding_lookup(self.embedding, inputs)

tf2里导入pretrain-embedding的方法
https://zermzhang.github.io/2022/04/29/tensorflow/tf2里导入pretrain-embedding的方法/
作者
知白
发布于
2022年4月29日
许可协议