aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/data_flow_grad_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 10:16:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 10:20:30 -0700
commite6b011763a60d239972c8c6c0f36536ab6f885a3 (patch)
tree8930a1e6f5efa50c860683ea86807335c7470cbf /tensorflow/cc/gradients/data_flow_grad_test.cc
parentf63aa7f49f81a66112bfef6670a18658d5a479e5 (diff)
Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
Diffstat (limited to 'tensorflow/cc/gradients/data_flow_grad_test.cc')
-rw-r--r--tensorflow/cc/gradients/data_flow_grad_test.cc4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc
index 3d027909f0..734dfd3af9 100644
--- a/tensorflow/cc/gradients/data_flow_grad_test.cc
+++ b/tensorflow/cc/gradients/data_flow_grad_test.cc
@@ -35,8 +35,8 @@ class DataFlowGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
- TF_ASSERT_OK(
- ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<float, float, float>(
+ scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 1e-4);
}