aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2017-11-22 13:42:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-22 13:50:02 -0800
commitb1d8c59e9b014b527fb2fbef9ce9afc14dbc4938 (patch)
treeaf207d5a90f3176bdd3fbffbe1e98258125bf389 /tensorflow/tools/graph_transforms
parente219aeb542779d90a582ffe16f8602cd1b275b22 (diff)
Merge changes from github.
PiperOrigin-RevId: 176695926
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/BUILD2
-rw-r--r--tensorflow/tools/graph_transforms/quantize_nodes.cc2
2 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 1bf7113c9e..9216008600 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -131,6 +131,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/contrib/rnn:gru_ops_op_lib",
+ "//tensorflow/contrib/rnn:lstm_ops_op_lib",
] + if_not_windows([
"//tensorflow/core/kernels:quantized_ops",
"//tensorflow/core/kernels:remote_fused_graph_rewriter_transform",
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc
index 2b85e7e83c..97e8f77616 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc
@@ -759,6 +759,7 @@ Status QuantizeNodes(const GraphDef& input_graph_def,
NodeDef reshape_dims;
reshape_dims.set_op("Const");
reshape_dims.set_name(unique_input_name + "/reshape_dims");
+ AddNodeInput("^" + input_name, &reshape_dims);
SetNodeAttr("dtype", DT_INT32, &reshape_dims);
Tensor reshape_dims_tensor(DT_INT32, {1});
reshape_dims_tensor.flat<int32>()(0) = -1;
@@ -768,6 +769,7 @@ Status QuantizeNodes(const GraphDef& input_graph_def,
NodeDef reduction_dims;
reduction_dims.set_op("Const");
reduction_dims.set_name(unique_input_name + "/reduction_dims");
+ AddNodeInput("^" + input_name, &reduction_dims);
SetNodeAttr("dtype", DT_INT32, &reduction_dims);
Tensor reduction_dims_tensor(DT_INT32, {1});
reduction_dims_tensor.flat<int32>()(0) = 0;