diff options
author | 2017-09-12 10:16:26 -0700 | |
---|---|---|
committer | 2017-09-12 10:20:30 -0700 | |
commit | e6b011763a60d239972c8c6c0f36536ab6f885a3 (patch) | |
tree | 8930a1e6f5efa50c860683ea86807335c7470cbf /tensorflow/cc/gradients/data_flow_grad_test.cc | |
parent | f63aa7f49f81a66112bfef6670a18658d5a479e5 (diff) |
Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
Diffstat (limited to 'tensorflow/cc/gradients/data_flow_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/data_flow_grad_test.cc | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc index 3d027909f0..734dfd3af9 100644 --- a/tensorflow/cc/gradients/data_flow_grad_test.cc +++ b/tensorflow/cc/gradients/data_flow_grad_test.cc @@ -35,8 +35,8 @@ class DataFlowGradTest : public ::testing::Test { const OutputList& ys, const std::vector<TensorShape>& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError<float, float, float>( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 1e-4); } |