aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-10-31 11:25:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-31 11:29:51 -0700
commit8a09bbc4a5de2cb8db20dd41112abec245eaff88 (patch)
tree67783676fe1617a1e29f60866e9d556fe71a80db /tensorflow
parent585432cc21f52ece2c7fd9bd21a45d40b1e63f42 (diff)
Add TFE_Py_TensorShapeSlice function
TFE_Py_TensorShapeSlice takes a list of EagerTensors and returns a list of their i'th dimensions. This utility is fairly niche but it is simple and reduces SPINN training time by over 12%. PiperOrigin-RevId: 174065044
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc68
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h12
-rw-r--r--tensorflow/python/eager/tensor_test.py89
-rw-r--r--tensorflow/python/ops/array_grad.py55
-rw-r--r--tensorflow/python/pywrap_tfe.i1
5 files changed, 205 insertions, 20 deletions
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 3adaea2c79..4cc8f91dbc 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -657,3 +657,71 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
EagerTensorType->tp_dictoffset = 0;
return reinterpret_cast<PyObject*>(EagerTensorType);
}
+
+PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
+ if (!PyList_Check(tensor_list)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "tensor_list argument must be a list. Got \"",
+ Py_TYPE(tensor_list)->tp_name, "\"")
+ .c_str());
+ return nullptr;
+ }
+ if (slice_dim < 0) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ tensorflow::strings::StrCat("Slice dimension must be non-negative. "
+ "Got ",
+ slice_dim)
+ .c_str());
+ return nullptr;
+ }
+
+ Py_ssize_t num_tensors = PyList_Size(tensor_list);
+ int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
+ auto tensor = tensorflow::make_safe(TF_AllocateTensor(
+ TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
+ int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
+ for (Py_ssize_t i = 0; i < num_tensors; ++i) {
+ PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
+ if (!EagerTensor_CheckExact(tensor_obj)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expected a list of EagerTensors but "
+ "element ",
+ i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
+ .c_str());
+ return nullptr;
+ }
+
+ EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
+ TFE_TensorHandle* handle = t->handle;
+ if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
+ PyErr_SetString(PyExc_IndexError,
+ tensorflow::strings::StrCat(
+ "Slice dimension (", slice_dim,
+ ") must be smaller than rank of all "
+ "tensors, but tensor at index ",
+ i, " has rank ", TFE_TensorHandleNumDims(handle))
+ .c_str());
+ return nullptr;
+ }
+ int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
+ data[i] = dim;
+ }
+
+ auto status = tensorflow::make_safe(TF_NewStatus());
+ TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
+ TF_Message(status.get()))
+ .c_str());
+ return nullptr;
+ }
+ // handle now owns the tensor. Release it from the smart pointer.
+ tensor.release();
+
+ return EagerTensorFromHandle(handle);
+}
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 9834095c87..1d03df2933 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -105,4 +105,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
PyObject* backward_function);
PyObject* TFE_Py_TapeExport(PyObject* tape);
+// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
+// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
+// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
+// `tensor_list`. For example, if `tensor_list` contains tensors of with shapes
+// [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
+// `slice_dim` equal to 1 will return [2, 5, 7].
+// On error, returns nullptr and sets python exception.
+// REQUIRES: `tensor_list` is a python list of EagerTensors
+// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
+// tensors in `tensor_list`.
+PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 2b7b5c727a..3a4b4c2414 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -22,6 +22,7 @@ import copy
import numpy as np
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import test
@@ -216,5 +217,93 @@ class TFETensorTest(test_util.TensorFlowTestCase):
_create_tensor("test string")
+class TFETensorUtilTest(test_util.TensorFlowTestCase):
+
+ def testListOfThree(self):
+ t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
+ t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
+ t3 = _create_tensor([[1], [3], [5], [6]], dtype=dtypes.int32)
+
+ r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 0)
+ self.assertAllEqual(np.array([3, 2, 4]), r.numpy())
+
+ r = pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2, t3], 1)
+ self.assertAllEqual(np.array([2, 3, 1]), r.numpy())
+
+ def testEmptyTensorList(self):
+ a = pywrap_tensorflow.TFE_Py_TensorShapeSlice([], 0)
+ self.assertTrue(isinstance(a, ops.EagerTensor))
+ self.assertEqual(0, a.numpy().size)
+
+ def testTensorListContainsNonTensors(self):
+ t1 = _create_tensor([1, 2], dtype=dtypes.int32)
+
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"Expected a list of EagerTensors but element 1 has type \"str\""):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, "abc"], 0)
+
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"Expected a list of EagerTensors but element 0 has type \"int\""):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([2, t1], 0)
+
+ def testTensorListNotList(self):
+ t1 = _create_tensor([1, 2], dtype=dtypes.int32)
+
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"tensor_list argument must be a list. Got \"EagerTensor\""):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
+
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"tensor_list argument must be a list. Got \"tuple\""):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice((t1,), -2)
+
+ def testNegativeSliceDim(self):
+ t1 = _create_tensor([1, 2], dtype=dtypes.int32)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Slice dimension must be non-negative. Got -2"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], -2)
+
+ def testSliceDimOutOfRange(self):
+ t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
+ t2 = _create_tensor([1, 2], dtype=dtypes.int32)
+ t3 = _create_tensor(2, dtype=dtypes.int32)
+
+ with self.assertRaisesRegexp(
+ IndexError,
+ r"Slice dimension \(2\) must be smaller than rank of all tensors, "
+ "but tensor at index 0 has rank 2"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1], 2)
+
+ with self.assertRaisesRegexp(
+ IndexError,
+ r"Slice dimension \(1\) must be smaller than rank of all tensors, "
+ "but tensor at index 0 has rank 1"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2], 1)
+
+ with self.assertRaisesRegexp(
+ IndexError,
+ r"Slice dimension \(1\) must be smaller than rank of all tensors, "
+ "but tensor at index 1 has rank 1"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t1, t2], 1)
+
+ with self.assertRaisesRegexp(
+ IndexError,
+ r"Slice dimension \(0\) must be smaller than rank of all tensors, "
+ "but tensor at index 0 has rank 0"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t3], 0)
+
+ with self.assertRaisesRegexp(
+ IndexError,
+ r"Slice dimension \(0\) must be smaller than rank of all tensors, "
+ "but tensor at index 2 has rank 0"):
+ pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 2ee298ad44..7e632c75e8 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from math import ceil
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@@ -102,32 +103,46 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
concat_dim = op.inputs[dim_index]
input_values = op.inputs[start_value_index:end_value_index]
- # Using mod here for convenience since concat_dim is already verified
- # in concat implementation to be within the allowed [-rank, rank) range.
- non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
out_grads = []
if isinstance(grad, ops.Tensor):
- # Get the inputs' tensor shapes
- sizes = _ExtractInputShapes(input_values)
- # The magic number of 16 was found through benchmarking a range of sizes
- # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
- # cases when switching implementations at N=16, but it is possible that
- # there will be a small number of performance regressions.
- # pylint: disable=protected-access
- if len(sizes) > 16:
- # extract the size of each input along the concat dimension
- sizes = array_ops.squeeze(
- array_ops.slice(
- array_ops.stack(
- sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
+ if context.in_eager_mode():
+ # Using mod here for convenience since concat_dim is already verified
+ # in concat implementation to be within the allowed [-rank, rank) range.
+ non_neg_concat_dim = (
+ concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
+ # All inputs are guaranteed to be EagerTensors in eager mode
+ sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
+ non_neg_concat_dim)
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else:
- offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
- for (begin, size) in zip(offset, sizes):
- out_grads.append(array_ops.slice(grad, begin, size))
- # pylint: enable=protected-access
+ # Using mod here for convenience since concat_dim is already verified
+ # in concat implementation to be within the allowed [-rank, rank) range.
+ non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
+
+ # Get the inputs' tensor shapes
+ sizes = _ExtractInputShapes(input_values)
+ # The magic number of 16 was found through benchmarking a range of sizes
+ # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
+ # cases when switching implementations at N=16, but it is possible that
+ # there will be a small number of performance regressions.
+ # pylint: disable=protected-access
+ if len(sizes) > 16:
+ # extract the size of each input along the concat dimension
+ sizes = array_ops.squeeze(
+ array_ops.slice(
+ array_ops.stack(
+ sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
+ out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
+ else:
+ offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
+ for (begin, size) in zip(offset, sizes):
+ out_grads.append(array_ops.slice(grad, begin, size))
+ # pylint: enable=protected-access
elif isinstance(grad, ops.IndexedSlices):
+ # Using mod here for convenience since concat_dim is already verified
+ # in concat implementation to be within the allowed [-rank, rank) range.
+ non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
concat_dim_static = tensor_util.constant_value(concat_dim)
if concat_dim_static is None:
raise ValueError("Can only compute IndexedSlices gradient with "
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index fa36b77311..637f738fed 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -34,6 +34,7 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_DeleteContextOptions;
+%rename("%s") TFE_Py_TensorShapeSlice;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"