aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-10-05 11:13:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 11:17:26 -0700
commitb1325838aaf902e52fae4b085c6396848c445062 (patch)
tree4ee30408eec000f133c87f1ff2f617cf9f1d3a21 /tensorflow/core
parentdd8afaad37fdb284dce3518a9be22aca1c25e475 (diff)
Declare that stateless random ops are not differentiable in C++ code.
PiperOrigin-RevId: 215935319
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/ops/stateless_random_grad.cc23
2 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6a3ee3c1cb..900a0e11c4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1242,6 +1242,7 @@ cc_library(
srcs = [
"ops/math_grad.cc",
"ops/random_grad.cc",
+ "ops/stateless_random_grad.cc",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc
new file mode 100644
index 0000000000..331e1d0152
--- /dev/null
+++ b/tensorflow/core/ops/stateless_random_grad.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+REGISTER_OP_NO_GRADIENT("StatelessRandomUniform");
+REGISTER_OP_NO_GRADIENT("StatelessRandomNormal");
+REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal");
+REGISTER_OP_NO_GRADIENT("StatelessMultinomial");
+} // end namespace tensorflow