aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py2
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py43
-rw-r--r--tensorflow/python/ops/script_ops.py97
3 files changed, 52 insertions, 90 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index ddb022c09e..8ba98cb88d 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -430,7 +430,7 @@ class Dataset(object):
ret_arrays = []
for ret, dtype in zip(flattened_values, flattened_types):
try:
- ret_arrays.append(script_ops.PythonFunc._convert( # pylint: disable=protected-access
+ ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
ret, dtype=dtype.as_numpy_dtype))
except (TypeError, ValueError):
raise TypeError(
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 0164c70b4c..50154a45a8 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -148,30 +148,6 @@ class PyFuncTest(test.TestCase):
script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
self.assertAllClose(y, [0.0, 1.0])
- def testConvertEmptyList(self):
- with self.test_session():
-
- def empty_list_func(args):
- del args
- return [],
-
- x = constant_op.constant(0, dtypes.int64)
- y = self.evaluate(
- script_ops.py_func(empty_list_func, [x], dtypes.int64))
- self.assertAllClose(y, [])
-
- def testConvertScalarType(self):
- with self.test_session():
-
- def one_func(args):
- del args
- return 1
-
- x = constant_op.constant(0, dtypes.int64)
- y = self.evaluate(
- script_ops.py_func(one_func, [x], dtypes.int64))
- self.assertAllClose(y, 1)
-
def testTuple(self):
# returns a tuple
with self.test_session():
@@ -285,6 +261,19 @@ class PyFuncTest(test.TestCase):
z, = script_ops.py_func(unicode_string, [], [dtypes.string])
self.assertEqual(z.eval(), correct.encode("utf8"))
+ def testBadNumpyReturnType(self):
+ with self.test_session():
+
+ def bad():
+ # Structured numpy arrays aren't supported.
+ return np.array([], dtype=[("foo", np.float32)])
+
+ y, = script_ops.py_func(bad, [], [dtypes.float32])
+
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "Unsupported numpy type"):
+ y.eval()
+
def testBadReturnType(self):
with self.test_session():
@@ -294,8 +283,8 @@ class PyFuncTest(test.TestCase):
z, = script_ops.py_func(bad, [], [dtypes.int64])
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError, "argument must.*not 'dict'"):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "Unsupported object type"):
z.eval()
def testReturnInput(self):
@@ -463,7 +452,7 @@ class PyFuncTest(test.TestCase):
# (see #18292)
_ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
-
+
# Call garbage collector to enforce deletion.
make_graphs()
ops.reset_default_graph()
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 72aac1ddbf..d11e446dbf 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -120,19 +120,30 @@ class EagerFunc(object):
return outputs
-class PythonFunc(object):
- """A wrapper for a function owned by PyFunc."""
+class FuncRegistry(object):
+ """A helper class to keep track of registered py functions.
- def __init__(self, func, Tout):
- """Constructs a PythonFunc.
+ FuncRegistry keeps a map from unique tokens (string) to python
+ functions, which takes numpy arrays and outputs numpy arrays.
+ """
- Args:
- func: The function to wrap.
- Tout: A list of datatypes for the output; an empty list if the output is
- None.
- """
- self._func = func
- self._out_dtypes = Tout
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._unique_id = 0 # GUARDED_BY(self._lock)
+ # Only store weakrefs to the functions. The strong reference is stored in
+ # the graph.
+ self._funcs = weakref.WeakValueDictionary()
+
+ def insert(self, func):
+ """Registers `func` and returns a unique token for this entry."""
+ token = self._next_unique_token()
+ # Store a weakref to the function
+ self._funcs[token] = func
+ return token
+
+ def remove(self, token):
+ """Removes the registered function corresponding to `token`."""
+ self._funcs.pop(token, None)
@staticmethod
def _convert(value, dtype=None):
@@ -160,52 +171,8 @@ class PythonFunc(object):
return np.asarray(value, order="C", dtype=object)
elif result.dtype.char == "U":
return result.astype(np.bytes_)
- return result
-
- def __call__(self, args):
- ret = self._func(*args)
- # Strings seem to lead to a memory leak here if they're not wrapped in a
- # list.
- if isinstance(ret, six.binary_type):
- ret = [ret]
- # Ensures that we return either a single numpy array or a list of numpy
- # arrays.
- if isinstance(ret, (tuple, list)):
- return [
- self._convert(x, dtype=dtype.as_numpy_dtype)
- for (x, dtype) in zip(ret, self._out_dtypes)
- ]
-
- dtype = None
- if isinstance(self._out_dtypes, (tuple, list)) and self._out_dtypes:
- dtype = self._out_dtypes[0].as_numpy_dtype
- return self._convert(ret, dtype)
-
-
-class FuncRegistry(object):
- """A helper class to keep track of registered py functions.
-
- FuncRegistry keeps a map from unique tokens (string) to python
- functions, which takes numpy arrays and outputs numpy arrays.
- """
-
- def __init__(self):
- self._lock = threading.Lock()
- self._unique_id = 0 # GUARDED_BY(self._lock)
- # Only store weakrefs to the functions. The strong reference is stored in
- # the graph.
- self._funcs = weakref.WeakValueDictionary()
-
- def insert(self, func):
- """Registers `func` and returns a unique token for this entry."""
- token = self._next_unique_token()
- # Store a weakref to the function
- self._funcs[token] = func
- return token
-
- def remove(self, token):
- """Removes the registered function corresponding to `token`."""
- self._funcs.pop(token, None)
+ else:
+ return result
def __call__(self, token, device, args):
"""Calls the registered function for `token` with args.
@@ -235,10 +202,18 @@ class FuncRegistry(object):
#
# TODO(akshayka): Key the tape cache in a thread-safe way.
return func(device, token, args)
- elif isinstance(func, PythonFunc):
- return func(args)
else:
- raise ValueError("callback %s type is invalid" % token)
+ ret = func(*args)
+ # Strings seem to lead to a memory leak here if they're not wrapped in a
+ # list.
+ if isinstance(ret, six.binary_type):
+ ret = [ret]
+ # Ensures that we return either a single numpy array or a list of numpy
+ # arrays.
+ if isinstance(ret, (tuple, list)):
+ return [self._convert(x) for x in ret]
+ else:
+ return self._convert(ret)
def size(self):
"""Returns how many functions are currently registered."""
@@ -274,8 +249,6 @@ def _internal_py_func(func,
if eager:
func = EagerFunc(func, Tout, is_grad_func)
- else:
- func = PythonFunc(func, Tout)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,