从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)

上一篇文章[《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》](https://cloud.tencent.com/developer/article/1361316)中介绍了如何从`pb`模型文件中提取网络结构图并实现可视化,本文介绍如何从`CKPT`模型文件中提取网络结构图并实现可视化。理论上,既然能从`pb`模型文件中提取网络结构图,`CKPT`模型文件自然也不是问题,但是其中会有一些问题。 # 1 解析CKPT网络结构 解析`CKPT`网络结构的第一步是读取`CKPT`模型中的图文件,得到图的`Graph`对象后即可得到完整的网络结构。读取图文件示例代码如下所示。 ```python saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True) graph = tf.get_default_graph() with tf.Session( graph=graph) as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess,ckpt_path) ``` 调用`graph.get_operations()`后即可得到当前图的所有计算节点,在利用`Operation`对象与`Tensor`对象之间的相互引用关系即可推断网络结构。但是需要注意的是,从`meta`文件中导入的图中获取计算节点存在如下问题。 > 包含反向梯度下降计算的所有节点 > 某些计算节点是按基础计算(加减乘除等)节点拆分成多个计算节点的,如`BatchNorm`,但其实是可以直接合并成一个节点的。 `pb`模型文件可以避免上面第一个问题,将`CKPT`模型转`pb`模型后,可以自动将反向梯度下降相关计算节点移除。对于第二点,`pb`模型文件会自动将基础计算组成一个计算节点,但是对于Tensor操作的函数如Slice等函数是无法合并的。因此,对于第2个问题,将`CKPT`模型转`pb`模型后,可以减少这类问题,但是无法避免。彻底避免的方法只能通过自己针对性地实现。经过以上分析,得出的结论是非常有必要将`CKPT`模型转`pb`模型。 # 2 自动将CKPT转pb,并提取网络图中节点 如果将CKPT自动转pb模型,那么就可以复用上一篇文章[《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》](https://cloud.tencent.com/developer/article/1361316)的代码。示例代码如下所示。 ```python def read_graph_from_ckpt(ckpt_path,input_names,output_name ): saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True) graph = tf.get_default_graph() with tf.Session( graph=graph) as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess,ckpt_path) output_tf =graph.get_tensor_by_name(output_name) pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name]) with tf.Graph().as_default() as g: tf.import_graph_def(pb_graph, name='') with tf.Session(graph=g) as sess: OPS=get_ops_from_pb(g,input_names,output_name) return OPS ``` 其中函数`get_ops_from_pb`在上一篇文章[《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》](https://cloud.tencent.com/developer/article/1361316)中已经实现。 # 3 测试 以[《MobileNet V1官方预训练模型的使用》](https://cloud.tencent.com/developer/article/1356892)文中介绍的MobileNet V1网络结构为例,下载MobileNet\_v1\_1.0\_192文件并压缩后,得到`mobilenet_v1_1.0_192.ckpt.data-00000-of-00001`、`mobilenet_v1_1.0_192.ckpt.index`、`mobilenet_v1_1.0_192.ckpt.meta`文件。我们还需要知道`mobilenet_v1_1.0_192.ckpt`模型对应的输入和输出`Tensor`对象的名称,官方提供的压缩包文件中并没有告知。一种方法是运行官方代码,把输入Tensor的名称打印出来。但是运行官方代码本身就需要一定的时间和精力,在在上一篇文章[《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》](https://cloud.tencent.com/developer/article/1361316)的代码实现中已经实现了将原始网络结构对应的字符串写入到`ori_network.txt`文件中。因此,可以先随意填写输入名称和输出名称,待生成`ori_network.txt`文件后,从文件中可以直观看到原始网络结构。`ori_network.txt`文件部分内容如下所示。 ![ori_network.txt文件部分内容](https://ask.qcloudimg.com/draft/1163893/oehrvs5fza.jpg) 通过该文件可知,输入`Tensor`的名称为:`batch:0`,输出`Tensor`名称为:`MobilenetV1/Predictions/Reshape_1:0`。有了这些信息后,调用函数`read_graph_from_ckpt`得到静态图的节点列表对象`ops`,调用函数`gen_graph(ops,"save/path/graph.html")`后,在目录`save/path`中得到`graph.html`文件,打开`graph.html`后,显示结果如下。 ![读取并显示CKPT模型的图结构](https://ask.qcloudimg.com/draft/1163893/p47542wf3w.gif) # 4 源码地址 [https://github.com/huachao1001/CNNGraph](https://github.com/huachao1001/CNNGraph)