从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)

Tensorflow官方提供的Tensorboard可以可视化神经网络结构图,但是说实话,我几乎从来不用。主要是因为Tensorboard中查看到的图结构太混乱了,包含了网络中所有的计算节点(读取数据节点、网络节点、loss计算节点等等)。更可怕的是,如果一个计算节点是由多个基础计算(如加减乘除等)构成,那么在Tensorboard中会将基础计算节点显示而不是作为一个整体显示(典型的如Squeeze计算节点)。最近为了排查网络结构BUG花费一周时间,因此,狠下心来决定自己写一个工具,将Tensorflow中的图以最简单的方式显示最关键的网络结构。 # 1 Tensor对象与Operation对象 Tensorflow中,Tensor对象主要用于存储数据如常量和变量(训练参数),Operation对象是计算节点,如卷积计算、反卷积计算、ReLU等等。每一个Operation对象均有输入和输出Tensor,同理,每个Tensor对象均有对应生成该Tensor的Operation对象和使用该Tensor对象作为输入的Operation对象。Tensor和Operation对象内均有相关属性和函数来获取其关联的Operation和Tensor对象,相关属性如下所示。 > Tensor对象的op属性指向生成该Tensor的Operation对象。 > Tensor对象的consumers()函数获取使用该Tensor对象作为输入的Operation对象。 > Operation对象的inputs属性指向该计算节点的输入Tensor对象。 > Operation对象的outputs属性执行该计算节点的输出Tensor对象。 如下图所示的网络结构中,调用`Tensor_2`对象的`consumers()`函数,返回的是`[op_1,op_2]`。`Tensor_3`的op属性指向的是`op_1`。`op_1`的inputs属性指向的是`[Tensor_1,Tensor_2]`,`op_1`的output属性指向的是`[Tensor_3]`。 ![Tensor与Operation](https://ask.qcloudimg.com/draft/1163893/r1g3hl72kc.jpg) 有了Tensor与Operation对应在图中的关联关系,就可以将网络结构给画出来。 # 2 提取pb文件中的网络结构图 pb文件是将模型参数固化到图文件中,并合并了一些基础计算和删除了反向传播相关计算得到的protobuf协议文件。如果读者还不懂如何将CKPT模型文件转pb文件,请参考我另一篇文章[《 Tensorflow MobileNet移植到Android》](https://cloud.tencent.com/developer/article/1357573)的第1节部分。有了pb模型文件后,接下来是加载模型,加载pb模型示例代码如下所示。 ```python def read_graph_from_pb(tf_model_path ,input_names,output_name): with open(tf_model_path, 'rb') as f: serialized = f.read() tf.reset_default_graph() gdef = tf.GraphDef() gdef.ParseFromString(serialized) with tf.Graph().as_default() as g: tf.import_graph_def(gdef, name='') with tf.Session(graph=g) as sess: OPS=get_ops_from_pb(g,input_names,output_name) return OPS ``` 其中,倒数第2行调用到的函数`get_ops_from_pb()`用于获取网络结构图中指定输入节点和指定输出节点之间的计算节点。之所以要指定输入和输出,是为了将输入之前的计算节点(如加载数据队列等相关计算节点)和输出之后的计算节点(如计算loss等相关计算节点)去除,免得碍眼。函数`get_ops_from_pb()`实现代码如下。 ```python def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True): if save_ori_network: with open('ori_network.txt','w+') as w: OPS=graph.get_operations() for op in OPS: txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs]) w.write(txt+'\n') inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names] output_tf =graph.get_tensor_by_name(output_name) OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] ) with open('network.txt','w+') as w: for op in OPS: txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs]) w.write(txt+'\n') OPS = sort_ops(OPS) OPS = merge_layers(OPS) return OPS ``` 在裁剪网络结构(即只保留input\_names和output\_name之间节点)之前,先将原始的网络结构写入到`ori_network.txt`中,文件中,每一行写入:`输入Tensor---->op---->输出Tensor`。接下来调用函数`get_ops_from_inputs_outputs`获取指定节点之间的节点。并调用`sort_ops`函数对所有的节点排序,以保证被依赖的节点总是出现在相关节点之前。最后调用`merge_layers`函数,将一些可以合并的计算合并成一个独立的节点,例如,`Squeeze`计算相关节点合并成一个单独的Squeeze节点,又如`const-->identity`两个计算节点可以直接忽略(即删除)。 > 注意:篇幅有限,这里不再将函数`get_ops_from_inputs_outputs`、`sort_ops`、`merge_layers`贴出,相关代码请前往文尾提供的源码地址中阅读。 # 3 绘制网络结构 考虑到`SVG`绘制图形的简单易用优点,将排好序的网络计算节点和相关`Tensor`对象数据以`Javascript`字符串的形式写入到`HTML`中,使用``标签绘制箭头,使用``标签绘制矩形,使用``标签绘制椭圆,使用``标签显示文字。绘制类似于如下所示图像 ![绘制网络结构示例](https://ask.qcloudimg.com/draft/1163893/9swuwpgdy6.jpg) > 注意:篇幅有限,这里不再介绍Javascript代码解析模型结构和SVG显示相关的原理,相关代码请前往文尾提供的源码地址中阅读。 # 4 测试模型显示 以[《MobileNet V1官方预训练模型的使用》](https://cloud.tencent.com/developer/article/1356892)文中介绍的MobileNet V1网络结构为例,下载`MobileNet_v1_1.0_192`文件并压缩后,得到`mobilenet_v1_1.0_192_frozen.pb`文件。我们还需要知道`mobilenet_v1_1.0_192_frozen.pb`模型对应的输入和输出`Tensor`对象的名称,好在`MobileNet_v1_1.0_192`压缩包中包含文件`mobilenet_v1_1.0_192_info.txt`。通过该文件可知,输入`Tensor`的名称为:`input:0`,输出Tensor名称为:`MobilenetV1/Predictions/Reshape_1:0`。有了这些信息后,调用函数`read_graph_from_pb`得到静态图的节点列表对象ops,调用函数`gen_graph(ops,"save/path/graph.html")`后,在目录`save/path`中得到`graph.html`文件,打开`graph.html`后,显示结果如下。 > 显示网络结构分两种模式:合并模式和展开模式,分别如下图所示。 ![合并模式网络结构](https://ask.qcloudimg.com/draft/1163893/sk026zzkwy.gif) ![截取的展开模式网络结构](https://ask.qcloudimg.com/draft/1163893/htfii94n90.gif) # 5 源码地址 [https://github.com/huachao1001/CNNGraph](https://github.com/huachao1001/CNNGraph)