author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-13 21:30:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-09-13 21:34:12 -0700
commit c54693a648730c9223e1aa0a118322f6f0ae8b17 (patch)
tree   50ebfacd3d25c0005ead74d76dbe52ae9dc2db1e /tensorflow/cc/gradients
parent de0bc082f153e36f9919c2cac8fc1063fe3c9186 (diff)
Use the numeric gradient checker for matrix gradient tests.
- Port MathGradTest to use ComputeGradientError.
- Add complex type support for MatMulGrad and BatchMatMulGrad.

PiperOrigin-RevId: 168638169
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       |  39
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  | 141
2 files changed, 78 insertions(+), 102 deletions(-)
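The patched tests lean on the C++ numeric gradient checker (ComputeGradientError, declared in tensorflow/cc/framework/gradient_checker.h), which perturbs each input, estimates the Jacobian numerically, and compares it against the symbolic gradient registered for the op. A minimal standalone sketch of the call pattern, using a hypothetical fixed-shape MatMul rather than the randomized shapes in the patched test:

    #include "tensorflow/cc/framework/gradient_checker.h"
    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/tensor_shape.h"

    using namespace tensorflow;
    using namespace tensorflow::ops;

    // Illustrative only: numerically check the MatMul gradient for a
    // fixed 2x3 * 3x4 product. The template arguments are the input
    // type, output type, and the type used to report the error.
    void CheckMatMulGradient() {
      Scope root = Scope::NewRootScope();
      auto x = Placeholder(root, DT_FLOAT, Placeholder::Shape({2, 3}));
      auto y = Placeholder(root, DT_FLOAT, Placeholder::Shape({3, 4}));
      auto z = MatMul(root, x, y);
      float max_error;
      // Extra parentheses keep the template commas from confusing the
      // macro, as in the patched test below.
      TF_CHECK_OK((ComputeGradientError<float, float, float>(
          root, {x, y}, {TensorShape({2, 3}), TensorShape({3, 4})}, {z},
          {TensorShape({2, 4})}, &max_error)));
      CHECK_LT(max_error, 1e-3);  // same tolerance the patched tests assert
    }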
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index b88332ebc7..aba17cfe0c 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -790,42 +790,37 @@ Status MatMulGradHelper(const Scope& scope, const bool is_batch,
// MatMulGrad common used to read and check node attr state, and determine
// proper MatMul products for gradients based on input matrix transposition
// combinations.
-// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
const bool is_batch,
const std::vector<Output>& grad_inputs,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
- DataType dtype;
- TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
- if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
- return errors::Unimplemented(
- "MatMul gradient for complex data type is not supported yet.");
+ auto a = op.input(0);
+ auto b = op.input(1);
+ // Use conjugate of the inputs for MatMul
+ if (is_batch == false) {
+ a = ConjugateHelper(scope, a);
+ b = ConjugateHelper(scope, b);
}
+ auto product = op.output(0);
bool ta;
bool tb;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));
+ TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
+ TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));
if (!ta && !tb) {
- return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
- true, op.input(0), true, grad_inputs[0], false,
- grad_outputs);
+ return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
+ true, grad_inputs[0], false, grad_outputs);
} else if (!ta && tb) {
- return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
- false, grad_inputs[0], true, op.input(0), false,
- grad_outputs);
+ return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
+ grad_inputs[0], true, a, false, grad_outputs);
} else if (ta && !tb) {
- return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
- true, op.input(0), false, grad_inputs[0], false,
- grad_outputs);
+ return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
+ false, grad_inputs[0], false, grad_outputs);
}
- return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
- true, grad_inputs[0], true, op.input(0), true,
- grad_outputs);
+ return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
+ grad_inputs[0], true, a, true, grad_outputs);
}
Status MatMulGrad(const Scope& scope, const Operation& op,
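In formula form, the math_grad.cc change drops the complex-dtype rejection and instead conjugates the inputs before forming the gradient products. For the non-transposed branch (ta = tb = false) of Z = XY, with a scalar loss L and \bar{\cdot} denoting elementwise conjugation, the code above emits

    \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}\,\bar{Y}^{\top},
    \qquad
    \frac{\partial L}{\partial Y} = \bar{X}^{\top}\,\frac{\partial L}{\partial Z},

which reduces to the familiar real-valued MatMul gradient when the inputs are real; the other three branches permute the transposition and conjugation analogously. The explicit ConjugateHelper call is skipped when is_batch is true, presumably because BatchMatMul's adj_x/adj_y attributes apply a conjugate transpose, so the conjugation is already built into the batch products.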
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 047243aa6a..3534f16e8f 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -507,77 +507,34 @@ class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
+ template <typename T>
void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
- // Generate random test data.
- std::vector<Tensor> data;
- RandMatMulGradData(is_batch, t_x, t_y, &data);
- auto x = Const(root_, data[0]);
- auto y = Const(root_, data[1]);
- auto dz = Const(root_, data[2]);
-
- std::vector<Tensor> grad_outputs;
- ComputeMatMulGrad(is_batch, x, t_x, y, t_y, dz, &grad_outputs);
-
- if (!t_x && !t_y) {
- test::ExpectClose(grad_outputs[0],
- ComputeMatMul(is_batch, dz, false, y, true));
- test::ExpectClose(grad_outputs[1],
- ComputeMatMul(is_batch, x, true, dz, false));
- } else if (t_x && !t_y) {
- test::ExpectClose(grad_outputs[0],
- ComputeMatMul(is_batch, y, false, dz, true));
- test::ExpectClose(grad_outputs[1],
- ComputeMatMul(is_batch, x, false, dz, false));
- } else if (!t_x && t_y) {
- test::ExpectClose(grad_outputs[0],
- ComputeMatMul(is_batch, dz, false, y, false));
- test::ExpectClose(grad_outputs[1],
- ComputeMatMul(is_batch, dz, true, x, false));
- } else {
- test::ExpectClose(grad_outputs[0],
- ComputeMatMul(is_batch, y, true, dz, true));
- test::ExpectClose(grad_outputs[1],
- ComputeMatMul(is_batch, dz, true, x, true));
- }
- }
-
- void ComputeMatMulGrad(const bool is_batch, const Output& x, const bool t_x,
- const Output& y, const bool t_y, const Output& dz,
- std::vector<Tensor>* out) {
- // Compute forward MatMul: z = MatMul(x, y).
- Output z;
- if (is_batch) {
- z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
- } else {
- z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
- }
TF_ASSERT_OK(root_.status());
- CHECK_NOTNULL(z.node());
- std::vector<Output> grad_outputs;
- // Call MatMulGrad which populates 'grad_outputs'.
- TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
- &grad_outputs));
- ASSERT_EQ(2, grad_outputs.size());
- // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
- test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
- }
-
- Tensor ComputeMatMul(const bool is_batch, const Output& x, const bool t_x,
- const Output& y, const bool t_y) {
+ // Generate random (but compatible) shapes for matrix multiplication.
+ std::vector<TensorShape> shapes;
+ RandMatMulShapes(is_batch, t_x, t_y, &shapes);
+ TensorShape x_shape = shapes[0];
+ TensorShape y_shape = shapes[1];
+ TensorShape z_shape = shapes[2];
+ auto x =
+ Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
+ auto y =
+ Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
Output z;
if (is_batch) {
z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
} else {
z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
}
- TF_EXPECT_OK(root_.status());
- Tensor out;
- test::GetTensor(root_, z, &out);
- return out;
+
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError<T, T, float>(
+ root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
+ EXPECT_LT(max_error, 1e-3);
}
- void RandMatMulGradData(const bool is_batch, const bool tx, const bool ty,
- std::vector<Tensor>* data) {
+ void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
+ std::vector<TensorShape>* shapes) {
// Choose a random batch size in [1, 4]
const int b = 1 + (random::New64() % 4);
// z = MatMul(x, y)
@@ -593,8 +550,7 @@ class MathGradTest : public ::testing::Test {
// x.shape = [m, k]
x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
}
- data->emplace_back(DT_FLOAT, x_shape);
- RandTensor(&data->back());
+ shapes->push_back(x_shape);
TensorShape y_shape;
if (is_batch) {
@@ -604,8 +560,7 @@ class MathGradTest : public ::testing::Test {
// y.shape = [k, n]
y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
}
- data->emplace_back(DT_FLOAT, y_shape);
- RandTensor(&data->back());
+ shapes->push_back(y_shape);
TensorShape z_shape;
if (is_batch) {
@@ -615,13 +570,7 @@ class MathGradTest : public ::testing::Test {
// z.shape = [m, n]
z_shape = TensorShape({m, n});
}
- data->emplace_back(DT_FLOAT, z_shape);
- RandTensor(&data->back());
- }
-
- void RandTensor(Tensor* t) {
- test::FillFn<float>(
- t, [this](const int i) { return static_cast<float>(Rand()); });
+ shapes->push_back(z_shape);
}
int Rand() { return 1 + (random::New64() % 10); }
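For reference, the shapes that RandMatMulShapes emits satisfy the usual multiplication-compatibility rule (with an extra leading batch dimension b when is_batch is true, and with the relevant axes swapped when tx or ty is set):

    X \in \mathbb{R}^{m \times k}, \qquad
    Y \in \mathbb{R}^{k \times n}, \qquad
    Z = XY \in \mathbb{R}^{m \times n}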
@@ -630,35 +579,67 @@ class MathGradTest : public ::testing::Test {
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
- TestMatMulGrad(false, false, false);
+ TestMatMulGrad<float>(false, false, false);
+}
+
+TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
+ TestMatMulGrad<complex64>(false, false, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
- TestMatMulGrad(false, true, false);
+ TestMatMulGrad<float>(false, true, false);
+}
+
+TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
+ TestMatMulGrad<complex64>(false, true, false);
}
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
- TestMatMulGrad(false, false, true);
+ TestMatMulGrad<float>(false, false, true);
+}
+
+TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
+ TestMatMulGrad<complex64>(false, false, true);
}
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
- TestMatMulGrad(false, true, true);
+ TestMatMulGrad<float>(false, true, true);
+}
+
+TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
+ TestMatMulGrad<complex64>(false, true, true);
}
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
- TestMatMulGrad(true, false, false);
+ TestMatMulGrad<float>(true, false, false);
+}
+
+TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
+ TestMatMulGrad<complex64>(true, false, false);
}
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
- TestMatMulGrad(true, true, false);
+ TestMatMulGrad<float>(true, true, false);
+}
+
+TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
+ TestMatMulGrad<complex64>(true, true, false);
}
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
- TestMatMulGrad(true, false, true);
+ TestMatMulGrad<float>(true, false, true);
+}
+
+TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
+ TestMatMulGrad<complex64>(true, false, true);
}
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
- TestMatMulGrad(true, true, true);
+ TestMatMulGrad<float>(true, true, true);
+}
+
+TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
+ TestMatMulGrad<complex64>(true, true, true);
}
class NaryGradTest : public ::testing::Test {