aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/logging_ops.cc2
-rw-r--r--tensorflow/tools/graph_transforms/BUILD4
-rw-r--r--tensorflow/tools/graph_transforms/README.md168
-rw-r--r--tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc213
-rw-r--r--tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc200
-rw-r--r--tensorflow/tools/graph_transforms/insert_logging.cc153
-rw-r--r--tensorflow/tools/graph_transforms/insert_logging_test.cc203
-rw-r--r--tensorflow/tools/graph_transforms/quantize_nodes.cc10
-rw-r--r--tensorflow/tools/graph_transforms/remove_nodes.cc2
-rw-r--r--tensorflow/tools/graph_transforms/round_weights.cc5
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils.cc27
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils.h22
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils_test.cc91
13 files changed, 1050 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index e5807840ce..67d603dd0a 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 3bd2cc9ac9..8bd18bec09 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -60,7 +60,9 @@ cc_library(
"fold_batch_norms.cc",
"fold_constants_lib.cc",
"fold_old_batch_norms.cc",
+ "freeze_requantization_ranges.cc",
"fuse_convolutions.cc",
+ "insert_logging.cc",
"obsfucate_names.cc",
"quantize_nodes.cc",
"quantize_weights.cc",
@@ -101,7 +103,9 @@ tf_cc_test(
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",
+ "freeze_requantization_ranges_test.cc",
"fuse_convolutions_test.cc",
+ "insert_logging_test.cc",
"obsfucate_names_test.cc",
"quantize_nodes_test.cc",
"quantize_weights_test.cc",
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 334c65117e..37545feb4f 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -16,7 +16,9 @@
* [fold_batch_norms](#fold_batch_norms)
* [fold_constants](#fold_constants)
* [fold_old_batch_norms](#fold_old_batch_norms)
+ * [freeze_requantization_ranges](#freeze_requantization_ranges)
* [fuse_convolutions](#fuse_convolutions)
+ * [insert_logging](#insert_logging)
* [merge_duplicate_nodes](#merge_duplicate_nodes)
* [obsfucate_names](#obsfucate_names)
* [quantize_nodes](#quantize_nodes)
@@ -334,7 +336,7 @@ within the saved model, and sets them to the defined default for that attribute.
### fold_batch_norms
-Args: None
+Args: None \
Prerequisites: [fold_constants](#fold_constants)
This transform tries to optimize away the Mul that's introduced after a Conv2D
@@ -347,7 +349,7 @@ produced by training for the Mul input is collapsed down into a simple constant.
### fold_constants
-Args: None\
+Args: None \
Prerequisites: None
Looks for any sub-graphs within the model that always evaluate to constant
@@ -359,7 +361,7 @@ to continue on past transient errors, since this is just an optimization phase.
### fold_old_batch_norms
-Args: None\
+Args: None \
Prerequisites: None
In the early days of TensorFlow, batch normalization was implemented using a
@@ -370,21 +372,143 @@ have a graph that uses the older-style, this transform will recognize and
optimize those ops for inference, in the same way that the
[fold_batch_norms](#fold_batch_norms) transform does for the new approach.
+### freeze_requantization_ranges
+
+Args:
+
+* min_max_log_file: Path to a log file containing ranges for ops.
+* min_percentile: Percentage cutoff to use to calculate an overall min.
+ Defaults to 5.
+* max_percentile: Percentage cutoff to use to calculate an overall max.
+ Defaults to 5.
+
+Quantized operations like convolution or matrix multiplies take their inputs as
+8-bit, but produce 32-bit results. To do further operations on these, they need
+to be converted back down to the lower depth. To make the most of those eight
+bits, you need to scale the thirty-two bits of original data down using a scale
+that matches the range that's actually being used.
+
+Because that range information isn't stored in the original graph, the
+[quantization process](#eight-bit-calculations) inserts RequantizationRange ops
+before each conversion from 32 to 8 bits. This op looks at the 32-bit output and
+calculates the current min and max every time it's run.
+
+This isn't incredibly time-consuming, but it is extra work that's nice to avoid
+if possible. One way of optimizing that away is replacing those
+RequantizationRange ops with a pair of Const nodes holding known min/max values,
+so the scaling down can be done without having to inspect the output every time.
+
+That's what this transform does. It's usually used in conjunction with a copy of
+the graph that's had [insert_logging](#insert_logging) run on it to instrument
+it to record the min/max values to stderr. Why is logging used rather than
+writing to a normal file? As you'll see later, to get best results you want to
+collect data from a lot of runs on real data, and for mobile apps especially
+it's a lot easier to do this by copying log files. As an example, here's how
+you'd add the logging operations for a quantized version of the Inception v3
+graph:
+
+```bash
+bazel build tensorflow/tools/graph_transforms:transform_graph
+bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
+--logtostderr \
+--in_graph=/tmp/quantized_inception.pb \
+--out_graph=/tmp/logged_quantized_inception.pb \
+--inputs=Mul \
+--outputs=softmax \
+--transforms='\
+insert_logging(op=RequantizationRange, show_name=true, message="__requant_min_max:")\
+'
+```
+
+Now, when you run the `/tmp/logged_quantized_inception.pb` graph, it will write
+out log statements that show the value of the min and max calculated by each
+RequantizationRange op. Here's an example of running label_image and saving the
+log:
+
+```bash
+bazel build tensorflow/examples/label_image:label_image
+bazel-bin/tensorflow/examples/label_image/label_image \
+--image=${HOME}/Downloads/grace_hopper.jpg \
+--logtostderr \
+--input_layer=Mul \
+--output_layer=softmax \
+--graph=/tmp/logged_quantized_inception.pb \
+--labels=learning/brain/models/image/inception_v3/imagenet_comp_graph_label_strings.txt \
+--logtostderr \
+2>/tmp/min_max_log_small.txt
+```
+
+If you look in `/tmp/min_max_log_small.txt`, you'll see a lot of lines like
+this:
+
+```
+I0108 21:45:42.261883 1972 logging_ops.cc:79] ;conv/Conv2D/eightbit/requant_range__print__;__requant_min_max:[-20.887871][22.274715]
+```
+
+This is a simple way of serializing the name of the RequantizationRange op and
+its min/max values every time it's run. It's a file like this that you pass into
+the transform as the `min_max_log_file` argument. The transform will attempt to
+extract all of the min/max values associated with ops, ignoring any irrelevant
+lines in the log, and replace the RequantizationRange ops with two Const nodes
+containing the found values.
+
+This isn't the whole story though. The min/max values can vary a lot depending
+on what the particular inputs to the graph are on any given run, which means
+picking ranges based on just one run can lead to clipping of values and a loss
+of accuracy. To get better results, you need to run your network against a range
+of different inputs. In Inception's case, I often use a thousand different
+images from the training set. You can then pass the whole concatenated log from
+all of the runs into the transform, and it will pick ranges based on the
+aggregate of the values found for each RequantizationRange op.
+
+To ensure that outliers don't increase the range too much, and so decrease the
+accuracy by putting too many bits into rare extreme values, the `min_percentile`
+and `max_percentile` arguments control how the overall min and max are chosen.
+At their default values of 5, this means that the lowest 5% of the minimum
+values will be discarded, taking the minimum of the remainder, and the
+equivalent for the maximum.
+
### fuse_convolutions
-Args: None\
+Args: None \
Prerequisites: None
-For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g.
-to scale up in the later stages of an image style transfer model),
-it can improve memory usage and latency to combine the spatial
-transformations with the convolution's im2col patch generation. This transform
-looks out for that particular pattern of ops and replaces them with a fused
-version that combines the resizing and padding with the convolution.
+For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g. to
+scale up in the later stages of an image style transfer model), it can improve
+memory usage and latency to combine the spatial transformations with the
+convolution's im2col patch generation. This transform looks out for that
+particular pattern of ops and replaces them with a fused version that combines
+the resizing and padding with the convolution.
+
+### insert_logging
+
+Args:
+
+* op: Insert a Print after every occurrence of this op type. Can be repeated
+ to cover multiple types. If not present, all op types will be instrumented.
+* prefix: Insert a Print after every node whose name starts with this value.
+ Can be repeated to cover multiple nodes. If not present, all node names will
+ be matched.
+* show_op: If true, the op type will be prepended to all log messages.
+* show_name: If true, the node's name will be prepended to all log messages.
+* message: Arbitrary text to log before the values.
+* first_n: How many times to print before suppressing. Defaults to -1, which
+ means never stop.
+* summarize: How long numerical results can be before they're truncated.
+ Defaults to 1024.
+
+The Print operator writes strings to stderr when it's run inside a graph, and
+prints out the numerical results of the node that it's reading from. This can be
+very useful when you're debugging and want to follow particular internal values
+while a graph is running. This transform allows you to insert those ops at
+particular points in the graph, and customize the message that's displayed. It's
+also used in conjunction with the
+[freeze_requantization_ranges](#freeze_requantization_ranges) transform to
+output information that it needs.
### merge_duplicate_nodes
-Args: None\
+Args: None \
Prerequisites: None
If there are Const nodes with the same types and contents, or nodes with the
@@ -396,7 +520,7 @@ duplicates of constants that are used in the quantize/dequantize process).
### obsfucate_names
-Args: None\
+Args: None \
Prerequisites: None
Replaces all nodes' names with short generated ids, other than the inputs and
@@ -429,14 +553,14 @@ Prerequisites: [quantize_weights](#quantize_weights)
Replaces any calculation nodes with their eight-bit equivalents (if available),
and adds in conversion layers to allow remaining float operations to
interoperate. This is one of the most complex transforms, and involves multiple
-passes and a lot of rewriting. It's also still an active area of research,
-so results may vary depending on the platform and operations you're using in
-your model. You should run quantize_weights first to ensure your Const ops are
-in eight-bit form.
+passes and a lot of rewriting. It's also still an active area of research, so
+results may vary depending on the platform and operations you're using in your
+model. You should run quantize_weights first to ensure your Const ops are in
+eight-bit form.
### quantize_weights
-Args: None\
+Args: None \
Prerequisites: None
Converts any large (more than 15 element) float Const op into an eight-bit
@@ -461,7 +585,7 @@ special circumstances though.
### remove_device
-Args: None
+Args: None \
Prerequisites: None
All ops can have a hardware device specified. This can be a problem when you're
@@ -548,7 +672,7 @@ device assigned.
### sort_by_execution_order
-Args: None\
+Args: None \
Prerequisites: None
Arranges the nodes in the GraphDef in topological order, so that the inputs of
@@ -741,8 +865,8 @@ transform:
This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
third is a Max which does the same. Assuming we know the Dequantize ops are
-pulling from the same eight-bit buffer, the end result of this sub-graph is
-a no-op, since it's just turning the eight-bit buffer into float, and then
+pulling from the same eight-bit buffer, the end result of this sub-graph is a
+no-op, since it's just turning the eight-bit buffer into float, and then
immediately converting it back to eight-bits, so if we look for this pattern and
remove it we can optimize the graph without changing the result.
@@ -874,7 +998,7 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
parameter:
```C++
-TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
+TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("num_steps", 256, &num_steps));
```
If the conversion fails or the parameter occurs more than once the helper
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
new file mode 100644
index 0000000000..8faad4a442
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -0,0 +1,213 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+struct MinMaxRecord {
+ string name;
+ float min;
+ float max;
+};
+
+// Try to parse a log file containing loosely-structured lines, some of which
+// are the min/max logs we want.
+Status ExtractMinMaxRecords(const string& log_file_name,
+ std::vector<MinMaxRecord>* records) {
+ string file_data;
+ TF_RETURN_IF_ERROR(
+ ReadFileToString(Env::Default(), log_file_name, &file_data));
+ const string print_suffix("__print__");
+ const string requant_prefix("__requant_min_max:");
+ std::vector<string> file_lines = str_util::Split(file_data, '\n');
+ for (const string& file_line : file_lines) {
+ // We expect to find a line with components separated by semicolons, so to
+ // start make sure that the basic structure is in place/
+ StringPiece line(file_line);
+ if (!line.contains(print_suffix + ";" + requant_prefix)) {
+ continue;
+ }
+ std::vector<string> line_parts = str_util::Split(file_line, ';');
+ if (line_parts.size() < 2) {
+ continue;
+ }
+ // Now we want to figure out which components have the name and min max
+ // values by scanning for the prefix we expect.
+ bool min_max_found = false;
+ int min_max_index;
+ for (int i = 1; i < line_parts.size(); ++i) {
+ StringPiece line_part(line_parts[i]);
+ if (line_part.starts_with(requant_prefix)) {
+ min_max_found = true;
+ min_max_index = i;
+ }
+ }
+ if (!min_max_found) {
+ continue;
+ }
+ // Finally we need to break out the values from the strings, and parse them
+ // into a form we can use.
+ string min_max_string = line_parts[min_max_index];
+ std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
+ if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
+ continue;
+ }
+ string min_string = min_max_parts[1];
+ std::vector<string> min_string_parts = str_util::Split(min_string, ']');
+ if (min_string_parts.size() != 2) {
+ continue;
+ }
+ string min_number_string = min_string_parts[0];
+ float min;
+ if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
+ continue;
+ }
+ string max_string = min_max_parts[2];
+ std::vector<string> max_string_parts = str_util::Split(max_string, ']');
+ if (max_string_parts.size() != 2) {
+ continue;
+ }
+ string max_number_string = max_string_parts[0];
+ float max;
+ if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
+ continue;
+ }
+ StringPiece name_string = line_parts[min_max_index - 1];
+ if (!name_string.ends_with(print_suffix)) {
+ continue;
+ }
+ string name =
+ name_string.substr(0, name_string.size() - print_suffix.size())
+ .ToString();
+ records->push_back({name, min, max});
+ }
+ return Status::OK();
+}
+
+// Uses the observed min/max values for requantization captured in a log file to
+// replace costly RequantizationRange ops with simple Consts.
+Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def) {
+ string min_max_log_file;
+ TF_RETURN_IF_ERROR(
+ context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
+ if (min_max_log_file == "") {
+ return errors::InvalidArgument(
+ "You must pass a file name to min_max_log_file");
+ }
+ float min_percentile;
+ TF_RETURN_IF_ERROR(
+ context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
+ float max_percentile;
+ TF_RETURN_IF_ERROR(
+ context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
+
+ std::vector<MinMaxRecord> records;
+ TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
+ if (records.empty()) {
+ return errors::InvalidArgument(
+ "No min/max range logs were found in the log file");
+ }
+
+ std::map<string, const NodeDef*> node_map;
+ MapNamesToNodes(input_graph_def, &node_map);
+ bool any_missing_nodes = false;
+ std::map<string, std::vector<MinMaxRecord>> records_by_node;
+ for (const MinMaxRecord& record : records) {
+ records_by_node[record.name].push_back(record);
+ if (!node_map.count(record.name)) {
+ any_missing_nodes = true;
+ LOG(WARNING) << "Node from log not found in graph: " << record.name;
+ }
+ }
+ if (any_missing_nodes) {
+ return errors::InvalidArgument(
+ "Nodes were found in the log file that aren't present in the graph");
+ }
+
+ // Now find out the largest and smallest min/max values for the node.
+ std::map<string, std::pair<float, float>> range_for_nodes;
+ for (const auto& record_info : records_by_node) {
+ const string& name = record_info.first;
+ const std::vector<MinMaxRecord> records = record_info.second;
+ std::vector<float> mins;
+ std::vector<float> maxs;
+ for (const MinMaxRecord& record : records) {
+ mins.push_back(record.min);
+ maxs.push_back(record.max);
+ }
+ std::sort(mins.begin(), mins.end());
+ std::sort(maxs.begin(), maxs.end());
+ int min_index = std::round(mins.size() * (min_percentile / 100.0f));
+ if (min_index < 0) {
+ min_index = 0;
+ }
+ int max_index =
+ std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
+ if (max_index > (maxs.size() - 1)) {
+ max_index = maxs.size() - 1;
+ }
+ const float min = mins[min_index];
+ const float max = maxs[max_index];
+ range_for_nodes[name] = {min, max};
+ }
+ std::map<string, string> inputs_to_rename;
+ GraphDef frozen_graph_def;
+ for (const NodeDef& node : input_graph_def.node()) {
+ if (range_for_nodes.count(node.name())) {
+ if (node.op() != "RequantizationRange") {
+ return errors::InvalidArgument(
+ "Node is expected to be a RequantizationRange op: ", node.name(),
+ ", but is: ", node.op());
+ }
+ const float min_value = range_for_nodes.at(node.name()).first;
+ NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
+ min_node->set_op("Const");
+ min_node->set_name(node.name() + "/frozen_min");
+ SetNodeAttr("dtype", DT_FLOAT, min_node);
+ Tensor min_tensor(DT_FLOAT, {});
+ min_tensor.flat<float>()(0) = min_value;
+ SetNodeTensorAttr<float>("value", min_tensor, min_node);
+ inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
+
+ const float max_value = range_for_nodes.at(node.name()).second;
+ NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
+ max_node->set_op("Const");
+ max_node->set_name(node.name() + "/frozen_max");
+ SetNodeAttr("dtype", DT_FLOAT, max_node);
+ Tensor max_tensor(DT_FLOAT, {});
+ max_tensor.flat<float>()(0) = max_value;
+ SetNodeTensorAttr<float>("value", max_tensor, max_node);
+ inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
+ } else {
+ NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
+ new_node->CopyFrom(node);
+ }
+ }
+ RenameNodeInputs(frozen_graph_def, inputs_to_rename,
+ std::unordered_set<string>(), output_graph_def);
+ return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
+ FreezeRequantizationRanges);
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc
new file mode 100644
index 0000000000..ab6b2ffef0
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc
@@ -0,0 +1,200 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+struct MinMaxRecord {
+ string name;
+ float min;
+ float max;
+};
+Status ExtractMinMaxRecords(const string& log_file_name,
+ std::vector<MinMaxRecord>* records);
+
+class FreezeRequantizationRangesTest : public ::testing::Test {
+ protected:
+ void TestFreezeRequantizationRanges() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
+ test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
+ Output quantized_op = Const(root.WithOpName("quantized_op"),
+ Input::Initializer(quantized_tensor));
+
+ Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&quantized_min_tensor, {2.0f});
+ Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
+ Input::Initializer(quantized_min_tensor));
+
+ Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&quantized_max_tensor, {2.0f});
+ Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
+ Input::Initializer(quantized_min_tensor));
+
+ Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
+ test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
+ Output offset_op =
+ Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
+
+ Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&offset_min_tensor, {0.0f});
+ Output offset_min_op = Const(root.WithOpName("offset_min_op"),
+ Input::Initializer(offset_min_tensor));
+
+ Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&offset_max_tensor, {255.0f});
+ Output offset_max_op = Const(root.WithOpName("offset_max_op"),
+ Input::Initializer(offset_max_tensor));
+
+ QuantizedBiasAdd quantized_bias_add_op(
+ root.WithOpName("bias_add_op"), quantized_op, offset_op,
+ quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
+ DT_QINT32);
+
+ RequantizationRange requantization_range_op(
+ root.WithOpName("requantization_range_op"),
+ quantized_bias_add_op.output, quantized_bias_add_op.min_out,
+ quantized_bias_add_op.max_out);
+
+ Requantize requantize_op(
+ root.WithOpName("requantize_op"), quantized_bias_add_op.output,
+ quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
+ requantization_range_op.output_min, requantization_range_op.output_max,
+ DT_QUINT8);
+
+ Output dequantize_op =
+ Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
+ requantize_op.output_min, requantize_op.output_max);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ const string min_max_log_file_name =
+ io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
+ {
+ std::unique_ptr<WritableFile> file;
+ TF_ASSERT_OK(
+ Env::Default()->NewWritableFile(min_max_log_file_name, &file));
+ TF_ASSERT_OK(file->Append("Something irrelevant\n"));
+ TF_ASSERT_OK(
+ file->Append("[SomePrefix] "
+ ";requantization_range_op__print__;__requant_min_max:"
+ "[-2.4313571][10.584145]\n"));
+ TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+ }
+
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"dequantize_op"};
+ context.params = {{"min_max_log_file", {min_max_log_file_name}}};
+
+ GraphDef frozen_graph_def;
+ TF_EXPECT_OK(
+ FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
+
+ std::map<string, const NodeDef*> node_map;
+ MapNamesToNodes(frozen_graph_def, &node_map);
+ EXPECT_EQ(0, node_map.count("requantization_range_op"));
+ EXPECT_EQ(1, node_map.count("requantize_op"));
+ const string& min_input =
+ NodeNameFromInput(node_map.at("requantize_op")->input(3));
+ ASSERT_EQ(1, node_map.count(min_input));
+ EXPECT_EQ("Const", node_map.at(min_input)->op());
+ const string& max_input =
+ NodeNameFromInput(node_map.at("requantize_op")->input(4));
+ ASSERT_EQ(1, node_map.count(max_input));
+ EXPECT_EQ("Const", node_map.at(max_input)->op());
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(
+ original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
+
+ std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
+ std::vector<Tensor> frozen_outputs;
+ TF_ASSERT_OK(
+ frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
+
+ ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
+ ASSERT_EQ(1, frozen_outputs.size());
+ test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
+ }
+
+ void TestExtractMinMaxRecords() {
+ const string min_max_log_file_name =
+ io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
+ {
+ std::unique_ptr<WritableFile> file;
+ TF_ASSERT_OK(
+ Env::Default()->NewWritableFile(min_max_log_file_name, &file));
+ TF_ASSERT_OK(file->Append("Something irrelevant\n"));
+ TF_ASSERT_OK(
+ file->Append("[SomePrefix] "
+ ";requantization_range_op__print__;__requant_min_max:"
+ "[-2.4313571][10.584145]\n"));
+ TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+ TF_ASSERT_OK(file->Append(
+ "[SomeOtherPrefix] "
+ ";other_requantization_range_op__print__;__requant_min_max:"
+ "[-1.0][2.0]\n"));
+ TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+ TF_ASSERT_OK(
+ file->Append("[SomePrefix] "
+ ";requantization_range_op__print__;__requant_min_max:"
+ "[-1.bad][2.0]\n"));
+ }
+ std::vector<MinMaxRecord> records;
+ TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
+ ASSERT_EQ(2, records.size());
+ EXPECT_EQ("requantization_range_op", records[0].name);
+ EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
+ EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
+ EXPECT_EQ("other_requantization_range_op", records[1].name);
+ EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
+ EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
+ }
+};
+
+TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
+ TestFreezeRequantizationRanges();
+}
+
+TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
+ TestExtractMinMaxRecords();
+}
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc
new file mode 100644
index 0000000000..c9d72f3d7d
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/insert_logging.cc
@@ -0,0 +1,153 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
+
+#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Inserts Print ops into the graph so the values flowing through it are
+// logged, optionally filtered by op type ("op") or node-name prefix
+// ("prefix").
+Status InsertLogging(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def) {
+ std::unordered_set<string> ops;
+ bool has_ops;
+ if (context.params.count("op")) {
+ has_ops = true;
+ for (const string& op : context.params.at("op")) {
+ ops.insert(op);
+ }
+ } else {
+ has_ops = false;
+ }
+
+ std::unordered_set<string> prefixes;
+ bool has_prefixes;
+ if (context.params.count("prefix")) {
+ has_prefixes = true;
+ for (const string& prefix : context.params.at("prefix")) {
+ prefixes.insert(prefix);
+ }
+ } else {
+ has_prefixes = false;
+ }
+
+ string message;
+ TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
+
+ bool show_name;
+ TF_RETURN_IF_ERROR(
+ context.GetOneBoolParameter("show_name", false, &show_name));
+
+ bool show_op;
+ TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
+
+ int32 first_n;
+ TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
+
+ int32 summarize;
+ TF_RETURN_IF_ERROR(
+ context.GetOneInt32Parameter("summarize", 1024, &summarize));
+
+ std::unordered_map<string, std::set<int>> node_outputs;
+ for (const NodeDef& node : input_graph_def.node()) {
+ for (const string& input : node.input()) {
+ const string canonical_input = CanonicalInputName(input);
+ string prefix;
+ string name;
+ string suffix;
+ NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
+ const string output_index_string = suffix.substr(1, suffix.size() - 1);
+ int32 output_index;
+ if (!strings::safe_strto32(output_index_string, &output_index)) {
+ return errors::InvalidArgument("Couldn't understand output number in ",
+ input);
+ }
+ node_outputs[name].insert(output_index);
+ }
+ }
+
+ std::map<string, string> inputs_to_rename;
+ std::unordered_set<string> ignore_when_renaming;
+ GraphDef logged_graph_def;
+ for (const NodeDef& node : input_graph_def.node()) {
+ NodeDef* new_node = logged_graph_def.mutable_node()->Add();
+ new_node->CopyFrom(node);
+ if (node_outputs[node.name()].empty()) {
+      // No other node consumes any output of this node, so skip it.
+ continue;
+ }
+ const bool op_matches = (ops.count(node.op()) > 0);
+ bool prefix_matches = false;
+ for (const string& prefix : prefixes) {
+ if (StringPiece(node.name()).starts_with(prefix)) {
+ prefix_matches = true;
+ }
+ }
+ // If we're not looking for ops, or we found the right op, and if we're not
+ // looking for prefixes or we found the right prefix, then add logging here.
+ if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
+ const string name_suffix = "__print__";
+ DataTypeVector input_types;
+ DataTypeVector output_types;
+ TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
+ NodeDef* print_node = logged_graph_def.mutable_node()->Add();
+ print_node->set_op("Print");
+ print_node->set_name(strings::StrCat(node.name(), name_suffix));
+ string node_message;
+ if (show_op) {
+ node_message += ";" + node.op() + ";";
+ }
+ if (show_name) {
+ node_message += ";" + print_node->name() + ";";
+ }
+ node_message += message;
+ SetNodeAttr("message", node_message, print_node);
+ SetNodeAttr("first_n", first_n, print_node);
+ SetNodeAttr("summarize", summarize, print_node);
+ print_node->add_input(node.name() + ":0");
+ SetNodeAttr("T", output_types[0], print_node);
+ for (int output_index : node_outputs[node.name()]) {
+ print_node->add_input(strings::StrCat(node.name(), ":", output_index));
+ }
+ SetNodeAttr("U", output_types, print_node);
+ ignore_when_renaming.insert(print_node->name());
+      // Rewrite the graph so all references to the first output of the
+      // original op now pull from the print op instead, so it's executed.
+ inputs_to_rename[node.name() + ":0"] =
+ strings::StrCat(node.name(), name_suffix, ":0");
+ }
+ }
+
+ output_graph_def->Clear();
+ RenameNodeInputs(logged_graph_def, inputs_to_rename, ignore_when_renaming,
+ output_graph_def);
+
+ return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/insert_logging_test.cc b/tensorflow/tools/graph_transforms/insert_logging_test.cc
new file mode 100644
index 0000000000..e1586a46e5
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/insert_logging_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status InsertLogging(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+
+class InsertLoggingTest : public ::testing::Test {
+ protected:
+ void CheckGraphCanRun(const GraphDef& graph_def,
+ const std::vector<string>& output_names) {
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
+ }
+
+ void TestInsertLogging() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+ test::FillIota<float>(&const_tensor, 1.0f);
+ Output const_node1 =
+ Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+ Output const_node2 =
+ Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+ Output const_node3 =
+ Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+ Output add_node2 =
+ Add(root.WithOpName("add_node2"), const_node1, const_node2);
+ Output add_node3 =
+ Add(root.WithOpName("add_node3"), const_node1, const_node3);
+ Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+ Output add_node4 =
+ Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ CheckGraphCanRun(graph_def, {"add_node4"});
+
+ GraphDef result;
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"add_node4"};
+ TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+ CheckGraphCanRun(result, {"add_node4"});
+
+ std::unordered_set<string> print_inputs;
+ for (const NodeDef& node : result.node()) {
+ if (node.op() == "Print") {
+ print_inputs.insert(node.input(0));
+ }
+ }
+
+ EXPECT_EQ(6, print_inputs.size());
+ EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+ EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+ EXPECT_EQ(1, print_inputs.count("const_node1:0"));
+ EXPECT_EQ(1, print_inputs.count("const_node2:0"));
+ EXPECT_EQ(1, print_inputs.count("const_node3:0"));
+ }
+
+ void TestInsertLoggingByOpType() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+ test::FillIota<float>(&const_tensor, 1.0f);
+ Output const_node1 =
+ Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+ Output const_node2 =
+ Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+ Output const_node3 =
+ Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+ Output add_node2 =
+ Add(root.WithOpName("add_node2"), const_node1, const_node2);
+ Output add_node3 =
+ Add(root.WithOpName("add_node3"), const_node1, const_node3);
+ Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+ Output add_node4 =
+ Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ CheckGraphCanRun(graph_def, {"add_node4"});
+
+ GraphDef result;
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"add_node4"};
+ context.params.insert(
+ std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
+ TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+ CheckGraphCanRun(result, {"add_node4"});
+
+ std::unordered_set<string> print_inputs;
+ for (const NodeDef& node : result.node()) {
+ if (node.op() == "Print") {
+ print_inputs.insert(node.input(0));
+ }
+ }
+
+ EXPECT_EQ(3, print_inputs.size());
+ EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+ EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node1:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node2:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node3:0"));
+ }
+
+ void TestInsertLoggingByPrefix() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+ test::FillIota<float>(&const_tensor, 1.0f);
+ Output const_node1 =
+ Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+ Output const_node2 =
+ Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+ Output const_node3 =
+ Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+ Output add_node2 =
+ Add(root.WithOpName("add_node2"), const_node1, const_node2);
+ Output add_node3 =
+ Add(root.WithOpName("add_node3"), const_node1, const_node3);
+ Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+ Output add_node4 =
+ Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ CheckGraphCanRun(graph_def, {"add_node4"});
+
+ GraphDef result;
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"add_node4"};
+ context.params.insert(
+ std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
+ TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+ CheckGraphCanRun(result, {"add_node4"});
+
+ std::unordered_set<string> print_inputs;
+ for (const NodeDef& node : result.node()) {
+ if (node.op() == "Print") {
+ print_inputs.insert(node.input(0));
+ }
+ }
+
+ EXPECT_EQ(2, print_inputs.size());
+ EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+ EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+ EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node1:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node2:0"));
+ EXPECT_EQ(0, print_inputs.count("const_node3:0"));
+ }
+};
+
+TEST_F(InsertLoggingTest, TestInsertLogging) { TestInsertLogging(); }
+
+TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
+ TestInsertLoggingByOpType();
+}
+
+TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
+ TestInsertLoggingByPrefix();
+}
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc
index fa089b86ef..f460f31d35 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc
@@ -236,7 +236,8 @@ Status MergeDuplicateNodes(const GraphDef& input_graph_def,
}
// Update the graph so that any nodes that referred to removed inputs now
// pull from the remaining duplicate.
- RenameNodeInputs(merged_graph_def, inputs_to_rename, &current_graph_def);
+ RenameNodeInputs(merged_graph_def, inputs_to_rename,
+ std::unordered_set<string>(), &current_graph_def);
} while (any_duplicates_found);
*output_graph_def = current_graph_def;
@@ -295,7 +296,8 @@ Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
},
{true}, &replaced_graph_def));
- RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
+ RenameNodeInputs(replaced_graph_def, inputs_to_rename,
+ std::unordered_set<string>(), output_graph_def);
return Status::OK();
}
@@ -372,9 +374,9 @@ Status QuantizePlaceholders(const GraphDef& input_graph_def,
GraphDef first_pass_graph_def;
RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
- &first_pass_graph_def);
+ std::unordered_set<string>(), &first_pass_graph_def);
RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
- output_graph_def);
+ std::unordered_set<string>(), output_graph_def);
return Status::OK();
}
diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc
index 3290e65512..429dbdd0b1 100644
--- a/tensorflow/tools/graph_transforms/remove_nodes.cc
+++ b/tensorflow/tools/graph_transforms/remove_nodes.cc
@@ -80,7 +80,7 @@ Status RemoveNodes(const GraphDef& input_graph_def,
{true}, &replaced_graph_def));
// Make sure all references to removed nodes now point to their inputs.
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
- &current_graph_def);
+ std::unordered_set<string>(), &current_graph_def);
} while (any_nodes_removed);
}
diff --git a/tensorflow/tools/graph_transforms/round_weights.cc b/tensorflow/tools/graph_transforms/round_weights.cc
index 6332876077..72927e439b 100644
--- a/tensorflow/tools/graph_transforms/round_weights.cc
+++ b/tensorflow/tools/graph_transforms/round_weights.cc
@@ -33,8 +33,9 @@ namespace graph_transforms {
Status RoundWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
- int64 num_steps;
- TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
+ int32 num_steps;
+ TF_RETURN_IF_ERROR(
+ context.GetOneInt32Parameter("num_steps", 256, &num_steps));
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"},
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index ecdae72c11..310c331e8f 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -469,6 +469,7 @@ Status ReplaceMatchingOpTypes(
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
+ const std::unordered_set<string>& nodes_to_ignore,
GraphDef* output_graph_def) {
std::map<string, std::vector<std::pair<string, string>>>
canonical_inputs_to_rename;
@@ -494,6 +495,9 @@ Status RenameNodeInputs(const GraphDef& input_graph_def,
input_node_name);
}
already_visited.insert(input_node_name);
+ if (nodes_to_ignore.count(node.name())) {
+ break;
+ }
bool any_match_found = false;
for (const std::pair<string, string>& input_to_rename :
canonical_inputs_to_rename.at(input_node_name)) {
@@ -630,9 +634,26 @@ Status TransformFuncContext::GetOneStringParameter(const string& name,
}
}
-Status TransformFuncContext::GetOneIntParameter(const string& name,
- int64 default_value,
- int64* result) const {
+Status TransformFuncContext::GetOneInt32Parameter(const string& name,
+ int32 default_value,
+ int32* result) const {
+ const int params_count = CountParameters(name);
+ if (params_count == 0) {
+ *result = default_value;
+ return Status::OK();
+ }
+ string string_value;
+ TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
+ if (!strings::safe_strto32(StringPiece(string_value), result)) {
+ return errors::InvalidArgument("Couldn't interpret the ", name,
+ " argument as a number:", string_value);
+ }
+ return Status::OK();
+}
+
+Status TransformFuncContext::GetOneInt64Parameter(const string& name,
+ int64 default_value,
+ int64* result) const {
const int params_count = CountParameters(name);
if (params_count == 0) {
*result = default_value;
diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h
index 2c98440907..54808efa9f 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.h
+++ b/tensorflow/tools/graph_transforms/transform_utils.h
@@ -17,8 +17,10 @@ limitations under the License.
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
#include <set>
+#include <unordered_set>
#include <vector>
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -201,9 +203,11 @@ Status ReplaceMatchingOpTypes(
// Returns a list of the unique nodes found in this match.
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
-// Changes all input references to a particular node name.
+// Changes all input references to a particular node name. Any nodes with names
+// listed in nodes_to_ignore will not have their inputs rewritten.
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
+ const std::unordered_set<string>& nodes_to_ignore,
GraphDef* output_graph_def);
// Utility function that copies all the nodes found in a match into the
@@ -225,11 +229,17 @@ struct TransformFuncContext {
Status GetOneStringParameter(const string& name, const string& default_value,
string* result) const;
- // Gets a single occurrence of a parameter as an integer, falling back to a
- // default if it isn't present and returning an error if it isn't convertible
- // to a number.
- Status GetOneIntParameter(const string& name, int64 default_value,
- int64* result) const;
+ // Gets a single occurrence of a parameter as a 32-bit integer, falling back
+ // to a default if it isn't present and returning an error if it isn't
+ // convertible to a number.
+ Status GetOneInt32Parameter(const string& name, int32 default_value,
+ int32* result) const;
+
+ // Gets a single occurrence of a parameter as a 64-bit integer, falling back
+ // to a default if it isn't present and returning an error if it isn't
+ // convertible to a number.
+ Status GetOneInt64Parameter(const string& name, int64 default_value,
+ int64* result) const;
// Gets a single occurrence of a parameter as a floating point number, falling
// back to a default if it isn't present and returning an error if it isn't
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index 9e6ddb46b7..92ebc35834 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -541,7 +541,9 @@ class TransformUtilsTest : public ::testing::Test {
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
- TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
+ TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
+ std::unordered_set<string>(),
+ &renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
@@ -579,7 +581,7 @@ class TransformUtilsTest : public ::testing::Test {
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(
graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
- &renamed_graph_def));
+ std::unordered_set<string>(), &renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
@@ -615,8 +617,9 @@ class TransformUtilsTest : public ::testing::Test {
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
- Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
- &renamed_graph_def);
+ Status rename_status =
+ RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
+ std::unordered_set<string>(), &renamed_graph_def);
EXPECT_FALSE(rename_status.ok());
}
@@ -650,6 +653,7 @@ class TransformUtilsTest : public ::testing::Test {
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
+ std::unordered_set<string>(),
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
@@ -658,6 +662,45 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
}
+ void TestRenameNodeInputsWithIgnores() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ const int width = 10;
+
+ Tensor a_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&a_data, 1.0f);
+ Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
+
+ Tensor b_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&b_data, 1.0f);
+ Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
+
+ Output add = Add(root.WithOpName("add"), a_const, a_const);
+
+ Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
+
+ Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+ Output mul = Mul(root.WithOpName("mul"), add, placeholder);
+
+ Output mul2 = Mul(root.WithOpName("output"), mul, add2);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ GraphDef renamed_graph_def;
+ TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
+ &renamed_graph_def));
+
+ std::map<string, const NodeDef*> node_map;
+ MapNamesToNodes(renamed_graph_def, &node_map);
+ EXPECT_EQ("b", node_map.at("add")->input(0));
+ EXPECT_EQ("b", node_map.at("add")->input(1));
+ EXPECT_EQ("a", node_map.at("add2")->input(0));
+ EXPECT_EQ("a", node_map.at("add2")->input(1));
+ }
+
void TestFindInvalidInputs() {
GraphDef graph_def;
@@ -943,20 +986,36 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ("d", value);
}
- void TestGetOneIntParameter() {
+ void TestGetOneInt32Parameter() {
+ TransformFuncContext context;
+ context.params.insert({"foo", {"10", "20"}});
+ context.params.insert({"bar", {"-23"}});
+ context.params.insert({"not_a_number", {"not_numerical"}});
+ context.params.insert({"float", {"-23.232323"}});
+ int32 value;
+ TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
+ EXPECT_EQ(-23, value);
+ EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
+ TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
+ EXPECT_EQ(10, value);
+ EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
+ EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
+ }
+
+ void TestGetOneInt64Parameter() {
TransformFuncContext context;
context.params.insert({"foo", {"10", "20"}});
context.params.insert({"bar", {"-23"}});
context.params.insert({"not_a_number", {"not_numerical"}});
context.params.insert({"float", {"-23.232323"}});
int64 value;
- TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
+ TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
EXPECT_EQ(-23, value);
- EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
- TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
+ EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
+ TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
EXPECT_EQ(10, value);
- EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
- EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
+ EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
+ EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
}
void TestGetOneFloatParameter() {
@@ -1111,6 +1170,10 @@ TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
TestRenameNodeInputsWithWildcard();
}
+TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
+ TestRenameNodeInputsWithIgnores();
+}
+
TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
@@ -1127,7 +1190,13 @@ TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
TestGetOneStringParameter();
}
-TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
+TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
+ TestGetOneInt32Parameter();
+}
+
+TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
+ TestGetOneInt64Parameter();
+}
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
TestGetOneFloatParameter();