aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/script_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/script_ops.py')
-rw-r--r--tensorflow/python/ops/script_ops.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 050dcd196b..ce5129edba 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -132,8 +132,8 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Args:
func: A python function.
inp: A list of `Tensor`.
- Tout: A list of tensorflow data types indicating what `func`
- returns.
+ Tout: A list of tensorflow data types or a single tensorflow data type
+ indicating what `func` returns.
stateful: A boolean indicating whether the function should be considered
stateful or stateless. I.e. whether it, given the same input, will
return the same output and at the same time does not change state
@@ -142,7 +142,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
name: A name for the operation (optional).
Returns:
- A list of `Tensor` which `func` computes.
+ A list of `Tensor` or a single `Tensor` which `func` computes.
"""
token = _py_funcs.insert(func)
# We tie the registered function's life-time with the current
@@ -162,14 +162,20 @@ def py_func(func, inp, Tout, stateful=True, name=None):
# the funcs registry.
g._cleanup_py_funcs_used_in_graph.append(cleanup)
+ if isinstance(Tout, list):
+ is_list = True
+ else:
+ Tout = [Tout]
+ is_list = False
if stateful:
- return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name)
+ result = gen_script_ops._py_func(
+ input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
else:
- return gen_script_ops._py_func_stateless(
- input=inp, token=token, Tout=Tout,
- name=name)
+ result = gen_script_ops._py_func_stateless(
+ input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
+ return result if is_list else result[0]
ops.RegisterShape("PyFunc")(common_shapes.call_cpp_shape_fn)