author A. Unique TensorFlower <gardener@tensorflow.org> 2017-09-07 20:23:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-07 20:27:26 -0700
commit d27db78cd0168f10b308f7508c11dfaa3c6707e9 (patch)
tree cec20fd5c5579626ecb9271e39d1ac59b54dac10
parent 2b5011625bcab6c50c51b948e68063393711bd30 (diff)
Implement C++ gradients for data_flow operators.
Closes #12856
PiperOrigin-RevId: 167949574
-rw-r--r-- tensorflow/cc/BUILD 31
-rw-r--r-- tensorflow/cc/gradients/data_flow_grad.cc 155
-rw-r--r-- tensorflow/cc/gradients/data_flow_grad_test.cc 69
3 files changed, 255 insertions, 0 deletions
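
The change registers C++ gradient functions for several data-flow ops (see data_flow_grad.cc below). For context, a minimal sketch of how those registrations would typically be exercised through the tensorflow/cc client API; this snippet is not part of the patch and assumes the ones-seeded AddSymbolicGradients overload and ClientSession from the existing C++ API:

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

// Build a small graph that uses DynamicPartition, then ask for symbolic
// gradients with respect to the data input. The registry lookup resolves to
// the DynamicPartitionGrad function added in this change.
int main() {
  using namespace tensorflow;       // NOLINT(build/namespaces)
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)

  Scope scope = Scope::NewRootScope();
  auto data = Const(scope, {10.f, 20.f, 30.f, 40.f, 50.f});
  auto partitions = Const(scope, {0, 0, 1, 1, 0});
  auto parts = DynamicPartition(scope, data, partitions, /*num_partitions=*/2);

  // d(parts)/d(data), seeded with ones for every partition output.
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, parts.outputs, {data}, &grads));

  ClientSession session(scope);
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run(grads, &out));
  // out[0] has the same shape as `data`; each element receives the gradient
  // that flowed into its partitioned position (all ones here).
  return 0;
}
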
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index c6d5792f49..028de60880 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -91,6 +91,7 @@ cc_library(
name = "grad_ops",
deps = [
":array_grad",
+ ":data_flow_grad",
":math_grad",
":nn_grad",
],
@@ -363,6 +364,36 @@ tf_cc_test(
],
)
+cc_library(
+ name = "data_flow_grad",
+ srcs = ["gradients/data_flow_grad.cc"],
+ deps = [
+ ":cc_ops",
+ ":cc_ops_internal",
+ ":grad_op_registry",
+ ":gradients",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "gradients_data_flow_grad_test",
+ size = "small",
+ srcs = ["gradients/data_flow_grad_test.cc"],
+ deps = [
+ ":cc_ops",
+ ":data_flow_grad",
+ ":grad_op_registry",
+ ":grad_testutil",
+ ":gradient_checker",
+ ":testutil",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [
diff --git a/tensorflow/cc/gradients/data_flow_grad.cc b/tensorflow/cc/gradients/data_flow_grad.cc
new file mode 100644
index 0000000000..496254bfc7
--- /dev/null
+++ b/tensorflow/cc/gradients/data_flow_grad.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/data_flow_ops.h"
+#include "tensorflow/cc/ops/data_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+
+namespace tensorflow {
+namespace ops {
+namespace {
+
+REGISTER_NO_GRADIENT_OP("Queue");
+REGISTER_NO_GRADIENT_OP("QueueEnqueue");
+REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
+REGISTER_NO_GRADIENT_OP("QueueDequeue");
+REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
+REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
+REGISTER_NO_GRADIENT_OP("QueueClose");
+REGISTER_NO_GRADIENT_OP("QueueSize");
+REGISTER_NO_GRADIENT_OP("Stack");
+REGISTER_NO_GRADIENT_OP("StackPush");
+REGISTER_NO_GRADIENT_OP("StackPop");
+REGISTER_NO_GRADIENT_OP("StackClose");
+REGISTER_NO_GRADIENT_OP("GetSessionHandle");
+REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
+REGISTER_NO_GRADIENT_OP("GetSessionTensor");
+REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");
+
+Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // DynamicPartition only moves input values into various positions
+ // in the output, so the gradient operation only has to map incoming
+ // gradients into their input source locations.
+ // running example:
+ // data = [10, 20, 30, 40, 50]
+ // partitions = [0, 0, 1, 1, 0]
+ // num_partitions = 2
+ // dynamic_partition(data, partitions, num_partitions) = {
+ // [10, 20, 50],
+ // [30, 40]
+ // }
+ // grads = {
+ // [g1, g2, g3],
+ // [g4, g5]
+ // }
+ // The desired propagation of the gradients back to the data inputs is:
+ // [g1, g2, g4, g5, g3]
+ auto data = op.input(0);
+ auto partitions = op.input(1);
+ int32 num_partitions;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));
+
+ // Note: the shape of the partitions is a prefix of the data shape.
+ // shape(partitions) = [5]
+ auto partitions_shape = Shape(scope, partitions);
+ // We now create a partitions-shaped tensor with integers from
+ // [0..size(partitions)). This will be dynamic_partitioned with the
+ // input parameters, providing the destination index for a given
+ // source item.
+ // partitions_size = prod([5]) = 5
+ // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
+ auto zero = Const(scope, 0);
+ auto one = Const(scope, 1);
+ auto original_indices = Reshape(
+ scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
+ partitions_shape);
+ // dynamic_partition(
+ // [0, 1, 2, 3, 4],
+ // [0, 0, 1, 1, 0], 2)
+ // = { [0, 1, 4],
+ // [2, 3] }
+ auto partitioned_indices =
+ DynamicPartition(scope, original_indices, partitions, num_partitions);
+
+ // Invert these indices with dynamic_stitch to map the incoming
+ // gradients to their source inputs.
+ // dynamic_stitch(
+ // { [0, 1, 4], [2, 3] },
+ // { [g1, g2, g3], [g4, g5] })
+ // = [g1, g2, g4, g5, g3]
+ auto reconstructed =
+ DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
+ // reshape back into a data-shaped tensor to propagate gradients for the data
+ // input.
+ grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
+ // Stop propagation along the partitions input
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);
+
+Status DynamicStitchGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // Running example:
+ // indices = {2, [1, 0]}
+ // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
+ // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
+ // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]
+
+ // indices and data are two equal-sized lists passed
+ // into DynamicStitch.
+ // num_values = 2
+ int32 num_values = op.num_inputs() / 2;
+
+ // Stop propagation along the indices list
+ for (int32 i = 0; i < num_values; i++) {
+ grad_outputs->push_back(NoGradient());
+ }
+
+ // DynamicStitch shuffles its data to the output (using items in
+ // indices) so the gradient propagated to a given data input simply
+ // selects the gradient for its output position.
+ for (int32 i = 0; i < num_values; i++) {
+ // index has the destination positions for the i'th data
+ // element. We cast it into an int32 if necessary, so we can use
+ // it from a Gather op.
+ // i = 0: index = 2
+ // i = 1: index = [1, 0]
+ auto index = op.input(i);
+ if (index.type() != DT_INT32) {
+ index = Cast(scope, index, DT_INT32);
+ }
+ // Gather the index specified locations in the gradient and
+ // propagate it as the gradient for the i'th data item.
+ // i = 0: gather(grad, 2) = [g_5, g_6]
+ // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
+ grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
+ }
+
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
+REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);
+
+} // anonymous namespace
+} // namespace ops
+} // namespace tensorflow
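
Continuing the illustrative sketch from above (same includes and main() scaffolding; again not part of the patch), the DynamicStitch registration can be exercised the same way. The shapes of the returned gradients mirror the Gather calls described in the comments:

  // The running example from DynamicStitchGrad: two index tensors and two
  // data tensors stitched into a [3, 2] output.
  std::vector<Output> indices = {Const(scope, 2), Const(scope, {1, 0})};
  std::vector<Output> data = {Const(scope, {1.f, 2.f}),
                              Const(scope, {{3.f, 4.f}, {5.f, 6.f}})};
  auto stitched = DynamicStitch(scope, indices, data);

  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, {stitched}, data, &grads));
  // grads[0] gathers row 2 of the output gradient (shape [2]);
  // grads[1] gathers rows 1 and 0 (shape [2, 2]), exactly as traced above.
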
diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc
new file mode 100644
index 0000000000..3d027909f0
--- /dev/null
+++ b/tensorflow/cc/gradients/data_flow_grad_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradient_checker.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+using namespace ops; // NOLINT(build/namespaces)
+
+namespace {
+
+class DataFlowGradTest : public ::testing::Test {
+ protected:
+ DataFlowGradTest() : scope_(Scope::NewRootScope()) {}
+
+ void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
+ const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
+ TF_ASSERT_OK(scope_.status());
+ float max_error;
+ TF_ASSERT_OK(
+ ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
+ EXPECT_LT(max_error, 1e-4);
+ }
+
+ Scope scope_;
+};
+
+TEST_F(DataFlowGradTest, DynamicPartitionGrad) {
+ TensorShape data_shape({2, 3, 2});
+ auto data = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
+ auto partitions = Const(scope_, {{2, 1, 0}, {1, 2, 0}});
+ auto y = DynamicPartition(scope_, data, partitions, 3);
+ TensorShape partition_shape({2, 2});
+ RunTest({data}, {data_shape}, y.outputs,
+ {partition_shape, partition_shape, partition_shape});
+}
+
+TEST_F(DataFlowGradTest, DynamicStitchGrad) {
+ TensorShape d1_shape({2});
+ TensorShape d2_shape({2, 2});
+ std::vector<Output> indices = {Const(scope_, 2), Const(scope_, {1, 0})};
+ std::vector<Output> data = {
+ Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d1_shape)),
+ Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d2_shape))};
+ auto y = DynamicStitch(scope_, indices, data);
+ TensorShape y_shape({3, 2});
+ RunTest(data, {d1_shape, d2_shape}, {y}, {y_shape});
+}
+
+} // namespace
+} // namespace tensorflow
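
With the tf_cc_test target added to tensorflow/cc/BUILD above, the new gradient checker test should be runnable with a standard Bazel invocation (exact configuration flags depend on the local setup):

bazel test //tensorflow/cc:gradients_data_flow_grad_test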