author    Patrick Nguyen <drpng@google.com>  2017-12-28 16:04:42 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-28 16:08:58 -0800
commit    20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch)
tree      b429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/cc/gradients
parent    2e2715baa84720f786b38d1f9cb6887399020d6f (diff)
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       | 212
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  |  17
2 files changed, 229 insertions(+), 0 deletions(-)
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index ebc0c77828..afd92fbf48 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -473,6 +473,41 @@ Status AddNGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);
+Status PowGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto x = ConjugateHelper(scope, op.input(0));
+ auto y = ConjugateHelper(scope, op.input(1));
+ auto z = ConjugateHelper(scope, op.output(0));
+ auto grad = grad_inputs[0];
+ // grad * y * pow(x, y - 1)
+ auto one = Cast(scope, Const(scope, 1.0), y.type());
+ auto gx_1 = Mul(scope,
+ Mul(scope, grad, y),
+ Pow(scope, x, Sub(scope, y, one)));
+ // Avoid false singularity at x = 0
+ DataType x_dtype = x.type();
+ auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
+ if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
+ // real(x) < 0 is fine for the complex case
+ auto log_x = Where3(scope,
+ NotEqual(scope, x, zero),
+ Log(scope, x),
+ ZerosLike(scope, x));
+ auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+ } else {
+ // There's no sensible real value to return if x < 0, so return 0
+ auto log_x = Where3(scope,
+ Greater(scope, x, zero),
+ Log(scope, x),
+ ZerosLike(scope, x));
+ auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+ }
+}
+REGISTER_GRADIENT_OP("Pow", PowGrad);
+
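For reference, the expressions implemented above are dz/dx = y * x^(y - 1) and dz/dy = z * log(x), with log(x) replaced by zero where it is not defined. The standalone sketch below (plain C++, independent of TensorFlow, written only for this note) compares the two analytic expressions against central finite differences at a few positive sample points:

// Hypothetical standalone check of the Pow gradient formulas above
// (illustrative only, not part of the TensorFlow change).
#include <cmath>
#include <cstdio>

int main() {
  const double eps = 1e-6;
  const double xs[] = {0.5, 1.3, 2.0};
  const double ys[] = {1.0, 2.0, 3.0};
  for (int i = 0; i < 3; ++i) {
    const double x = xs[i], y = ys[i];
    const double z = std::pow(x, y);
    // Analytic gradients, as in PowGrad: dz/dx = y * x^(y-1), dz/dy = z * log(x).
    const double dz_dx = y * std::pow(x, y - 1.0);
    const double dz_dy = (x > 0.0) ? z * std::log(x) : 0.0;
    // Central finite differences for comparison.
    const double num_dx = (std::pow(x + eps, y) - std::pow(x - eps, y)) / (2 * eps);
    const double num_dy = (std::pow(x, y + eps) - std::pow(x, y - eps)) / (2 * eps);
    std::printf("x=%.2f y=%.2f  dz/dx: %.6f vs %.6f  dz/dy: %.6f vs %.6f\n",
                x, y, dz_dx, num_dx, dz_dy, num_dy);
  }
  return 0;
}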
// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
@@ -812,6 +847,183 @@ Status MinOrMaxGrad(const Scope& scope, const Operation& op,
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
+Status ProdGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto zero = Const(scope, 0);
+ auto one = Const(scope, 1);
+
+ // The gradient can be expressed by dividing the product by each entry of
+ // the input tensor. If our input is
+ // [
+ // [3, 4],
+ // [5, 6],
+ // [7, 8]
+ // ]
+ // and we take the Prod along axis 0, we obtain [[105, 192]].
+ // The gradient will have the same shape as the input:
+ // dz * [
+ //   [105/3, 192/4],
+ //   [105/5, 192/6],
+ //   [105/7, 192/8]
+ // ]
+ // If the input contains a zero, this division is impossible, but note that
+ // the calculation which produced the first gradient entry, (3 * 5 * 7)/3, is
+ // equal to 5 * 7. The trick is therefore to cumprod the elements along the
+ // axis while excluding the element at the current position (3 in the
+ // example above). The walkthrough below uses the following input:
+ // [
+ // [
+ // [3.0, 4.0],
+ // [5.0, 6.0],
+ // [7.0, 8.0]
+ // ],
+ // [
+ // [3.0, 5.0],
+ // [0.0, 6.0],
+ // [5.0, 6.0]
+ // ]
+ // ]
+
+ // [2, 3, 2]
+ auto input_shape = Shape(scope, op.input(0));
+
+ // The Reshape with -1 flattens the reduction indices.
+ // [1]
+ auto reduction_indices = Reshape(scope, op.input(1), {-1});
+
+ // [2, 1, 2]
+ auto output_shape_kept_dims =
+ ReducedShapeHelper(scope, input_shape, reduction_indices);
+
+ // [1, 3, 1]
+ auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);
+
+ // [[[105, 192]], [[0, 180]]]
+ auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
+
+ // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
+ auto grad_tiled = Tile(scope, grad, tile_scaling);
+
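+ // Pin the integer shape and index bookkeeping below to the host CPU.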
+ Scope cpu_scope = scope.WithDevice("/cpu:0");
+
+ // [3]
+ auto rank = Rank(cpu_scope, op.input(0));
+
+ // Normalize any negative indices in the reduction_axes to positive values.
+ auto reduction_indices_pos =
+     Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);
+
+ // [1]
+ auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);
+
+ // [0, 1, 2]
+ auto idx = Range(cpu_scope, zero, rank, one);
+
+ // [0, 2]
+ auto other = SetDiff1D(cpu_scope, idx, reduced).out;
+
+ // [1, 0, 2]
+ auto perm =
+ Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);
+
+ // Product of the sizes of the reduced dimensions: 3
+ auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);
+
+ // Product of the sizes of the remaining dimensions: 2 * 2 = 4
+ auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);
+
+ // [
+ // [
+ // [ 3., 4.],
+ // [ 3., 5.]
+ // ],
+ // [
+ // [ 5., 6.],
+ // [ 0., 6.]
+ // ],
+ // [
+ // [ 7., 8.],
+ // [ 5., 6.]
+ // ]
+ // ]
+ auto permuted = Transpose(scope, op.input(0), perm);
+
+ // [3, 2, 2]
+ auto permuted_shape = Shape(scope, permuted);
+
+ // [
+ // [ 3., 4., 3., 5.],
+ // [ 5., 6., 0., 6.],
+ // [ 7., 8., 5., 6.]
+ // ]
+ auto reshaped = Reshape(
+ scope, permuted,
+ Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));
+
+ // [
+ // [ 1., 1., 1., 1.],
+ // [ 3., 4., 3., 5.],
+ // [ 15., 24., 0., 30.]
+ // ]
+ auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));
+
+ // [
+ // [ 35., 48., 0., 36.],
+ // [ 7., 8., 5., 6.],
+ // [ 1., 1., 1., 1.]
+ // ]
+ auto right =
+ Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));
+
+ // left * right =
+ // [
+ // [ 35., 48., 0., 36.],
+ // [ 21., 32., 15., 30.],
+ // [ 15., 24., 0., 30.]
+ // ]
+ // y =
+ // [
+ // [
+ // [ 35., 48.],
+ // [ 0., 36.]
+ // ],
+ // [
+ // [ 21., 32.],
+ // [ 15., 30.]
+ // ],
+ // [
+ // [ 15., 24.],
+ // [ 0., 30.]
+ // ]
+ // ]
+ auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);
+
+ // out =
+ // [
+ // [
+ // [ 35., 48.],
+ // [ 21., 32.],
+ // [ 15., 24.]
+ // ],
+ // [
+ // [ 0., 36.],
+ // [ 15., 30.],
+ // [ 0., 30.]
+ // ]
+ // ]
+ auto out =
+ Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm)));
+
+ grad_outputs->push_back(Reshape(scope, out, input_shape));
+
+ // stop propagation along reduction_indices
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Prod", ProdGrad);
+
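The heart of ProdGrad above is that the product of all entries except the current one can be built from an exclusive cumulative product taken from the left multiplied by one taken from the right, which stays well defined when the input contains zeros; the permute and reshape steps merely arrange the data so the reduction axes are contiguous. A minimal 1-D sketch of that trick (plain C++, illustrative only, not part of the change):

#include <cstdio>
#include <vector>

// Product of all elements except self, computed without division so that
// zeros in the input are handled correctly: the same role Cumprod with
// Exclusive(true) and Reverse(true) plays in ProdGrad above.
std::vector<double> ProdExceptSelf(const std::vector<double>& v) {
  const int n = static_cast<int>(v.size());
  std::vector<double> left(n, 1.0), right(n, 1.0), out(n, 1.0);
  // Exclusive cumulative product from the left: left[i] = v[0] * ... * v[i-1].
  for (int i = 1; i < n; ++i) left[i] = left[i - 1] * v[i - 1];
  // Exclusive cumulative product from the right: right[i] = v[i+1] * ... * v[n-1].
  for (int i = n - 2; i >= 0; --i) right[i] = right[i + 1] * v[i + 1];
  for (int i = 0; i < n; ++i) out[i] = left[i] * right[i];
  return out;
}

int main() {
  // The first column of the example in math_grad.cc, and a column with a zero.
  for (const std::vector<double>& v : {std::vector<double>{3, 5, 7},
                                       std::vector<double>{3, 0, 5}}) {
    for (double g : ProdExceptSelf(v)) std::printf("%g ", g);
    std::printf("\n");  // prints "35 21 15" and then "0 15 0"
  }
  return 0;
}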
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 29def3c3ea..b94d797711 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -843,6 +843,14 @@ TEST_F(NaryGradTest, SquaredDifference) {
RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
+TEST_F(NaryGradTest, Pow) {
+ TensorShape shape({3});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // Use a fixed exponent so the numeric gradient check does not overflow.
+ auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
+ RunTest({x}, {shape}, {y}, {shape});
+}
+
TEST_F(NaryGradTest, Maximum) {
TensorShape shape({3, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
@@ -865,6 +873,15 @@ TEST_F(NaryGradTest, Minimum) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NaryGradTest, Prod) {
+ TensorShape x_shape({2, 3, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = Prod(scope_, x, {1});
+ // y's shape is the result of reducing x along axis 1
+ TensorShape y_shape({2, 1, 2});
+ RunTest({x}, {x_shape}, {y}, {y_shape});
+}
+
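For orientation, Prod(x, {1}) on the [2, 3, 2] input multiplies across the middle axis, leaving one product per remaining position. A small sketch of that reduction (plain C++, sample values borrowed from the ProdGrad comments, not part of the change):

#include <cstdio>

int main() {
  // The same [2, 3, 2] sample used in the ProdGrad comments in math_grad.cc.
  const double x[2][3][2] = {{{3, 4}, {5, 6}, {7, 8}},
                             {{3, 5}, {0, 6}, {5, 6}}};
  double y[2][2];
  for (int b = 0; b < 2; ++b) {
    for (int c = 0; c < 2; ++c) {
      y[b][c] = 1.0;
      for (int r = 0; r < 3; ++r) y[b][c] *= x[b][r][c];  // reduce along axis 1
    }
  }
  // Prints "105 192" and "0 180", matching the reduced values in the comments.
  std::printf("%g %g\n%g %g\n", y[0][0], y[0][1], y[1][0], y[1][1]);
  return 0;
}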
TEST_F(NaryGradTest, Select) {
TensorShape shape({3, 4});
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));