tf2中subclassing自定义模型
1. tf2中的subclassing模型
tf2中常见的定义模型的方法分成三种:functional、sequenctial、subclassing
其中,subclassing方法就是自己实现一个类来继承tf.keras.Model
,来构建一个Model类的子类,在构建过程中,主要需要实现下面两个方法:
__init__()
这个方法是子类的构造器,主要是用来初始化参数(比如layers之类的)。super
主要用来初始化父构造器。call()
该函数用于执行在__init__
中定义的layers操作
同时,该函数可以使我们定义的模型直接作为函数使用,该功能主要依赖于python的__call__
实现。
2. subclassing模型的例子
3. 使用summary/plot_model观察模型结构
上述的模型定义无法通过summary观察模型结构,主要是因为在定义模型的时候没有指定input_shape,只需要增加一个定义input_shape的函数就可。
增加该函数后,即可以在定义input_shape之后通过summary和plot_model直观的观察模型结构情况。
3.1 summary输出
3.2 plot_model输出
tf2中subclassing自定义模型
https://zermzhang.github.io/2022/04/25/tensorflow/tf2中subclassing自定义模型/