author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-09-07 08:31:55 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-09-07 09:47:28 -0700
commit    a70623971587c96e450b12a8c16c166cfc040ec9 (patch)
tree      a501b9ab396f2e087a617388f3f39b40c90e9798 /tensorflow/cc/gradients
parent    649d182d5059a76308c279c9b790359e557e7aa0 (diff)
C++ Gradients: Adds gradient function for BatchMatMul, cleans up unit tests.
Change: 132445367
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       |  57
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  | 152
2 files changed, 136 insertions(+), 73 deletions(-)
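
For context, the identity this commit extends to the batched case (a standard matrix-calculus result, stated here as editorial annotation, not text from the commit itself): for $Z = XY$ with upstream gradient $dZ = \partial L / \partial Z$,

$$dX = dZ\,Y^\top, \qquad dY = X^\top dZ.$$

BatchMatMul applies the same per-matrix identities independently across the leading batch dimension, which is why the existing MatMul gradient plumbing can be shared.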
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 85093015b7..54e938ed07 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -26,16 +26,26 @@ REGISTER_NO_GRADIENT_OP("Const");
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
-Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
- const Output& x1, const bool adj_x1, const Output& y0,
- const bool adj_y0, const Output& y1, const bool adj_y1,
+Status MatMulGradHelper(const Scope& scope, const bool is_batch,
+ const Output& x0, const bool adj_x0, const Output& x1,
+ const bool adj_x1, const Output& y0, const bool adj_y0,
+ const Output& y1, const bool adj_y1,
std::vector<Output>* grad_outputs) {
- auto dx =
- MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
- grad_outputs->push_back(dx);
- auto dy =
- MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
- grad_outputs->push_back(dy);
+ if (is_batch == false) {
+ auto dx =
+ MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
+ grad_outputs->push_back(dx);
+ auto dy =
+ MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
+ grad_outputs->push_back(dy);
+ } else {
+ auto dx =
+ BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
+ grad_outputs->push_back(dx);
+ auto dy =
+ BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
+ grad_outputs->push_back(dy);
+ }
return Status::OK();
}
@@ -44,6 +54,7 @@ Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
+ const bool is_batch,
const std::vector<Output>& grad_inputs,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
@@ -60,32 +71,40 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
if (!ta && !tb) {
- return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
- op.input(0), true, grad_inputs[0], false,
+ return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
+ true, op.input(0), true, grad_inputs[0], false,
grad_outputs);
} else if (!ta && tb) {
- return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
- grad_inputs[0], true, op.input(0), false,
+ return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
+ false, grad_inputs[0], true, op.input(0), false,
grad_outputs);
} else if (ta && !tb) {
- return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
- op.input(0), false, grad_inputs[0], false,
+ return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
+ true, op.input(0), false, grad_inputs[0], false,
grad_outputs);
}
- return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
- grad_inputs[0], true, op.input(0), true,
+ return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
+ true, grad_inputs[0], true, op.input(0), true,
grad_outputs);
}
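
For readability, the four branches above implement the following standard identities (editorial annotation; $dZ$ is grad_inputs[0], and $X$, $Y$ are op.input(0), op.input(1)):

$$
\begin{aligned}
Z = XY: \quad & dX = dZ\,Y^\top, \quad & dY = X^\top dZ\\
Z = XY^\top: \quad & dX = dZ\,Y, \quad & dY = dZ^\top X\\
Z = X^\top Y: \quad & dX = Y\,dZ^\top, \quad & dY = X\,dZ\\
Z = X^\top Y^\top: \quad & dX = Y^\top dZ^\top, \quad & dY = dZ^\top X^\top
\end{aligned}
$$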
Status MatMulGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
- return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
- grad_outputs);
+ return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
+ "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
+Status BatchMatMulGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
+ grad_outputs);
+}
+REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
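
With the registration in place, BatchMatMul becomes differentiable through the C++ symbolic-gradient API. The following is a minimal sketch (not part of the commit; it assumes the AddSymbolicGradients entry point from tensorflow/cc/framework/gradients.h and uses illustrative shapes):

    #include "tensorflow/cc/framework/gradients.h"
    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    using namespace tensorflow;
    using namespace tensorflow::ops;

    int main() {
      Scope root = Scope::NewRootScope();
      // Batch of two matmuls: x is [2, 3, 4], y is [2, 4, 5], so z is [2, 3, 5].
      auto x = RandomNormal(root, {2, 3, 4}, DT_FLOAT);
      auto y = RandomNormal(root, {2, 4, 5}, DT_FLOAT);
      auto z = BatchMatMul(root, x, y);
      // Seed gradient dL/dz of all ones, with the same shape as z.
      auto dz = Const(root, 1.0f, {2, 3, 5});
      // grads[0] holds dL/dx and grads[1] holds dL/dy; both are produced by
      // the BatchMatMulGrad function registered in this commit.
      std::vector<Output> grads;
      TF_CHECK_OK(AddSymbolicGradients(root, {z}, {x, y}, {dz}, &grads));
      return 0;
    }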
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index d10a96a4ab..1248c0aa32 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,11 +33,50 @@ class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
- void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
- const bool t_y, const Output& dz,
+ void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
+ // Generate random test data.
+ std::vector<Tensor> data;
+ RandMatMulGradData(is_batch, t_x, t_y, &data);
+ auto x = Const(root_, data[0]);
+ auto y = Const(root_, data[1]);
+ auto dz = Const(root_, data[2]);
+
+ std::vector<Tensor> grad_outputs;
+ ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
+
+ if (!t_x && !t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, dz, false, y, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, x, true, dz, false));
+ } else if (t_x && !t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, y, false, dz, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, x, false, dz, false));
+ } else if (!t_x && t_y) {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, dz, false, y, false));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, dz, true, x, false));
+ } else {
+ test::ExpectClose(grad_outputs[0],
+ ComputeMatMul(is_batch, y, true, dz, true));
+ test::ExpectClose(grad_outputs[1],
+ ComputeMatMul(is_batch, dz, true, x, true));
+ }
+ }
+
+ void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
+ const Output& y, const bool t_y, const Output& dz,
std::vector<Tensor>* out) {
// Compute forward MatMul: z = MatMul(x, y).
- auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ Output z;
+ if (is_batch) {
+ z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+ } else {
+ z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ }
TF_ASSERT_OK(root_.status());
CHECK_NOTNULL(z.node());
std::vector<Output> grad_outputs;
@@ -49,31 +88,60 @@ class MathGradTest : public ::testing::Test {
test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
}
- Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
- const bool t_y) {
- auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
+ const Output& y, const bool t_y) {
+ Output z;
+ if (is_batch) {
+ z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+ } else {
+ z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+ }
TF_EXPECT_OK(root_.status());
Tensor out;
test::GetTensor(root_, z, &out);
return out;
}
- void RandMatMulGradData(const bool tx, const bool ty,
+ void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
std::vector<Tensor>* data) {
+ // Choose a random batch size in [1, 4]
+ const int b = 1 + (random::New64() % 4);
// z = MatMul(x, y)
const int m = Rand();
const int k = Rand();
const int n = Rand();
- // x.shape = [m, k]
- const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+
+ TensorShape x_shape;
+ if (is_batch) {
+ // x.shape = [b, m, k]
+ x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
+ } else {
+ // x.shape = [m, k]
+ x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+ }
data->emplace_back(DT_FLOAT, x_shape);
RandTensor(&data->back());
- // y.shape = [k, n]
- const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+
+ TensorShape y_shape;
+ if (is_batch) {
+ // y.shape = [b, k, n]
+ y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
+ } else {
+ // y.shape = [k, n]
+ y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+ }
data->emplace_back(DT_FLOAT, y_shape);
RandTensor(&data->back());
- // z.shape = [m, n]
- data->emplace_back(DT_FLOAT, TensorShape({m, n}));
+
+ TensorShape z_shape;
+ if (is_batch) {
+ // z.shape = [b, m, n]
+ z_shape = TensorShape({b, m, n});
+ } else {
+ // z.shape = [m, n]
+ z_shape = TensorShape({m, n});
+ }
+ data->emplace_back(DT_FLOAT, z_shape);
RandTensor(&data->back());
}
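
One concrete instance of the shapes generated above (illustrative numbers, not fixed by the test): with b = 2, m = 3, k = 4, n = 5 and tx set, x is created as [2, 4, 3] so that BatchMatMul's adj_x flag, which transposes only the two innermost dimensions of each batch element, yields an effective [2, 3, 4] operand; y and z follow the same pattern. Since the test data is DT_FLOAT (real-valued), adjoint and transpose coincide, so the same t_x/t_y flags drive both MatMul::TransposeA/TransposeB and BatchMatMul::AdjX/AdjY.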
@@ -88,59 +156,35 @@ class MathGradTest : public ::testing::Test {
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
- std::vector<Tensor> data;
- RandMatMulGradData(false, false, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
-
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
-
- test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
+ TestMatMulGrad(false, false, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
- std::vector<Tensor> data;
- RandMatMulGradData(true, false, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
-
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
-
- test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
+ TestMatMulGrad(false, true, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
- std::vector<Tensor> data;
- RandMatMulGradData(false, true, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
+ TestMatMulGrad(false, false, true);
+}
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
+ TestMatMulGrad(false, true, true);
+}
- test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
+TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
+ TestMatMulGrad(true, false, false);
}
-TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
- std::vector<Tensor> data;
- RandMatMulGradData(true, true, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
+ TestMatMulGrad(true, true, false);
+}
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
+ TestMatMulGrad(true, false, true);
+}
- test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
- test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
+ TestMatMulGrad(true, true, true);
}
} // namespace