aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/math_grad_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-06 15:05:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 15:10:04 -0700
commitcb29a6b2217a140c248188015da424669fd08e54 (patch)
tree2b9e718b900551937f12132b3c0d81ff5f0f1d2f /tensorflow/core/ops/math_grad_test.cc
parent9eacf865fac500480f3c2708539e6c6893f2a36a (diff)
Add c++ gradient for cast op.
PiperOrigin-RevId: 207615481
Diffstat (limited to 'tensorflow/core/ops/math_grad_test.cc')
-rw-r--r--tensorflow/core/ops/math_grad_test.cc54
1 files changed, 40 insertions, 14 deletions
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index cfa3a64328..2a27ef3ddb 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -38,42 +38,45 @@ std::unique_ptr<Session> NewSession() {
class MathGradTest : public ::testing::Test {
protected:
// Unary
- Status Unary(const string& op, const Tensor& x, Tensor* y) {
- const DataType T = x.dtype();
- auto adef = [T](const string& name) { // E.g., x:float, dy:double
- return strings::StrCat(name, ":", DataTypeString(T));
+ // dst is the output dtype of op_node.
+ Status Unary(const FDH::Node& op_node, const Tensor& x, const DataType dst,
+ Tensor* y) {
+ const DataType src = x.dtype();
+ auto adef = [](const string& name,
+ const DataType type) { // E.g., x:float, dy:double
+ return strings::StrCat(name, ":", DataTypeString(type));
};
// Sum(op(x)), sum all output of op(x).
- auto test = FDH::Define("Test", {adef("x")}, {adef("l")}, {},
+ auto test = FDH::Define("Test", {adef("x", src)}, {adef("l", dst)}, {},
{
- {{"y"}, op, {"x"}, {{"T", T}}},
+ op_node,
FDH::Const("zero", 0),
FDH::Const("one", 1),
- {{"r"}, "Rank", {"x"}, {{"T", T}}},
+ {{"r"}, "Rank", {"x"}, {{"T", src}}},
{{"indices"}, "Range", {"zero", "r", "one"}},
- {{"l"}, "Sum", {"y", "indices"}, {{"T", T}}},
+ {{"l"}, "Sum", {"y", "indices"}, {{"T", dst}}},
});
// TestGrad = Test'(x)
auto grad = FDH::Define(
- "TestGrad", {adef("x")}, {adef("dx")}, {},
+ "TestGrad", {adef("x", src)}, {adef("dx", src)}, {},
{
FDH::Const("one", 1),
- {{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
+ {{"dy"}, "Cast", {"one"}, {{"DstT", dst}, {"SrcT", DT_INT32}}},
{{"grad"},
"SymbolicGradient",
{"x", "dy"},
{
{"f", FDH::FunctionRef("Test")},
- {"Tin", DataTypeSlice{T, T}},
- {"Tout", DataTypeSlice{T}},
+ {"Tin", DataTypeSlice{src, dst}},
+ {"Tout", DataTypeSlice{src}},
}},
- {{"dx"}, "Identity", {"grad"}, {{"T", T}}},
+ {{"dx"}, "Identity", {"grad"}, {{"T", src}}},
});
// Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef(
{
- f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+ f::NDef("x", "Placeholder", {}, {{"dtype", src}}),
f::NDef("dx", "TestGrad", {"x"}, {}),
},
{test, grad});
@@ -90,6 +93,11 @@ class MathGradTest : public ::testing::Test {
return s;
}
+ Status Unary(const string& op, const Tensor& x, Tensor* y) {
+ const FDH::Node op_node = {{"y"}, op, {"x"}, {{"T", x.dtype()}}};
+ return Unary(op_node, x, x.dtype(), y);
+ }
+
// Unary op expecting OK.
Tensor SymGrad(const string& op, const Tensor& x) {
Tensor ret;
@@ -97,6 +105,14 @@ class MathGradTest : public ::testing::Test {
return ret;
}
+ Tensor SymCastGrad(const Tensor& x, const DataType dst) {
+ Tensor ret;
+ const FDH::Node op_node = {
+ {"y"}, "Cast", {"x"}, {{"SrcT", x.dtype()}, {"DstT", dst}}};
+ TF_CHECK_OK(Unary(op_node, x, dst, &ret));
+ return ret;
+ }
+
// Binary
void SymGrad(const string& op, const Tensor& x, const Tensor& y, Tensor* dx,
Tensor* dy) {
@@ -609,6 +625,16 @@ TEST_F(MathGradTest, Cos) {
test::ExpectClose(ans, dx);
}
+TEST_F(MathGradTest, Cast) {
+ auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
+ TensorShape({2, 3}));
+ auto g = [](float x) { return 1.f; };
+ auto dx = test::AsTensor<float>(
+ {g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
+ Tensor ans = SymCastGrad(x, DT_INT32);
+ test::ExpectClose(ans, dx);
+}
+
// TODO(zhifengc)
// TEST_F(MathGradSComplexTest, Real) {}
// TEST_F(MathGradSComplexTest, Imag) {}