diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-13 21:30:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-13 21:34:12 -0700 |
commit | c54693a648730c9223e1aa0a118322f6f0ae8b17 (patch) | |
tree | 50ebfacd3d25c0005ead74d76dbe52ae9dc2db1e /tensorflow/cc/gradients | |
parent | de0bc082f153e36f9919c2cac8fc1063fe3c9186 (diff) |
Use the numeric gradient checker for matrix gradient tests.
- Port MathGradTest to use ComputeGradientError.
- Add complex type support for MatMulGrad and BatchMatMulGrad.
PiperOrigin-RevId: 168638169
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 39 | ||||
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 141 |
2 files changed, 78 insertions, 102 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index b88332ebc7..aba17cfe0c 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -790,42 +790,37 @@ Status MatMulGradHelper(const Scope& scope, const bool is_batch, // MatMulGrad common used to read and check node attr state, and determine // proper MatMul products for gradients based on input matrix transposition // combinations. -// TODO(andydavis) Re-use this function for BatchMatMulGrad. Status MatMulGradCommon(const Scope& scope, const Operation& op, const bool is_batch, const std::vector<Output>& grad_inputs, const string& attr_adj_x, const string& attr_adj_y, std::vector<Output>* grad_outputs) { - DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype)); - if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { - return errors::Unimplemented( - "MatMul gradient for complex data type is not supported yet."); + auto a = op.input(0); + auto b = op.input(1); + // Use conjugate of the inputs for MatMul + if (is_batch == false) { + a = ConjugateHelper(scope, a); + b = ConjugateHelper(scope, b); } + auto product = op.output(0); bool ta; bool tb; - TF_RETURN_IF_ERROR( - GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta)); - TF_RETURN_IF_ERROR( - GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb)); + TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta)); + TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { - return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), - true, op.input(0), true, grad_inputs[0], false, - grad_outputs); + return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a, + true, grad_inputs[0], false, grad_outputs); } else if (!ta && tb) { - return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), - false, grad_inputs[0], true, op.input(0), false, - grad_outputs); + return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false, + grad_inputs[0], true, a, false, grad_outputs); } else if (ta && !tb) { - return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0], - true, op.input(0), false, grad_inputs[0], false, - grad_outputs); + return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a, + false, grad_inputs[0], false, grad_outputs); } - return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0], - true, grad_inputs[0], true, op.input(0), true, - grad_outputs); + return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true, + grad_inputs[0], true, a, true, grad_outputs); } Status MatMulGrad(const Scope& scope, const Operation& op, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 047243aa6a..3534f16e8f 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -507,77 +507,34 @@ class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} + template <typename T> void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { - // Generate random test data. - std::vector<Tensor> data; - RandMatMulGradData(is_batch, t_x, t_y, &data); - auto x = Const(root_, data[0]); - auto y = Const(root_, data[1]); - auto dz = Const(root_, data[2]); - - std::vector<Tensor> grad_outputs; - ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs); - - if (!t_x && !t_y) { - test::ExpectClose(grad_outputs[0], - ComputeMatMul(is_batch, dz, false, y, true)); - test::ExpectClose(grad_outputs[1], - ComputeMatMul(is_batch, x, true, dz, false)); - } else if (t_x && !t_y) { - test::ExpectClose(grad_outputs[0], - ComputeMatMul(is_batch, y, false, dz, true)); - test::ExpectClose(grad_outputs[1], - ComputeMatMul(is_batch, x, false, dz, false)); - } else if (!t_x && t_y) { - test::ExpectClose(grad_outputs[0], - ComputeMatMul(is_batch, dz, false, y, false)); - test::ExpectClose(grad_outputs[1], - ComputeMatMul(is_batch, dz, true, x, false)); - } else { - test::ExpectClose(grad_outputs[0], - ComputeMatMul(is_batch, y, true, dz, true)); - test::ExpectClose(grad_outputs[1], - ComputeMatMul(is_batch, dz, true, x, true)); - } - } - - void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x, - const Output& y, const bool t_y, const Output& dz, - std::vector<Tensor>* out) { - // Compute forward MatMul: z = MatMul(x, y). - Output z; - if (is_batch) { - z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); - } else { - z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); - } TF_ASSERT_OK(root_.status()); - CHECK_NOTNULL(z.node()); - std::vector<Output> grad_outputs; - // Call MatMulGrad which populates 'grad_outputs'. - TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz}, - &grad_outputs)); - ASSERT_EQ(2, grad_outputs.size()); - // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. - test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); - } - - Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x, - const Output& y, const bool t_y) { + // Generate random (but compatible) shapes for matrix multiplication. + std::vector<TensorShape> shapes; + RandMatMulShapes(is_batch, t_x, t_y, &shapes); + TensorShape x_shape = shapes[0]; + TensorShape y_shape = shapes[1]; + TensorShape z_shape = shapes[2]; + auto x = + Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape)); + auto y = + Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape)); Output z; if (is_batch) { z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); } else { z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); } - TF_EXPECT_OK(root_.status()); - Tensor out; - test::GetTensor(root_, z, &out); - return out; + + float max_error; + TF_ASSERT_OK((ComputeGradientError<T, T, float>( + root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error))); + EXPECT_LT(max_error, 1e-3); } - void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty, - std::vector<Tensor>* data) { + void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty, + std::vector<TensorShape>* shapes) { // Choose a random batch size in [1, 4] const int b = 1 + (random::New64() % 4); // z = MatMul(x, y) @@ -593,8 +550,7 @@ class MathGradTest : public ::testing::Test { // x.shape = [m, k] x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); } - data->emplace_back(DT_FLOAT, x_shape); - RandTensor(&data->back()); + shapes->push_back(x_shape); TensorShape y_shape; if (is_batch) { @@ -604,8 +560,7 @@ class MathGradTest : public ::testing::Test { // y.shape = [k, n] y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); } - data->emplace_back(DT_FLOAT, y_shape); - RandTensor(&data->back()); + shapes->push_back(y_shape); TensorShape z_shape; if (is_batch) { @@ -615,13 +570,7 @@ class MathGradTest : public ::testing::Test { // z.shape = [m, n] z_shape = TensorShape({m, n}); } - data->emplace_back(DT_FLOAT, z_shape); - RandTensor(&data->back()); - } - - void RandTensor(Tensor* t) { - test::FillFn<float>( - t, [this](const int i) { return static_cast<float>(Rand()); }); + shapes->push_back(z_shape); } int Rand() { return 1 + (random::New64() % 10); } @@ -630,35 +579,67 @@ class MathGradTest : public ::testing::Test { }; TEST_F(MathGradTest, MatMulGrad_NoTranspose) { - TestMatMulGrad(false, false, false); + TestMatMulGrad<float>(false, false, false); +} + +TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) { + TestMatMulGrad<complex64>(false, false, false); } TEST_F(MathGradTest, MatMulGrad_TransposeX) { - TestMatMulGrad(false, true, false); + TestMatMulGrad<float>(false, true, false); +} + +TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) { + TestMatMulGrad<complex64>(false, true, false); } TEST_F(MathGradTest, MatMulGrad_TransposeY) { - TestMatMulGrad(false, false, true); + TestMatMulGrad<float>(false, false, true); +} + +TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) { + TestMatMulGrad<complex64>(false, false, true); } TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { - TestMatMulGrad(false, true, true); + TestMatMulGrad<float>(false, true, true); +} + +TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) { + TestMatMulGrad<complex64>(false, true, true); } TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { - TestMatMulGrad(true, false, false); + TestMatMulGrad<float>(true, false, false); +} + +TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) { + TestMatMulGrad<complex64>(true, false, false); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { - TestMatMulGrad(true, true, false); + TestMatMulGrad<float>(true, true, false); +} + +TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) { + TestMatMulGrad<complex64>(true, true, false); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { - TestMatMulGrad(true, false, true); + TestMatMulGrad<float>(true, false, true); +} + +TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) { + TestMatMulGrad<complex64>(true, false, true); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { - TestMatMulGrad(true, true, true); + TestMatMulGrad<float>(true, true, true); +} + +TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) { + TestMatMulGrad<complex64>(true, true, true); } class NaryGradTest : public ::testing::Test { |