author:    Patrick Nguyen <drpng@google.com>  2017-12-28 16:04:42 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>  2017-12-28 16:08:58 -0800
commit:    20765b3e1ae3b718699592c98aa9805cb874b6d1
tree:      b429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/cc/gradients
parent:    2e2715baa84720f786b38d1f9cb6887399020d6f
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       | 212
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  |  17
2 files changed, 229 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index ebc0c77828..afd92fbf48 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -473,6 +473,41 @@ Status AddNGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("AddN", AddNGrad);
 
+Status PowGrad(const Scope& scope, const Operation& op,
+               const std::vector<Output>& grad_inputs,
+               std::vector<Output>* grad_outputs) {
+  auto x = ConjugateHelper(scope, op.input(0));
+  auto y = ConjugateHelper(scope, op.input(1));
+  auto z = ConjugateHelper(scope, op.output(0));
+  auto grad = grad_inputs[0];
+  // grad * y * pow(x, y - 1)
+  auto one = Cast(scope, Const(scope, 1.0), y.type());
+  auto gx_1 = Mul(scope,
+                  Mul(scope, grad, y),
+                  Pow(scope, x, Sub(scope, y, one)));
+  // Avoid false singularity at x = 0
+  DataType x_dtype = x.type();
+  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
+  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
+    // real(x) < 0 is fine for the complex case
+    auto log_x = Where3(scope,
+                        NotEqual(scope, x, zero),
+                        Log(scope, x),
+                        ZerosLike(scope, x));
+    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+  } else {
+    // There's no sensible real value to return if x < 0, so return 0
+    auto log_x = Where3(scope,
+                        Greater(scope, x, zero),
+                        Log(scope, x),
+                        ZerosLike(scope, x));
+    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+  }
+}
+REGISTER_GRADIENT_OP("Pow", PowGrad);
+
 // MaximumMinimumGradCommon adds shared ops to calculate gradients for
 // the binary Maximum and Minimum ops.
 Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
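The two quantities gx_1 and gy_1 in the hunk above are the standard partials of z = x^y:

\[
\frac{\partial z}{\partial x} = y \, x^{y-1}, \qquad
\frac{\partial z}{\partial y} = x^{y} \ln x = z \ln x .
\]

The second identity, z ln x, is why gy_1 reuses the forward output z = op.output(0) rather than recomputing Pow(x, y). Because ln x is undefined at x = 0 (and has no real value for x < 0), the Where3 calls mask log_x to ZerosLike(x) exactly at those points.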
@@ -812,6 +847,183 @@ Status MinOrMaxGrad(const Scope& scope, const Operation& op,
 REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
 REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
 
+Status ProdGrad(const Scope& scope, const Operation& op,
+                const std::vector<Output>& grad_inputs,
+                std::vector<Output>* grad_outputs) {
+  auto zero = Const(scope, 0);
+  auto one = Const(scope, 1);
+
+  // The gradient can be expressed by dividing the product by each entry of
+  // the input tensor. If our input is
+  // [
+  //   [3, 4],
+  //   [5, 6],
+  //   [7, 8]
+  // ]
+  // and we do a Prod operation along axis 0, we obtain [[105, 192]].
+  // The gradient has the same shape as the input:
+  //      [
+  //        [105/3, 192/4],
+  // dz *   [105/5, 192/6],
+  //        [105/7, 192/8]
+  //      ]
+  // If the input contains a zero, this division is impossible. But note
+  // that the calculation behind the first gradient entry, (3 * 5 * 7)/3,
+  // equals 5 * 7; so the trick is to cumprod the elements along the axis
+  // while excluding the element at the current position (3 in the example
+  // above). We will use the following input as the running example:
+  // [
+  //   [
+  //     [3.0, 4.0],
+  //     [5.0, 6.0],
+  //     [7.0, 8.0]
+  //   ],
+  //   [
+  //     [3.0, 5.0],
+  //     [0.0, 6.0],
+  //     [5.0, 6.0]
+  //   ]
+  // ]
+
+  // [2, 3, 2]
+  auto input_shape = Shape(scope, op.input(0));
+
+  // The Reshape with -1 flattens the reduction indices.
+  // [1]
+  auto reduction_indices = Reshape(scope, op.input(1), {-1});
+
+  // [2, 1, 2]
+  auto output_shape_kept_dims =
+      ReducedShapeHelper(scope, input_shape, reduction_indices);
+
+  // [1, 3, 1]
+  auto tile_scaling =
+      SafeDivHelper(scope, input_shape, output_shape_kept_dims);
+
+  // [[[105, 192]], [[0, 180]]]
+  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
+
+  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
+  auto grad_tiled = Tile(scope, grad, tile_scaling);
+
+  Scope cpu_scope = scope.WithDevice("/cpu:0");
+
+  // 3
+  auto rank = Rank(cpu_scope, op.input(0));
+
+  // Normalize any negative indices in the reduction_axes to positive values.
+  auto reduction_indices_pos =
+      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);
+
+  // [1]
+  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);
+
+  // [0, 1, 2]
+  auto idx = Range(cpu_scope, zero, rank, one);
+
+  // [0, 2]
+  auto other = SetDiff1D(cpu_scope, idx, reduced).out;
+
+  // [1, 0, 2]
+  auto perm =
+      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);
+
+  // 3 (the product of the reduced dimensions)
+  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);
+
+  // 2 * 2 = 4 (the product of the remaining dimensions)
+  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);
+
+  // [
+  //   [
+  //     [ 3., 4.],
+  //     [ 3., 5.]
+  //   ],
+  //   [
+  //     [ 5., 6.],
+  //     [ 0., 6.]
+  //   ],
+  //   [
+  //     [ 7., 8.],
+  //     [ 5., 6.]
+  //   ]
+  // ]
+  auto permuted = Transpose(scope, op.input(0), perm);
+
+  // [3, 2, 2]
+  auto permuted_shape = Shape(scope, permuted);
+
+  // [
+  //   [ 3., 4., 3., 5.],
+  //   [ 5., 6., 0., 6.],
+  //   [ 7., 8., 5., 6.]
+  // ]
+  auto reshaped = Reshape(
+      scope, permuted,
+      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));
+
+  // [
+  //   [  1.,  1., 1.,  1.],
+  //   [  3.,  4., 3.,  5.],
+  //   [ 15., 24., 0., 30.]
+  // ]
+  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));
+
+  // [
+  //   [ 35., 48., 0., 36.],
+  //   [  7.,  8., 5.,  6.],
+  //   [  1.,  1., 1.,  1.]
+  // ]
+  auto right =
+      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));
+
+  // left * right =
+  // [
+  //   [ 35., 48.,  0., 36.],
+  //   [ 21., 32., 15., 30.],
+  //   [ 15., 24.,  0., 30.]
+  // ]
+  // y =
+  // [
+  //   [
+  //     [ 35., 48.],
+  //     [  0., 36.]
+  //   ],
+  //   [
+  //     [ 21., 32.],
+  //     [ 15., 30.]
+  //   ],
+  //   [
+  //     [ 15., 24.],
+  //     [  0., 30.]
+  //   ]
+  // ]
+  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);
+
+  // out =
+  // [
+  //   [
+  //     [ 35., 48.],
+  //     [ 21., 32.],
+  //     [ 15., 24.]
+  //   ],
+  //   [
+  //     [  0., 36.],
+  //     [ 15., 30.],
+  //     [  0., 30.]
+  //   ]
+  // ]
+  auto out = Mul(scope, grad_tiled,
+                 Transpose(scope, y, InvertPermutation(scope, perm)));
+
+  grad_outputs->push_back(Reshape(scope, out, input_shape));
+
+  // Stop propagation along reduction_indices.
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Prod", ProdGrad);
+
 // MatMulGrad helper function used to compute two MatMul operations
 // based on input matrix transposition combinations.
 Status MatMulGradHelper(const Scope& scope, const bool is_batch,
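The pair of exclusive cumprods above computes, for every position, the product of all other entries along the reduced axis without any division, which is what makes zeros in the input safe. A minimal standalone sketch of the same trick in plain C++ (illustrative only; the vector x is made up and nothing below uses the TensorFlow API):

#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {3.0, 0.0, 5.0};
  const int n = static_cast<int>(x.size());
  // left[i]  = x[0] * ... * x[i-1]   (exclusive cumprod, forward)
  // right[i] = x[i+1] * ... * x[n-1] (exclusive cumprod, reverse)
  std::vector<double> left(n, 1.0), right(n, 1.0);
  for (int i = 1; i < n; ++i) left[i] = left[i - 1] * x[i - 1];
  for (int i = n - 2; i >= 0; --i) right[i] = right[i + 1] * x[i + 1];
  for (int i = 0; i < n; ++i) {
    // left[i] * right[i] is the product of every entry except x[i]:
    // prints 0, 15, 0 for this input, with no division by x[i].
    std::printf("d(prod)/dx[%d] = %g\n", i, left[i] * right[i]);
  }
  return 0;
}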
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 29def3c3ea..b94d797711 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -843,6 +843,14 @@ TEST_F(NaryGradTest, SquaredDifference) {
   RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
 }
 
+TEST_F(NaryGradTest, Pow) {
+  TensorShape shape({3});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  // Fix the exponent to avoid overflow.
+  auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
+  RunTest({x}, {shape}, {y}, {shape});
+}
+
 TEST_F(NaryGradTest, Maximum) {
   TensorShape shape({3, 2});
   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
@@ -865,6 +873,15 @@ TEST_F(NaryGradTest, Minimum) {
   RunTest(x, x_init_value, y, shape);
 }
 
+TEST_F(NaryGradTest, Prod) {
+  TensorShape x_shape({2, 3, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+  auto y = Prod(scope_, x, {1});
+  // y's shape is the result of reducing x along axis 1.
+  TensorShape y_shape({2, 1, 2});
+  RunTest({x}, {x_shape}, {y}, {y_shape});
+}
+
 TEST_F(NaryGradTest, Select) {
   TensorShape shape({3, 4});
   auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
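Beyond the numeric checks in the test fixture, a registered gradient like PowGrad is looked up by name whenever symbolic gradients are requested. A sketch of driving the new "Pow" registration through AddSymbolicGradients and a ClientSession (assumes the usual tensorflow/cc headers and build setup; the constants are illustrative):

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;

  Scope scope = Scope::NewRootScope();
  auto x = Const(scope, {2.0f, 3.0f});
  auto y = Pow(scope, x, Const(scope, {3.0f, 2.0f}));

  // PowGrad is resolved via the "Pow" gradient registration:
  // dy/dx = exponent * x^(exponent - 1) = {3 * 2^2, 2 * 3^1} = {12, 6}.
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));

  ClientSession session(scope);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run(grads, &outputs));
  // outputs[0] now holds {12, 6}.
  return 0;
}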