tf2中subclassing自定义模型
1. tf2中的subclassing模型
tf2中常见的定义模型的方法分成三种:functional、sequenctial、subclassing
其中,subclassing方法就是自己实现一个类来继承tf.keras.Model
,来构建一个Model类的子类,在构建过程中,主要需要实现下面两个方法:
__init__()
这个方法是子类的构造器,主要是用来初始化参数(比如layers之类的)。super
主要用来初始化父构造器。call()
该函数用于执行在__init__
中定义的layers操作
同时,该函数可以使我们定义的模型直接作为函数使用,该功能主要依赖于python的__call__
实现。
2. subclassing模型的例子
import tensorflow as tf
class MLP(tf.keras.Model):
def __init__(self):
super(MLP, self).__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(units=10)
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense1(x)
x = self.dense2(x)
output = tf.nn.softmax(x)
return output
3. 使用summary/plot_model观察模型结构
上述的模型定义无法通过summary观察模型结构,主要是因为在定义模型的时候没有指定input_shape,只需要增加一个定义input_shape的函数就可。
def build_graph(self, input_shape):
input_ = tf.keras.Input(shape=input_shape)
return tf.keras.models.Model(inputs=[input_], outputs=self.call(input_))
增加该函数后,即可以在定义input_shape之后通过summary和plot_model直观的观察模型结构情况。
# 定义模型
test_model = MLP()
3.1 summary输出
# 通过summary观察模型结构
test_model.build_graph(input_shape=(16,)).summary()
3.2 plot_model输出
# 通过plot_model观察模型结构
tf.keras.utils.plot_model(test_model.build_graph(input_shape=(16,)), to_file='./test_model.png', show_shapes=True)
tf2中subclassing自定义模型
https://zermzhang.github.io/2022/04/25/tensorflow/tf2中subclassing自定义模型/