-rw-r--r--tensorflow/core/graph/quantize_training.cc229
-rw-r--r--tensorflow/core/graph/quantize_training.h37
-rw-r--r--tensorflow/core/graph/quantize_training_test.cc161
-rw-r--r--tensorflow/core/graph/testlib.cc9
-rw-r--r--tensorflow/core/graph/testlib.h3
5 files changed, 439 insertions, 0 deletions
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
new file mode 100644
index 0000000000..23ce7daeff
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -0,0 +1,229 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/graph/quantize_training.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+// Node types to rewrite. A quantize_and_dequantize op is inserted for each of
+// their inputs.
+const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{
+ "MatMul", "Conv2D"};
+
+// Contains necessary parameters to convert an edge.
+struct EdgeToConvert {
+ // Edge is not owned here.
+ const Edge* edge;
+ int32 num_bits;
+ bool signed_input;
+ bool range_given;
+ float input_min;
+ float input_max;
+
+  EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
+                float max)
+      : edge(e),
+        num_bits(bits),
+        signed_input(sign),
+        range_given(range),
+        input_min(min),
+        input_max(max) {}
+};
+
+// Decide if a node is in the backward pass by checking whether its name
+// starts with "gradients".
+// TODO(jmchen): Make this check more robust, as forward nodes are not
+// guaranteed never to be named with a leading "gradients".
+inline bool IsGradientNode(const Graph* graph, const Node* node) {
+ static const string tag = "gradients";
+ return (node->name().compare(0, tag.size(), tag) == 0);
+}
+
+// Find the type of the input to set the parameters for the
+// quantize_and_dequantize op.
+// Returns true if the root tensor op type is known, false otherwise.
+bool FindType(const Graph* graph, const Node* node, bool* signed_input,
+ bool* range_given, float* input_min, float* input_max) {
+ const string src_op = node->type_string();
+ if (src_op == "Const" || src_op == "Variable") {
+ *signed_input = true;
+ *range_given = false;
+ } else if (src_op == "Relu") {
+ // Range is not given for Relu.
+ *signed_input = false;
+ *range_given = false;
+ } else if (src_op == "Relu6") {
+ *signed_input = false;
+ *range_given = true;
+ *input_min = 0;
+ *input_max = 6;
+ } else if (src_op == "Sigmoid") {
+ *signed_input = false;
+ *range_given = true;
+ *input_min = 0;
+ *input_max = 1;
+ } else if (src_op == "Tanh") {
+ *signed_input = true;
+ *range_given = true;
+ *input_min = -1;
+ *input_max = 1;
+ } else if (src_op == "Reshape") {
+ // Reshape has 2 inputs and the first one is the tensor.
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
+ FindType(graph, edge->src(), signed_input, range_given, input_min,
+ input_max);
+ }
+ }
+ } else if (src_op == "Identity" || src_op == "MaxPool" ||
+ src_op == "AvgPool" || src_op == "MaxPool3D" ||
+ src_op == "AvgPool3D") {
+    // All of these ops have a single data input.
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->src_output() != Graph::kControlSlot) {
+ FindType(graph, edge->src(), signed_input, range_given, input_min,
+ input_max);
+ }
+ }
+ } else {
+    // Unknown op type; it could be the model's input examples.
+    // TODO: Set the parameters for inputs based on a user-provided hint.
+ *signed_input = true;
+ *range_given = false;
+ return false;
+ }
+
+ return true;
+}
+
+// Insert conversion ops, connect them to the graph, and remove the old edges.
+Status ProcessTargetEdges(Graph* graph,
+ const std::vector<EdgeToConvert>& target_edges) {
+  // Remember previously created convert ops to avoid duplicate conversions of
+  // the same input.
+  std::unordered_map<string, Node*, StringPiece::Hasher> name_index;
+  for (const EdgeToConvert& edge : target_edges) {
+ Node* convert_node;
+ string name =
+ strings::StrCat(edge.edge->src()->name(), "/_QuantizeAndDequantize");
+
+ auto iter = name_index.find(name);
+ if (iter == name_index.end()) {
+ TF_RETURN_IF_ERROR(NodeBuilder(name, "_QuantizeAndDequantize")
+ .Input(edge.edge->src())
+ .Attr("signed_input", edge.signed_input)
+ .Attr("num_bits", edge.num_bits)
+ .Attr("range_given", edge.range_given)
+ .Attr("input_min", edge.input_min)
+ .Attr("input_max", edge.input_max)
+ .Finalize(graph, &convert_node));
+
+ name_index[name] = convert_node;
+ } else {
+ convert_node = iter->second;
+ }
+
+ graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
+ graph->RemoveEdge(edge.edge);
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status DoQuantizeTraining(int32 num_bits, Graph* graph) {
+ if (graph == nullptr) {
+    return errors::InvalidArgument("Cannot accept a null graph pointer.");
+ }
+
+ if (num_bits < 1 || num_bits > 63) {
+ return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
+ num_bits);
+ }
+ int potential_input = 0;
+ std::vector<EdgeToConvert> target_edges;
+ for (Node* node : graph->nodes()) {
+ if (nodes_to_rewrite.find(node->type_string()) != nodes_to_rewrite.end() &&
+ !IsGradientNode(graph, node)) {
+      // Find out which types are the inputs and convert them accordingly.
+      // 1. Const/Variable OP: This is quantized as a signed tensor with no
+      // given range.
+      // 2. Activation OP: Set the range according to the activation type.
+      // Currently we handle {Relu, Relu6, Sigmoid, Tanh}.
+      // 3. Identity OP: The quantization parameters depend on its input.
+      // 4. Pooling OPs: Various pooling ops; these also depend on their
+      // input.
+      // 5. Reshape OP: Also depends on the first input to this op.
+      // 6. Not-Listed-Above OP: If there is only 1 such op, it is treated as
+      // the model input. However, if there is more than 1 unknown op, we
+      // return an error for now to avoid unexpected behavior.
+      // Note: the list above may be incomplete. Please let us know if you hit
+      // the error so we can handle your case.
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->src_output() == Graph::kControlSlot) {
+ // Skip the control dependency input.
+ continue;
+ } else {
+ bool signed_input = false;
+ bool range_given = false;
+ float input_min = 0;
+ float input_max = 0;
+ bool known_op = FindType(graph, edge->src(), &signed_input,
+ &range_given, &input_min, &input_max);
+          if (!known_op) {
+            // An unknown op is treated as a model input.
+            // Only one input is supported for now.
+            // TODO: Make this configurable if this turns out to be the right
+            // way to find model inputs.
+ if (potential_input > 0) {
+              return errors::Unimplemented(
+                  "Found a second unknown op: ", edge->src()->name(),
+                  " with type: ", edge->src()->type_string(),
+                  "; unknown ops are treated as model inputs for now and "
+                  "only 1 input is supported currently.");
+ }
+ potential_input++;
+ }
+
+          target_edges.emplace_back(edge, num_bits, signed_input, range_given,
+                                    input_min, input_max);
+ }
+ }
+ }
+ }
+
+ TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, target_edges));
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
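For reference, the sketch below shows roughly the arithmetic a quantize-and-dequantize step of this kind typically performs when range_given is true. The _QuantizeAndDequantize kernel itself is outside this diff, so this is an assumption about its semantics (uniform quantization over [input_min, input_max]), not a copy of the actual implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Minimal sketch: clamp x to [input_min, input_max], snap it to one of
    // the 2^num_bits uniformly spaced levels over that range, and map it
    // back to float. The tensor stays float, but it now carries the rounding
    // error that inference-time quantization would introduce.
    float QuantizeAndDequantize(float x, int num_bits, float input_min,
                                float input_max) {
      if (input_max <= input_min) return input_min;  // degenerate range
      const double levels =
          static_cast<double>((std::uint64_t{1} << num_bits) - 1);
      const double scale = (input_max - input_min) / levels;
      const double clamped = std::min(
          std::max(static_cast<double>(x), static_cast<double>(input_min)),
          static_cast<double>(input_max));
      const double q = std::round((clamped - input_min) / scale);
      return static_cast<float>(input_min + q * scale);
    }

When range_given is false, a real kernel would first derive input_min/input_max from the tensor contents (e.g. a min/max reduction over its elements) before applying the same rounding step.
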
diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h
new file mode 100644
index 0000000000..694c491620
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training.h
@@ -0,0 +1,37 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#define TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// Rewrites the graph for quantized training.
+// The forward pass is rewritten to include the precision loss introduced by
+// quantization so the model can learn to compensate for it and achieve better
+// accuracy when it is later quantized for inference.
+// Note that num_bits must be in [1, 63] and 'g' must not be null.
+//
+// On success, returns OK.
+//
+// On failure, returns the error status. Possible errors include:
+// - num_bits out of range.
+// - g is null.
+// - More than one unknown op encountered.
+Status DoQuantizeTraining(int32 num_bits, Graph* g);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
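A minimal sketch of a call site for the new API, assuming the caller already owns a fully constructed training graph; RewriteForQuantizedTraining is a hypothetical wrapper, not part of this change:

    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/quantize_training.h"

    namespace tensorflow {

    // Hypothetical helper: rewrite 'g' in place for 8-bit quantized training,
    // surfacing any rewrite failure to the caller as a Status.
    Status RewriteForQuantizedTraining(Graph* g) {
      const int32 num_bits = 8;  // must be in [1, 63]
      return DoQuantizeTraining(num_bits, g);
    }

    }  // namespace tensorflow
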
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
new file mode 100644
index 0000000000..d6663e0a50
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/quantize_training.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class QuantizeTrainingTest : public ::testing::Test {
+ protected:
+ QuantizeTrainingTest() { Reset(); }
+ void Reset() { g_.reset(new Graph(OpRegistry::Global())); }
+
+ template <typename T>
+ Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) {
+ return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
+ }
+
+ std::unique_ptr<Graph> g_;
+};
+
+TEST_F(QuantizeTrainingTest, NormalGraph) {
+ // Construct the following graph
+ /*
+ m1 m2
+ / \ / \
+ Relu Identity c
+ | |
+ a b
+ */
+ Reset();
+ Graph* g = g_.get();
+ Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+ g->AddControlEdge(g->source_node(), a);
+ g->AddControlEdge(g->source_node(), b);
+ g->AddControlEdge(g->source_node(), c);
+ Node* relu = test::graph::Relu(g, a);
+ Node* identity = test::graph::Identity(g, b);
+ Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+ Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+ g->AddControlEdge(m1, g->sink_node());
+ g->AddControlEdge(m2, g->sink_node());
+
+ // The graph after the rewriting should be:
+ // "Q" is the quantize_and_dequantize op.
+ // Note the Q in the middle is shared by both m1 and m2.
+ /*
+ m1 m2
+ / \ / \
+ Q Q Q
+ | | |
+ Relu Identity c
+ | |
+ a b
+ */
+ int num_bits = 8;
+ // 4 edges to modify
+ TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+ // There should be 12 nodes in total including the source and sink nodes.
+ EXPECT_EQ(12, g->num_nodes());
+ // Nodes m1 and m2's inputs should be the quantize_and_dequantize op.
+ std::vector<Node*> target_nodes{m1, m2};
+ for (Node* n : target_nodes) {
+ for (Node* in : n->in_nodes()) {
+ EXPECT_EQ("_QuantizeAndDequantize", in->type_string());
+ }
+ }
+
+ // relu, identity, c should now connect to the quantize_and_dequantize nodes.
+ std::vector<Node*> target_inputs{relu, identity, c};
+ for (Node* n : target_inputs) {
+ for (Node* out : n->out_nodes()) {
+ EXPECT_EQ("_QuantizeAndDequantize", out->type_string());
+ }
+ }
+
+ // Quantize_and_dequantize node for identity should have signed_input==true.
+ NodeDef identity_Q = identity->out_nodes().begin()->def();
+ ASSERT_EQ("true",
+ SummarizeAttrValue(identity_Q.attr().find("signed_input")->second));
+ // Quantize_and_dequantize node for relu should have signed_input==false.
+ NodeDef relu_Q = relu->out_nodes().begin()->def();
+ ASSERT_EQ("false",
+ SummarizeAttrValue(relu_Q.attr().find("signed_input")->second));
+}
+
+TEST_F(QuantizeTrainingTest, WithBackwardNodes) {
+ // Construct the same graph plus another backward Matmul.
+ Reset();
+ Graph* g = g_.get();
+ Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+ g->AddControlEdge(g->source_node(), a);
+ g->AddControlEdge(g->source_node(), b);
+ g->AddControlEdge(g->source_node(), c);
+ Node* relu = test::graph::Relu(g, a);
+ Node* identity = test::graph::Identity(g, b);
+ Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+ Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+ g->AddControlEdge(m1, g->sink_node());
+ g->AddControlEdge(m2, g->sink_node());
+
+  // Add a MatMul node whose name starts with "gradients".
+ Node* backward_m;
+ TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
+ .Input(m1)
+ .Input(m2)
+ .Attr("transpose_a", true)
+ .Attr("transpose_b", false)
+ .Finalize(g, &backward_m));
+ g->AddControlEdge(backward_m, g->sink_node());
+
+ int num_bits = 8;
+  // Still 4 edges are converted, since the inputs of the backward node are
+  // skipped.
+ TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+  // There should be 13 nodes now: the 12 from before plus the backward
+  // MatMul, whose inputs were not converted.
+ EXPECT_EQ(13, g->num_nodes());
+ EXPECT_EQ(2, m2->num_inputs());
+}
+
+} // namespace
+} // namespace tensorflow
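When debugging the rewrite, a quick way to see what was inserted is to walk the rewritten graph; a minimal sketch, assuming 'g' has already been through DoQuantizeTraining (CountQuantizeNodes is a hypothetical helper):

    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    // Hypothetical helper: count and log the conversion nodes inserted by
    // the rewrite.
    int CountQuantizeNodes(Graph* g) {
      int count = 0;
      for (Node* n : g->nodes()) {
        if (n->type_string() == "_QuantizeAndDequantize") {
          ++count;
          LOG(INFO) << "inserted conversion node: " << n->name();
        }
      }
      return count;
    }

    }  // namespace tensorflow

In the NormalGraph test above this returns 3: one conversion node each for relu, identity, and c, with identity's node shared by m1 and m2 (hence 4 converted edges but only 3 new nodes).
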
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 0d0a84db79..ec878437dc 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -384,6 +384,15 @@ Node* GetSessionTensor(Graph* g, Node* in) {
return ret;
}
+Node* Relu(Graph* g, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
+ .Input(in, 0)
+ .Attr("T", DT_FLOAT)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
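For context, the new helper mirrors the existing testlib constructors; a typical call site (the values here are illustrative) looks like:

    // Build a float constant and feed it through the new Relu helper.
    Node* input = test::graph::Constant(
        g, test::AsTensor<float>({-1.0f, 0.0f, 2.0f, 3.0f}, {2, 2}));
    Node* relu = test::graph::Relu(g, input);
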
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 511f6b4310..bc4863563f 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -169,6 +169,9 @@ Node* GetSessionTensor(Graph* g, Node* in);
// given in "tensors".
Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors);
+// Add a Relu node in "g".
+Node* Relu(Graph* g, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow