0. 背景
tf2版本中,提供了三种不同的模型构建逻辑。
- sequential API
适用于层叠式的模型层结构,并且每层都只有一个明确的输入tensor和输出tensor
- functional API
函数式API是一种比Sequential API跟加简单和灵活的模型创建方式。
函数式API可以处理具有非线性拓扑的模型、具有共享层的模型,以及具有多个输入或输出的模型
- model subclassing API
model subclassing API具有更大的自由度,可以让开发者控制模型、layer以及训练过程
1. Sequential API
1.1 构建方法
构建Sequential 模型主要通过Sequential constructor对layers的列表进行模型构建
1.2 其他构建方法
- 除了直接通过sequential constructor对layers list进行模型构建外,也可以通过sequential constructor的add属性逐层进行构建
- 注意,通过这种方法进行模型构建的时候可以将一个 name 传入到sequential constructor中
2. Functional API
2.1 构建方法
- 注意:通过Functional API进行模型构建的时候,需要先创建一个输入节点:
inputs = tf.keras.Input(shape=(512, 128))
- 然后可以在
inputs
对象上调用层,在层计算图中创建新的节点, tf.keras.layers.Dense(64, activation='relu')(inputs)
- 每调用一个层,就意味着将上一个层的输出作为新的调用层的输入,并通过该层处理后进行输出结果。
- 最后,可以通过
tf.keras.Model
指定输入和输出来创建模型。
2.2 使用模型进行训练、评估和推断
通过Functional API和Sequential API构建的模型,都可以通过同一种方法进行模型的训练、评估和推断工作。
3. Subclassing API
subclassing种的build_graph不是必须的,但是不具有build_graph的subclass无法进行模型结构的展示。
可以通过如下方式,进行模型结构的展示。