aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-01 13:09:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-01 13:13:40 -0800
commit8307faacb96808eae1550ed879fa9a85cf76d897 (patch)
tree96e33ecac46d1e197f99f1010ae53c33fe75d401
parent16478853c73d9e6dfab26e73e99d931f4c74043c (diff)
Add support for keyword args for dynamically converted functions.
PiperOrigin-RevId: 187522324
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py
index f18f9f6086..ca8726f916 100644
--- a/tensorflow/contrib/py2tf/converters/call_trees.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees.py
@@ -185,7 +185,7 @@ class CallTreeTransformer(transformer.Base):
"""
return templates.replace(template, func=node.func, original_args=node.args)
- def _converted_call(self, node):
+ def _insert_dynamic_conversion(self, node):
"""Inlines a dynamic conversion for a dynamic function."""
# TODO(mdan): Pass information on the statically compiled functions.
# Having access to the statically compiled functions can help avoid
@@ -208,7 +208,10 @@ class CallTreeTransformer(transformer.Base):
"""
call_expr = templates.replace(
template, func=node.func, original_args=node.args)
- return call_expr[0].value
+ new_call = call_expr[0].value
+ # TODO(mdan): Improve the template mechanism to better support this.
+ new_call.keywords = node.keywords
+ return new_call
# pylint:disable=invalid-name
@@ -251,7 +254,7 @@ class CallTreeTransformer(transformer.Base):
raise NotImplementedError('py_func with return values')
else:
if self.context.recursive:
- node = self._converted_call(node)
+ node = self._insert_dynamic_conversion(node)
else:
# Unresolved functions are allowed in non-recursive mode.
pass