aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 10:44:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 10:44:45 -0700
commit20622daa9db9d403a66111f5682367566e2bb8db (patch)
treecc6dc5330a0873f417b0505ee6696ce1a0930c7c
parent7fedebf6381167dcc8f005ac785a236b3e4895fc (diff)
parent3c343c19f2414f10aaea6ba3802d883d2303d6fd (diff)
Merge pull request #21038 from recogni:dev-recogni/bug_20983
PiperOrigin-RevId: 208232199
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py2
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py43
-rw-r--r--tensorflow/python/ops/script_ops.py97
3 files changed, 90 insertions, 52 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8ba98cb88d..ddb022c09e 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -430,7 +430,7 @@ class Dataset(object):
ret_arrays = []
for ret, dtype in zip(flattened_values, flattened_types):
try:
- ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
+ ret_arrays.append(script_ops.PythonFunc._convert( # pylint: disable=protected-access
ret, dtype=dtype.as_numpy_dtype))
except (TypeError, ValueError):
raise TypeError(
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 50154a45a8..0164c70b4c 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -148,6 +148,30 @@ class PyFuncTest(test.TestCase):
script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
self.assertAllClose(y, [0.0, 1.0])
+ def testConvertEmptyList(self):
+ with self.test_session():
+
+ def empty_list_func(args):
+ del args
+ return [],
+
+ x = constant_op.constant(0, dtypes.int64)
+ y = self.evaluate(
+ script_ops.py_func(empty_list_func, [x], dtypes.int64))
+ self.assertAllClose(y, [])
+
+ def testConvertScalarType(self):
+ with self.test_session():
+
+ def one_func(args):
+ del args
+ return 1
+
+ x = constant_op.constant(0, dtypes.int64)
+ y = self.evaluate(
+ script_ops.py_func(one_func, [x], dtypes.int64))
+ self.assertAllClose(y, 1)
+
def testTuple(self):
# returns a tuple
with self.test_session():
@@ -261,19 +285,6 @@ class PyFuncTest(test.TestCase):
z, = script_ops.py_func(unicode_string, [], [dtypes.string])
self.assertEqual(z.eval(), correct.encode("utf8"))
- def testBadNumpyReturnType(self):
- with self.test_session():
-
- def bad():
- # Structured numpy arrays aren't supported.
- return np.array([], dtype=[("foo", np.float32)])
-
- y, = script_ops.py_func(bad, [], [dtypes.float32])
-
- with self.assertRaisesRegexp(errors.UnimplementedError,
- "Unsupported numpy type"):
- y.eval()
-
def testBadReturnType(self):
with self.test_session():
@@ -283,8 +294,8 @@ class PyFuncTest(test.TestCase):
z, = script_ops.py_func(bad, [], [dtypes.int64])
- with self.assertRaisesRegexp(errors.UnimplementedError,
- "Unsupported object type"):
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, "argument must.*not 'dict'"):
z.eval()
def testReturnInput(self):
@@ -452,7 +463,7 @@ class PyFuncTest(test.TestCase):
# (see #18292)
_ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
-
+
# Call garbage collector to enforce deletion.
make_graphs()
ops.reset_default_graph()
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index d11e446dbf..72aac1ddbf 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -120,30 +120,19 @@ class EagerFunc(object):
return outputs
-class FuncRegistry(object):
- """A helper class to keep track of registered py functions.
-
- FuncRegistry keeps a map from unique tokens (string) to python
- functions, which takes numpy arrays and outputs numpy arrays.
- """
+class PythonFunc(object):
+ """A wrapper for a function owned by PyFunc."""
- def __init__(self):
- self._lock = threading.Lock()
- self._unique_id = 0 # GUARDED_BY(self._lock)
- # Only store weakrefs to the functions. The strong reference is stored in
- # the graph.
- self._funcs = weakref.WeakValueDictionary()
-
- def insert(self, func):
- """Registers `func` and returns a unique token for this entry."""
- token = self._next_unique_token()
- # Store a weakref to the function
- self._funcs[token] = func
- return token
+ def __init__(self, func, Tout):
+ """Constructs a PythonFunc.
- def remove(self, token):
- """Removes the registered function corresponding to `token`."""
- self._funcs.pop(token, None)
+ Args:
+ func: The function to wrap.
+ Tout: A list of datatypes for the output; an empty list if the output is
+ None.
+ """
+ self._func = func
+ self._out_dtypes = Tout
@staticmethod
def _convert(value, dtype=None):
@@ -171,8 +160,52 @@ class FuncRegistry(object):
return np.asarray(value, order="C", dtype=object)
elif result.dtype.char == "U":
return result.astype(np.bytes_)
- else:
- return result
+ return result
+
+ def __call__(self, args):
+ ret = self._func(*args)
+ # Strings seem to lead to a memory leak here if they're not wrapped in a
+ # list.
+ if isinstance(ret, six.binary_type):
+ ret = [ret]
+ # Ensures that we return either a single numpy array or a list of numpy
+ # arrays.
+ if isinstance(ret, (tuple, list)):
+ return [
+ self._convert(x, dtype=dtype.as_numpy_dtype)
+ for (x, dtype) in zip(ret, self._out_dtypes)
+ ]
+
+ dtype = None
+ if isinstance(self._out_dtypes, (tuple, list)) and self._out_dtypes:
+ dtype = self._out_dtypes[0].as_numpy_dtype
+ return self._convert(ret, dtype)
+
+
+class FuncRegistry(object):
+ """A helper class to keep track of registered py functions.
+
+ FuncRegistry keeps a map from unique tokens (string) to python
+ functions, which takes numpy arrays and outputs numpy arrays.
+ """
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._unique_id = 0 # GUARDED_BY(self._lock)
+ # Only store weakrefs to the functions. The strong reference is stored in
+ # the graph.
+ self._funcs = weakref.WeakValueDictionary()
+
+ def insert(self, func):
+ """Registers `func` and returns a unique token for this entry."""
+ token = self._next_unique_token()
+ # Store a weakref to the function
+ self._funcs[token] = func
+ return token
+
+ def remove(self, token):
+ """Removes the registered function corresponding to `token`."""
+ self._funcs.pop(token, None)
def __call__(self, token, device, args):
"""Calls the registered function for `token` with args.
@@ -202,18 +235,10 @@ class FuncRegistry(object):
#
# TODO(akshayka): Key the tape cache in a thread-safe way.
return func(device, token, args)
+ elif isinstance(func, PythonFunc):
+ return func(args)
else:
- ret = func(*args)
- # Strings seem to lead to a memory leak here if they're not wrapped in a
- # list.
- if isinstance(ret, six.binary_type):
- ret = [ret]
- # Ensures that we return either a single numpy array or a list of numpy
- # arrays.
- if isinstance(ret, (tuple, list)):
- return [self._convert(x) for x in ret]
- else:
- return self._convert(ret)
+ raise ValueError("callback %s type is invalid" % token)
def size(self):
"""Returns how many functions are currently registered."""
@@ -249,6 +274,8 @@ def _internal_py_func(func,
if eager:
func = EagerFunc(func, Tout, is_grad_func)
+ else:
+ func = PythonFunc(func, Tout)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,