aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-12-06 18:43:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 18:47:41 -0800
commitfe8406149feec453250905965a14285465cd2063 (patch)
treebe3cd75d543f3c0f29f368da61d915abbae7fcbf /tensorflow/cc/gradients
parent8ad62af489df718992561710123bc8c037e7d17b (diff)
Merge changes from github.
PiperOrigin-RevId: 178185697
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc12
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc7
2 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 09fadfcab5..13a3bba5e6 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -196,6 +196,18 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper);
+Status LRNGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs){
+ internal::LRNGrad::Attrs grad_attrs;
+
+ auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0),
+ grad_attrs);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LRN", LRNGradHelper);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index ac66f51cf0..f9063e8365 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -191,5 +191,12 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
+TEST_F(NNGradTest, LRN){
+ TensorShape x_shape({1, 1, 2, 1});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = LRN(scope_, x);
+ RunTest(x, x_shape, y, x_shape);
+}
+
} // namespace
} // namespace tensorflow