aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-26 17:04:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 17:12:17 -0800
commit49d4e9233cebdff001ffcc2e3d703e815ba0a881 (patch)
tree7ae23c98343d1d7e62a0dca9d93fd60cb3cec322
parente37a7ae2277a2a2f7b50ad5ef361e41c30edeb41 (diff)
Consolidate the builtin function overrides into a single module, and use a generic `dynamic_builtin` function to dispatch between implementations. Use the generic dispatcher in the generated code.
PiperOrigin-RevId: 187104685
-rw-r--r--tensorflow/contrib/py2tf/converters/builtin_functions.py13
-rw-r--r--tensorflow/contrib/py2tf/utils/BUILD12
-rw-r--r--tensorflow/contrib/py2tf/utils/__init__.py4
-rw-r--r--tensorflow/contrib/py2tf/utils/builtins.py (renamed from tensorflow/contrib/py2tf/utils/printing.py)32
-rw-r--r--tensorflow/contrib/py2tf/utils/builtins_test.py (renamed from tensorflow/contrib/py2tf/utils/printing_test.py)39
-rw-r--r--tensorflow/contrib/py2tf/utils/misc.py13
-rw-r--r--tensorflow/contrib/py2tf/utils/misc_test.py27
7 files changed, 72 insertions, 68 deletions
diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py
index e69038aced..b5aa9756da 100644
--- a/tensorflow/contrib/py2tf/converters/builtin_functions.py
+++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py
@@ -36,23 +36,24 @@ class BuiltinFunctionTransformer(transformer.Base):
# pylint:disable=invalid-name
- def _convert_len(self, node):
+ def _convert_builtin(self, node):
template = """
- py2tf_utils.dynamic_len(args)
+ py2tf_utils.dynamic_builtin(func, args)
"""
- return templates.replace(template, args=node.args)[0].value
+ return templates.replace(template, func=node.func, args=node.args)[0].value
def _convert_print(self, node):
template = """
- py2tf_utils.call_print(args)
+ py2tf_utils.dynamic_print(args)
"""
return templates.replace(template, args=node.args)[0].value
def visit_Call(self, node):
self.generic_visit(node)
# TODO(mdan): This won't work if the function was hidden.
- if isinstance(node.func, gast.Name) and node.func.id == 'len':
- return self._convert_len(node)
+ if isinstance(node.func, gast.Name) and node.func.id in ('len',):
+ return self._convert_builtin(node)
+ # Print needs to be handled separately because it can be read as statement.
if isinstance(node.func, gast.Name) and node.func.id == 'print':
return self._convert_print(node)
return node
diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD
index c2fdd40707..2086a9ef60 100644
--- a/tensorflow/contrib/py2tf/utils/BUILD
+++ b/tensorflow/contrib/py2tf/utils/BUILD
@@ -20,10 +20,10 @@ py_library(
name = "utils",
srcs = [
"__init__.py",
+ "builtins.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
- "printing.py",
"py_func.py",
"tensor_list.py",
"type_check.py",
@@ -77,16 +77,6 @@ py_test(
)
py_test(
- name = "printing_test",
- srcs = ["printing_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "type_check_test",
srcs = ["type_check_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py
index d931322bf3..19bf2272bc 100644
--- a/tensorflow/contrib/py2tf/utils/__init__.py
+++ b/tensorflow/contrib/py2tf/utils/__init__.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
-from tensorflow.contrib.py2tf.utils.misc import dynamic_len
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
-from tensorflow.contrib.py2tf.utils.printing import call_print
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/py2tf/utils/builtins.py
index 95a62bd80b..0a50b80b60 100644
--- a/tensorflow/contrib/py2tf/utils/printing.py
+++ b/tensorflow/contrib/py2tf/utils/builtins.py
@@ -12,14 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""TensorFlow printing support utilities."""
+"""Builtin conversion utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.py2tf.utils import py_func
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.util import tf_inspect
+
+
+def dynamic_builtin(f, *args, **kwargs):
+ """Converts a builtin function call inline."""
+ if not tf_inspect.isbuiltin(f):
+ return f(*args, **kwargs)
+
+ if f is len:
+ return dynamic_len(*args, **kwargs)
+
+ raise NotImplementedError('The "%s" builtin is not yet supported.' % f)
+
+
+def dynamic_len(list_or_tensor):
+ """Implementation of len using dynamic dispatch."""
+ if tensor_util.is_tensor(list_or_tensor):
+ shape = list_or_tensor.shape
+ if not shape:
+ raise ValueError(
+ 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
+ return array_ops.shape(list_or_tensor)[0]
+
+ return len(list_or_tensor)
def is_tf_print_compatible(value):
@@ -30,8 +56,8 @@ def is_tf_print_compatible(value):
return False
-def call_print(*values):
- """Compiled counterpart of the print builtin.
+def dynamic_print(*values):
+ """Implementartion of print using dynamic dispatch.
The function attempts to use tf.Print if all the values are compatible.
Otherwise, it will fall back to py_func.
diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/py2tf/utils/builtins_test.py
index 2070deb304..19a72c63ec 100644
--- a/tensorflow/contrib/py2tf/utils/printing_test.py
+++ b/tensorflow/contrib/py2tf/utils/builtins_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for printing module."""
+"""Tests for builtins module."""
from __future__ import absolute_import
from __future__ import division
@@ -22,28 +22,53 @@ import sys
import six
-from tensorflow.contrib.py2tf.utils import printing
+from tensorflow.contrib.py2tf.utils import builtins
+from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
-class ContextManagersTest(test.TestCase):
+class BuiltinsTest(test.TestCase):
- def test_call_print_tf(self):
+ def test_dynamic_len_tf_scalar(self):
+ a = constant_op.constant(1)
+
+ with self.assertRaises(ValueError):
+ with self.test_session() as sess:
+ sess.run(builtins.dynamic_builtin(len, a))
+
+ def test_dynamic_len_tf_array(self):
+ a = constant_op.constant([1, 2, 3])
+
+ with self.test_session() as sess:
+ self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
+
+ def test_dynamic_len_tf_matrix(self):
+ a = constant_op.constant([[1, 2], [3, 4]])
+
+ with self.test_session() as sess:
+ self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
+
+ def test_dynamic_len_py_list(self):
+ a = [3] * 5
+
+ self.assertEqual(5, builtins.dynamic_builtin(len, a))
+
+ def test_dynamic_print_tf(self):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
with self.test_session() as sess:
- sess.run(printing.call_print('test message', 1))
+ sess.run(builtins.dynamic_print('test message', 1))
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
finally:
sys.stdout = sys.__stdout__
- def test_call_print_py_func(self):
+ def test_dynamic_print_complex(self):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
with self.test_session() as sess:
- sess.run(printing.call_print('test message', [1, 2]))
+ sess.run(builtins.dynamic_print('test message', [1, 2]))
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
finally:
sys.stdout = sys.__stdout__
diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/py2tf/utils/misc.py
index 7548048388..1b06caf0bd 100644
--- a/tensorflow/contrib/py2tf/utils/misc.py
+++ b/tensorflow/contrib/py2tf/utils/misc.py
@@ -19,22 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-def dynamic_len(list_or_tensor):
- """Implementation of len using dynamic dispatch."""
- if tensor_util.is_tensor(list_or_tensor):
- shape = list_or_tensor.shape
- if not shape:
- raise ValueError(
- 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
- return array_ops.shape(list_or_tensor)[0]
-
- return len(list_or_tensor)
-
-
def alias_tensors(*args):
"""Wrap any Tensor arguments with an identity op.
diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/py2tf/utils/misc_test.py
index ec88e7cb74..8aedd4cd64 100644
--- a/tensorflow/contrib/py2tf/utils/misc_test.py
+++ b/tensorflow/contrib/py2tf/utils/misc_test.py
@@ -19,37 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
-from tensorflow.contrib.py2tf.utils.misc import dynamic_len
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops.variables import Variable
from tensorflow.python.platform import test
-class ContextManagersTest(test.TestCase):
-
- def test_dynamic_len_tf_scalar(self):
- a = constant(1)
-
- with self.assertRaises(ValueError):
- with self.test_session() as sess:
- sess.run(dynamic_len(a))
-
- def test_dynamic_len_tf_array(self):
- a = constant([1, 2, 3])
-
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(dynamic_len(a)))
-
- def test_dynamic_len_tf_matrix(self):
- a = constant([[1, 2], [3, 4]])
-
- with self.test_session() as sess:
- self.assertEqual(2, sess.run(dynamic_len(a)))
-
- def test_dynamic_len_py_list(self):
- a = [3] * 5
-
- self.assertEqual(5, dynamic_len(a))
+class MiscTest(test.TestCase):
def test_alias_single_tensor(self):
a = constant(1)