about summary refs log tree commit diff homepage
path: root/tensorflow/cc/gradients/math_grad_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc | 152
1 file changed, 98 insertions(+), 54 deletions(-)
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index d10a96a4ab..1248c0aa32 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,11 +33,50 @@ class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
- void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
- const bool t_y, const Output& dz,
+ void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
+ // Generate random test data.
+ std::vector<Tensor> data;
+ RandMatMulGradData(is_batch, t_x, t_y, &data);
+ auto x = Const(root_, data[0]);
+ auto y = Const(root_, data[1]);
+ auto dz = Const(root_, data[2]);
+
+ std::vector<Tensor> grad_outputs;
+ ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
+
+ if (!t_x && !t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, dz, false, y, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, x, true, dz, false));
+ } else if (t_x && !t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, y, false, dz, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, x, false, dz, false));
+ } else if (!t_x && t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, dz, false, y, false));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, dz, true, x, false));
+ } else {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, y, true, dz, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, dz, true, x, true));
+ }
+ }
+
+ void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
+ const Output& y, const bool t_y, const Output& dz,
std::vector<Tensor>* out) {
// Compute forward MatMul: z = MatMul(x, y).
- auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ Output z;
+ if (is_batch) {
+ z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+ } else {
+ z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ }
TF_ASSERT_OK(root_.status());
CHECK_NOTNULL(z.node());
std::vector<Output> grad_outputs;
@@ -49,31 +88,60 @@ class MathGradTest : public ::testing::Test {
test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
}
- Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
- const bool t_y) {
- auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
+ const Output& y, const bool t_y) {
+ Output z;
+ if (is_batch) {
+ z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+ } else {
+ z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ }
TF_EXPECT_OK(root_.status());
Tensor out;
test::GetTensor(root_, z, &out);
return out;
}
- void RandMatMulGradData(const bool tx, const bool ty,
+ void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
std::vector<Tensor>* data) {
+ // Choose a random batch size in [1, 4]
+ const int b = 1 + (random::New64() % 4);
// z = MatMul(x, y)
const int m = Rand();
const int k = Rand();
const int n = Rand();
- // x.shape = [m, k]
- const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+
+ TensorShape x_shape;
+ if (is_batch) {
+ // x.shape = [b, m, k]
+ x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
+ } else {
+ // x.shape = [m, k]
+ x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+ }
data->emplace_back(DT_FLOAT, x_shape);
RandTensor(&data->back());
- // y.shape = [k, n]
- const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+
+ TensorShape y_shape;
+ if (is_batch) {
+ // y.shape = [b, k, n]
+ y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
+ } else {
+ // y.shape = [k, n]
+ y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+ }
data->emplace_back(DT_FLOAT, y_shape);
RandTensor(&data->back());
- // z.shape = [m, n]
- data->emplace_back(DT_FLOAT, TensorShape({m, n}));
+
+ TensorShape z_shape;
+ if (is_batch) {
+ // z.shape = [b, m, n]
+ z_shape = TensorShape({b, m, n});
+ } else {
+ // z.shape = [m, n]
+ z_shape = TensorShape({m, n});
+ }
+ data->emplace_back(DT_FLOAT, z_shape);
RandTensor(&data->back());
}
@@ -88,59 +156,35 @@ class MathGradTest : public ::testing::Test {
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
- std::vector<Tensor> data;
- RandMatMulGradData(false, false, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
-
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
-
- test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
+ TestMatMulGrad(false, false, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
- std::vector<Tensor> data;
- RandMatMulGradData(true, false, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
-
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
-
- test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
+ TestMatMulGrad(false, true, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
- std::vector<Tensor> data;
- RandMatMulGradData(false, true, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
+ TestMatMulGrad(false, false, true);
+}
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
+ TestMatMulGrad(false, true, true);
+}
- test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
+TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
+ TestMatMulGrad(true, false, false);
}
-TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
- std::vector<Tensor> data;
- RandMatMulGradData(true, true, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
+ TestMatMulGrad(true, true, false);
+}
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
+ TestMatMulGrad(true, false, true);
+}
- test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
+ TestMatMulGrad(true, true, true);
}
} // namespace