diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-01 03:12:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-01 04:17:45 -0700 |
commit | cabaef0f964c3639f6fe20dd218b1d692b7d632e (patch) | |
tree | 22d8ef114a268f92f3bd9fd7d22a4ed8fc07ef67 | |
parent | 023f01093dd61c1b9e86ea48a83d16488e77da15 (diff) |
Allow tf.py_func to have no return values.
Change: 134873490
-rw-r--r-- | tensorflow/core/ops/script_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/op_def_library.py | 19 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/py_func_test.py | 33 | ||||
-rw-r--r-- | tensorflow/python/lib/core/py_func.cc | 25 |
4 files changed, 66 insertions, 15 deletions
diff --git a/tensorflow/core/ops/script_ops.cc b/tensorflow/core/ops/script_ops.cc index 8328229cf7..8197327b56 100644 --- a/tensorflow/core/ops/script_ops.cc +++ b/tensorflow/core/ops/script_ops.cc @@ -23,7 +23,7 @@ REGISTER_OP("PyFunc") .Output("output: Tout") .Attr("token: string") .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type)") + .Attr("Tout: list(type) >=0") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( @@ -45,7 +45,7 @@ REGISTER_OP("PyFuncStateless") .Output("output: Tout") .Attr("token: string") .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type)") + .Attr("Tout: list(type) >= 0") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( A stateless version of PyFunc. diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 9d6bf0c418..9065b2a8e0 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -725,7 +725,7 @@ class OpDefLibrary(object): elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr) types = t.list.type - output_structure.append(len(t.list.type)) + output_structure.append(len(types)) else: types = [arg.type] output_structure.append(None) @@ -743,14 +743,15 @@ class OpDefLibrary(object): if arg.is_ref] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph + op = g.create_op(op_type_name, inputs, output_types, name=scope, + input_types=input_types, attrs=attr_protos, + op_def=op_def) if output_structure: - op = g.create_op(op_type_name, inputs, output_types, name=scope, - input_types=input_types, attrs=attr_protos, - op_def=op_def) outputs = op.outputs - return _Restructure(ops.convert_n_to_tensor(outputs), - output_structure) + res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) + if isinstance(res, list) and not res and op_def.is_stateful: + return op + else: + return res else: - return g.create_op(op_type_name, inputs, output_types, name=scope, - input_types=input_types, attrs=attr_protos, - op_def=op_def) + return op diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 62a6613a73..025dbb71df 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -226,6 +226,39 @@ class PyOpTest(tf.test.TestCase): # This will result in a deadlock if the py_func's don't run in parallel. session.run([x, y]) + def testNoReturnValueStateful(self): + + class State(object): + + def __init__(self): + self._value = np.array([1], np.int64) + + def _increment(self, diff): + self._value += diff + + def increment(self, diff): + return tf.py_func(self._increment, [diff], [], stateful=True) + + @property + def value(self): + return self._value + + with self.test_session() as sess: + s = State() + op = s.increment(tf.constant(2, tf.int64)) + ret = sess.run(op) + self.assertIsNone(ret) + self.assertAllEqual([3], s.value) + + def testNoReturnValueStateless(self): + + def do_nothing(unused_x): + pass + + f = tf.py_func(do_nothing, [tf.constant(3, tf.int64)], [], stateful=False) + with self.test_session() as sess: + self.assertEqual(sess.run(f), []) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 5581079c8c..61ebdecc6c 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -154,6 +154,21 @@ Status NumericNpDTypeToTfDType(const int np, DataType* tf) { return Status::OK(); } +bool IsSingleNone(PyObject* obj) { + if (!PyArray_Check(obj)) { + return false; + } + PyArrayObject* array_obj = reinterpret_cast<PyArrayObject*>(obj); + if (PyArray_NDIM(array_obj) != 0 || PyArray_SIZE(array_obj) != 1) { + return false; + } + npy_intp indices[] = {}; + char* item_ptr = static_cast<char*>(PyArray_GetPtr(array_obj, indices)); + PyObject* item = PyArray_GETITEM(array_obj, item_ptr); + CHECK(item); + return item == Py_None; +} + // Calls the registered py function through the trampoline. Status DoCallPyFunc(PyCall* call) { PyObject* trampoline = GetPyTrampoline(); @@ -193,10 +208,12 @@ Status DoCallPyFunc(PyCall* call) { } } else if (PyArray_Check(result)) { // 'result' is a single ndarray. - Tensor t; - s = ConvertNdarrayToTensor(result, &t); - if (s.ok()) { - call->out.push_back(t); + if (!IsSingleNone(result)) { + Tensor t; + s = ConvertNdarrayToTensor(result, &t); + if (s.ok()) { + call->out.push_back(t); + } } } else { s = errors::Internal("Unexpected pyobject is returned: ", |