aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/adding_an_op
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-03 12:42:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-03 13:05:58 -0800
commit88c9fb09bd667df03cdf7e9f75ff225853ad01e1 (patch)
tree7f352980385adf8cd8e851554c393927917ab9be /tensorflow/examples/adding_an_op
parent20af1b7baa0452256c761a5961966a82b796162a (diff)
Restore the adding_an_op code examples that used to live under
g3doc/, now under examples/. Partial fix of #8029. Change: 149142119
Diffstat (limited to 'tensorflow/examples/adding_an_op')
-rw-r--r--tensorflow/examples/adding_an_op/BUILD157
-rw-r--r--tensorflow/examples/adding_an_op/__init__.py0
-rw-r--r--tensorflow/examples/adding_an_op/attr_examples.cc46
-rw-r--r--tensorflow/examples/adding_an_op/cuda_op.py27
-rw-r--r--tensorflow/examples/adding_an_op/cuda_op_kernel.cc55
-rw-r--r--tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc31
-rw-r--r--tensorflow/examples/adding_an_op/cuda_op_test.py35
-rw-r--r--tensorflow/examples/adding_an_op/fact_test.py32
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_1_test.py42
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_2_test.py59
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_3_test.py52
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_grad_2.py44
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_1.py27
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_2.py29
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_3.py27
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc62
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc115
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc72
18 files changed, 912 insertions, 0 deletions
diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD
new file mode 100644
index 0000000000..ffaf9349d2
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/BUILD
@@ -0,0 +1,157 @@
+# Description:
+# Code examples referenced by adding_an_op
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+ features = [
+ "-layering_check",
+ "-parse_headers",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_tests_tags")
+
+exports_files(["LICENSE"])
+
+tf_custom_op_library(
+ name = "zero_out_op_kernel_1.so",
+ srcs = ["zero_out_op_kernel_1.cc"],
+)
+
+py_library(
+ name = "zero_out_op_1",
+ srcs = ["zero_out_op_1.py"],
+ data = [":zero_out_op_kernel_1.so"],
+ srcs_version = "PY2AND3",
+)
+
+tf_custom_op_library(
+ name = "zero_out_op_kernel_2.so",
+ srcs = ["zero_out_op_kernel_2.cc"],
+)
+
+py_library(
+ name = "zero_out_op_2",
+ srcs = ["zero_out_op_2.py"],
+ data = [":zero_out_op_kernel_2.so"],
+ srcs_version = "PY2AND3",
+)
+
+tf_custom_op_library(
+ name = "zero_out_op_kernel_3.so",
+ srcs = ["zero_out_op_kernel_3.cc"],
+)
+
+py_library(
+ name = "zero_out_op_3",
+ srcs = ["zero_out_op_3.py"],
+ data = [":zero_out_op_kernel_3.so"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "zero_out_grad_2",
+ srcs = ["zero_out_grad_2.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":zero_out_op_2",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:sparse_ops",
+ ],
+)
+
+py_test(
+ name = "zero_out_1_test",
+ size = "small",
+ srcs = ["zero_out_1_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":zero_out_op_1",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "zero_out_2_test",
+ size = "small",
+ srcs = ["zero_out_2_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":zero_out_grad_2",
+ ":zero_out_op_2",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "zero_out_3_test",
+ size = "small",
+ srcs = ["zero_out_3_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":zero_out_op_3",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_custom_op_library(
+ name = "cuda_op_kernel.so",
+ srcs = ["cuda_op_kernel.cc"],
+ gpu_srcs = ["cuda_op_kernel.cu.cc"],
+)
+
+py_library(
+ name = "cuda_op",
+ srcs = ["cuda_op.py"],
+ data = [":cuda_op_kernel.so"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "cuda_op_test",
+ size = "small",
+ srcs = ["cuda_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":cuda_op",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "fact_test",
+ size = "small",
+ srcs = ["fact_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cc_binary(
+ name = "attr_examples",
+ srcs = ["attr_examples.cc"],
+ deps = [
+ "//tensorflow/core",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/examples/adding_an_op/__init__.py b/tensorflow/examples/adding_an_op/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/__init__.py
diff --git a/tensorflow/examples/adding_an_op/attr_examples.cc b/tensorflow/examples/adding_an_op/attr_examples.cc
new file mode 100644
index 0000000000..4eb35668ce
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/attr_examples.cc
@@ -0,0 +1,46 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+#include "tensorflow/core/framework/op.h"
+
+REGISTER_OP("RestrictedTypeExample").Attr("t: {int32, float, bool}");
+
+REGISTER_OP("NumberType").Attr("t: numbertype");
+
+REGISTER_OP("EnumExample").Attr("e: {'apple', 'orange'}");
+
+REGISTER_OP("MinIntExample").Attr("a: int >= 2");
+
+REGISTER_OP("TypeListExample").Attr("a: list({int32, float}) >= 3");
+
+REGISTER_OP("AttrDefaultExample").Attr("i: int = 0");
+
+REGISTER_OP("AttrDefaultExampleForAllTypes")
+ .Attr("s: string = 'foo'")
+ .Attr("i: int = 0")
+ .Attr("f: float = 1.0")
+ .Attr("b: bool = true")
+ .Attr("ty: type = DT_INT32")
+ .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
+ .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
+ .Attr("l_empty: list(int) = []")
+ .Attr("l_int: list(int) = [2, 3, 5, 7]");
+
+int main(int argc, char* argv[]) {
+ printf("All registered ops:\n%s\n",
+ tensorflow::OpRegistry::Global()->DebugString(false).c_str());
+ return 0;
+}
diff --git a/tensorflow/examples/adding_an_op/cuda_op.py b/tensorflow/examples/adding_an_op/cuda_op.py
new file mode 100644
index 0000000000..dd5428870b
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/cuda_op.py
@@ -0,0 +1,27 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cuda op Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+if tf.test.is_built_with_cuda():
+ _cuda_op_module = tf.load_op_library(os.path.join(
+ tf.resource_loader.get_data_files_path(), 'cuda_op_kernel.so'))
+ add_one = _cuda_op_module.add_one
diff --git a/tensorflow/examples/adding_an_op/cuda_op_kernel.cc b/tensorflow/examples/adding_an_op/cuda_op_kernel.cc
new file mode 100644
index 0000000000..2f323b43d8
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/cuda_op_kernel.cc
@@ -0,0 +1,55 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace tensorflow; // NOLINT(build/namespaces)
+
+REGISTER_OP("AddOne")
+ .Input("input: int32")
+ .Output("output: int32")
+ .Doc(R"doc(
+Adds 1 to all elements of the tensor.
+
+output: A Tensor.
+ output = input + 1
+)doc");
+
+void AddOneKernelLauncher(const int* in, const int N, int* out);
+
+class AddOneOp : public OpKernel {
+ public:
+ explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<int32>();
+
+ // Create an output tensor
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+ &output_tensor));
+ auto output = output_tensor->template flat<int32>();
+
+ // Set all but the first element of the output tensor to 0.
+ const int N = input.size();
+ // Call the cuda kernel launcher
+ AddOneKernelLauncher(input.data(), N, output.data());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_GPU), AddOneOp);
diff --git a/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc
new file mode 100644
index 0000000000..65b50bd3ae
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc
@@ -0,0 +1,31 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+__global__ void AddOneKernel(const int* in, const int N, int* out) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] = in[i] + 1;
+ }
+}
+
+void AddOneKernelLauncher(const int* in, const int N, int* out) {
+ AddOneKernel<<<32, 256>>>(in, N, out);
+}
+
+#endif
diff --git a/tensorflow/examples/adding_an_op/cuda_op_test.py b/tensorflow/examples/adding_an_op/cuda_op_test.py
new file mode 100644
index 0000000000..07390bc3bf
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/cuda_op_test.py
@@ -0,0 +1,35 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for version 1 of the zero_out op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.examples.adding_an_op import cuda_op
+
+
+class AddOneTest(tf.test.TestCase):
+
+ def test(self):
+ if tf.test.is_built_with_cuda():
+ with self.test_session():
+ result = cuda_op.add_one([5, 4, 3, 2, 1])
+ self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/adding_an_op/fact_test.py b/tensorflow/examples/adding_an_op/fact_test.py
new file mode 100644
index 0000000000..f7f17e5180
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/fact_test.py
@@ -0,0 +1,32 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test that user ops can be used as expected."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class FactTest(tf.test.TestCase):
+
+ def test(self):
+ with self.test_session():
+ print(tf.user_ops.my_fact().eval())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/adding_an_op/zero_out_1_test.py b/tensorflow/examples/adding_an_op/zero_out_1_test.py
new file mode 100644
index 0000000000..fac486100d
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_1_test.py
@@ -0,0 +1,42 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for version 1 of the zero_out op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+from tensorflow.examples.adding_an_op import zero_out_op_1
+
+
+class ZeroOut1Test(tf.test.TestCase):
+
+ def test(self):
+ with self.test_session():
+ result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
+ self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
+
+ def testLoadTwice(self):
+ zero_out_loaded_again = tf.load_op_library(os.path.join(
+ tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so'))
+ self.assertEqual(zero_out_loaded_again, zero_out_op_1._zero_out_module)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/adding_an_op/zero_out_2_test.py b/tensorflow/examples/adding_an_op/zero_out_2_test.py
new file mode 100644
index 0000000000..217bbbcffa
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_2_test.py
@@ -0,0 +1,59 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for version 2 of the zero_out op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import
+from tensorflow.examples.adding_an_op import zero_out_op_2
+
+
+class ZeroOut2Test(tf.test.TestCase):
+
+ def test(self):
+ with self.test_session():
+ result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
+ self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
+
+ def test_2d(self):
+ with self.test_session():
+ result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
+ self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
+
+ def test_grad(self):
+ with self.test_session():
+ shape = (5,)
+ x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
+ y = zero_out_op_2.zero_out(x)
+ err = tf.test.compute_gradient_error(x, shape, y, shape)
+ self.assertLess(err, 1e-4)
+
+ def test_grad_2d(self):
+ with self.test_session():
+ shape = (2, 3)
+ x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
+ y = zero_out_op_2.zero_out(x)
+ err = tf.test.compute_gradient_error(x, shape, y, shape)
+ self.assertLess(err, 1e-4)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/adding_an_op/zero_out_3_test.py b/tensorflow/examples/adding_an_op/zero_out_3_test.py
new file mode 100644
index 0000000000..01280caf49
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_3_test.py
@@ -0,0 +1,52 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test for version 3 of the zero_out op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.examples.adding_an_op import zero_out_op_3
+
+
+class ZeroOut3Test(tf.test.TestCase):
+
+ def test(self):
+ with self.test_session():
+ result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
+ self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
+
+ def testAttr(self):
+ with self.test_session():
+ result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
+ self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
+
+ def testNegative(self):
+ with self.test_session():
+ result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
+ with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
+ result.eval()
+
+ def testLarge(self):
+ with self.test_session():
+ result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
+ with self.assertRaisesOpError("preserve_index out of range"):
+ result.eval()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/examples/adding_an_op/zero_out_grad_2.py b/tensorflow/examples/adding_an_op/zero_out_grad_2.py
new file mode 100644
index 0000000000..dc24678e33
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_grad_2.py
@@ -0,0 +1,44 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""The gradient of the tutorial zero_out op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+
+
+@ops.RegisterGradient("ZeroOut")
+def _zero_out_grad(op, grad):
+ """The gradients for `zero_out`.
+
+ Args:
+ op: The `zero_out` `Operation` that we are differentiating, which we can use
+ to find the inputs and outputs of the original op.
+ grad: Gradient with respect to the output of the `zero_out` op.
+
+ Returns:
+ Gradients with respect to the input of `zero_out`.
+ """
+ to_zero = op.inputs[0]
+ shape = array_ops.shape(to_zero)
+ index = array_ops.zeros_like(shape)
+ first_grad = array_ops.reshape(grad, [-1])[0]
+ to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
+ return [to_zero_grad] # List of one Tensor, since we have one input
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_1.py b/tensorflow/examples/adding_an_op/zero_out_op_1.py
new file mode 100644
index 0000000000..6bd98b1f06
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_1.py
@@ -0,0 +1,27 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ZeroOut op Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+_zero_out_module = tf.load_op_library(
+ os.path.join(tf.resource_loader.get_data_files_path(),
+ 'zero_out_op_kernel_1.so'))
+zero_out = _zero_out_module.zero_out
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_2.py b/tensorflow/examples/adding_an_op/zero_out_op_2.py
new file mode 100644
index 0000000000..ba1e8b4d62
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_2.py
@@ -0,0 +1,29 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ZeroOut ops Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+_zero_out_module = tf.load_op_library(
+ os.path.join(tf.resource_loader.get_data_files_path(),
+ 'zero_out_op_kernel_2.so'))
+zero_out = _zero_out_module.zero_out
+zero_out2 = _zero_out_module.zero_out2
+zero_out3 = _zero_out_module.zero_out3
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_3.py b/tensorflow/examples/adding_an_op/zero_out_op_3.py
new file mode 100644
index 0000000000..4354ecaf5a
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_3.py
@@ -0,0 +1,27 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ZeroOut op Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+_zero_out_module = tf.load_op_library(
+ os.path.join(tf.resource_loader.get_data_files_path(),
+ 'zero_out_op_kernel_3.so'))
+zero_out = _zero_out_module.zero_out
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc
new file mode 100644
index 0000000000..cc8c719c1a
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc
@@ -0,0 +1,62 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow; // NOLINT(build/namespaces)
+
+REGISTER_OP("ZeroOut")
+ .Input("to_zero: int32")
+ .Output("zeroed: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Zeros out all but the first value of a Tensor.
+
+zeroed: A Tensor whose first value is identical to `to_zero`, and 0
+ otherwise.
+)doc");
+
+class ZeroOutOp : public OpKernel {
+ public:
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<int32>();
+
+ // Create an output tensor
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+ &output_tensor));
+ auto output = output_tensor->template flat<int32>();
+
+ // Set all but the first element of the output tensor to 0.
+ const int N = input.size();
+ for (int i = 1; i < N; i++) {
+ output(i) = 0;
+ }
+
+ // Preserve the first input value.
+ if (N > 0) output(0) = input(0);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc
new file mode 100644
index 0000000000..3aa18c7307
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc
@@ -0,0 +1,115 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow; // NOLINT(build/namespaces)
+
+REGISTER_OP("ZeroOut")
+ .Attr("T: realnumbertype")
+ .Input("to_zero: T")
+ .Output("zeroed: T")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Zeros out all but the first value of a Tensor.
+
+zeroed: A Tensor whose first value is identical to `to_zero`, and 0
+ otherwise.
+)doc");
+
+REGISTER_OP("ZeroOut2")
+ .Attr("T: realnumbertype")
+ .Input("to_zero: T")
+ .Output("zeroed: T")
+ .Doc(R"doc(
+Zeros out all but the first value of a Tensor.
+
+zeroed: A Tensor whose first value is identical to `to_zero`, and 0
+ otherwise.
+)doc");
+
+REGISTER_OP("ZeroOut3")
+ .Attr("T: realnumbertype")
+ .Input("to_zero: T")
+ .Output("zeroed: T")
+ .Doc(R"doc(
+Zeros out all but the first value of a Tensor.
+
+zeroed: A Tensor whose first value is identical to `to_zero`, and 0
+ otherwise.
+)doc");
+
+template <typename T>
+class ZeroOutOp : public OpKernel {
+ public:
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<T>();
+
+ // Create an output tensor
+ Tensor* output = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_tensor.shape(), &output));
+ auto output_flat = output->template flat<T>();
+
+ // Set all the elements of the output tensor to 0
+ const int N = input.size();
+ for (int i = 0; i < N; i++) {
+ output_flat(i) = T(0);
+ }
+
+ // Preserve the first input value
+ if (N > 0) output_flat(0) = input(0);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ ZeroOutOp<float>);
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ ZeroOutOp<double>);
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int>("T"),
+ ZeroOutOp<int>);
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ZeroOut2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ZeroOutOp<type>)
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+REGISTER_KERNEL(int32);
+
+#undef REGISTER_KERNEL
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ZeroOut3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ZeroOutOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
+#undef REGISTER_KERNEL
diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc
new file mode 100644
index 0000000000..76f6efa334
--- /dev/null
+++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc
@@ -0,0 +1,72 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow; // NOLINT(build/namespaces)
+
+REGISTER_OP("ZeroOut")
+ .Attr("preserve_index: int = 0")
+ .Input("to_zero: int32")
+ .Output("zeroed: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ });
+
+class ZeroOutOp : public OpKernel {
+ public:
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
+ // Get the index of the value to preserve
+ OP_REQUIRES_OK(context,
+ context->GetAttr("preserve_index", &preserve_index_));
+ // Check that preserve\_index is positive
+ OP_REQUIRES(context, preserve_index_ >= 0,
+ errors::InvalidArgument("Need preserve_index >= 0, got ",
+ preserve_index_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<int32>();
+
+ // Check that preserve_index is in range
+ OP_REQUIRES(context, preserve_index_ < input.dimension(0),
+ errors::InvalidArgument("preserve_index out of range"));
+
+ // Create an output tensor
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+ &output_tensor));
+ auto output = output_tensor->template flat<int32>();
+
+ // Set all the elements of the output tensor to 0
+ const int N = input.size();
+ for (int i = 0; i < N; i++) {
+ output(i) = 0;
+ }
+
+ // Preserve the requested input value
+ output(preserve_index_) = input(preserve_index_);
+ }
+
+ private:
+ int preserve_index_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);