author    Dan Moldovan <mdan@google.com>  2018-09-05 07:34:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-05 07:39:13 -0700
commit    580a50a4bb30853199de191ba4d98f7390a138db (patch)
tree      38ecf4d64c4323b60cf265e456ab5f9da245422b /tensorflow/contrib/autograph
parent    ffaab58cad72e177ada0e7d1d3724de63032928d (diff)
utils cleanup: move the builtins module under operators.
PiperOrigin-RevId: 211631516
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions.py | 41
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions_test.py | 9
-rw-r--r-- | tensorflow/contrib/autograph/impl/api.py | 4
-rw-r--r-- | tensorflow/contrib/autograph/operators/BUILD | 11
-rw-r--r-- | tensorflow/contrib/autograph/operators/__init__.py | 5
-rw-r--r-- | tensorflow/contrib/autograph/operators/control_flow.py | 6
-rw-r--r-- | tensorflow/contrib/autograph/operators/py_builtins.py | 225
-rw-r--r-- | tensorflow/contrib/autograph/operators/py_builtins_test.py | 131
-rw-r--r-- | tensorflow/contrib/autograph/utils/BUILD | 23
-rw-r--r-- | tensorflow/contrib/autograph/utils/__init__.py | 3
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins.py | 143
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins_test.py | 145
-rw-r--r-- | tensorflow/contrib/autograph/utils/tensors.py | 41
-rw-r--r-- | tensorflow/contrib/autograph/utils/tensors_test.py | 57
14 files changed, 508 insertions, 336 deletions
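
For orientation: this change replaces the per-builtin `dynamic_*` helpers in `utils.builtins` with named overloads in the new `operators.py_builtins` module, selected through `overload_of`. A minimal sketch of the new entry points follows (TF 1.x contrib paths as in this diff; the example inputs are illustrative and not part of the change):

from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.framework import constant_op

# overload_of maps a supported Python builtin to its AutoGraph overload;
# unsupported builtins pass through unchanged.
py_builtins.overload_of(len) is py_builtins.len_     # True
py_builtins.overload_of(sorted) is sorted             # True

# The overloads dispatch on the argument type: plain Python values use the
# real builtin, tensors are handled with TF ops.
py_builtins.len_([1, 2, 3])                            # 3 (plain Python len)
py_builtins.len_(constant_op.constant([[1], [2]]))     # 2 (static first dimension, known at graph-build time)
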
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index b26c52294c..29dce13999 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
@@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base):
TF equivalent, like `len`.
"""
- def _convert_builtin(self, node):
+ def _convert_builtin(self, f, args, as_expression):
template = """
- ag__.utils.dynamic_builtin(func, args)
+ ag__.func(args)
"""
- return templates.replace(template, func=node.func, args=node.args)[0].value
-
- def _convert_print(self, node):
- template = """
- ag__.utils.dynamic_print(args)
- """
- return templates.replace(template, args=node.args)[0].value
+ if as_expression:
+ return templates.replace_as_expression(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
+ else:
+ return templates.replace(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
def visit_Call(self, node):
- self.generic_visit(node)
- # TODO(mdan): This won't work if the function was hidden.
- # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
- if (isinstance(node.func, gast.Name) and
- node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
- return self._convert_builtin(node)
- # Print needs to be handled separately because it can be read as statement.
- if isinstance(node.func, gast.Name) and node.func.id == 'print':
- return self._convert_print(node)
+ node = self.generic_visit(node)
+ if anno.hasanno(node.func, 'live_val'):
+ live_val = anno.getanno(node.func, 'live_val')
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
return node
def visit_Print(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
args = node.values
# Following is the case when calling print(a, b)
if len(args) == 1 and isinstance(args[0], gast.Tuple):
args = args[0].elts
- template = """
- fname(args)
- """
- function_call = templates.replace(template, fname='print', args=args)[0]
- return self.visit(function_call)
+ return self._convert_builtin(print, args, as_expression=False)
def transform(node, ctx):
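
Roughly, the effect of the template change above on converted code looks like the sketch below; the input function is hypothetical and the exact generator output may differ slightly, but the call forms follow the old and new templates shown in the diff:

def test_fn(a):
  return len(a)

# Old template "ag__.utils.dynamic_builtin(func, args)" produced:
#   return ag__.utils.dynamic_builtin(len, a)
# New template "ag__.func(args)", with func = py_builtins.overload_of(len).__name__, produces:
#   return ag__.len_(a)
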
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index d0a0cbbeb6..3e3a04f38b 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -23,6 +23,7 @@ import six
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
def test_fn(a):
return len(a)
- with self.converted(test_fn, builtin_functions, {'len': len},
- array_ops.shape) as result:
+ with self.converted(test_fn, builtin_functions, {'len': len}) as result:
with self.cached_session() as sess:
- ops = result.test_fn(constant_op.constant([0, 0, 0]))
- self.assertEqual(sess.run(ops), 3)
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ ops = result.test_fn(p)
+ self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
def test_print(self):
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 276a387180..8b38d5d080 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -29,9 +29,9 @@ import six
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
@@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
unknown_arg_value = object() # Sentinel for arguments of unknown value
if inspect_utils.isbuiltin(f):
- return builtins.dynamic_builtin(f, *args, **kwargs)
+ return py_builtins.overload_of(f)(*args, **kwargs)
if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
# Regular functions
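
The same overload table now drives dynamic conversion: when `converted_call` receives a builtin at runtime, it resolves the overload directly instead of going through `utils.builtins`. A rough, hypothetical mirror of that fast path (the helper name `call_builtin` is illustrative only):

from tensorflow.contrib.autograph.operators import py_builtins

def call_builtin(f, *args):
  # Sketch of the converted_call fast path for builtins.
  return py_builtins.overload_of(f)(*args)

call_builtin(len, [1, 2, 3])    # 3
call_builtin(range, 1, 4)       # range(1, 4) for plain Python arguments
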
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 332d5dab19..29759bad79 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
+ "py_builtins.py",
"slices.py",
],
srcs_version = "PY2AND3",
@@ -62,6 +63,16 @@ py_test(
)
py_test(
+ name = "py_builtins_test",
+ srcs = ["py_builtins_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "slices_test",
srcs = ["slices_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 392cb60bcc..c4fbc260a2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack
from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.py_builtins import float_
+from tensorflow.contrib.autograph.operators.py_builtins import int_
+from tensorflow.contrib.autograph.operators.py_builtins import len_
+from tensorflow.contrib.autograph.operators.py_builtins import print_
+from tensorflow.contrib.autograph.operators.py_builtins import range_
from tensorflow.contrib.autograph.operators.slices import get_item
from tensorflow.contrib.autograph.operators.slices import GetItemOpts
from tensorflow.contrib.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 9909e52164..9a66a6bb60 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state):
def _known_len_for_stmt(iter_, extra_test, body, init_state):
- """Overload of for_stmt that iterates over objects that define a length."""
- n = builtins.dynamic_len(iter_)
+ """Overload of for_stmt that iterates over objects that admit a length."""
+ n = py_builtins.len_(iter_)
def while_body(iterate_index, *state):
iterate = iter_[iterate_index]
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py
new file mode 100644
index 0000000000..c5730934e7
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+UNDEFINED = object()
+
+
+def overload_of(f):
+ if f in SUPPORTED_BUILTINS:
+    return BUILTIN_FUNCTIONS_MAP[f.__name__]
+ return f
+
+
+def abs_(x):
+ if tensor_util.is_tensor(x):
+ return _tf_abs(x)
+ return _py_abs(x)
+
+
+def _tf_abs(x):
+ return math_ops.abs(x)
+
+
+def _py_abs(x):
+ return abs(x)
+
+
+def float_(x=0):
+ if tensor_util.is_tensor(x):
+ return _tf_float(x)
+ return _py_float(x)
+
+
+def _tf_float(x):
+ # TODO(mdan): We shouldn't assume float32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+ return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+ return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+ if tensor_util.is_tensor(x):
+ return _tf_int(x, base)
+ return _py_int(x, base)
+
+
+def _tf_int(x, base):
+ if base not in (10, UNDEFINED):
+ raise NotImplementedError('base {} not supported for int'.format(base))
+
+ # TODO(mdan): We shouldn't assume int32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+ return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+ if base is UNDEFINED:
+ return int(x)
+ return int(x, base)
+
+
+def len_(s):
+ if tensors.is_tensor_array(s):
+ return _tf_tensor_array_len(s)
+ elif tensors.is_tensor_list(s):
+ return _tf_tensor_list_len(s)
+ elif tensor_util.is_tensor(s):
+ return _tf_tensor_len(s)
+ return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+ return s.size()
+
+
+def _tf_tensor_list_len(s):
+ return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+ """Overload of len_ for Tensor arguments."""
+ # Statically shaped tensors: length is known ahead of time.
+ if s.shape.ndims and s.shape[0].value is not None:
+ return s.shape[0].value
+
+ # Static shape of unknown dimensions: use dynamic shape but statically
+  # check that it is not a scalar.
+ shape = array_ops.shape(s)
+
+ assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+ if shape.shape[0] == 0:
+ raise ValueError(
+ 'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+ if shape.shape[0].value is not None:
+ return array_ops.shape(s)[0]
+
+ # Fully dynamic shape: use ops.
+ rank = array_ops.rank(s)
+
+ def raise_zero_rank_error():
+ msg = gen_string_ops.string_join(
+ ['len requires non-zero rank, got ',
+ gen_string_ops.as_string(rank)])
+ with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+ return constant_op.constant(0, dtype=dtypes.int32)
+
+ return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+ raise_zero_rank_error)
+
+
+def _py_len(s):
+ return len(s)
+
+
+def print_(*objects, **kwargs):
+ # Note: Python 2.6 doesn't support explicit keywords after starargs.
+ unknown_kwargs = tuple(
+ set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+ if unknown_kwargs:
+ raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+ # TODO(mdan): use logging_ops.Print when py_func is not supported.
+ return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+ """Overload of print_ as a py_func implementation."""
+ override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+ if 'flush' not in override_kwargs:
+ # Defaulting to flushing the console in graph mode, which helps reduce
+ # garbled output in IPython.
+ override_kwargs['flush'] = True
+
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(
+ v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+ six.print_(*vals, **override_kwargs)
+
+ return py_func.wrap_py_func(
+ print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+ if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+ return _tf_range(start_or_stop, stop, step)
+ return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+ # TODO(mdan): We should optimize this when a full tensor is not required.
+ if step is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop)
+ return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+ if step is not UNDEFINED:
+ return range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return range(start_or_stop, stop)
+ return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+ SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUNCTIONS_MAP = {
+ 'abs': abs_,
+ 'float': float_,
+ 'int': int_,
+ 'len': len_,
+ 'print': print_,
+ 'range': range_,
+ 'xrange': range_,
+}
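
A few dispatch examples for the new overloads, mirroring the tests that follow (a sketch only; the tensor-valued results are graph ops that still need a session to evaluate):

from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.framework import constant_op

# Plain Python values fall through to the real builtins.
py_builtins.int_('11', 2)                          # 3
py_builtins.float_(10)                             # 10.0
list(py_builtins.range_(2, 0, -1))                 # [2, 1]

# Tensor arguments are routed to TF ops instead.
py_builtins.float_(constant_op.constant('1.0'))    # string_to_number -> float32 tensor
py_builtins.abs_(constant_op.constant([-1, 2]))    # math_ops.abs -> [1, 2] once evaluated
py_builtins.range_(constant_op.constant(3))        # math_ops.range -> [0, 1, 2] once evaluated
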
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000000..4073c51785
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+
+ def test_abs(self):
+ self.assertEqual(py_builtins.abs_(-1), 1)
+ with self.test_session() as sess:
+ t = py_builtins.abs_(constant_op.constant(-1))
+ self.assertEqual(sess.run(t), 1)
+ t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+ self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+ def test_float(self):
+ self.assertEqual(py_builtins.float_(10), 10.0)
+ self.assertEqual(py_builtins.float_('10.0'), 10.0)
+ with self.test_session() as sess:
+ t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+ self.assertEqual(sess.run(t), 1.0)
+ st = py_builtins.float_(constant_op.constant('1.0'))
+ self.assertEqual(sess.run(st), 1.0)
+
+ def test_int(self):
+ self.assertEqual(py_builtins.int_(10.0), 10)
+ self.assertEqual(py_builtins.int_('11', 2), 3)
+ with self.test_session() as sess:
+ t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+ self.assertEqual(sess.run(t), 1)
+ st = py_builtins.int_(constant_op.constant('1'))
+ self.assertEqual(sess.run(st), 1)
+ st = py_builtins.int_(constant_op.constant('1'), 10)
+ self.assertEqual(sess.run(st), 1)
+
+ def test_int_unsupported_base(self):
+ t = constant_op.constant(1, dtype=dtypes.float64)
+ with self.assertRaises(NotImplementedError):
+ py_builtins.int_(t, 2)
+
+ def test_len(self):
+ self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+ with self.test_session() as sess:
+ t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+ self.assertEqual(t, 3)
+ ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+ self.assertEqual(sess.run(ta), 5)
+ tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+ self.assertEqual(sess.run(tl), 3)
+
+ def test_len_scalar(self):
+ with self.assertRaises(ValueError):
+ py_builtins.len_(constant_op.constant(1))
+
+ def test_len_dynamic_shape(self):
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ t = py_builtins.len_(p)
+ self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ t = py_builtins.len_(p)
+ sess.run(t, {p: 1})
+
+ def test_print_tensors(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+ self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_complex(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(
+ py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+ self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_range(self):
+ self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+ self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+ self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+ def test_range_tensor(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [0, 1, 2])
+ r = py_builtins.range_(1, constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [1, 2])
+ r = py_builtins.range_(2, 0, constant_op.constant(-1))
+ self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index d2b399f19b..4504a5c7a3 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -20,12 +20,12 @@ py_library(
name = "utils",
srcs = [
"__init__.py",
- "builtins.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
"py_func.py",
"tensor_list.py",
+ "tensors.py",
"testing.py",
"type_check.py",
],
@@ -42,17 +42,6 @@ py_library(
)
py_test(
- name = "builtins_test",
- srcs = ["builtins_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":utils",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "context_managers_test",
srcs = ["context_managers_test.py"],
srcs_version = "PY2AND3",
@@ -113,3 +102,13 @@ py_test(
"//tensorflow/python:list_ops",
],
)
+
+py_test(
+ name = "tensors_test",
+ srcs = ["tensors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
index 57b5f74741..38e0a0a8f0 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/contrib/autograph/utils/__init__.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_print
-from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.autograph.utils.misc import alias_tensors
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
deleted file mode 100644
index 4dd440ef19..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Builtin conversion utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import list_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-
-
-def dynamic_builtin(f, *args, **kwargs):
- """Converts a builtin function call inline."""
- if f is len:
- return dynamic_len(*args, **kwargs)
- if six.PY2 and f is xrange:
- return dynamic_range(*args, **kwargs)
- if f is range:
- return dynamic_range(*args, **kwargs)
- if f is int:
- return dynamic_int(*args, **kwargs)
- if f is float:
- return dynamic_float(*args, **kwargs)
- if f is abs:
- return dynamic_abs(*args, **kwargs)
-
- raise NotImplementedError(
- 'The "%s" builtin is not yet supported.' % f.__name__)
-
-
-def dynamic_len(list_or_tensor):
- """Implementation of len using dynamic dispatch."""
- if _is_tensor_list(list_or_tensor):
- return list_ops.tensor_list_length(list_or_tensor)
- elif tensor_util.is_tensor(list_or_tensor):
- shape = list_or_tensor.shape
- if not shape.ndims:
- raise ValueError(
- 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
- return array_ops.shape(list_or_tensor)[0]
- return len(list_or_tensor)
-
-
-def _is_tensor_list(list_or_tensor):
- return (tensor_util.is_tensor(list_or_tensor)
- and list_or_tensor.dtype == dtypes.variant)
-
-
-def dynamic_int(num_or_tensor, **kwargs):
- """Implementation of int() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
- return int(num_or_tensor)
-
-
-def dynamic_float(num_or_tensor, **kwargs):
- """Implementation of float() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
- return float(num_or_tensor)
-
-
-def dynamic_abs(num_or_tensor, **kwargs):
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.abs(num_or_tensor, **kwargs)
- else:
- return abs(num_or_tensor, **kwargs)
-
-
-def dynamic_range(start_or_stop, stop=None, step=None):
- """Implementation of range using dynamic dispatch."""
- if type_check.is_tensor(start_or_stop, stop, step):
- if step is not None:
- return math_ops.range(start_or_stop, stop, step)
- if stop is not None:
- return math_ops.range(start_or_stop, stop)
- return math_ops.range(start_or_stop)
-
- if step is not None:
- return range(start_or_stop, stop, step)
- elif stop is not None:
- return range(start_or_stop, stop)
- return range(start_or_stop)
-
-
-def is_tf_print_compatible(value):
- # TODO(mdan): Enable once we can reliably test this.
- # This is currently disabled because we can't capture the output of
- # op kernels from Python.
- del value
- return False
-
-
-def dynamic_print(*values):
- """Implementation of print using dynamic dispatch.
-
- The function attempts to use tf.Print if all the values are compatible.
- Otherwise, it will fall back to py_func.
-
- Args:
- *values: values to print
- Returns:
- A dummy value indicating the print completed. If tf.
- """
-
- if all(map(is_tf_print_compatible, values)):
- return logging_ops.Print(1, values)
-
- def print_wrapper(*vals):
- if six.PY3:
- # TensorFlow doesn't seem to generate Unicode when passing strings to
- # py_func. This causes the print to add a "b'" wrapper to the output,
- # which is probably never what you want.
- vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
- print(*vals)
- # The flush helps avoid garbled output in IPython.
- sys.stdout.flush()
-
- return py_func.wrap_py_func(
- print_wrapper, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
deleted file mode 100644
index b1cd5253bc..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for builtins module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
-
-class BuiltinsTest(test.TestCase):
-
- def test_dynamic_len_tf_scalar(self):
- a = constant_op.constant(1)
-
- with self.assertRaisesRegexp(ValueError,
- 'len requires non-zero rank for tensor.*'):
- with self.test_session() as sess:
- sess.run(builtins.dynamic_builtin(len, a))
-
- def test_dynamic_len_tf_array(self):
- a = constant_op.constant([1, 2, 3])
-
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_abs_tf_scalar(self):
- a = constant_op.constant(-1)
-
- with self.test_session() as sess:
- self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
-
- def test_dynamic_abs_tf_array(self):
- a = constant_op.constant([-1, 2, -3])
-
- with self.test_session() as sess:
- self.assertListEqual([1, 2, 3],
- list(sess.run(builtins.dynamic_builtin(abs, a))))
-
- def test_dynamic_abs_py_scalar(self):
- a = -1
- self.assertEqual(1, builtins.dynamic_builtin(abs, a))
-
- def test_dynamic_len_tf_matrix(self):
- a = constant_op.constant([[1, 2], [3, 4]])
-
- with self.test_session() as sess:
- self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_len_py_list(self):
- a = [3] * 5
-
- self.assertEqual(5, builtins.dynamic_builtin(len, a))
-
- def test_dynamic_range_all_python(self):
- self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2])
- self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1])
-
- def test_dynamic_range_tf(self):
- with self.test_session() as sess:
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))),
- [0, 1, 2])
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))),
- [1, 2])
- self.assertAllEqual(
- sess.run(
- builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))),
- [2, 1])
-
- def test_dynamic_range_detection(self):
- def range(x): # pylint:disable=redefined-builtin
- return x
-
- # Functions that just have the names of builtins are rejected.
- with self.assertRaises(NotImplementedError):
- self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
- if six.PY2:
- self.assertListEqual(
- list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
-
- def test_casts(self):
- i = constant_op.constant(2, dtype=dtypes.int32)
- f = constant_op.constant(1.0, dtype=dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
- self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, True), 1)
- self.assertEqual(builtins.dynamic_builtin(int, False), 0)
- self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
- self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
-
- def test_dynamic_print_tf(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', 1))
- self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_dynamic_print_complex(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', [1, 2]))
- self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
- finally:
- sys.stdout = sys.__stdout__
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py
new file mode 100644
index 0000000000..fa5db81a71
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+These utilities are not defined in TensorFlow because they may not be fully
+robust, although they work in the vast majority of cases. They are defined
+here so that their behavior can be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
+def is_tensor_array(t):
+ return isinstance(t, tensor_array_ops.TensorArray)
+
+
+def is_tensor_list(t):
+ # TODO(mdan): This is just a heuristic.
+ # With TF lacking support for templated types, this is unfortunately the
+ # closest we can get right now. A dedicated op ought to be possible to
+ # construct.
+ return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
+ not t.shape.ndims)
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py
new file mode 100644
index 0000000000..e855e0b6cb
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class TensorsTest(test.TestCase):
+
+ def _simple_tensor_array(self):
+ return tensor_array_ops.TensorArray(dtypes.int32, size=3)
+
+ def _simple_tensor_list(self):
+ return list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
+
+ def _simple_list_of_tensors(self):
+ return [constant_op.constant(1), constant_op.constant(2)]
+
+ def test_is_tensor_array(self):
+ self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
+ self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_array(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_array(None))
+
+ def test_is_tensor_list(self):
+ self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array()))
+ self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_list(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_list(None))
+
+
+if __name__ == '__main__':
+ test.main()