author      2016-09-07 08:31:55 -0800
committer   2016-09-07 09:47:28 -0700
commit      a70623971587c96e450b12a8c16c166cfc040ec9 (patch)
tree        a501b9ab396f2e087a617388f3f39b40c90e9798 /tensorflow/cc/gradients
parent      649d182d5059a76308c279c9b790359e557e7aa0 (diff)
C++ Gradients: Adds gradient function for BatchMatMul, cleans up unit tests.
Change: 132445367
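
The change routes both ops through the shared MatMulGradCommon/MatMulGradHelper pair, switching between MatMul and BatchMatMul on an is_batch flag. For reference, the gradient identities this dispatch encodes, for Z = op(X) op(Y) with upstream gradient dZ, written out here in our own notation (for BatchMatMul the same identities apply per batch slice, with adjoints in place of transposes):

\begin{aligned}
Z &= XY: & dX &= dZ\,Y^{\top}, & dY &= X^{\top}dZ \\
Z &= X^{\top}Y: & dX &= Y\,dZ^{\top}, & dY &= X\,dZ \\
Z &= XY^{\top}: & dX &= dZ\,Y, & dY &= dZ^{\top}X \\
Z &= X^{\top}Y^{\top}: & dX &= Y^{\top}dZ^{\top}, & dY &= dZ^{\top}X^{\top}
\end{aligned}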
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       |  57
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  | 152
2 files changed, 136 insertions(+), 73 deletions(-)
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 85093015b7..54e938ed07 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -26,16 +26,26 @@ REGISTER_NO_GRADIENT_OP("Const");
 
 // MatMulGrad helper function used to compute two MatMul operations
 // based on input matrix transposition combinations.
-Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
-                        const Output& x1, const bool adj_x1, const Output& y0,
-                        const bool adj_y0, const Output& y1, const bool adj_y1,
+Status MatMulGradHelper(const Scope& scope, const bool is_batch,
+                        const Output& x0, const bool adj_x0, const Output& x1,
+                        const bool adj_x1, const Output& y0, const bool adj_y0,
+                        const Output& y1, const bool adj_y1,
                         std::vector<Output>* grad_outputs) {
-  auto dx =
-      MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
-  grad_outputs->push_back(dx);
-  auto dy =
-      MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
-  grad_outputs->push_back(dy);
+  if (is_batch == false) {
+    auto dx =
+        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
+    grad_outputs->push_back(dx);
+    auto dy =
+        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
+    grad_outputs->push_back(dy);
+  } else {
+    auto dx =
+        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
+    grad_outputs->push_back(dx);
+    auto dy =
+        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
+    grad_outputs->push_back(dy);
+  }
   return Status::OK();
 }
 
@@ -44,6 +54,7 @@ Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
 // combinations.
 // TODO(andydavis) Re-use this function for BatchMatMulGrad.
 Status MatMulGradCommon(const Scope& scope, const Operation& op,
+                        const bool is_batch,
                         const std::vector<Output>& grad_inputs,
                         const string& attr_adj_x, const string& attr_adj_y,
                         std::vector<Output>* grad_outputs) {
@@ -60,32 +71,40 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
   TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
 
   if (!ta && !tb) {
-    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
-                            op.input(0), true, grad_inputs[0], false,
+    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
+                            true, op.input(0), true, grad_inputs[0], false,
                             grad_outputs);
   } else if (!ta && tb) {
-    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
-                            grad_inputs[0], true, op.input(0), false,
+    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
+                            false, grad_inputs[0], true, op.input(0), false,
                             grad_outputs);
   } else if (ta && !tb) {
-    return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
-                            op.input(0), false, grad_inputs[0], false,
+    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
+                            true, op.input(0), false, grad_inputs[0], false,
                             grad_outputs);
   }
-  return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
-                          grad_inputs[0], true, op.input(0), true,
+  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
+                          true, grad_inputs[0], true, op.input(0), true,
                           grad_outputs);
 }
 
 Status MatMulGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
-  return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
-                          grad_outputs);
+  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
+                          "transpose_b", grad_outputs);
 }
 REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
 
+Status BatchMatMulGrad(const Scope& scope, const Operation& op,
+                       const std::vector<Output>& grad_inputs,
+                       std::vector<Output>* grad_outputs) {
+  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
+                          grad_outputs);
+}
+REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);
+
 }  // anonymous namespace
 }  // namespace ops
 }  // namespace tensorflow
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index d10a96a4ab..1248c0aa32 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,11 +33,50 @@ class MathGradTest : public ::testing::Test {
  protected:
   MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
 
-  void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
-                         const bool t_y, const Output& dz,
+  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
+    // Generate random test data.
+    std::vector<Tensor> data;
+    RandMatMulGradData(is_batch, t_x, t_y, &data);
+    auto x = Const(root_, data[0]);
+    auto y = Const(root_, data[1]);
+    auto dz = Const(root_, data[2]);
+
+    std::vector<Tensor> grad_outputs;
+    ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
+
+    if (!t_x && !t_y) {
+      test::ExpectClose(grad_outputs[0],
+                        ComputeMatMul(is_batch, dz, false, y, true));
+      test::ExpectClose(grad_outputs[1],
+                        ComputeMatMul(is_batch, x, true, dz, false));
+    } else if (t_x && !t_y) {
+      test::ExpectClose(grad_outputs[0],
+                        ComputeMatMul(is_batch, y, false, dz, true));
+      test::ExpectClose(grad_outputs[1],
+                        ComputeMatMul(is_batch, x, false, dz, false));
+    } else if (!t_x && t_y) {
+      test::ExpectClose(grad_outputs[0],
+                        ComputeMatMul(is_batch, dz, false, y, false));
+      test::ExpectClose(grad_outputs[1],
+                        ComputeMatMul(is_batch, dz, true, x, false));
+    } else {
+      test::ExpectClose(grad_outputs[0],
+                        ComputeMatMul(is_batch, y, true, dz, true));
+      test::ExpectClose(grad_outputs[1],
+                        ComputeMatMul(is_batch, dz, true, x, true));
+    }
+  }
+
+  void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
+                         const Output& y, const bool t_y, const Output& dz,
                          std::vector<Tensor>* out) {
     // Compute forward MatMul: z = MatMul(x, y).
-    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+    Output z;
+    if (is_batch) {
+      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+    } else {
+      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+    }
     TF_ASSERT_OK(root_.status());
     CHECK_NOTNULL(z.node());
     std::vector<Output> grad_outputs;
@@ -49,31 +88,60 @@ class MathGradTest : public ::testing::Test {
     test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
   }
 
-  Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
-                       const bool t_y) {
-    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+  Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
+                       const Output& y, const bool t_y) {
+    Output z;
+    if (is_batch) {
+      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
+    } else {
+      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
+    }
     TF_EXPECT_OK(root_.status());
     Tensor out;
     test::GetTensor(root_, z, &out);
     return out;
   }
 
-  void RandMatMulGradData(const bool tx, const bool ty,
+  void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
                           std::vector<Tensor>* data) {
+    // Choose a random batch size in [1, 4]
+    const int b = 1 + (random::New64() % 4);
     // z = MatMul(x, y)
     const int m = Rand();
     const int k = Rand();
     const int n = Rand();
-    // x.shape = [m, k]
-    const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+
+    TensorShape x_shape;
+    if (is_batch) {
+      // x.shape = [b, m, k]
+      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
+    } else {
+      // x.shape = [m, k]
+      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
+    }
     data->emplace_back(DT_FLOAT, x_shape);
     RandTensor(&data->back());
-    // y.shape = [k, n]
-    const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+
+    TensorShape y_shape;
+    if (is_batch) {
+      // y.shape = [b, k, n]
+      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
+    } else {
+      // y.shape = [k, n]
+      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
+    }
     data->emplace_back(DT_FLOAT, y_shape);
     RandTensor(&data->back());
-    // z.shape = [m, n]
-    data->emplace_back(DT_FLOAT, TensorShape({m, n}));
+
+    TensorShape z_shape;
+    if (is_batch) {
+      // z.shape = [b, m, n]
+      z_shape = TensorShape({b, m, n});
+    } else {
+      // z.shape = [m, n]
+      z_shape = TensorShape({m, n});
+    }
+    data->emplace_back(DT_FLOAT, z_shape);
     RandTensor(&data->back());
   }
 
@@ -88,59 +156,35 @@ class MathGradTest : public ::testing::Test {
 };
 
 TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
-  std::vector<Tensor> data;
-  RandMatMulGradData(false, false, &data);
-  auto x = Const(root_, data[0]);
-  auto y = Const(root_, data[1]);
-  auto dz = Const(root_, data[2]);
-
-  std::vector<Tensor> grad_outputs;
-  ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
-
-  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
-  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
+  TestMatMulGrad(false, false, false);
 }
 
 TEST_F(MathGradTest, MatMulGrad_TransposeX) {
-  std::vector<Tensor> data;
-  RandMatMulGradData(true, false, &data);
-  auto x = Const(root_, data[0]);
-  auto y = Const(root_, data[1]);
-  auto dz = Const(root_, data[2]);
-
-  std::vector<Tensor> grad_outputs;
-  ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
-
-  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
-  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
+  TestMatMulGrad(false, true, false);
 }
 
 TEST_F(MathGradTest, MatMulGrad_TransposeY) {
-  std::vector<Tensor> data;
-  RandMatMulGradData(false, true, &data);
-  auto x = Const(root_, data[0]);
-  auto y = Const(root_, data[1]);
-  auto dz = Const(root_, data[2]);
+  TestMatMulGrad(false, false, true);
+}
 
-  std::vector<Tensor> grad_outputs;
-  ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
+  TestMatMulGrad(false, true, true);
+}
 
-  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
-  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
+TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
+  TestMatMulGrad(true, false, false);
 }
 
-TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
-  std::vector<Tensor> data;
-  RandMatMulGradData(true, true, &data);
-  auto x = Const(root_, data[0]);
-  auto y = Const(root_, data[1]);
-  auto dz = Const(root_, data[2]);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
+  TestMatMulGrad(true, true, false);
+}
 
-  std::vector<Tensor> grad_outputs;
-  ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
+  TestMatMulGrad(true, false, true);
+}
 
-  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
-  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
+TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
+  TestMatMulGrad(true, true, true);
 }
 
 }  // namespace
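
Not part of the commit, but a minimal sketch of how the new registration can be exercised from client code, assuming the tensorflow/cc API of this era (Scope, Const, BatchMatMul, and AddSymbolicGradients from tensorflow/cc/framework/gradients.h); the function name and shapes below are illustrative:

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

using namespace tensorflow;       // Sketch only; avoid broad usings in real code.
using namespace tensorflow::ops;

Status BatchMatMulGradExample() {
  Scope root = Scope::NewRootScope();
  // Forward op: z = BatchMatMul(x, y), a batch of two [3,4] x [4,5] products.
  auto x = Const(root, 1.0f, {2, 3, 4});
  auto y = Const(root, 1.0f, {2, 4, 5});
  auto z = BatchMatMul(root, x, y);
  // Upstream gradient dz with z's shape [2, 3, 5].
  auto dz = Const(root, 1.0f, {2, 3, 5});
  std::vector<Output> grad_outputs;
  // Looks up the gradient registered for "BatchMatMul" and emits
  // dx = BatchMatMul(dz, y, AdjY(true)) and dy = BatchMatMul(x, dz, AdjX(true)).
  TF_RETURN_IF_ERROR(
      AddSymbolicGradients(root, {z}, {x, y}, {dz}, &grad_outputs));
  // grad_outputs[0] is dx (shape [2, 3, 4]); grad_outputs[1] is dy ([2, 4, 5]).
  return root.status();
}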