diff options
-rw-r--r-- | tensorflow/examples/label_image/main.cc | 15 | ||||
-rw-r--r-- | tensorflow/tools/benchmark/benchmark_model.cc | 27 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/README.md | 8 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/quantize_weights.cc | 12 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/quantize_weights_test.cc | 14 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/summarize_graph_main.cc | 12 |
6 files changed, 66 insertions, 22 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index a98c0817e3..63bc39de6c 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -89,7 +89,6 @@ Status ReadLabelsFile(const string& file_name, std::vector<string>* result, static Status ReadEntireFile(tensorflow::Env* env, const string& filename, Tensor* output) { - tensorflow::uint64 file_size = 0; TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); @@ -124,15 +123,15 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, // read file_name into a tensor named input Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape()); - TF_RETURN_IF_ERROR(ReadEntireFile(tensorflow::Env::Default(), file_name, - &input)); + TF_RETURN_IF_ERROR( + ReadEntireFile(tensorflow::Env::Default(), file_name, &input)); // use a placeholder to read input data - auto file_reader = Placeholder(root.WithOpName("input"), - tensorflow::DataType::DT_STRING); + auto file_reader = + Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING); std::vector<std::pair<string, tensorflow::Tensor>> inputs = { - {"input", input}, + {"input", input}, }; // Now try to figure out what kind of file it is and decode it. @@ -285,8 +284,8 @@ int main(int argc, char* argv[]) { "tensorflow/examples/label_image/data/imagenet_slim_labels.txt"; int32 input_width = 299; int32 input_height = 299; - int32 input_mean = 0; - int32 input_std = 255; + float input_mean = 0; + float input_std = 255; string input_layer = "input"; string output_layer = "InceptionV3/Predictions/Reshape_1"; bool self_test = false; diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 81892aef96..38c1bd3fb5 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -138,8 +138,21 @@ Status GetOutputShapes(const std::vector<InputLayerInfo>& inputs, std::vector<std::pair<string, tensorflow::Tensor> > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); std::vector<tensorflow::Tensor> output_tensors; - std::vector<string> output_tensor_names(wanted_shapes.begin(), - wanted_shapes.end()); + std::vector<string> output_tensor_names; + for (const string& wanted_shape : wanted_shapes) { + bool is_input = false; + for (const std::pair<string, tensorflow::Tensor>& input_tensor : + input_tensors) { + if (input_tensor.first == wanted_shape) { + (*node_shapes)[wanted_shape] = input_tensor.second.shape(); + is_input = true; + break; + } + } + if (!is_input) { + output_tensor_names.push_back(wanted_shape); + } + } TF_RETURN_IF_ERROR( session->Run(input_tensors, output_tensor_names, {}, &output_tensors)); CHECK_EQ(output_tensors.size(), output_tensor_names.size()); @@ -156,7 +169,8 @@ Status CalculateFlops(const GraphDef& graph, Session* session, int64* total_flops, std::unordered_map<string, int64>* flops_by_op) { std::unordered_set<string> floppable_ops = { - "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul"}; + "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul", + "DepthwiseConv2dNative"}; std::set<string> wanted_shapes; for (const NodeDef& node : graph.node()) { @@ -201,6 +215,13 @@ Status CalculateFlops(const GraphDef& graph, } int64 output_count = output_shape.num_elements(); current_flops = k * output_count * 2; + } else if (node.op() == "DepthwiseConv2dNative") { + const TensorShape& filter_shape = found_shapes[node.input(1)]; + const TensorShape& output_shape = found_shapes[node.name()]; + int64 filter_height = filter_shape.dim_size(0); + int64 filter_width = filter_shape.dim_size(1); + int64 output_count = output_shape.num_elements(); + current_flops = output_count * filter_height * filter_width * 2; } (*flops_by_op)[node.op()] += current_flops; *total_flops += current_flops; diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index b4274e67df..18d1523070 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -578,10 +578,14 @@ eight-bit form. ### quantize_weights -Args: None \ +Args: + +* minimum_size: Tensors with fewer elements than this won't be quantized +(defaults to 1024) + Prerequisites: None -Converts any large (more than 15 element) float Const op into an eight-bit +Converts any large (more than minimum_size) float Const op into an eight-bit equivalent, followed by a float conversion op so that the result is usable by subsequent nodes. This is mostly useful for [shrinking file sizes](#shrinking-file-size), but also helps with the more advanced diff --git a/tensorflow/tools/graph_transforms/quantize_weights.cc b/tensorflow/tools/graph_transforms/quantize_weights.cc index 66d800f0da..cccae8a992 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights.cc @@ -35,11 +35,15 @@ namespace graph_transforms { Status QuantizeWeights(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { + int32 minimum_size; + TF_RETURN_IF_ERROR( + context.GetOneInt32Parameter("minimum_size", 1024, &minimum_size)); TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( input_graph_def, {"Const"}, - [](const NodeMatch& match, const std::set<string>& input_nodes, - const std::set<string>& output_nodes, - std::vector<NodeDef>* new_nodes) { + [minimum_size](const NodeMatch& match, + const std::set<string>& input_nodes, + const std::set<string>& output_nodes, + std::vector<NodeDef>* new_nodes) { const NodeDef& old_const_node = match.node; if (!old_const_node.attr().count("dtype")) { return errors::InvalidArgument("No 'dtype' attribute for Const node ", @@ -58,7 +62,7 @@ Status QuantizeWeights(const GraphDef& input_graph_def, const size_t num_elements = old_tensor.NumElements(); // If this isn't a float constant, or it's too small, then reuse the // same node with no changes. - if ((old_dtype != DT_FLOAT) || (num_elements < 16)) { + if ((old_dtype != DT_FLOAT) || (num_elements < minimum_size)) { new_nodes->push_back(old_const_node); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc index 63c5b5a64d..e1828831db 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -70,9 +70,12 @@ class QuantizeWeightsTest : public ::testing::Test { 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}, &original_graph_def); + TransformFuncContext context; + context.output_names = {"output"}; + context.params["minimum_size"] = {"16"}; GraphDef quantized_graph_def; - TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}}, - &quantized_graph_def)); + TF_ASSERT_OK( + QuantizeWeights(original_graph_def, context, &quantized_graph_def)); // Verify the structure of the quantized graph. std::map<string, const NodeDef*> node_lookup; @@ -122,9 +125,12 @@ TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) { 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}, &original_graph_def); + TransformFuncContext context; + context.output_names = {"output"}; + context.params["minimum_size"] = {"16"}; GraphDef quantized_graph_def; - TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}}, - &quantized_graph_def)); + TF_ASSERT_OK( + QuantizeWeights(original_graph_def, context, &quantized_graph_def)); std::map<string, const NodeDef*> node_lookup; MapNamesToNodes(quantized_graph_def, &node_lookup); diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index e79e7ba121..50353c3654 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -118,7 +118,17 @@ Status PrintStructure(const GraphDef& graph) { TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph)); for (const NodeDef& node : sorted_graph.node()) { std::cout << node.name() << " (" << node.op() << "): [" - << str_util::Join(node.input(), ", ") << "]" << std::endl; + << str_util::Join(node.input(), ", ") << "]"; + if (node.op() == "Const") { + Tensor tensor; + if (node.attr().count("value") && + tensor.FromProto(node.attr().at("value").tensor())) { + std::cout << ", value=" << tensor.DebugString(); + } else { + LOG(WARNING) << "Decoding Tensor failed for node" << node.name(); + } + } + std::cout << std::endl; } return Status::OK(); } |