aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-01 03:12:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-01 04:17:45 -0700
commitcabaef0f964c3639f6fe20dd218b1d692b7d632e (patch)
tree22d8ef114a268f92f3bd9fd7d22a4ed8fc07ef67
parent023f01093dd61c1b9e86ea48a83d16488e77da15 (diff)
Allow tf.py_func to have no return values.
Change: 134873490
-rw-r--r--tensorflow/core/ops/script_ops.cc4
-rw-r--r--tensorflow/python/framework/op_def_library.py19
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py33
-rw-r--r--tensorflow/python/lib/core/py_func.cc25
4 files changed, 66 insertions, 15 deletions
diff --git a/tensorflow/core/ops/script_ops.cc b/tensorflow/core/ops/script_ops.cc
index 8328229cf7..8197327b56 100644
--- a/tensorflow/core/ops/script_ops.cc
+++ b/tensorflow/core/ops/script_ops.cc
@@ -23,7 +23,7 @@ REGISTER_OP("PyFunc")
.Output("output: Tout")
.Attr("token: string")
.Attr("Tin: list(type) >= 0")
- .Attr("Tout: list(type)")
+ .Attr("Tout: list(type) >=0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
@@ -45,7 +45,7 @@ REGISTER_OP("PyFuncStateless")
.Output("output: Tout")
.Attr("token: string")
.Attr("Tin: list(type) >= 0")
- .Attr("Tout: list(type)")
+ .Attr("Tout: list(type) >= 0")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
A stateless version of PyFunc.
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 9d6bf0c418..9065b2a8e0 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -725,7 +725,7 @@ class OpDefLibrary(object):
elif arg.type_list_attr:
t = _AttrValue(attr_protos, arg.type_list_attr)
types = t.list.type
- output_structure.append(len(t.list.type))
+ output_structure.append(len(types))
else:
types = [arg.type]
output_structure.append(None)
@@ -743,14 +743,15 @@ class OpDefLibrary(object):
if arg.is_ref]
with _MaybeColocateWith(must_colocate_inputs):
# Add Op to graph
+ op = g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
if output_structure:
- op = g.create_op(op_type_name, inputs, output_types, name=scope,
- input_types=input_types, attrs=attr_protos,
- op_def=op_def)
outputs = op.outputs
- return _Restructure(ops.convert_n_to_tensor(outputs),
- output_structure)
+ res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
+ if isinstance(res, list) and not res and op_def.is_stateful:
+ return op
+ else:
+ return res
else:
- return g.create_op(op_type_name, inputs, output_types, name=scope,
- input_types=input_types, attrs=attr_protos,
- op_def=op_def)
+ return op
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 62a6613a73..025dbb71df 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -226,6 +226,39 @@ class PyOpTest(tf.test.TestCase):
# This will result in a deadlock if the py_func's don't run in parallel.
session.run([x, y])
+ def testNoReturnValueStateful(self):
+
+ class State(object):
+
+ def __init__(self):
+ self._value = np.array([1], np.int64)
+
+ def _increment(self, diff):
+ self._value += diff
+
+ def increment(self, diff):
+ return tf.py_func(self._increment, [diff], [], stateful=True)
+
+ @property
+ def value(self):
+ return self._value
+
+ with self.test_session() as sess:
+ s = State()
+ op = s.increment(tf.constant(2, tf.int64))
+ ret = sess.run(op)
+ self.assertIsNone(ret)
+ self.assertAllEqual([3], s.value)
+
+ def testNoReturnValueStateless(self):
+
+ def do_nothing(unused_x):
+ pass
+
+ f = tf.py_func(do_nothing, [tf.constant(3, tf.int64)], [], stateful=False)
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), [])
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 5581079c8c..61ebdecc6c 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -154,6 +154,21 @@ Status NumericNpDTypeToTfDType(const int np, DataType* tf) {
return Status::OK();
}
+bool IsSingleNone(PyObject* obj) {
+ if (!PyArray_Check(obj)) {
+ return false;
+ }
+ PyArrayObject* array_obj = reinterpret_cast<PyArrayObject*>(obj);
+ if (PyArray_NDIM(array_obj) != 0 || PyArray_SIZE(array_obj) != 1) {
+ return false;
+ }
+ npy_intp indices[] = {};
+ char* item_ptr = static_cast<char*>(PyArray_GetPtr(array_obj, indices));
+ PyObject* item = PyArray_GETITEM(array_obj, item_ptr);
+ CHECK(item);
+ return item == Py_None;
+}
+
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call) {
PyObject* trampoline = GetPyTrampoline();
@@ -193,10 +208,12 @@ Status DoCallPyFunc(PyCall* call) {
}
} else if (PyArray_Check(result)) {
// 'result' is a single ndarray.
- Tensor t;
- s = ConvertNdarrayToTensor(result, &t);
- if (s.ok()) {
- call->out.push_back(t);
+ if (!IsSingleNone(result)) {
+ Tensor t;
+ s = ConvertNdarrayToTensor(result, &t);
+ if (s.ok()) {
+ call->out.push_back(t);
+ }
}
} else {
s = errors::Internal("Unexpected pyobject is returned: ",