about summary refs log tree commit diff homepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-02 20:00:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 20:04:04 -0700
commit2597b883a14749c77fffd7e5f9677107021ff40a (patch)
treeb4a8ea07c72bd3f509edd2bb7d1c90dc3f81ba2d /tensorflow/python/data
parent4b2d0180ba8c903f098f52eb9a12d26a7626dd34 (diff)
Automated rollback of commit b7e9cbab27c893283acc4a6154d7a59dffb23758
PiperOrigin-RevId: 215503549
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py60
1 file changed, 34 insertions, 26 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index d90da5908d..46ce191f7b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -30,7 +30,6 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
-from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -38,7 +37,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1715,8 +1713,7 @@ class _VariantDataset(Dataset):
class StructuredFunctionWrapper(object):
- """A wrapper for `defun` that supports structured arguments and return values.
-
+ """A wrapper for `Defun` that supports structured arguments and return values.
"""
def __init__(self, func, transformation_name, dataset=None,
@@ -1768,7 +1765,7 @@ class StructuredFunctionWrapper(object):
# TODO(b/110122868): Enable this support for all `tf.data` functions.
self._nested_dataset_support = experimental_nested_dataset_support
- @eager_function.defun(input_signature=self._defun_args())
+ @function.Defun(*self._defun_args())
def tf_data_structured_function_wrapper(*args):
"""Wrapper for passing nested structures to and from tf.data functions."""
flat_args = []
@@ -1853,43 +1850,36 @@ class StructuredFunctionWrapper(object):
self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
self._output_types = nest.pack_sequence_as(ret, flat_types)
- return flat_ret
+ _warn_if_collections(transformation_name)
- table_initializers_len = len(ops.get_default_graph().get_collection(
- ops.GraphKeys.TABLE_INITIALIZERS))
+ return flat_ret
- self._function = tf_data_structured_function_wrapper.get_concrete_function()
+ self._function = tf_data_structured_function_wrapper
if add_to_graph:
self._function.add_to_graph(ops.get_default_graph())
- if len(
- self._function.graph.get_collection(
- ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len:
- warnings.warn(
- "Creating lookup tables inside a function passed to %s is not"
- " supported. Create each table outside the function, and "
- "capture it inside the function to use it." % transformation_name)
+ else:
+ # Use the private method that will execute
+ # `tf_data_structured_function_wrapper` but delay adding it to the graph
+ # in case (e.g.) we need to rerun the function.
+ self._function._create_definition_if_needed() # pylint: disable=protected-access
def _defun_args(self):
- """Returns a list of `tf.TensorSpec` for the input element structure."""
+ """Returns a flat list of `tf.DType` for the input element structure."""
ret = []
- for input_type, input_shape, input_class in zip(
- nest.flatten(self._input_types), nest.flatten(self._input_shapes),
- nest.flatten(self._input_classes)):
+ for input_type, input_class in zip(nest.flatten(self._input_types),
+ nest.flatten(self._input_classes)):
# TODO(b/110122868): Add a registration mechanism for new component types.
if input_class is sparse_tensor_lib.SparseTensor:
- ret.append(
- tensor_spec.TensorSpec(
- tensor_shape.TensorShape(None), dtypes.variant))
+ ret.append(dtypes.variant)
elif isinstance(input_class, _NestedDatasetComponent):
if not self._nested_dataset_support:
raise NotImplementedError(
"The %s transformation does not currently support nested "
"datasets as inputs." % self._transformation_name)
- ret.append(
- tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant))
+ ret.append(dtypes.variant)
else:
assert isinstance(input_type, dtypes.DType)
- ret.append(tensor_spec.TensorSpec(input_shape, input_type))
+ ret.append(input_type)
return ret
@property
@@ -2589,6 +2579,24 @@ def _should_unpack_args(args):
return type(args) is tuple # pylint: disable=unidiomatic-typecheck
+def _warn_if_collections(transformation_name):
+ """Prints warning message if the current graph uses common graph collections.
+
+ NOTE(mrry): Currently a warning is only generated for lookup tables. Any
+ variables created will be automatically hoisted out to the outermost scope
+ using `init_scope()`. Some collections (such as for control-flow contexts)
+ are benign and should not generate a warning.
+
+ Args:
+ transformation_name: A human-readable name for the transformation.
+ """
+ if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS):
+ warnings.warn("Creating lookup tables inside a function passed to %s is not"
+ " supported. Create each table outside the function, and "
+ "capture it inside the function to use it."
+ % transformation_name)
+
+
class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""