| author | 2018-08-06 15:05:50 -0700 |
|---|---|
| committer | 2018-08-06 15:10:04 -0700 |
| commit | cb29a6b2217a140c248188015da424669fd08e54 (patch) |
| tree | 2b9e718b900551937f12132b3c0d81ff5f0f1d2f /tensorflow/core/ops/math_grad_test.cc |
| parent | 9eacf865fac500480f3c2708539e6c6893f2a36a (diff) |
Add C++ gradient for Cast op.
PiperOrigin-RevId: 207615481
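The diffstat below is limited to the test file, so the gradient registration this commit adds is not shown here. For context, here is a minimal sketch of what a Cast gradient in tensorflow/core/ops/math_grad.cc plausibly looks like, following the FunctionDefHelper (FDH) pattern that SymbolicGradient resolves registered gradients against; the exact body is an assumption, not part of this diff:

```cpp
// Sketch (assumption; not shown in this diff): an FDH-based gradient for
// Cast. The incoming dy arrives in the destination dtype and is simply cast
// back to the source dtype, since Cast is identity-like elementwise.
Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes: dx = Cast<DstT -> SrcT>(dy)
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return Status::OK();
}
REGISTER_OP_GRADIENT("Cast", CastGrad);
```

Registering the gradient under the op name "Cast" is what lets the test's SymbolicGradient node differentiate through the Cast op.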
Diffstat (limited to 'tensorflow/core/ops/math_grad_test.cc')
-rw-r--r-- | tensorflow/core/ops/math_grad_test.cc | 54
1 file changed, 40 insertions, 14 deletions
```diff
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index cfa3a64328..2a27ef3ddb 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -38,42 +38,45 @@ std::unique_ptr<Session> NewSession() {
 
 class MathGradTest : public ::testing::Test {
  protected:
   // Unary
-  Status Unary(const string& op, const Tensor& x, Tensor* y) {
-    const DataType T = x.dtype();
-    auto adef = [T](const string& name) {  // E.g., x:float, dy:double
-      return strings::StrCat(name, ":", DataTypeString(T));
+  // dst is the output dtype of op_node.
+  Status Unary(const FDH::Node& op_node, const Tensor& x, const DataType dst,
+               Tensor* y) {
+    const DataType src = x.dtype();
+    auto adef = [](const string& name,
+                   const DataType type) {  // E.g., x:float, dy:double
+      return strings::StrCat(name, ":", DataTypeString(type));
     };
     // Sum(op(x)), sum all output of op(x).
-    auto test = FDH::Define("Test", {adef("x")}, {adef("l")}, {},
+    auto test = FDH::Define("Test", {adef("x", src)}, {adef("l", dst)}, {},
                             {
-                                {{"y"}, op, {"x"}, {{"T", T}}},
+                                op_node,
                                 FDH::Const("zero", 0),
                                 FDH::Const("one", 1),
-                                {{"r"}, "Rank", {"x"}, {{"T", T}}},
+                                {{"r"}, "Rank", {"x"}, {{"T", src}}},
                                 {{"indices"}, "Range", {"zero", "r", "one"}},
-                                {{"l"}, "Sum", {"y", "indices"}, {{"T", T}}},
+                                {{"l"}, "Sum", {"y", "indices"}, {{"T", dst}}},
                             });
     // TestGrad = Test'(x)
     auto grad = FDH::Define(
-        "TestGrad", {adef("x")}, {adef("dx")}, {},
+        "TestGrad", {adef("x", src)}, {adef("dx", src)}, {},
         {
             FDH::Const("one", 1),
-            {{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
+            {{"dy"}, "Cast", {"one"}, {{"DstT", dst}, {"SrcT", DT_INT32}}},
             {{"grad"},
              "SymbolicGradient",
              {"x", "dy"},
              {
                  {"f", FDH::FunctionRef("Test")},
-                 {"Tin", DataTypeSlice{T, T}},
-                 {"Tout", DataTypeSlice{T}},
+                 {"Tin", DataTypeSlice{src, dst}},
+                 {"Tout", DataTypeSlice{src}},
              }},
-            {{"dx"}, "Identity", {"grad"}, {{"T", T}}},
+            {{"dx"}, "Identity", {"grad"}, {{"T", src}}},
         });
     // Each test case will feed in "x:0" and expects to get "dx:0".
     auto gdef = test::function::GDef(
         {
-            f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+            f::NDef("x", "Placeholder", {}, {{"dtype", src}}),
             f::NDef("dx", "TestGrad", {"x"}, {}),
         },
         {test, grad});
@@ -90,6 +93,11 @@ class MathGradTest : public ::testing::Test {
     return s;
   }
 
+  Status Unary(const string& op, const Tensor& x, Tensor* y) {
+    const FDH::Node op_node = {{"y"}, op, {"x"}, {{"T", x.dtype()}}};
+    return Unary(op_node, x, x.dtype(), y);
+  }
+
   // Unary op expecting OK.
   Tensor SymGrad(const string& op, const Tensor& x) {
     Tensor ret;
@@ -97,6 +105,14 @@ class MathGradTest : public ::testing::Test {
     return ret;
   }
 
+  Tensor SymCastGrad(const Tensor& x, const DataType dst) {
+    Tensor ret;
+    const FDH::Node op_node = {
+        {"y"}, "Cast", {"x"}, {{"SrcT", x.dtype()}, {"DstT", dst}}};
+    TF_CHECK_OK(Unary(op_node, x, dst, &ret));
+    return ret;
+  }
+
   // Binary
   void SymGrad(const string& op, const Tensor& x, const Tensor& y, Tensor* dx,
                Tensor* dy) {
@@ -609,6 +625,16 @@ TEST_F(MathGradTest, Cos) {
   test::ExpectClose(ans, dx);
 }
 
+TEST_F(MathGradTest, Cast) {
+  auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+                                 TensorShape({2, 3}));
+  auto g = [](float x) { return 1.f; };
+  auto dx = test::AsTensor<float>(
+      {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+  Tensor ans = SymCastGrad(x, DT_INT32);
+  test::ExpectClose(ans, dx);
+}
+
 // TODO(zhifengc)
 // TEST_F(MathGradSComplexTest, Real) {}
 // TEST_F(MathGradSComplexTest, Imag) {}
```
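Why the expected gradient in MathGradTest.Cast is all ones: the test function computes Test(x) = Sum(Cast(x)), so SymbolicGradient feeds an all-ones dy (in the destination dtype, here int32) back through Cast's gradient, which casts it to the source dtype unchanged. A standalone sketch of that arithmetic for the float → int32 case (plain C++ with a hypothetical helper name, no TensorFlow dependency):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical illustration (not TensorFlow code): the Cast gradient treats
// the op as identity-like, so dy (destination dtype) is cast back to the
// source dtype element by element.
std::vector<float> CastGradToFloat(const std::vector<int32_t>& dy) {
  std::vector<float> dx(dy.size());
  for (std::size_t i = 0; i < dy.size(); ++i) dx[i] = static_cast<float>(dy[i]);
  return dx;
}

int main() {
  // Test(x) = Sum(Cast<float -> int32>(x)); d(Sum)/dy is all ones (int32).
  const std::vector<int32_t> dy(6, 1);
  for (float v : CastGradToFloat(dy)) std::cout << v << ' ';
  std::cout << '\n';  // Prints "1 1 1 1 1 1", matching g(x) = 1.f above.
  return 0;
}
```

This is exactly the shape of the expected dx tensor in the test: six ones, one per element of the 2x3 input.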