aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar feiquan <feiquan@wacai.com>2018-08-13 23:44:38 +0800
committerGravatar feiquan <feiquan@wacai.com>2018-08-13 23:44:38 +0800
commit22ebbbc60e5d94d67cdf6c26b44919f7dbb8f600 (patch)
tree693a0e105633c4ce835b69139c91fabc14a3df56 /tensorflow/contrib/autograph
parent3a99980fcaa8e6df827df121b9b2e15d75f3ace1 (diff)
extends the tensor index operator to support character access
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/operators/slices.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
index 04fbeb2f6e..d878bddf3c 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/contrib/autograph/operators/slices.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import gen_string_ops
# TODO(mdan): Support extended slices.
@@ -57,6 +58,8 @@ def get_item(target, i, opts):
elif tensor_util.is_tensor(target):
if target.dtype == dtypes.variant:
return _tf_tensor_list_get_item(target, i, opts)
+ if target.dtype == dtypes.string:
+ return _tf_tensor_string_get_item(target, i)
else:
return _tf_tensor_get_item(target, i)
else:
@@ -81,6 +84,10 @@ def _tf_tensor_get_item(target, i):
"""Overload of get_item that stages a Tensor (not Tensor list) read."""
return target[i]
+def _tf_tensor_string_get_item(target, i):
+ """Overload of get_item that stages a Tensor string read."""
+ x = gen_string_ops.substr(target, i, 1)
+ return x
def _py_get_item(target, i):
"""Overload of get_item that executes a Python list modification."""