diff options
author | 2018-10-02 17:48:25 -0700 | |
---|---|---|
committer | 2018-10-02 17:52:21 -0700 | |
commit | b7e9cbab27c893283acc4a6154d7a59dffb23758 (patch) | |
tree | 2eed3086deb2390f8bf6b8074ddd654ca59cfec8 /tensorflow/python/data | |
parent | 80821abd6410f47130fc031b15e9ac220de5b1b9 (diff) |
Use `defun` instead of `Defun` for `tf.data`, except for `make_one_shot_iterator` which is to be deprecated in future.
PiperOrigin-RevId: 215491729
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 60 |
1 files changed, 26 insertions, 34 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 46ce191f7b..d90da5908d 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -37,6 +38,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1713,7 +1715,8 @@ class _VariantDataset(Dataset): class StructuredFunctionWrapper(object): - """A wrapper for `Defun` that supports structured arguments and return values. + """A wrapper for `defun` that supports structured arguments and return values. + """ def __init__(self, func, transformation_name, dataset=None, @@ -1765,7 +1768,7 @@ class StructuredFunctionWrapper(object): # TODO(b/110122868): Enable this support for all `tf.data` functions. self._nested_dataset_support = experimental_nested_dataset_support - @function.Defun(*self._defun_args()) + @eager_function.defun(input_signature=self._defun_args()) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" flat_args = [] @@ -1850,36 +1853,43 @@ class StructuredFunctionWrapper(object): self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) self._output_types = nest.pack_sequence_as(ret, flat_types) - _warn_if_collections(transformation_name) - return flat_ret - self._function = tf_data_structured_function_wrapper + table_initializers_len = len(ops.get_default_graph().get_collection( + ops.GraphKeys.TABLE_INITIALIZERS)) + + self._function = tf_data_structured_function_wrapper.get_concrete_function() if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) - else: - # Use the private method that will execute - # `tf_data_structured_function_wrapper` but delay adding it to the graph - # in case (e.g.) we need to rerun the function. - self._function._create_definition_if_needed() # pylint: disable=protected-access + if len( + self._function.graph.get_collection( + ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len: + warnings.warn( + "Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." % transformation_name) def _defun_args(self): - """Returns a flat list of `tf.DType` for the input element structure.""" + """Returns a list of `tf.TensorSpec` for the input element structure.""" ret = [] - for input_type, input_class in zip(nest.flatten(self._input_types), - nest.flatten(self._input_classes)): + for input_type, input_shape, input_class in zip( + nest.flatten(self._input_types), nest.flatten(self._input_shapes), + nest.flatten(self._input_classes)): # TODO(b/110122868): Add a registration mechanism for new component types. if input_class is sparse_tensor_lib.SparseTensor: - ret.append(dtypes.variant) + ret.append( + tensor_spec.TensorSpec( + tensor_shape.TensorShape(None), dtypes.variant)) elif isinstance(input_class, _NestedDatasetComponent): if not self._nested_dataset_support: raise NotImplementedError( "The %s transformation does not currently support nested " "datasets as inputs." % self._transformation_name) - ret.append(dtypes.variant) + ret.append( + tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant)) else: assert isinstance(input_type, dtypes.DType) - ret.append(input_type) + ret.append(tensor_spec.TensorSpec(input_shape, input_type)) return ret @property @@ -2579,24 +2589,6 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck -def _warn_if_collections(transformation_name): - """Prints warning message if the current graph uses common graph collections. - - NOTE(mrry): Currently a warning is only generated for lookup tables. Any - variables created will be automatically hoisted out to the outermost scope - using `init_scope()`. Some collections (such as for control-flow contexts) - are benign and should not generate a warning. - - Args: - transformation_name: A human-readable name for the transformation. - """ - if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): - warnings.warn("Creating lookup tables inside a function passed to %s is not" - " supported. Create each table outside the function, and " - "capture it inside the function to use it." - % transformation_name) - - class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" |