diff options
author | Derek Murray <mrry@google.com> | 2018-10-02 20:00:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 20:04:04 -0700 |
commit | 2597b883a14749c77fffd7e5f9677107021ff40a (patch) | |
tree | b4a8ea07c72bd3f509edd2bb7d1c90dc3f81ba2d /tensorflow/python/data | |
parent | 4b2d0180ba8c903f098f52eb9a12d26a7626dd34 (diff) |
Automated rollback of commit b7e9cbab27c893283acc4a6154d7a59dffb23758
PiperOrigin-RevId: 215503549
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 60 |
1 files changed, 34 insertions, 26 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d90da5908d..46ce191f7b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -30,7 +30,6 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context -from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -38,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1715,8 +1713,7 @@ class _VariantDataset(Dataset): class StructuredFunctionWrapper(object): - """A wrapper for `defun` that supports structured arguments and return values. - + """A wrapper for `Defun` that supports structured arguments and return values. """ def __init__(self, func, transformation_name, dataset=None, @@ -1768,7 +1765,7 @@ class StructuredFunctionWrapper(object): # TODO(b/110122868): Enable this support for all `tf.data` functions. self._nested_dataset_support = experimental_nested_dataset_support - @eager_function.defun(input_signature=self._defun_args()) + @function.Defun(*self._defun_args()) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" flat_args = [] @@ -1853,43 +1850,36 @@ class StructuredFunctionWrapper(object): self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) self._output_types = nest.pack_sequence_as(ret, flat_types) - return flat_ret + _warn_if_collections(transformation_name) - table_initializers_len = len(ops.get_default_graph().get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) + return flat_ret - self._function = tf_data_structured_function_wrapper.get_concrete_function() + self._function = tf_data_structured_function_wrapper if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) - if len( - self._function.graph.get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len: - warnings.warn( - "Creating lookup tables inside a function passed to %s is not" - " supported. Create each table outside the function, and " - "capture it inside the function to use it." % transformation_name) + else: + # Use the private method that will execute + # `tf_data_structured_function_wrapper` but delay adding it to the graph + # in case (e.g.) we need to rerun the function. + self._function._create_definition_if_needed() # pylint: disable=protected-access def _defun_args(self): - """Returns a list of `tf.TensorSpec` for the input element structure.""" + """Returns a flat list of `tf.DType` for the input element structure.""" ret = [] - for input_type, input_shape, input_class in zip( - nest.flatten(self._input_types), nest.flatten(self._input_shapes), - nest.flatten(self._input_classes)): + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): # TODO(b/110122868): Add a registration mechanism for new component types. if input_class is sparse_tensor_lib.SparseTensor: - ret.append( - tensor_spec.TensorSpec( - tensor_shape.TensorShape(None), dtypes.variant)) + ret.append(dtypes.variant) elif isinstance(input_class, _NestedDatasetComponent): if not self._nested_dataset_support: raise NotImplementedError( "The %s transformation does not currently support nested " "datasets as inputs." % self._transformation_name) - ret.append( - tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant)) + ret.append(dtypes.variant) else: assert isinstance(input_type, dtypes.DType) - ret.append(tensor_spec.TensorSpec(input_shape, input_type)) + ret.append(input_type) return ret @property @@ -2589,6 +2579,24 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck +def _warn_if_collections(transformation_name): + """Prints warning message if the current graph uses common graph collections. + + NOTE(mrry): Currently a warning is only generated for lookup tables. Any + variables created will be automatically hoisted out to the outermost scope + using `init_scope()`. Some collections (such as for control-flow contexts) + are benign and should not generate a warning. + + Args: + transformation_name: A human-readable name for the transformation. + """ + if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): + warnings.warn("Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." + % transformation_name) + + class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" |