diff options
-rw-r--r-- | tensorflow/core/graph/quantize_training.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/graph/quantize_training.h | 11 | ||||
-rw-r--r-- | tensorflow/core/graph/quantize_training_test.cc | 36 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fake_quantize_training.cc | 50 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fake_quantize_training_test.cc | 63 |
6 files changed, 180 insertions, 16 deletions
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index 48b6b2a497..b74fa2127e 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -653,28 +653,38 @@ Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type, return Status::OK(); } -Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph, - int32 num_bits, - const string& quant_op_type, - string* result_graph) { - // First create the graph from the GraphDef. +Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, + int32 num_bits, const string& quant_op_type, + GraphDef* result_graphdef) { Graph graph(OpRegistry::Global()); GraphConstructorOptions opts; - GraphDef input_graphdef; - if (!ParseProtoUnlimited(&input_graphdef, input_graph)) { - return errors::InvalidArgument("Invalid input graph"); - } TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph)); // Call the rewriter on the graph. TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph)); // Convert the result graph back to a GraphDef. + graph.ToGraphDef(result_graphdef); + return Status::OK(); +} + +Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string, + int32 num_bits, + const string& quant_op_type, + string* result_graph_string) { + // First create the graph from the GraphDef. + GraphDef input_graphdef; + if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) { + return errors::InvalidArgument( + "input_graph_string is not a serialized GraphDef protocol buffer"); + } GraphDef output_graphdef; - graph.ToGraphDef(&output_graphdef); + TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef( + input_graphdef, num_bits, quant_op_type, &output_graphdef)); - if (!output_graphdef.SerializeToString(result_graph)) { - return errors::InvalidArgument("Invalid output graph"); + if (!output_graphdef.SerializeToString(result_graph_string)) { + return errors::Internal( + "quantize training transformation resulted in invalid GraphDef"); } return Status::OK(); } diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h index 2c1a7e6ae3..2bb4ee1cf0 100644 --- a/tensorflow/core/graph/quantize_training.h +++ b/tensorflow/core/graph/quantize_training.h @@ -38,12 +38,19 @@ namespace tensorflow { Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type, Graph* g); -// Converts a input GraphDef and returns a rewritten GraphDef with the -// quantized training. +// Converts the input serialized GraphDef and returns a rewritten serialized +// GraphDef for quantized training. Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph, int32 num_bits, const string& quant_op_type, string* result_graph); + +// Converts the input GraphDef and returns a rewritten GraphDef for quantized +// training. +Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, + int32 num_bits, const string& quant_op_type, + GraphDef* result_graphdef); + } // namespace tensorflow #endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_ diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc index d817d980de..2ad69dbd0c 100644 --- a/tensorflow/core/graph/quantize_training_test.cc +++ b/tensorflow/core/graph/quantize_training_test.cc @@ -282,7 +282,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) { g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node)); } -TEST_F(QuantizeTrainingTest, QuantizeGraphDef) { +TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) { // Construct a simple graph with 5 nodes. Reset(); Graph* graph = g_.get(); @@ -310,8 +310,40 @@ TEST_F(QuantizeTrainingTest, QuantizeGraphDef) { GraphDef result_graphdef; EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string)); + // Ensure that quantizing the serialized graph_def results in a graph with the + // same number of nodes as quantizing the graph. + GraphConstructorOptions opts; + Graph result_graph(OpRegistry::Global()); + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph)); + TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph)); + EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes()); +} + +TEST_F(QuantizeTrainingTest, QuantizeGraphDef) { + // Construct a simple graph with 5 nodes. + Reset(); + Graph* graph = g_.get(); + Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); + Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); + graph->AddControlEdge(graph->source_node(), const_a); + graph->AddControlEdge(graph->source_node(), const_b); + Node* relu = test::graph::Relu(graph, const_a); + Node* identity = test::graph::Identity(graph, const_b); + Node* matmul = test::graph::Matmul(graph, relu, identity, false, false); + graph->AddControlEdge(matmul, graph->sink_node()); + + int num_bits = 8; + + // Convert the graph to the graphdef string. + GraphDef input_graphdef; + graph->ToGraphDef(&input_graphdef); + + GraphDef result_graphdef; + TF_ASSERT_OK(DoQuantizeTrainingOnGraphDef( + input_graphdef, num_bits, "QuantizeAndDequantizeV2", &result_graphdef)); + // Ensure that quantizing the graph_def results in a graph with the same - // number of nodes. + // number of nodes as the graph_def. GraphConstructorOptions opts; Graph result_graph(OpRegistry::Global()); TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph)); diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index d9ec8e8e9b..ec582739b7 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -60,6 +60,7 @@ cc_library( srcs = [ "add_default_attributes.cc", "backports.cc", + "fake_quantize_training.cc", "fold_batch_norms.cc", "fold_constants_lib.cc", "fold_old_batch_norms.cc", @@ -109,6 +110,7 @@ tf_cc_test( srcs = [ "add_default_attributes_test.cc", "backports_test.cc", + "fake_quantize_training_test.cc", "fold_batch_norms_test.cc", "fold_constants_test.cc", "fold_old_batch_norms_test.cc", diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc new file mode 100644 index 0000000000..321de47db1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fake_quantize_training.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/graph/quantize_training.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Rewrites the GraphDef for quantized training. +// Rewrites the forward pass to include the precision loss with quantization so +// the model can learn to deal with such loss and achieve better accuracy when +// it is quantized later for inference. +// Quantization range information is collected in FakeQuantizeWithMinMaxVars +// ops. +// +// TODO(suharshs): Provide instructions on converting the resulting graph for +// inference. +// TODO(suharshs): Implement this using the GTT rather than calling the old +// prototype function. +Status FakeQuantizeTraining(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // TODO(suharshs): Make num_bits a parameter. + const int32 num_bits = 8; + // TODO(suharshs): Make quantization op a parameter? + const string quant_op_type = "FakeQuantWithMinMaxVars"; + + return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type, + output_graph_def); +} + +REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc new file mode 100644 index 0000000000..3ea7f512c6 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status FakeQuantizeTraining(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class FakeQuantizeTrainingTest : public ::testing::Test {}; + +// For now, since the fake_quantize_training transform just calls the +// quantize_training rewrite from tensorflow/core/graph/quantize_training.h, +// we just test that the graph has been changed by the transform. +// TODO(suharshs): Once we implement the fake_quantize_training transform +// using the GTT, write proper tests of the transform here. +TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor a_data(DT_FLOAT, TensorShape()); + test::FillIota<float>(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape()); + test::FillIota<float>(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef result; + TransformFuncContext context; + TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result)); + + // Test that the transformation resulted in a graph with more nodes. + EXPECT_GT(result.node_size(), graph_def.node_size()); +} + +} // namespace graph_transforms +} // namespace tensorflow |