aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-10-02 17:48:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 17:52:21 -0700
commitb7e9cbab27c893283acc4a6154d7a59dffb23758 (patch)
tree2eed3086deb2390f8bf6b8074ddd654ca59cfec8 /tensorflow/python/eager
parent80821abd6410f47130fc031b15e9ac220de5b1b9 (diff)
Use `defun` instead of `Defun` for `tf.data`, except for `make_one_shot_iterator` which is to be deprecated in future.
PiperOrigin-RevId: 215491729
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py14
-rw-r--r--tensorflow/python/eager/function_test.py9
2 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f261d92d64..aeb1cac3e9 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -663,6 +663,11 @@ class Function(object):
return self._build_call_outputs(outputs)
@property
+ def name(self):
+ """Function name."""
+ return self._inference_function.name
+
+ @property
def graph(self):
"""Returns the graph from which this function was constructed."""
return self._func_graph
@@ -719,6 +724,10 @@ class Function(object):
return nest.map_structure(lambda x: x.dtype if x is not None else None,
self._func_graph.structured_outputs)
+ def add_to_graph(self, g):
+ """Adds this function into the graph g."""
+ return self._inference_function.add_to_graph(g)
+
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
@@ -1122,6 +1131,8 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
+ if self._input_signature:
+ args, kwargs = None, None
graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
@@ -1304,6 +1315,9 @@ def register(func, *args, **kwargs):
function definition into graph. Register function with different input param
will result into multiple version of functions registered in graph.
+ Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was
+ created with an `input_signature`.
+
Args:
func: the PolymorphicFunction instance that generated by a @defun
*args: input arguments for the Python function.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9ce367a837..ac45606eb0 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1750,11 +1750,10 @@ class FunctionTest(test.TestCase):
# pylint: disable=protected-access
self.assertEqual(len(graph._functions), 3)
- # Test input param shape mismatch
- t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- with self.assertRaisesRegexp(
- ValueError, 'Python inputs incompatible with input_signature'):
- function.register(defun_matmul, t2, t2)
+ # Test register function with cache, note inputs are ignored.
+ function.register(defun_matmul)
+ graph = ops.get_default_graph()
+ self.assertEqual(len(graph._functions), 3)
def testRegisterFunctionWithCache(self):
def matmul(x, y):