diff options
author | Derek Murray <mrry@google.com> | 2018-10-02 20:00:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 20:04:04 -0700 |
commit | 2597b883a14749c77fffd7e5f9677107021ff40a (patch) | |
tree | b4a8ea07c72bd3f509edd2bb7d1c90dc3f81ba2d | |
parent | 4b2d0180ba8c903f098f52eb9a12d26a7626dd34 (diff) |
Automated rollback of commit b7e9cbab27c893283acc4a6154d7a59dffb23758
PiperOrigin-RevId: 215503549
-rw-r--r-- | tensorflow/contrib/distribute/python/input_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 60 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 14 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 9 |
4 files changed, 40 insertions, 45 deletions
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py index 423952c9e2..f07ec8234d 100644 --- a/tensorflow/contrib/distribute/python/input_ops.py +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -78,7 +78,7 @@ def auto_shard_dataset(dataset, num_shards, index): elif hasattr(dataset, "_map_func"): # TODO(priyag): Make this check more robust by enforcing some common # property on all map/flatmap/interleave datasets. - map_func_def = dataset._map_func.function_def + map_func_def = dataset._map_func.definition for node in map_func_def.node_def: if node.op in _READER_DATASET_OPS: found_reader_op = True diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d90da5908d..46ce191f7b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -30,7 +30,6 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context -from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -38,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1715,8 +1713,7 @@ class _VariantDataset(Dataset): class StructuredFunctionWrapper(object): - """A wrapper for `defun` that supports structured arguments and return values. - + """A wrapper for `Defun` that supports structured arguments and return values. """ def __init__(self, func, transformation_name, dataset=None, @@ -1768,7 +1765,7 @@ class StructuredFunctionWrapper(object): # TODO(b/110122868): Enable this support for all `tf.data` functions. self._nested_dataset_support = experimental_nested_dataset_support - @eager_function.defun(input_signature=self._defun_args()) + @function.Defun(*self._defun_args()) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" flat_args = [] @@ -1853,43 +1850,36 @@ class StructuredFunctionWrapper(object): self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) self._output_types = nest.pack_sequence_as(ret, flat_types) - return flat_ret + _warn_if_collections(transformation_name) - table_initializers_len = len(ops.get_default_graph().get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) + return flat_ret - self._function = tf_data_structured_function_wrapper.get_concrete_function() + self._function = tf_data_structured_function_wrapper if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) - if len( - self._function.graph.get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len: - warnings.warn( - "Creating lookup tables inside a function passed to %s is not" - " supported. Create each table outside the function, and " - "capture it inside the function to use it." % transformation_name) + else: + # Use the private method that will execute + # `tf_data_structured_function_wrapper` but delay adding it to the graph + # in case (e.g.) we need to rerun the function. + self._function._create_definition_if_needed() # pylint: disable=protected-access def _defun_args(self): - """Returns a list of `tf.TensorSpec` for the input element structure.""" + """Returns a flat list of `tf.DType` for the input element structure.""" ret = [] - for input_type, input_shape, input_class in zip( - nest.flatten(self._input_types), nest.flatten(self._input_shapes), - nest.flatten(self._input_classes)): + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): # TODO(b/110122868): Add a registration mechanism for new component types. if input_class is sparse_tensor_lib.SparseTensor: - ret.append( - tensor_spec.TensorSpec( - tensor_shape.TensorShape(None), dtypes.variant)) + ret.append(dtypes.variant) elif isinstance(input_class, _NestedDatasetComponent): if not self._nested_dataset_support: raise NotImplementedError( "The %s transformation does not currently support nested " "datasets as inputs." % self._transformation_name) - ret.append( - tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant)) + ret.append(dtypes.variant) else: assert isinstance(input_type, dtypes.DType) - ret.append(tensor_spec.TensorSpec(input_shape, input_type)) + ret.append(input_type) return ret @property @@ -2589,6 +2579,24 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck +def _warn_if_collections(transformation_name): + """Prints warning message if the current graph uses common graph collections. + + NOTE(mrry): Currently a warning is only generated for lookup tables. Any + variables created will be automatically hoisted out to the outermost scope + using `init_scope()`. Some collections (such as for control-flow contexts) + are benign and should not generate a warning. + + Args: + transformation_name: A human-readable name for the transformation. + """ + if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): + warnings.warn("Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." + % transformation_name) + + class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index aeb1cac3e9..f261d92d64 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -663,11 +663,6 @@ class Function(object): return self._build_call_outputs(outputs) @property - def name(self): - """Function name.""" - return self._inference_function.name - - @property def graph(self): """Returns the graph from which this function was constructed.""" return self._func_graph @@ -724,10 +719,6 @@ class Function(object): return nest.map_structure(lambda x: x.dtype if x is not None else None, self._func_graph.structured_outputs) - def add_to_graph(self, g): - """Adds this function into the graph g.""" - return self._inference_function.add_to_graph(g) - def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) @@ -1131,8 +1122,6 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ - if self._input_signature: - args, kwargs = None, None graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function @@ -1315,9 +1304,6 @@ def register(func, *args, **kwargs): function definition into graph. Register function with different input param will result into multiple version of functions registered in graph. - Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was - created with an `input_signature`. - Args: func: the PolymorphicFunction instance that generated by a @defun *args: input arguments for the Python function. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index ac45606eb0..9ce367a837 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1750,10 +1750,11 @@ class FunctionTest(test.TestCase): # pylint: disable=protected-access self.assertEqual(len(graph._functions), 3) - # Test register function with cache, note inputs are ignored. - function.register(defun_matmul) - graph = ops.get_default_graph() - self.assertEqual(len(graph._functions), 3) + # Test input param shape mismatch + t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with self.assertRaisesRegexp( + ValueError, 'Python inputs incompatible with input_signature'): + function.register(defun_matmul, t2, t2) def testRegisterFunctionWithCache(self): def matmul(x, y): |