diff options
author | Dan Moldovan <mdan@google.com> | 2018-09-05 07:34:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 07:39:13 -0700 |
commit | 580a50a4bb30853199de191ba4d98f7390a138db (patch) | |
tree | 38ecf4d64c4323b60cf265e456ab5f9da245422b /tensorflow/contrib/autograph | |
parent | ffaab58cad72e177ada0e7d1d3724de63032928d (diff) |
utils cleanup: move the builtins module under operators.
PiperOrigin-RevId: 211631516
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions.py | 41 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/converters/builtin_functions_test.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/api.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/operators/BUILD | 11 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/operators/__init__.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/operators/control_flow.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/operators/py_builtins.py | 225 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/operators/py_builtins_test.py | 131 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/__init__.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins.py | 143 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins_test.py | 145 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/tensors.py | 41 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/tensors_test.py | 57 |
14 files changed, 508 insertions, 336 deletions
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index b26c52294c..29dce13999 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates @@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base): TF equivalent, like `len`. """ - def _convert_builtin(self, node): + def _convert_builtin(self, f, args, as_expression): template = """ - ag__.utils.dynamic_builtin(func, args) + ag__.func(args) """ - return templates.replace(template, func=node.func, args=node.args)[0].value - - def _convert_print(self, node): - template = """ - ag__.utils.dynamic_print(args) - """ - return templates.replace(template, args=node.args)[0].value + if as_expression: + return templates.replace_as_expression( + template, func=py_builtins.overload_of(f).__name__, args=args) + else: + return templates.replace( + template, func=py_builtins.overload_of(f).__name__, args=args) def visit_Call(self, node): - self.generic_visit(node) - # TODO(mdan): This won't work if the function was hidden. - # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. - if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange', 'float', 'int')): - return self._convert_builtin(node) - # Print needs to be handled separately because it can be read as statement. - if isinstance(node.func, gast.Name) and node.func.id == 'print': - return self._convert_print(node) + node = self.generic_visit(node) + if anno.hasanno(node.func, 'live_val'): + live_val = anno.getanno(node.func, 'live_val') + if live_val in py_builtins.SUPPORTED_BUILTINS: + node = self._convert_builtin(live_val, node.args, as_expression=True) return node def visit_Print(self, node): - self.generic_visit(node) + node = self.generic_visit(node) args = node.values # Following is the case when calling print(a, b) if len(args) == 1 and isinstance(args[0], gast.Tuple): args = args[0].elts - template = """ - fname(args) - """ - function_call = templates.replace(template, fname='print', args=args)[0] - return self.visit(function_call) + return self._convert_builtin(print, args, as_expression=False) def transform(node, ctx): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index d0a0cbbeb6..3e3a04f38b 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,6 +23,7 @@ import six from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase): def test_fn(a): return len(a) - with self.converted(test_fn, builtin_functions, {'len': len}, - array_ops.shape) as result: + with self.converted(test_fn, builtin_functions, {'len': len}) as result: with self.cached_session() as sess: - ops = result.test_fn(constant_op.constant([0, 0, 0])) - self.assertEqual(sess.run(ops), 3) + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + ops = result.test_fn(p) + self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3) def test_print(self): diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 276a387180..8b38d5d080 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -29,9 +29,9 @@ import six from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator @@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, unknown_arg_value = object() # Sentinel for arguments of unknown value if inspect_utils.isbuiltin(f): - return builtins.dynamic_builtin(f, *args, **kwargs) + return py_builtins.overload_of(f)(*args, **kwargs) if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 332d5dab19..29759bad79 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", + "py_builtins.py", "slices.py", ], srcs_version = "PY2AND3", @@ -62,6 +63,16 @@ py_test( ) py_test( + name = "py_builtins_test", + srcs = ["py_builtins_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "slices_test", srcs = ["slices_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 392cb60bcc..c4fbc260a2 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.py_builtins import float_ +from tensorflow.contrib.autograph.operators.py_builtins import int_ +from tensorflow.contrib.autograph.operators.py_builtins import len_ +from tensorflow.contrib.autograph.operators.py_builtins import print_ +from tensorflow.contrib.autograph.operators.py_builtins import range_ from tensorflow.contrib.autograph.operators.slices import get_item from tensorflow.contrib.autograph.operators.slices import GetItemOpts from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 9909e52164..9a66a6bb60 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state): def _known_len_for_stmt(iter_, extra_test, body, init_state): - """Overload of for_stmt that iterates over objects that define a length.""" - n = builtins.dynamic_len(iter_) + """Overload of for_stmt that iterates over objects that admit a length.""" + n = py_builtins.len_(iter_) def while_body(iterate_index, *state): iterate = iter_[iterate_index] diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py new file mode 100644 index 0000000000..c5730934e7 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins.py @@ -0,0 +1,225 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators corresponding to Python builtin functions. + +List of built-in functions: https://docs.python.org/3/library/functions.html +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_string_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops + + +UNDEFINED = object() + + +def overload_of(f): + if f in SUPPORTED_BUILTINS: + return BUILTIN_FUINCTIONS_MAP[f.__name__] + return f + + +def abs_(x): + if tensor_util.is_tensor(x): + return _tf_abs(x) + return _py_abs(x) + + +def _tf_abs(x): + return math_ops.abs(x) + + +def _py_abs(x): + return abs(x) + + +def float_(x=0): + if tensor_util.is_tensor(x): + return _tf_float(x) + return _py_float(x) + + +def _tf_float(x): + # TODO(mdan): We shouldn't assume float32. + if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32) + return math_ops.cast(x, dtype=dtypes.float32) + + +def _py_float(x): + return float(x) + + +def int_(x=0, base=UNDEFINED): + if tensor_util.is_tensor(x): + return _tf_int(x, base) + return _py_int(x, base) + + +def _tf_int(x, base): + if base not in (10, UNDEFINED): + raise NotImplementedError('base {} not supported for int'.format(base)) + + # TODO(mdan): We shouldn't assume int32. + if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32) + return math_ops.cast(x, dtype=dtypes.int32) + + +def _py_int(x, base): + if base is UNDEFINED: + return int(x) + return int(x, base) + + +def len_(s): + if tensors.is_tensor_array(s): + return _tf_tensor_array_len(s) + elif tensors.is_tensor_list(s): + return _tf_tensor_list_len(s) + elif tensor_util.is_tensor(s): + return _tf_tensor_len(s) + return _py_len(s) + + +def _tf_tensor_array_len(s): + return s.size() + + +def _tf_tensor_list_len(s): + return list_ops.tensor_list_length(s) + + +def _tf_tensor_len(s): + """Overload of len_ for Tensor arguments.""" + # Statically shaped tensors: length is known ahead of time. + if s.shape.ndims and s.shape[0].value is not None: + return s.shape[0].value + + # Static shape of unknown dimensions: use dynamic shape but statically + # chech that it's a scalar. + shape = array_ops.shape(s) + + assert shape.shape, 'shape tensor of zero size? {}'.format(shape) + + if shape.shape[0] == 0: + raise ValueError( + 'len requires a non-scalar tensor, got one of shape {}'.format(shape)) + + if shape.shape[0].value is not None: + return array_ops.shape(s)[0] + + # Fully dynamic shape: use ops. + rank = array_ops.rank(s) + + def raise_zero_rank_error(): + msg = gen_string_ops.string_join( + ['len requires non-zero rank, got ', + gen_string_ops.as_string(rank)]) + with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]): + return constant_op.constant(0, dtype=dtypes.int32) + + return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0], + raise_zero_rank_error) + + +def _py_len(s): + return len(s) + + +def print_(*objects, **kwargs): + # Note: Python 2.6 doesn't support explicit keywords after starargs. + unknown_kwargs = tuple( + set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush'))) + if unknown_kwargs: + raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs)) + + # TODO(mdan): use logging_ops.Print when py_func is not supported. + return _tf_py_func_print(objects, kwargs) + + +def _tf_py_func_print(objects, kwargs): + """Overload of print_ as a py_func implementation.""" + override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED} + if 'flush' not in override_kwargs: + # Defaulting to flushing the console in graph mode, which helps reduce + # garbled output in IPython. + override_kwargs['flush'] = True + + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple( + v.decode('utf-8') if isinstance(v, bytes) else v for v in vals) + six.print_(*vals, **override_kwargs) + + return py_func.wrap_py_func( + print_wrapper, None, objects, use_dummy_return=True) + + +def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED): + if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)): + return _tf_range(start_or_stop, stop, step) + return _py_range(start_or_stop, stop, step) + + +def _tf_range(start_or_stop, stop, step): + # TODO(mdan): We should optimize this when a full tensor is not required. + if step is not UNDEFINED: + return math_ops.range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + +def _py_range(start_or_stop, stop, step): + if step is not UNDEFINED: + return range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return range(start_or_stop, stop) + return range(start_or_stop) + + +SUPPORTED_BUILTINS = set((abs, float, int, len, print, range)) + +if six.PY2: + SUPPORTED_BUILTINS.add(xrange) + +BUILTIN_FUINCTIONS_MAP = { + 'abs': abs_, + 'float': float_, + 'int': int_, + 'len': len_, + 'print': print_, + 'range': range_, + 'xrange': range_, +} diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py new file mode 100644 index 0000000000..4073c51785 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for py_builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class PyBuiltinsTest(test.TestCase): + + def test_abs(self): + self.assertEqual(py_builtins.abs_(-1), 1) + with self.test_session() as sess: + t = py_builtins.abs_(constant_op.constant(-1)) + self.assertEqual(sess.run(t), 1) + t = py_builtins.abs_(constant_op.constant([-1, 2, -3])) + self.assertAllEqual(sess.run(t), [1, 2, 3]) + + def test_float(self): + self.assertEqual(py_builtins.float_(10), 10.0) + self.assertEqual(py_builtins.float_('10.0'), 10.0) + with self.test_session() as sess: + t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64)) + self.assertEqual(sess.run(t), 1.0) + st = py_builtins.float_(constant_op.constant('1.0')) + self.assertEqual(sess.run(st), 1.0) + + def test_int(self): + self.assertEqual(py_builtins.int_(10.0), 10) + self.assertEqual(py_builtins.int_('11', 2), 3) + with self.test_session() as sess: + t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64)) + self.assertEqual(sess.run(t), 1) + st = py_builtins.int_(constant_op.constant('1')) + self.assertEqual(sess.run(st), 1) + st = py_builtins.int_(constant_op.constant('1'), 10) + self.assertEqual(sess.run(st), 1) + + def test_int_unsupported_base(self): + t = constant_op.constant(1, dtype=dtypes.float64) + with self.assertRaises(NotImplementedError): + py_builtins.int_(t, 2) + + def test_len(self): + self.assertEqual(py_builtins.len_([1, 2, 3]), 3) + with self.test_session() as sess: + t = py_builtins.len_(constant_op.constant([[1], [2], [3]])) + self.assertEqual(t, 3) + ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5)) + self.assertEqual(sess.run(ta), 5) + tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5])) + self.assertEqual(sess.run(tl), 3) + + def test_len_scalar(self): + with self.assertRaises(ValueError): + py_builtins.len_(constant_op.constant(1)) + + def test_len_dynamic_shape(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + t = py_builtins.len_(p) + self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3) + + with self.assertRaises(errors_impl.InvalidArgumentError): + t = py_builtins.len_(p) + sess.run(t, {p: 1}) + + def test_print_tensors(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(py_builtins.print_(constant_op.constant('test message'), 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run( + py_builtins.print_(constant_op.constant('test message'), [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + def test_range(self): + self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2]) + self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2]) + self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1]) + + def test_range_tensor(self): + with self.test_session() as sess: + r = py_builtins.range_(constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [0, 1, 2]) + r = py_builtins.range_(1, constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [1, 2]) + r = py_builtins.range_(2, 0, constant_op.constant(-1)) + self.assertAllEqual(sess.run(r), [2, 1]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d2b399f19b..4504a5c7a3 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,12 +20,12 @@ py_library( name = "utils", srcs = [ "__init__.py", - "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", "py_func.py", "tensor_list.py", + "tensors.py", "testing.py", "type_check.py", ], @@ -42,17 +42,6 @@ py_library( ) py_test( - name = "builtins_test", - srcs = ["builtins_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ - ":utils", - "//tensorflow/python:client_testlib", - ], -) - -py_test( name = "context_managers_test", srcs = ["context_managers_test.py"], srcs_version = "PY2AND3", @@ -113,3 +102,13 @@ py_test( "//tensorflow/python:list_ops", ], ) + +py_test( + name = "tensors_test", + srcs = ["tensors_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py index 57b5f74741..38e0a0a8f0 100644 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin -from tensorflow.contrib.autograph.utils.builtins import dynamic_print -from tensorflow.contrib.autograph.utils.builtins import dynamic_range from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.autograph.utils.misc import alias_tensors from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py deleted file mode 100644 index 4dd440ef19..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Builtin conversion utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import py_func -from tensorflow.contrib.autograph.utils import type_check -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import list_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops - - -def dynamic_builtin(f, *args, **kwargs): - """Converts a builtin function call inline.""" - if f is len: - return dynamic_len(*args, **kwargs) - if six.PY2 and f is xrange: - return dynamic_range(*args, **kwargs) - if f is range: - return dynamic_range(*args, **kwargs) - if f is int: - return dynamic_int(*args, **kwargs) - if f is float: - return dynamic_float(*args, **kwargs) - if f is abs: - return dynamic_abs(*args, **kwargs) - - raise NotImplementedError( - 'The "%s" builtin is not yet supported.' % f.__name__) - - -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if _is_tensor_list(list_or_tensor): - return list_ops.tensor_list_length(list_or_tensor) - elif tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape.ndims: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) - - -def _is_tensor_list(list_or_tensor): - return (tensor_util.is_tensor(list_or_tensor) - and list_or_tensor.dtype == dtypes.variant) - - -def dynamic_int(num_or_tensor, **kwargs): - """Implementation of int() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) - return int(num_or_tensor) - - -def dynamic_float(num_or_tensor, **kwargs): - """Implementation of float() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) - return float(num_or_tensor) - - -def dynamic_abs(num_or_tensor, **kwargs): - if tensor_util.is_tensor(num_or_tensor): - return math_ops.abs(num_or_tensor, **kwargs) - else: - return abs(num_or_tensor, **kwargs) - - -def dynamic_range(start_or_stop, stop=None, step=None): - """Implementation of range using dynamic dispatch.""" - if type_check.is_tensor(start_or_stop, stop, step): - if step is not None: - return math_ops.range(start_or_stop, stop, step) - if stop is not None: - return math_ops.range(start_or_stop, stop) - return math_ops.range(start_or_stop) - - if step is not None: - return range(start_or_stop, stop, step) - elif stop is not None: - return range(start_or_stop, stop) - return range(start_or_stop) - - -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. - del value - return False - - -def dynamic_print(*values): - """Implementation of print using dynamic dispatch. - - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. - - Args: - *values: values to print - Returns: - A dummy value indicating the print completed. If tf. - """ - - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - - def print_wrapper(*vals): - if six.PY3: - # TensorFlow doesn't seem to generate Unicode when passing strings to - # py_func. This causes the print to add a "b'" wrapper to the output, - # which is probably never what you want. - vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) - print(*vals) - # The flush helps avoid garbled output in IPython. - sys.stdout.flush() - - return py_func.wrap_py_func( - print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py deleted file mode 100644 index b1cd5253bc..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for builtins module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import builtins -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import test - - -class BuiltinsTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant_op.constant(1) - - with self.assertRaisesRegexp(ValueError, - 'len requires non-zero rank for tensor.*'): - with self.test_session() as sess: - sess.run(builtins.dynamic_builtin(len, a)) - - def test_dynamic_len_tf_array(self): - a = constant_op.constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_abs_tf_scalar(self): - a = constant_op.constant(-1) - - with self.test_session() as sess: - self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) - - def test_dynamic_abs_tf_array(self): - a = constant_op.constant([-1, 2, -3]) - - with self.test_session() as sess: - self.assertListEqual([1, 2, 3], - list(sess.run(builtins.dynamic_builtin(abs, a)))) - - def test_dynamic_abs_py_scalar(self): - a = -1 - self.assertEqual(1, builtins.dynamic_builtin(abs, a)) - - def test_dynamic_len_tf_matrix(self): - a = constant_op.constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, builtins.dynamic_builtin(len, a)) - - def test_dynamic_range_all_python(self): - self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) - self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1]) - - def test_dynamic_range_tf(self): - with self.test_session() as sess: - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), - [0, 1, 2]) - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), - [1, 2]) - self.assertAllEqual( - sess.run( - builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), - [2, 1]) - - def test_dynamic_range_detection(self): - def range(x): # pylint:disable=redefined-builtin - return x - - # Functions that just have the names of builtins are rejected. - with self.assertRaises(NotImplementedError): - self.assertEqual(builtins.dynamic_builtin(range, 1), 1) - if six.PY2: - self.assertListEqual( - list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) - - def test_casts(self): - i = constant_op.constant(2, dtype=dtypes.int32) - f = constant_op.constant(1.0, dtype=dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) - self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, True), 1) - self.assertEqual(builtins.dynamic_builtin(int, False), 0) - self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) - self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) - - def test_dynamic_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_dynamic_print_complex(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py new file mode 100644 index 0000000000..fa5db81a71 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module defines tensor utilities not found in TensorFlow. + +The reason these utilities are not defined in TensorFlow is because they may +not be not fully robust, although they work in the vast majority of cases. So +we define them here in order for their behavior to be consistently verified. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import tensor_array_ops + + +def is_tensor_array(t): + return isinstance(t, tensor_array_ops.TensorArray) + + +def is_tensor_list(t): + # TODO(mdan): This is just a heuristic. + # With TF lacking support for templated types, this is unfortunately the + # closest we can get right now. A dedicated op ought to be possible to + # construct. + return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and + not t.shape.ndims) diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py new file mode 100644 index 0000000000..e855e0b6cb --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors_test.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensors module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class TensorsTest(test.TestCase): + + def _simple_tensor_array(self): + return tensor_array_ops.TensorArray(dtypes.int32, size=3) + + def _simple_tensor_list(self): + return list_ops.empty_tensor_list( + element_shape=constant_op.constant([1]), element_dtype=dtypes.int32) + + def _simple_list_of_tensors(self): + return [constant_op.constant(1), constant_op.constant(2)] + + def test_is_tensor_array(self): + self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array())) + self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_array(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_array(None)) + + def test_is_tensor_list(self): + self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array())) + self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_list(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_list(None)) + + +if __name__ == '__main__': + test.main() |