使用protobuf解析tensorflow pb模型

    xiaoxiao2022-07-02  103

    1、安装protobuf,使用protoc把tensorflow模型中的proto文件转换为对应的.h和.cc,如下图所示:

    解析用到的proto位置大概在tensorflow\core\framework和tensorflow\core\protobuf,当然这里面的proto比较多,只要把graph.proto包含的proto文件用到就可以了。

    proto里面有syntax = "proto3",这就意味着protobuf的版本需要在3.0以上。

    2、搭建VS工程

    主要信息获取可以通过以下代码获取:

    // tensorflow_pb.cpp : 定义控制台应用程序的入口点。 // #include "stdafx.h" #include <iostream> #include <fstream> #include "graph.pb.h" int main() { //char* pbname = "test.txt"; char* pbname = "test.pb"; std::fstream fp; fp.open(pbname,std::ios::in | std::ios::binary); if (!fp) { printf("file not found\n"); } tensorflow::GraphDef graph_def; bool suc = graph_def.ParseFromIstream(&fp); fp.close(); int size = graph_def.node_size(); //yes tensorflow::NodeDef node_info; node_info = graph_def.node(8); //对应于python:oplist = get_operation();node_info = oplist[8]; size = node_info.input_size(); //std::string input0_name = node_info.input(0); //std::string input1_name = node_info.input(1); google::protobuf::Map< std::string, tensorflow::AttrValue> attr_map; attr_map = node_info.attr(); //获取该结点的attr的map /*if (attr_map.contains("transpose_a")) { printf("yes"); } tensorflow::AttrValue attr_temp = attr_map["transpose_a"];*/ google::protobuf::Map< std::string, tensorflow::AttrValue>::iterator iter = attr_map.begin(); while (iter != attr_map.end()) { std::cout << iter->first << std::endl; //查看attrmap的key值 iter++; } if (attr_map.contains("value")) { printf("yes"); } tensorflow::AttrValue tensor_content = attr_map["value"]; suc = tensor_content.has_tensor(); tensorflow::TensorProto tensor_info = tensor_content.tensor(); std::string tmp = tensor_info.tensor_content(); float* data = (float*)tmp.data(); //如果该attr包含的tensor有值,需要将该值转换为float来用 tensorflow::TensorShapeProto tens_shape = tensor_info.tensor_shape(); size = tens_shape.dim_size(); tensorflow::TensorShapeProto_Dim tens_dim = tens_shape.dim(0); int64_t size_zz = tens_dim.size(); //获取该Tensor的维度信息,这里获取shape的第一维信息 return 0; }

    当然,在VS工程里需要将使用的protobuf里的src\google\protobuf这部分内容拉到工程里面,为了跑通程序,还需要将上述protobuf的名字包含test的文件去除,这些文件需要另外包含gmock等依赖;

    3、最初打算分析tensorflow源码,但是后来感觉之前python的解析脚本用到的指令应该是tensorflow对上述代码的封装,就拿get_operation(tensor)_by_name为例,tensorflow内部会维护一个op结点(对应上述代码的node)及其name的字典;为了不依赖tensorflow,只好从protobuf解析来做了~

    大致梳理一下:首先需要区分const节点和非const节点,然后将const节点中能计算的信息全部计算出来,便于后续直接获取;

    然后根据数据流将所有节点进行排序;最后根据需求将可合并的节点进行合并,并解析;

     

    最新回复(0)