-rw-r--r--  tensorflow/core/graph/quantize_training.cc      | 229
-rw-r--r--  tensorflow/core/graph/quantize_training.h       |  37
-rw-r--r--  tensorflow/core/graph/quantize_training_test.cc | 161
-rw-r--r--  tensorflow/core/graph/testlib.cc                |   9
-rw-r--r--  tensorflow/core/graph/testlib.h                 |   3
5 files changed, 439 insertions, 0 deletions
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
new file mode 100644
index 0000000000..23ce7daeff
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -0,0 +1,229 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <atomic>
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/quantize_training.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+// Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
+const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{
+    "MatMul", "Conv2D"};
+
+// Contains the necessary parameters to convert an edge.
+struct EdgeToConvert {
+  // Edge is not owned here.
+  const Edge* edge;
+  int32 num_bits;
+  bool signed_input;
+  bool range_given;
+  float input_min;
+  float input_max;
+
+  EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
+                float max) {
+    edge = e;
+    num_bits = bits;
+    signed_input = sign;
+    range_given = range;
+    input_min = min;
+    input_max = max;
+  }
+};
+
+// Decide whether a node is in the backward pass by checking whether its name
+// starts with "gradients".
+// TODO(jmchen): Make this check more robust as it is not guaranteed that the
+// forward node will not be named with a leading "gradients".
+inline bool IsGradientNode(const Graph* graph, const Node* node) {
+  static const string tag = "gradients";
+  return (node->name().compare(0, tag.size(), tag) == 0);
+}
+
+// Find the type of the input to set the parameters for the
+// quantize_and_dequantize op.
+// Returns true if the root tensor op type is known, false otherwise.
+bool FindType(const Graph* graph, const Node* node, bool* signed_input,
+              bool* range_given, float* input_min, float* input_max) {
+  const string src_op = node->type_string();
+  if (src_op == "Const" || src_op == "Variable") {
+    *signed_input = true;
+    *range_given = false;
+  } else if (src_op == "Relu") {
+    // Range is not given for Relu.
+    *signed_input = false;
+    *range_given = false;
+  } else if (src_op == "Relu6") {
+    *signed_input = false;
+    *range_given = true;
+    *input_min = 0;
+    *input_max = 6;
+  } else if (src_op == "Sigmoid") {
+    *signed_input = false;
+    *range_given = true;
+    *input_min = 0;
+    *input_max = 1;
+  } else if (src_op == "Tanh") {
+    *signed_input = true;
+    *range_given = true;
+    *input_min = -1;
+    *input_max = 1;
+  } else if (src_op == "Reshape") {
+    // Reshape has 2 inputs and the first one is the tensor.
+    for (const Edge* edge : node->in_edges()) {
+      if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
+        FindType(graph, edge->src(), signed_input, range_given, input_min,
+                 input_max);
+      }
+    }
+  } else if (src_op == "Identity" || src_op == "MaxPool" ||
+             src_op == "AvgPool" || src_op == "MaxPool3D" ||
+             src_op == "AvgPool3D") {
+    // All these Ops only have 1 data input.
+    for (const Edge* edge : node->in_edges()) {
+      if (edge->src_output() != Graph::kControlSlot) {
+        FindType(graph, edge->src(), signed_input, range_given, input_min,
+                 input_max);
+      }
+    }
+  } else {
+    // Unknown type, could be the model input examples.
+    // TODO: Set the params for input with user's hint.
+    *signed_input = true;
+    *range_given = false;
+    return false;
+  }
+
+  return true;
+}
+
+// Insert conversion op, connect it to the graph and remove the old edge.
+Status ProcessTargetEdges(Graph* graph,
+                          const std::vector<EdgeToConvert>& target_edges) {
+  // Remember previous convert ops to avoid duplicated conversion on the same
+  // input.
+  std::unordered_map<string, Node*, StringPiece::Hasher> name_index;
+  for (const EdgeToConvert edge : target_edges) {
+    Node* convert_node;
+    string name =
+        strings::StrCat(edge.edge->src()->name(), "/_QuantizeAndDequantize");
+
+    auto iter = name_index.find(name);
+    if (iter == name_index.end()) {
+      TF_RETURN_IF_ERROR(NodeBuilder(name, "_QuantizeAndDequantize")
+                             .Input(edge.edge->src())
+                             .Attr("signed_input", edge.signed_input)
+                             .Attr("num_bits", edge.num_bits)
+                             .Attr("range_given", edge.range_given)
+                             .Attr("input_min", edge.input_min)
+                             .Attr("input_max", edge.input_max)
+                             .Finalize(graph, &convert_node));
+
+      name_index[name] = convert_node;
+    } else {
+      convert_node = iter->second;
+    }
+
+    graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
+    graph->RemoveEdge(edge.edge);
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status DoQuantizeTraining(int32 num_bits, Graph* graph) {
+  if (graph == nullptr) {
+    return errors::InvalidArgument("Cannot accept empty graph pointer.");
+  }
+
+  if (num_bits < 1 || num_bits > 63) {
+    return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
+                              num_bits);
+  }
+  int potential_input = 0;
+  std::vector<EdgeToConvert> target_edges;
+  for (Node* node : graph->nodes()) {
+    if (nodes_to_rewrite.find(node->type_string()) != nodes_to_rewrite.end() &&
+        !IsGradientNode(graph, node)) {
+      // Find out which types are the inputs and convert them accordingly.
+      // 1. Const/Variable OP: This is quantized as signed tensors with no given
+      //    range.
+      // 2. Activation OP: Set the range accordingly for different types of
+      //    activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}.
+      // 3. Identity OP: The quantization parameters depend on its input.
+      // 4. Pooling OPs: various pooling ops. Also depends on its input.
+      // 5. Reshape OP: Also depends on the first input to this op.
+      // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as
+      //    the model input.
+      //    However, if there is more than one unknown op, return an
+      //    error for now to avoid unexpected behavior.
+      // Note: The list above might not be a complete list. Please let us
+      // know if you see the error so we can handle your case.
+      for (const Edge* edge : node->in_edges()) {
+        if (edge->src_output() == Graph::kControlSlot) {
+          // Skip the control dependency input.
+          continue;
+        } else {
+          bool signed_input = false;
+          bool range_given = false;
+          float input_min = 0;
+          float input_max = 0;
+          bool known_op = FindType(graph, edge->src(), &signed_input,
+                                   &range_given, &input_min, &input_max);
+          if (!known_op) {
+            // An unknown op is considered as input.
+            // Only support one input for now.
+            // TODO: Make this configurable if this is the desirable way to
+            // find input.
+            if (potential_input > 0) {
+              return errors::Unimplemented(
+                  "Found a second unknown op: ", edge->src()->name(),
+                  " with type: ", edge->src()->type_string(),
+                  "; Unknown ops are considered as model input for now and "
+                  "only 1 input is supported currently.");
+            }
+            potential_input++;
+          }
+
+          target_edges.emplace_back(EdgeToConvert(
+              edge, num_bits, signed_input, range_given, input_min, input_max));
+        }
+      }
+    }
+  }
+
+  TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, target_edges));
+
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h
new file mode 100644
index 0000000000..694c491620
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training.h
@@ -0,0 +1,37 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#define TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// Rewrites the graph for quantized training.
+// Rewrites the forward pass to include the precision loss from quantization so
+// the model can learn to deal with such loss and achieve better accuracy when
+// it is later quantized for inference.
+// Note that num_bits should be in [1, 63] and 'g' must not be null.
+//
+// On success, returns OK.
+//
+// On failure, returns the error status. Possible errors include:
+//   - num_bits out of range.
+//   - g is null.
+//   - More than one unknown op encountered.
+Status DoQuantizeTraining(int32 num_bits, Graph* g);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
new file mode 100644
index 0000000000..d6663e0a50
--- /dev/null
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/quantize_training.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class QuantizeTrainingTest : public ::testing::Test {
+ protected:
+  QuantizeTrainingTest() { Reset(); }
+  void Reset() { g_.reset(new Graph(OpRegistry::Global())); }
+
+  template <typename T>
+  Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) {
+    return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
+  }
+
+  std::unique_ptr<Graph> g_;
+};
+
+TEST_F(QuantizeTrainingTest, NormalGraph) {
+  // Construct the following graph
+  /*
+         m1      m2
+        /  \    /  \
+     Relu  Identity  c
+       |      |
+       a      b
+  */
+  Reset();
+  Graph* g = g_.get();
+  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+  Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+  g->AddControlEdge(g->source_node(), a);
+  g->AddControlEdge(g->source_node(), b);
+  g->AddControlEdge(g->source_node(), c);
+  Node* relu = test::graph::Relu(g, a);
+  Node* identity = test::graph::Identity(g, b);
+  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+  Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+  g->AddControlEdge(m1, g->sink_node());
+  g->AddControlEdge(m2, g->sink_node());
+
+  // The graph after the rewriting should be:
+  // "Q" is the quantize_and_dequantize op.
+  // Note the Q in the middle is shared by both m1 and m2.
+  /*
+         m1      m2
+        /  \    /  \
+       Q     Q      Q
+       |     |      |
+     Relu Identity  c
+       |     |
+       a     b
+  */
+  int num_bits = 8;
+  // 4 edges to modify.
+  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+  // There should be 12 nodes in total including the source and sink nodes.
+  EXPECT_EQ(12, g->num_nodes());
+  // Nodes m1 and m2's inputs should be the quantize_and_dequantize op.
+  std::vector<Node*> target_nodes{m1, m2};
+  for (Node* n : target_nodes) {
+    for (Node* in : n->in_nodes()) {
+      EXPECT_EQ("_QuantizeAndDequantize", in->type_string());
+    }
+  }
+
+  // relu, identity, c should now connect to the quantize_and_dequantize nodes.
+  std::vector<Node*> target_inputs{relu, identity, c};
+  for (Node* n : target_inputs) {
+    for (Node* out : n->out_nodes()) {
+      EXPECT_EQ("_QuantizeAndDequantize", out->type_string());
+    }
+  }
+
+  // Quantize_and_dequantize node for identity should have signed_input==true.
+  NodeDef identity_Q = identity->out_nodes().begin()->def();
+  ASSERT_EQ("true",
+            SummarizeAttrValue(identity_Q.attr().find("signed_input")->second));
+  // Quantize_and_dequantize node for relu should have signed_input==false.
+  NodeDef relu_Q = relu->out_nodes().begin()->def();
+  ASSERT_EQ("false",
+            SummarizeAttrValue(relu_Q.attr().find("signed_input")->second));
+}
+
+TEST_F(QuantizeTrainingTest, WithBackwardNodes) {
+  // Construct the same graph plus another backward Matmul.
+  Reset();
+  Graph* g = g_.get();
+  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
+  Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
+  g->AddControlEdge(g->source_node(), a);
+  g->AddControlEdge(g->source_node(), b);
+  g->AddControlEdge(g->source_node(), c);
+  Node* relu = test::graph::Relu(g, a);
+  Node* identity = test::graph::Identity(g, b);
+  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
+  Node* m2 = test::graph::Matmul(g, identity, c, false, false);
+  g->AddControlEdge(m1, g->sink_node());
+  g->AddControlEdge(m2, g->sink_node());
+
+  // Add a Matmul node with name starting with "gradients".
+  Node* backward_m;
+  TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
+                   .Input(m1)
+                   .Input(m2)
+                   .Attr("transpose_a", true)
+                   .Attr("transpose_b", false)
+                   .Finalize(g, &backward_m));
+  g->AddControlEdge(backward_m, g->sink_node());
+
+  int num_bits = 8;
+  // Still 4 changes since the inputs of the backward node will not be
+  // converted.
+  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
+
+  // Nodes m1 and m2's inputs should now be the quantize_and_dequantize op.
+  EXPECT_EQ(13, g->num_nodes());
+  EXPECT_EQ(2, m2->num_inputs());
+}
+
+#undef SIMPLE_GRAPH
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 0d0a84db79..ec878437dc 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -384,6 +384,15 @@ Node* GetSessionTensor(Graph* g, Node* in) {
   return ret;
 }
 
+Node* Relu(Graph* g, Node* in) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
+                  .Input(in, 0)
+                  .Attr("T", DT_FLOAT)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
 
 }  // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 511f6b4310..bc4863563f 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -169,6 +169,9 @@ Node* GetSessionTensor(Graph* g, Node* in);
 // given in "tensors".
 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors);
 
+// Add a Relu node in "g".
+Node* Relu(Graph* g, Node* in);
+
 }  // end namespace graph
 }  // end namespace test
 }  // end namespace tensorflow
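
For orientation only (not part of the commit), the sketch below shows how a caller might drive the new rewrite on a serialized graph: build a Graph from a GraphDef, run DoQuantizeTraining, and serialize the result back. The wrapper name RewriteForQuantizedTraining and the choice of 8 bits are hypothetical; ConvertGraphDefToGraph and Graph::ToGraphDef are the existing conversion helpers assumed here.

// Hypothetical usage sketch; the wrapper function is illustrative only.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/quantize_training.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status RewriteForQuantizedTraining(const GraphDef& input_def,
                                   GraphDef* output_def) {
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  // Build a Graph from the serialized GraphDef.
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_def, &graph));
  // Insert _QuantizeAndDequantize ops on the inputs of MatMul/Conv2D nodes
  // in the forward pass, using 8-bit quantization as an example.
  TF_RETURN_IF_ERROR(DoQuantizeTraining(/*num_bits=*/8, &graph));
  // Serialize the rewritten graph back to a GraphDef.
  graph.ToGraphDef(output_def);
  return Status::OK();
}

}  // namespace tensorflow

The rewritten GraphDef can then be used for training as usual; per the header comment, DoQuantizeTraining fails if num_bits is outside [1, 63], if the graph pointer is null, or if more than one unknown op (treated as the model input) is encountered.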