aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/gradients_impl.py1
-rw-r--r--tensorflow/python/ops/manip_grad.py32
-rw-r--r--tensorflow/python/ops/manip_ops.py36
-rw-r--r--tensorflow/python/ops/standard_ops.py4
4 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 314726ede6..230b6c5946 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
+from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py
new file mode 100644
index 0000000000..573e8c0a0d
--- /dev/null
+++ b/tensorflow/python/ops/manip_grad.py
@@ -0,0 +1,32 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gradients for operators defined in manip_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import manip_ops
+
+
+@ops.RegisterGradient("Roll")
+def _RollGrad(op, grad):
+ # The gradient is just the roll reversed
+ shift = op.inputs[1]
+ axis = op.inputs[2]
+ roll_grad = manip_ops.roll(grad, -shift, axis)
+ return roll_grad, None, None
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
new file mode 100644
index 0000000000..c5f39784f4
--- /dev/null
+++ b/tensorflow/python/ops/manip_ops.py
@@ -0,0 +1,36 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators for manipulating tensors.
+
+@@roll
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+# pylint: disable=protected-access
+def roll(input, shift, axis):
+ return _gen_manip_ops.roll(input, shift, axis)
+
+roll.__doc__ = _gen_manip_ops.roll.__doc__
+# pylint: enable=protected-access
+
+_allowed_symbols = ['roll']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 30bf4e4ef1..737b923415 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -26,6 +26,7 @@ import sys as _sys
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
@@ -59,6 +60,7 @@ from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
from tensorflow.python.ops.lookup_ops import initialize_all_tables
from tensorflow.python.ops.lookup_ops import tables_initializer
+from tensorflow.python.ops.manip_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
@@ -105,6 +107,7 @@ from tensorflow.python.ops import init_ops as _init_ops
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import linalg_ops as _linalg_ops
from tensorflow.python.ops import logging_ops as _logging_ops
+from tensorflow.python.ops import manip_ops as _manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops import numerics as _numerics
from tensorflow.python.ops import parsing_ops as _parsing_ops
@@ -280,6 +283,7 @@ remove_undocumented(__name__, _allowed_symbols,
_io_ops,
_linalg_ops,
_logging_ops,
+ _manip_ops,
_math_ops,
_numerics,
_parsing_ops,