-rw-r--r--  tensorflow/core/graph/quantize_training.cc                          34
-rw-r--r--  tensorflow/core/graph/quantize_training.h                           11
-rw-r--r--  tensorflow/core/graph/quantize_training_test.cc                     36
-rw-r--r--  tensorflow/tools/graph_transforms/BUILD                              2
-rw-r--r--  tensorflow/tools/graph_transforms/fake_quantize_training.cc         50
-rw-r--r--  tensorflow/tools/graph_transforms/fake_quantize_training_test.cc    63
6 files changed, 180 insertions(+), 16 deletions(-)
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index 48b6b2a497..b74fa2127e 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -653,28 +653,38 @@ Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
return Status::OK();
}
-Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
- int32 num_bits,
- const string& quant_op_type,
- string* result_graph) {
- // First create the graph from the GraphDef.
+Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
+ int32 num_bits, const string& quant_op_type,
+ GraphDef* result_graphdef) {
Graph graph(OpRegistry::Global());
GraphConstructorOptions opts;
- GraphDef input_graphdef;
- if (!ParseProtoUnlimited(&input_graphdef, input_graph)) {
- return errors::InvalidArgument("Invalid input graph");
- }
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
// Call the rewriter on the graph.
TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
// Convert the result graph back to a GraphDef.
+ graph.ToGraphDef(result_graphdef);
+ return Status::OK();
+}
+
+Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
+ int32 num_bits,
+ const string& quant_op_type,
+ string* result_graph_string) {
+ // First parse the GraphDef from the input serialized string.
+ GraphDef input_graphdef;
+ if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
+ return errors::InvalidArgument(
+ "input_graph_string is not a serialized GraphDef protocol buffer");
+ }
GraphDef output_graphdef;
- graph.ToGraphDef(&output_graphdef);
+ TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
+ input_graphdef, num_bits, quant_op_type, &output_graphdef));
- if (!output_graphdef.SerializeToString(result_graph)) {
- return errors::InvalidArgument("Invalid output graph");
+ if (!output_graphdef.SerializeToString(result_graph_string)) {
+ return errors::Internal(
+ "quantize training transformation resulted in invalid GraphDef");
}
return Status::OK();
}
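
With this change, the string-based entry point becomes a thin wrapper: it parses the serialized proto, delegates to the new DoQuantizeTrainingOnGraphDef, and re-serializes the result. A minimal caller-side sketch (the wrapper function name is illustrative; the num_bits and op-type values mirror the tests in this patch):

    #include <string>

    #include "tensorflow/core/graph/quantize_training.h"

    // Sketch: round-trip through the unchanged string-based API.
    // `serialized_in` is assumed to hold a binary-serialized GraphDef.
    tensorflow::Status QuantizeSerialized(const std::string& serialized_in,
                                          std::string* serialized_out) {
      return tensorflow::DoQuantizeTrainingOnSerializedGraphDef(
          serialized_in, /*num_bits=*/8, "QuantizeAndDequantizeV2",
          serialized_out);
    }
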
diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h
index 2c1a7e6ae3..2bb4ee1cf0 100644
--- a/tensorflow/core/graph/quantize_training.h
+++ b/tensorflow/core/graph/quantize_training.h
@@ -38,12 +38,19 @@ namespace tensorflow {
Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
Graph* g);
-// Converts a input GraphDef and returns a rewritten GraphDef with the
-// quantized training.
+// Converts the input serialized GraphDef and returns a rewritten serialized
+// GraphDef for quantized training.
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
int32 num_bits,
const string& quant_op_type,
string* result_graph);
+
+// Converts the input GraphDef and returns a rewritten GraphDef for quantized
+// training.
+Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
+ int32 num_bits, const string& quant_op_type,
+ GraphDef* result_graphdef);
+
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
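
Callers that already hold a GraphDef in memory can now use the second declaration directly and skip the serialize/parse round trip. A sketch under the same assumptions as above (illustrative wrapper name; FakeQuantWithMinMaxVars is the op type the new transform below passes in):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/graph/quantize_training.h"

    // Sketch: rewrite an in-memory GraphDef for 8-bit quantized training.
    tensorflow::Status RewriteForQuantizedTraining(
        const tensorflow::GraphDef& input_graphdef,
        tensorflow::GraphDef* result_graphdef) {
      return tensorflow::DoQuantizeTrainingOnGraphDef(
          input_graphdef, /*num_bits=*/8, "FakeQuantWithMinMaxVars",
          result_graphdef);
    }
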
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
index d817d980de..2ad69dbd0c 100644
--- a/tensorflow/core/graph/quantize_training_test.cc
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -282,7 +282,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node));
}
-TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
+TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
// Construct a simple graph with 5 nodes.
Reset();
Graph* graph = g_.get();
@@ -310,8 +310,40 @@ TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
GraphDef result_graphdef;
EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string));
+ // Ensure that quantizing the serialized graph_def results in a graph with the
+ // same number of nodes as quantizing the graph.
+ GraphConstructorOptions opts;
+ Graph result_graph(OpRegistry::Global());
+ TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
+ TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph));
+ EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
+}
+
+TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
+ // Construct a simple graph with 5 nodes.
+ Reset();
+ Graph* graph = g_.get();
+ Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ graph->AddControlEdge(graph->source_node(), const_a);
+ graph->AddControlEdge(graph->source_node(), const_b);
+ Node* relu = test::graph::Relu(graph, const_a);
+ Node* identity = test::graph::Identity(graph, const_b);
+ Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
+ graph->AddControlEdge(matmul, graph->sink_node());
+
+ int num_bits = 8;
+
+ // Convert the graph to a GraphDef.
+ GraphDef input_graphdef;
+ graph->ToGraphDef(&input_graphdef);
+
+ GraphDef result_graphdef;
+ TF_ASSERT_OK(DoQuantizeTrainingOnGraphDef(
+ input_graphdef, num_bits, "QuantizeAndDequantizeV2", &result_graphdef));
+
// Ensure that quantizing the graph_def results in a graph with the same
- // number of nodes.
+ // number of nodes as quantizing the graph directly.
GraphConstructorOptions opts;
Graph result_graph(OpRegistry::Global());
TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index d9ec8e8e9b..ec582739b7 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -60,6 +60,7 @@ cc_library(
srcs = [
"add_default_attributes.cc",
"backports.cc",
+ "fake_quantize_training.cc",
"fold_batch_norms.cc",
"fold_constants_lib.cc",
"fold_old_batch_norms.cc",
@@ -109,6 +110,7 @@ tf_cc_test(
srcs = [
"add_default_attributes_test.cc",
"backports_test.cc",
+ "fake_quantize_training_test.cc",
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc
new file mode 100644
index 0000000000..321de47db1
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/fake_quantize_training.cc
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/graph/quantize_training.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Rewrites the GraphDef for quantized training.
+// Rewrites the forward pass to include the precision loss with quantization so
+// the model can learn to deal with such loss and achieve better accuracy when
+// it is quantized later for inference.
+// Quantization range information is collected in FakeQuantWithMinMaxVars
+// ops.
+//
+// TODO(suharshs): Provide instructions on converting the resulting graph for
+// inference.
+// TODO(suharshs): Implement this using the GTT rather than calling the old
+// prototype function.
+Status FakeQuantizeTraining(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def) {
+ // TODO(suharshs): Make num_bits a parameter.
+ const int32 num_bits = 8;
+ // TODO(suharshs): Make quantization op a parameter?
+ const string quant_op_type = "FakeQuantWithMinMaxVars";
+
+ return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type,
+ output_graph_def);
+}
+
+REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
+
+} // namespace graph_transforms
+} // namespace tensorflow
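
The num_bits TODO could eventually be served by the context argument the transform already receives. A hedged sketch, assuming the GetOneInt32Parameter helper declared in transform_utils.h (the mechanism other transforms use for optional parameters):

    #include "tensorflow/tools/graph_transforms/transform_utils.h"

    namespace tensorflow {
    namespace graph_transforms {

    // Sketch: read num_bits from the transform parameters, e.g.
    // --transforms='fake_quantize_training(num_bits=4)', falling back to
    // the currently hard-coded default of 8.
    Status GetNumBits(const TransformFuncContext& context, int32* num_bits) {
      return context.GetOneInt32Parameter("num_bits", 8, num_bits);
    }

    }  // namespace graph_transforms
    }  // namespace tensorflow
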
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
new file mode 100644
index 0000000000..3ea7f512c6
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status FakeQuantizeTraining(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+
+class FakeQuantizeTrainingTest : public ::testing::Test {};
+
+// For now, since the fake_quantize_training transform just calls the
+// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
+// we just test that the graph has been changed by the transform.
+// TODO(suharshs): Once we implement the fake_quantize_training transform
+// using the GTT, write proper tests of the transform here.
+TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor a_data(DT_FLOAT, TensorShape({1, 1}));
+ test::FillIota<float>(&a_data, 1.0f);
+ Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
+
+ Tensor b_data(DT_FLOAT, TensorShape({1, 1}));
+ test::FillIota<float>(&b_data, 1.0f);
+ Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
+
+ Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ GraphDef result;
+ TransformFuncContext context;
+ TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result));
+
+ // Test that the transformation resulted in a graph with more nodes.
+ EXPECT_GT(result.node_size(), graph_def.node_size());
+}
+
+} // namespace graph_transforms
+} // namespace tensorflow
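
Until the GTT-based rewrite lands, the node-count assertion can be tightened without new infrastructure by scanning the result for the inserted quantization ops. A sketch of lines that could close out TransformOccurred, assuming the transform keeps emitting FakeQuantWithMinMaxVars nodes:

    // Sketch: verify at least one FakeQuantWithMinMaxVars node was inserted,
    // using only the generated GraphDef/NodeDef protobuf accessors.
    int fake_quant_count = 0;
    for (const NodeDef& node : result.node()) {
      if (node.op() == "FakeQuantWithMinMaxVars") {
        ++fake_quant_count;
      }
    }
    EXPECT_GT(fake_quant_count, 0);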