diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-10-02 17:48:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 17:52:21 -0700 |
commit | b7e9cbab27c893283acc4a6154d7a59dffb23758 (patch) | |
tree | 2eed3086deb2390f8bf6b8074ddd654ca59cfec8 | |
parent | 80821abd6410f47130fc031b15e9ac220de5b1b9 (diff) |
Use `defun` instead of `Defun` for `tf.data`, except for `make_one_shot_iterator`, which is to be deprecated in the future.
PiperOrigin-RevId: 215491729
-rw-r--r-- | tensorflow/contrib/distribute/python/input_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 60 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 14 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 9 |
4 files changed, 45 insertions, 40 deletions
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py index f07ec8234d..423952c9e2 100644 --- a/tensorflow/contrib/distribute/python/input_ops.py +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -78,7 +78,7 @@ def auto_shard_dataset(dataset, num_shards, index): elif hasattr(dataset, "_map_func"): # TODO(priyag): Make this check more robust by enforcing some common # property on all map/flatmap/interleave datasets. - map_func_def = dataset._map_func.definition + map_func_def = dataset._map_func.function_def for node in map_func_def.node_def: if node.op in _READER_DATASET_OPS: found_reader_op = True diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 46ce191f7b..d90da5908d 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -37,6 +38,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1713,7 +1715,8 @@ class _VariantDataset(Dataset): class StructuredFunctionWrapper(object): - """A wrapper for `Defun` that supports structured arguments and return values. 
+ """A wrapper for `defun` that supports structured arguments and return values. + """ def __init__(self, func, transformation_name, dataset=None, @@ -1765,7 +1768,7 @@ class StructuredFunctionWrapper(object): # TODO(b/110122868): Enable this support for all `tf.data` functions. self._nested_dataset_support = experimental_nested_dataset_support - @function.Defun(*self._defun_args()) + @eager_function.defun(input_signature=self._defun_args()) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" flat_args = [] @@ -1850,36 +1853,43 @@ class StructuredFunctionWrapper(object): self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) self._output_types = nest.pack_sequence_as(ret, flat_types) - _warn_if_collections(transformation_name) - return flat_ret - self._function = tf_data_structured_function_wrapper + table_initializers_len = len(ops.get_default_graph().get_collection( + ops.GraphKeys.TABLE_INITIALIZERS)) + + self._function = tf_data_structured_function_wrapper.get_concrete_function() if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) - else: - # Use the private method that will execute - # `tf_data_structured_function_wrapper` but delay adding it to the graph - # in case (e.g.) we need to rerun the function. - self._function._create_definition_if_needed() # pylint: disable=protected-access + if len( + self._function.graph.get_collection( + ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len: + warnings.warn( + "Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." 
% transformation_name) def _defun_args(self): - """Returns a flat list of `tf.DType` for the input element structure.""" + """Returns a list of `tf.TensorSpec` for the input element structure.""" ret = [] - for input_type, input_class in zip(nest.flatten(self._input_types), - nest.flatten(self._input_classes)): + for input_type, input_shape, input_class in zip( + nest.flatten(self._input_types), nest.flatten(self._input_shapes), + nest.flatten(self._input_classes)): # TODO(b/110122868): Add a registration mechanism for new component types. if input_class is sparse_tensor_lib.SparseTensor: - ret.append(dtypes.variant) + ret.append( + tensor_spec.TensorSpec( + tensor_shape.TensorShape(None), dtypes.variant)) elif isinstance(input_class, _NestedDatasetComponent): if not self._nested_dataset_support: raise NotImplementedError( "The %s transformation does not currently support nested " "datasets as inputs." % self._transformation_name) - ret.append(dtypes.variant) + ret.append( + tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant)) else: assert isinstance(input_type, dtypes.DType) - ret.append(input_type) + ret.append(tensor_spec.TensorSpec(input_shape, input_type)) return ret @property @@ -2579,24 +2589,6 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck -def _warn_if_collections(transformation_name): - """Prints warning message if the current graph uses common graph collections. - - NOTE(mrry): Currently a warning is only generated for lookup tables. Any - variables created will be automatically hoisted out to the outermost scope - using `init_scope()`. Some collections (such as for control-flow contexts) - are benign and should not generate a warning. - - Args: - transformation_name: A human-readable name for the transformation. - """ - if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): - warnings.warn("Creating lookup tables inside a function passed to %s is not" - " supported. 
Create each table outside the function, and " - "capture it inside the function to use it." - % transformation_name) - - class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f261d92d64..aeb1cac3e9 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -663,6 +663,11 @@ class Function(object): return self._build_call_outputs(outputs) @property + def name(self): + """Function name.""" + return self._inference_function.name + + @property def graph(self): """Returns the graph from which this function was constructed.""" return self._func_graph @@ -719,6 +724,10 @@ class Function(object): return nest.map_structure(lambda x: x.dtype if x is not None else None, self._func_graph.structured_outputs) + def add_to_graph(self, g): + """Adds this function into the graph g.""" + return self._inference_function.add_to_graph(g) + def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) @@ -1122,6 +1131,8 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ + if self._input_signature: + args, kwargs = None, None graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function @@ -1304,6 +1315,9 @@ def register(func, *args, **kwargs): function definition into graph. Register function with different input param will result into multiple version of functions registered in graph. + Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was + created with an `input_signature`. + Args: func: the PolymorphicFunction instance that generated by a @defun *args: input arguments for the Python function. 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 9ce367a837..ac45606eb0 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1750,11 +1750,10 @@ class FunctionTest(test.TestCase): # pylint: disable=protected-access self.assertEqual(len(graph._functions), 3) - # Test input param shape mismatch - t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - with self.assertRaisesRegexp( - ValueError, 'Python inputs incompatible with input_signature'): - function.register(defun_matmul, t2, t2) + # Test register function with cache, note inputs are ignored. + function.register(defun_matmul) + graph = ops.get_default_graph() + self.assertEqual(len(graph._functions), 3) def testRegisterFunctionWithCache(self): def matmul(x, y): |