diff options
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 152 |
1 files changed, 98 insertions, 54 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index d10a96a4ab..1248c0aa32 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -33,11 +33,50 @@ class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} - void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y, - const bool t_y, const Output& dz, + void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { + // Generate random test data. + std::vector<Tensor> data; + RandMatMulGradData(is_batch, t_x, t_y, &data); + auto x = Const(root_, data[0]); + auto y = Const(root_, data[1]); + auto dz = Const(root_, data[2]); + + std::vector<Tensor> grad_outputs; + ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs); + + if (!t_x && !t_y) { + test::ExpectClose(grad_outputs[0], + ComputeMatMul(is_batch, dz, false, y, true)); + test::ExpectClose(grad_outputs[1], + ComputeMatMul(is_batch, x, true, dz, false)); + } else if (t_x && !t_y) { + test::ExpectClose(grad_outputs[0], + ComputeMatMul(is_batch, y, false, dz, true)); + test::ExpectClose(grad_outputs[1], + ComputeMatMul(is_batch, x, false, dz, false)); + } else if (!t_x && t_y) { + test::ExpectClose(grad_outputs[0], + ComputeMatMul(is_batch, dz, false, y, false)); + test::ExpectClose(grad_outputs[1], + ComputeMatMul(is_batch, dz, true, x, false)); + } else { + test::ExpectClose(grad_outputs[0], + ComputeMatMul(is_batch, y, true, dz, true)); + test::ExpectClose(grad_outputs[1], + ComputeMatMul(is_batch, dz, true, x, true)); + } + } + + void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x, + const Output& y, const bool t_y, const Output& dz, std::vector<Tensor>* out) { // Compute forward MatMul: z = MatMul(x, y). - auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); + Output z; + if (is_batch) { + z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); + } else { + z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); + } TF_ASSERT_OK(root_.status()); CHECK_NOTNULL(z.node()); std::vector<Output> grad_outputs; @@ -49,31 +88,60 @@ class MathGradTest : public ::testing::Test { test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); } - Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y, - const bool t_y) { - auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); + Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x, + const Output& y, const bool t_y) { + Output z; + if (is_batch) { + z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); + } else { + z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); + } TF_EXPECT_OK(root_.status()); Tensor out; test::GetTensor(root_, z, &out); return out; } - void RandMatMulGradData(const bool tx, const bool ty, + void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty, std::vector<Tensor>* data) { + // Choose a random batch size in [1, 4] + const int b = 1 + (random::New64() % 4); // z = MatMul(x, y) const int m = Rand(); const int k = Rand(); const int n = Rand(); - // x.shape = [m, k] - const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); + + TensorShape x_shape; + if (is_batch) { + // x.shape = [b, m, k] + x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k}); + } else { + // x.shape = [m, k] + x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); + } data->emplace_back(DT_FLOAT, x_shape); RandTensor(&data->back()); - // y.shape = [k, n] - const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); + + TensorShape y_shape; + if (is_batch) { + // y.shape = [b, k, n] + y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); + } else { + // y.shape = [k, n] + y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); + } data->emplace_back(DT_FLOAT, y_shape); RandTensor(&data->back()); - // z.shape = [m, n] - data->emplace_back(DT_FLOAT, TensorShape({m, n})); + + TensorShape z_shape; + if (is_batch) { + // z.shape = [b, m, n] + z_shape = TensorShape({b, m, n}); + } else { + // z.shape = [m, n] + z_shape = TensorShape({m, n}); + } + data->emplace_back(DT_FLOAT, z_shape); RandTensor(&data->back()); } @@ -88,59 +156,35 @@ class MathGradTest : public ::testing::Test { }; TEST_F(MathGradTest, MatMulGrad_NoTranspose) { - std::vector<Tensor> data; - RandMatMulGradData(false, false, &data); - auto x = Const(root_, data[0]); - auto y = Const(root_, data[1]); - auto dz = Const(root_, data[2]); - - std::vector<Tensor> grad_outputs; - ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs); - - test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true)); - test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false)); + TestMatMulGrad(false, false, false); } TEST_F(MathGradTest, MatMulGrad_TransposeX) { - std::vector<Tensor> data; - RandMatMulGradData(true, false, &data); - auto x = Const(root_, data[0]); - auto y = Const(root_, data[1]); - auto dz = Const(root_, data[2]); - - std::vector<Tensor> grad_outputs; - ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs); - - test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true)); - test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false)); + TestMatMulGrad(false, true, false); } TEST_F(MathGradTest, MatMulGrad_TransposeY) { - std::vector<Tensor> data; - RandMatMulGradData(false, true, &data); - auto x = Const(root_, data[0]); - auto y = Const(root_, data[1]); - auto dz = Const(root_, data[2]); + TestMatMulGrad(false, false, true); +} - std::vector<Tensor> grad_outputs; - ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs); +TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { + TestMatMulGrad(false, true, true); +} - test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false)); - test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false)); +TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { + TestMatMulGrad(true, false, false); } -TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { - std::vector<Tensor> data; - RandMatMulGradData(true, true, &data); - auto x = Const(root_, data[0]); - auto y = Const(root_, data[1]); - auto dz = Const(root_, data[2]); +TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { + TestMatMulGrad(true, true, false); +} - std::vector<Tensor> grad_outputs; - ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs); +TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { + TestMatMulGrad(true, false, true); +} - test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true)); - test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true)); +TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { + TestMatMulGrad(true, true, true); } } // namespace |