aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-10-02 17:48:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 17:52:21 -0700
commitb7e9cbab27c893283acc4a6154d7a59dffb23758 (patch)
tree2eed3086deb2390f8bf6b8074ddd654ca59cfec8
parent80821abd6410f47130fc031b15e9ac220de5b1b9 (diff)
Use `defun` instead of `Defun` for `tf.data`, except for `make_one_shot_iterator` which is to be deprecated in future.
PiperOrigin-RevId: 215491729
-rw-r--r--tensorflow/contrib/distribute/python/input_ops.py2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py60
-rw-r--r--tensorflow/python/eager/function.py14
-rw-r--r--tensorflow/python/eager/function_test.py9
4 files changed, 45 insertions, 40 deletions
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py
index f07ec8234d..423952c9e2 100644
--- a/tensorflow/contrib/distribute/python/input_ops.py
+++ b/tensorflow/contrib/distribute/python/input_ops.py
@@ -78,7 +78,7 @@ def auto_shard_dataset(dataset, num_shards, index):
elif hasattr(dataset, "_map_func"):
# TODO(priyag): Make this check more robust by enforcing some common
# property on all map/flatmap/interleave datasets.
- map_func_def = dataset._map_func.definition
+ map_func_def = dataset._map_func.function_def
for node in map_func_def.node_def:
if node.op in _READER_DATASET_OPS:
found_reader_op = True
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 46ce191f7b..d90da5908d 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -37,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1713,7 +1715,8 @@ class _VariantDataset(Dataset):
class StructuredFunctionWrapper(object):
- """A wrapper for `Defun` that supports structured arguments and return values.
+ """A wrapper for `defun` that supports structured arguments and return values.
+
"""
def __init__(self, func, transformation_name, dataset=None,
@@ -1765,7 +1768,7 @@ class StructuredFunctionWrapper(object):
# TODO(b/110122868): Enable this support for all `tf.data` functions.
self._nested_dataset_support = experimental_nested_dataset_support
- @function.Defun(*self._defun_args())
+ @eager_function.defun(input_signature=self._defun_args())
def tf_data_structured_function_wrapper(*args):
"""Wrapper for passing nested structures to and from tf.data functions."""
flat_args = []
@@ -1850,36 +1853,43 @@ class StructuredFunctionWrapper(object):
self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
self._output_types = nest.pack_sequence_as(ret, flat_types)
- _warn_if_collections(transformation_name)
-
return flat_ret
- self._function = tf_data_structured_function_wrapper
+ table_initializers_len = len(ops.get_default_graph().get_collection(
+ ops.GraphKeys.TABLE_INITIALIZERS))
+
+ self._function = tf_data_structured_function_wrapper.get_concrete_function()
if add_to_graph:
self._function.add_to_graph(ops.get_default_graph())
- else:
- # Use the private method that will execute
- # `tf_data_structured_function_wrapper` but delay adding it to the graph
- # in case (e.g.) we need to rerun the function.
- self._function._create_definition_if_needed() # pylint: disable=protected-access
+ if len(
+ self._function.graph.get_collection(
+ ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len:
+ warnings.warn(
+ "Creating lookup tables inside a function passed to %s is not"
+ " supported. Create each table outside the function, and "
+ "capture it inside the function to use it." % transformation_name)
def _defun_args(self):
- """Returns a flat list of `tf.DType` for the input element structure."""
+ """Returns a list of `tf.TensorSpec` for the input element structure."""
ret = []
- for input_type, input_class in zip(nest.flatten(self._input_types),
- nest.flatten(self._input_classes)):
+ for input_type, input_shape, input_class in zip(
+ nest.flatten(self._input_types), nest.flatten(self._input_shapes),
+ nest.flatten(self._input_classes)):
# TODO(b/110122868): Add a registration mechanism for new component types.
if input_class is sparse_tensor_lib.SparseTensor:
- ret.append(dtypes.variant)
+ ret.append(
+ tensor_spec.TensorSpec(
+ tensor_shape.TensorShape(None), dtypes.variant))
elif isinstance(input_class, _NestedDatasetComponent):
if not self._nested_dataset_support:
raise NotImplementedError(
"The %s transformation does not currently support nested "
"datasets as inputs." % self._transformation_name)
- ret.append(dtypes.variant)
+ ret.append(
+ tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant))
else:
assert isinstance(input_type, dtypes.DType)
- ret.append(input_type)
+ ret.append(tensor_spec.TensorSpec(input_shape, input_type))
return ret
@property
@@ -2579,24 +2589,6 @@ def _should_unpack_args(args):
return type(args) is tuple # pylint: disable=unidiomatic-typecheck
-def _warn_if_collections(transformation_name):
- """Prints warning message if the current graph uses common graph collections.
-
- NOTE(mrry): Currently a warning is only generated for lookup tables. Any
- variables created will be automatically hoisted out to the outermost scope
- using `init_scope()`. Some collections (such as for control-flow contexts)
- are benign and should not generate a warning.
-
- Args:
- transformation_name: A human-readable name for the transformation.
- """
- if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS):
- warnings.warn("Creating lookup tables inside a function passed to %s is not"
- " supported. Create each table outside the function, and "
- "capture it inside the function to use it."
- % transformation_name)
-
-
class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f261d92d64..aeb1cac3e9 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -663,6 +663,11 @@ class Function(object):
return self._build_call_outputs(outputs)
@property
+ def name(self):
+ """Function name."""
+ return self._inference_function.name
+
+ @property
def graph(self):
"""Returns the graph from which this function was constructed."""
return self._func_graph
@@ -719,6 +724,10 @@ class Function(object):
return nest.map_structure(lambda x: x.dtype if x is not None else None,
self._func_graph.structured_outputs)
+ def add_to_graph(self, g):
+ """Adds this function into the graph g."""
+ return self._inference_function.add_to_graph(g)
+
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
@@ -1122,6 +1131,8 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
+ if self._input_signature:
+ args, kwargs = None, None
graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
@@ -1304,6 +1315,9 @@ def register(func, *args, **kwargs):
function definition into graph. Register function with different input param
will result into multiple version of functions registered in graph.
+ Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was
+ created with an `input_signature`.
+
Args:
func: the PolymorphicFunction instance that generated by a @defun
*args: input arguments for the Python function.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9ce367a837..ac45606eb0 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1750,11 +1750,10 @@ class FunctionTest(test.TestCase):
# pylint: disable=protected-access
self.assertEqual(len(graph._functions), 3)
- # Test input param shape mismatch
- t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- with self.assertRaisesRegexp(
- ValueError, 'Python inputs incompatible with input_signature'):
- function.register(defun_matmul, t2, t2)
+ # Test register function with cache, note inputs are ignored.
+ function.register(defun_matmul)
+ graph = ops.get_default_graph()
+ self.assertEqual(len(graph._functions), 3)
def testRegisterFunctionWithCache(self):
def matmul(x, y):