diff options
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/py_func_test.py | 43 | ||||
-rw-r--r-- | tensorflow/python/ops/script_ops.py | 97 |
3 files changed, 52 insertions, 90 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index ddb022c09e..8ba98cb88d 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -430,7 +430,7 @@ class Dataset(object): ret_arrays = [] for ret, dtype in zip(flattened_values, flattened_types): try: - ret_arrays.append(script_ops.PythonFunc._convert( # pylint: disable=protected-access + ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access ret, dtype=dtype.as_numpy_dtype)) except (TypeError, ValueError): raise TypeError( diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 0164c70b4c..50154a45a8 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -148,30 +148,6 @@ class PyFuncTest(test.TestCase): script_ops.py_func(list_func, [x], [dtypes.float64] * 2)) self.assertAllClose(y, [0.0, 1.0]) - def testConvertEmptyList(self): - with self.test_session(): - - def empty_list_func(args): - del args - return [], - - x = constant_op.constant(0, dtypes.int64) - y = self.evaluate( - script_ops.py_func(empty_list_func, [x], dtypes.int64)) - self.assertAllClose(y, []) - - def testConvertScalarType(self): - with self.test_session(): - - def one_func(args): - del args - return 1 - - x = constant_op.constant(0, dtypes.int64) - y = self.evaluate( - script_ops.py_func(one_func, [x], dtypes.int64)) - self.assertAllClose(y, 1) - def testTuple(self): # returns a tuple with self.test_session(): @@ -285,6 +261,19 @@ class PyFuncTest(test.TestCase): z, = script_ops.py_func(unicode_string, [], [dtypes.string]) self.assertEqual(z.eval(), correct.encode("utf8")) + def testBadNumpyReturnType(self): + with self.test_session(): + + def bad(): + # Structured numpy arrays aren't supported. + return np.array([], dtype=[("foo", np.float32)]) + + y, = script_ops.py_func(bad, [], [dtypes.float32]) + + with self.assertRaisesRegexp(errors.UnimplementedError, + "Unsupported numpy type"): + y.eval() + def testBadReturnType(self): with self.test_session(): @@ -294,8 +283,8 @@ class PyFuncTest(test.TestCase): z, = script_ops.py_func(bad, [], [dtypes.int64]) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, "argument must.*not 'dict'"): + with self.assertRaisesRegexp(errors.UnimplementedError, + "Unsupported object type"): z.eval() def testReturnInput(self): @@ -463,7 +452,7 @@ class PyFuncTest(test.TestCase): # (see #18292) _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) - + # Call garbage collector to enforce deletion. make_graphs() ops.reset_default_graph() diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 72aac1ddbf..d11e446dbf 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -120,19 +120,30 @@ class EagerFunc(object): return outputs -class PythonFunc(object): - """A wrapper for a function owned by PyFunc.""" +class FuncRegistry(object): + """A helper class to keep track of registered py functions. - def __init__(self, func, Tout): - """Constructs a PythonFunc. + FuncRegistry keeps a map from unique tokens (string) to python + functions, which takes numpy arrays and outputs numpy arrays. + """ - Args: - func: The function to wrap. - Tout: A list of datatypes for the output; an empty list if the output is - None. - """ - self._func = func - self._out_dtypes = Tout + def __init__(self): + self._lock = threading.Lock() + self._unique_id = 0 # GUARDED_BY(self._lock) + # Only store weakrefs to the functions. The strong reference is stored in + # the graph. + self._funcs = weakref.WeakValueDictionary() + + def insert(self, func): + """Registers `func` and returns a unique token for this entry.""" + token = self._next_unique_token() + # Store a weakref to the function + self._funcs[token] = func + return token + + def remove(self, token): + """Removes the registered function corresponding to `token`.""" + self._funcs.pop(token, None) @staticmethod def _convert(value, dtype=None): @@ -160,52 +171,8 @@ class PythonFunc(object): return np.asarray(value, order="C", dtype=object) elif result.dtype.char == "U": return result.astype(np.bytes_) - return result - - def __call__(self, args): - ret = self._func(*args) - # Strings seem to lead to a memory leak here if they're not wrapped in a - # list. - if isinstance(ret, six.binary_type): - ret = [ret] - # Ensures that we return either a single numpy array or a list of numpy - # arrays. - if isinstance(ret, (tuple, list)): - return [ - self._convert(x, dtype=dtype.as_numpy_dtype) - for (x, dtype) in zip(ret, self._out_dtypes) - ] - - dtype = None - if isinstance(self._out_dtypes, (tuple, list)) and self._out_dtypes: - dtype = self._out_dtypes[0].as_numpy_dtype - return self._convert(ret, dtype) - - -class FuncRegistry(object): - """A helper class to keep track of registered py functions. - - FuncRegistry keeps a map from unique tokens (string) to python - functions, which takes numpy arrays and outputs numpy arrays. - """ - - def __init__(self): - self._lock = threading.Lock() - self._unique_id = 0 # GUARDED_BY(self._lock) - # Only store weakrefs to the functions. The strong reference is stored in - # the graph. - self._funcs = weakref.WeakValueDictionary() - - def insert(self, func): - """Registers `func` and returns a unique token for this entry.""" - token = self._next_unique_token() - # Store a weakref to the function - self._funcs[token] = func - return token - - def remove(self, token): - """Removes the registered function corresponding to `token`.""" - self._funcs.pop(token, None) + else: + return result def __call__(self, token, device, args): """Calls the registered function for `token` with args. @@ -235,10 +202,18 @@ class FuncRegistry(object): # # TODO(akshayka): Key the tape cache in a thread-safe way. return func(device, token, args) - elif isinstance(func, PythonFunc): - return func(args) else: - raise ValueError("callback %s type is invalid" % token) + ret = func(*args) + # Strings seem to lead to a memory leak here if they're not wrapped in a + # list. + if isinstance(ret, six.binary_type): + ret = [ret] + # Ensures that we return either a single numpy array or a list of numpy + # arrays. + if isinstance(ret, (tuple, list)): + return [self._convert(x) for x in ret] + else: + return self._convert(ret) def size(self): """Returns how many functions are currently registered.""" @@ -274,8 +249,6 @@ def _internal_py_func(func, if eager: func = EagerFunc(func, Tout, is_grad_func) - else: - func = PythonFunc(func, Tout) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, |