diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-10 10:44:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-10 10:44:45 -0700 |
commit | 20622daa9db9d403a66111f5682367566e2bb8db (patch) | |
tree | cc6dc5330a0873f417b0505ee6696ce1a0930c7c | |
parent | 7fedebf6381167dcc8f005ac785a236b3e4895fc (diff) | |
parent | 3c343c19f2414f10aaea6ba3802d883d2303d6fd (diff) |
Merge pull request #21038 from recogni:dev-recogni/bug_20983
PiperOrigin-RevId: 208232199
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/py_func_test.py | 43 | ||||
-rw-r--r-- | tensorflow/python/ops/script_ops.py | 97 |
3 files changed, 90 insertions, 52 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 8ba98cb88d..ddb022c09e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -430,7 +430,7 @@ class Dataset(object): ret_arrays = [] for ret, dtype in zip(flattened_values, flattened_types): try: - ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access + ret_arrays.append(script_ops.PythonFunc._convert( # pylint: disable=protected-access ret, dtype=dtype.as_numpy_dtype)) except (TypeError, ValueError): raise TypeError( diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 50154a45a8..0164c70b4c 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -148,6 +148,30 @@ class PyFuncTest(test.TestCase): script_ops.py_func(list_func, [x], [dtypes.float64] * 2)) self.assertAllClose(y, [0.0, 1.0]) + def testConvertEmptyList(self): + with self.test_session(): + + def empty_list_func(args): + del args + return [], + + x = constant_op.constant(0, dtypes.int64) + y = self.evaluate( + script_ops.py_func(empty_list_func, [x], dtypes.int64)) + self.assertAllClose(y, []) + + def testConvertScalarType(self): + with self.test_session(): + + def one_func(args): + del args + return 1 + + x = constant_op.constant(0, dtypes.int64) + y = self.evaluate( + script_ops.py_func(one_func, [x], dtypes.int64)) + self.assertAllClose(y, 1) + def testTuple(self): # returns a tuple with self.test_session(): @@ -261,19 +285,6 @@ class PyFuncTest(test.TestCase): z, = script_ops.py_func(unicode_string, [], [dtypes.string]) self.assertEqual(z.eval(), correct.encode("utf8")) - def testBadNumpyReturnType(self): - with self.test_session(): - - def bad(): - # Structured numpy arrays aren't supported. - return np.array([], dtype=[("foo", np.float32)]) - - y, = script_ops.py_func(bad, [], [dtypes.float32]) - - with self.assertRaisesRegexp(errors.UnimplementedError, - "Unsupported numpy type"): - y.eval() - def testBadReturnType(self): with self.test_session(): @@ -283,8 +294,8 @@ class PyFuncTest(test.TestCase): z, = script_ops.py_func(bad, [], [dtypes.int64]) - with self.assertRaisesRegexp(errors.UnimplementedError, - "Unsupported object type"): + with self.assertRaisesRegexp( + errors.InvalidArgumentError, "argument must.*not 'dict'"): z.eval() def testReturnInput(self): @@ -452,7 +463,7 @@ class PyFuncTest(test.TestCase): # (see #18292) _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) - + # Call garbage collector to enforce deletion. make_graphs() ops.reset_default_graph() diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index d11e446dbf..72aac1ddbf 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -120,30 +120,19 @@ class EagerFunc(object): return outputs -class FuncRegistry(object): - """A helper class to keep track of registered py functions. - - FuncRegistry keeps a map from unique tokens (string) to python - functions, which takes numpy arrays and outputs numpy arrays. - """ +class PythonFunc(object): + """A wrapper for a function owned by PyFunc.""" - def __init__(self): - self._lock = threading.Lock() - self._unique_id = 0 # GUARDED_BY(self._lock) - # Only store weakrefs to the functions. The strong reference is stored in - # the graph. - self._funcs = weakref.WeakValueDictionary() - - def insert(self, func): - """Registers `func` and returns a unique token for this entry.""" - token = self._next_unique_token() - # Store a weakref to the function - self._funcs[token] = func - return token + def __init__(self, func, Tout): + """Constructs a PythonFunc. - def remove(self, token): - """Removes the registered function corresponding to `token`.""" - self._funcs.pop(token, None) + Args: + func: The function to wrap. + Tout: A list of datatypes for the output; an empty list if the output is + None. + """ + self._func = func + self._out_dtypes = Tout @staticmethod def _convert(value, dtype=None): @@ -171,8 +160,52 @@ class FuncRegistry(object): return np.asarray(value, order="C", dtype=object) elif result.dtype.char == "U": return result.astype(np.bytes_) - else: - return result + return result + + def __call__(self, args): + ret = self._func(*args) + # Strings seem to lead to a memory leak here if they're not wrapped in a + # list. + if isinstance(ret, six.binary_type): + ret = [ret] + # Ensures that we return either a single numpy array or a list of numpy + # arrays. + if isinstance(ret, (tuple, list)): + return [ + self._convert(x, dtype=dtype.as_numpy_dtype) + for (x, dtype) in zip(ret, self._out_dtypes) + ] + + dtype = None + if isinstance(self._out_dtypes, (tuple, list)) and self._out_dtypes: + dtype = self._out_dtypes[0].as_numpy_dtype + return self._convert(ret, dtype) + + +class FuncRegistry(object): + """A helper class to keep track of registered py functions. + + FuncRegistry keeps a map from unique tokens (string) to python + functions, which takes numpy arrays and outputs numpy arrays. + """ + + def __init__(self): + self._lock = threading.Lock() + self._unique_id = 0 # GUARDED_BY(self._lock) + # Only store weakrefs to the functions. The strong reference is stored in + # the graph. + self._funcs = weakref.WeakValueDictionary() + + def insert(self, func): + """Registers `func` and returns a unique token for this entry.""" + token = self._next_unique_token() + # Store a weakref to the function + self._funcs[token] = func + return token + + def remove(self, token): + """Removes the registered function corresponding to `token`.""" + self._funcs.pop(token, None) def __call__(self, token, device, args): """Calls the registered function for `token` with args. @@ -202,18 +235,10 @@ class FuncRegistry(object): # # TODO(akshayka): Key the tape cache in a thread-safe way. return func(device, token, args) + elif isinstance(func, PythonFunc): + return func(args) else: - ret = func(*args) - # Strings seem to lead to a memory leak here if they're not wrapped in a - # list. - if isinstance(ret, six.binary_type): - ret = [ret] - # Ensures that we return either a single numpy array or a list of numpy - # arrays. - if isinstance(ret, (tuple, list)): - return [self._convert(x) for x in ret] - else: - return self._convert(ret) + raise ValueError("callback %s type is invalid" % token) def size(self): """Returns how many functions are currently registered.""" @@ -249,6 +274,8 @@ def _internal_py_func(func, if eager: func = EagerFunc(func, Tout, is_grad_func) + else: + func = PythonFunc(func, Tout) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, |