aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/quantize_training_test.cc
diff options
context:
space:
mode:
authorGravatar Jianmin Chen <goog.jmchen@gmail.com>2016-06-06 15:18:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-06 16:33:17 -0700
commita8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8 (patch)
treec999c38e537a506a674d39f9e546f8dba0c508e5 /tensorflow/core/graph/quantize_training_test.cc
parent17f56a8f11e9a98213ecc1d7d9059ef986b52d89 (diff)
Rewriting training graph to simulate the precision loss for quantized inference.
This finds all the matmul and conv2d ops (with the most precision loss) and convert its inputs according to their types. This rewriting uses the quantize_and_dequantize op to convert tensors with the following types. 1. Const/Variable OP: This is quantized as signed tensors with no given range. 2. Activation OP: Set the range accordingly for different types of activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh} 3. Identity OP: The quantization parameters depend on what its input is. 4. Pooling OPs: various pooling ops. Also depends on its input. 5. Reshape OP: Also depends on the first input to this op. 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the model input. However, if there are >1 unknown ops, then return an error for now to avoid unexpected bahavior. Note: The list above might not be a complete list. Please let us know if you see the CHECK failure so we can include your use case. Change: 124190453
Diffstat (limited to 'tensorflow/core/graph/quantize_training_test.cc')
-rw-r--r--tensorflow/core/graph/quantize_training_test.cc161
1 files changed, 161 insertions, 0 deletions
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
new file mode 100644
index 0000000000..d6663e0a50
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/quantize_training.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class QuantizeTrainingTest : public ::testing::Test {
+ protected:
+ QuantizeTrainingTest() { Reset(); }
+ void Reset() { g_.reset(new Graph(OpRegistry::Global())); }
+
+ template <typename T>
+ Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) {
+ return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
+ }
+
+ std::unique_ptr<Graph> g_;
+};
+
+TEST_F(QuantizeTrainingTest, NormalGraph) {
+ // Construct the following graph
+ /*
+ m1 m2
+ / \ / \
+ Relu Identity c
+ | |
+ a b
+ */
+ Reset();
+ Graph* g = g_.get();
+ Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+ g->AddControlEdge(g->source_node(), a);
+ g->AddControlEdge(g->source_node(), b);
+ g->AddControlEdge(g->source_node(), c);
+ Node* relu = test::graph::Relu(g, a);
+ Node* identity = test::graph::Identity(g, b);
+ Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+ Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+ g->AddControlEdge(m1, g->sink_node());
+ g->AddControlEdge(m2, g->sink_node());
+
+ // The graph after the rewriting should be:
+ // "Q" is the quantize_and_dequantize op.
+ // Note the Q in the middle is shared by both m1 and m2.
+ /*
+ m1 m2
+ / \ / \
+ Q Q Q
+ | | |
+ Relu Identity c
+ | |
+ a b
+ */
+ int num_bits = 8;
+ // 4 edges to modify
+ TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+ // There should be 12 nodes in total including the source and sink nodes.
+ EXPECT_EQ(12, g->num_nodes());
+ // Nodes m1 and m2's inputs should be the quantize_and_dequantize op.
+ std::vector<Node*> target_nodes{m1, m2};
+ for (Node* n : target_nodes) {
+ for (Node* in : n->in_nodes()) {
+ EXPECT_EQ("_QuantizeAndDequantize", in->type_string());
+ }
+ }
+
+ // relu, identity, c should now connect to the quantize_and_dequantize nodes.
+ std::vector<Node*> target_inputs{relu, identity, c};
+ for (Node* n : target_inputs) {
+ for (Node* out : n->out_nodes()) {
+ EXPECT_EQ("_QuantizeAndDequantize", out->type_string());
+ }
+ }
+
+ // Quantize_and_dequantize node for identity should have signed_input==true.
+ NodeDef identity_Q = identity->out_nodes().begin()->def();
+ ASSERT_EQ("true",
+ SummarizeAttrValue(identity_Q.attr().find("signed_input")->second));
+ // Quantize_and_dequantize node for relu should have signed_input==false.
+ NodeDef relu_Q = relu->out_nodes().begin()->def();
+ ASSERT_EQ("false",
+ SummarizeAttrValue(relu_Q.attr().find("signed_input")->second));
+}
+
+TEST_F(QuantizeTrainingTest, WithBackwardNodes) {
+ // Construct the same graph plus another backward Matmul.
+ Reset();
+ Graph* g = g_.get();
+ Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+ g->AddControlEdge(g->source_node(), a);
+ g->AddControlEdge(g->source_node(), b);
+ g->AddControlEdge(g->source_node(), c);
+ Node* relu = test::graph::Relu(g, a);
+ Node* identity = test::graph::Identity(g, b);
+ Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+ Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+ g->AddControlEdge(m1, g->sink_node());
+ g->AddControlEdge(m2, g->sink_node());
+
+ // Add a Matmul node with name starting with "gradients".
+ Node* backward_m;
+ TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
+ .Input(m1)
+ .Input(m2)
+ .Attr("transpose_a", true)
+ .Attr("transpose_b", false)
+ .Finalize(g, &backward_m));
+ g->AddControlEdge(backward_m, g->sink_node());
+
+ int num_bits = 8;
+ // Still 4 changes since the inputs of backward node will not be converted.
+ TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+ // Nodes m1 and m2's inputs should now be the quantize_and_dequantize op.
+ EXPECT_EQ(13, g->num_nodes());
+ EXPECT_EQ(2, m2->num_inputs());
+}
+
+#undef SIMPLE_GRAPH
+
+} // namespace
+} // namespace tensorflow