about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/autograph/utils/builtins.py
blob: 998087e056c2cd264399982220d6e0528aab9edb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builtin conversion utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import six

from tensorflow.contrib.autograph.utils import py_func
from tensorflow.contrib.autograph.utils import type_check
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops


def dynamic_builtin(f, *args, **kwargs):
  """Dispatches a call to a Python builtin to its dynamic counterpart.

  The builtin is matched by identity, and the call is forwarded unchanged to
  the corresponding dynamic_* function in this module.

  Args:
    f: the builtin being called; one of len, range, xrange (Python 2 only),
      int or float.
    *args: positional arguments forwarded to the replacement function.
    **kwargs: keyword arguments forwarded to the replacement function.

  Returns:
    Whatever the selected dynamic_* function returns.

  Raises:
    NotImplementedError: if f is not one of the builtins handled here.
  """
  # xrange only exists on Python 2, so the name is only looked up there.
  range_builtins = (range, xrange) if six.PY2 else (range,)

  if f is len:
    handler = dynamic_len
  elif any(f is candidate for candidate in range_builtins):
    handler = dynamic_range
  elif f is int:
    handler = dynamic_int
  elif f is float:
    handler = dynamic_float
  else:
    raise NotImplementedError(
        'The "%s" builtin is not yet supported.' % f.__name__)

  return handler(*args, **kwargs)


def dynamic_len(list_or_tensor):
  """Implementation of len using dynamic dispatch.

  Args:
    list_or_tensor: either a value for which tensor_util.is_tensor is true, or
      any object accepted by the builtin len.

  Returns:
    For tensors, the first element of array_ops.shape(list_or_tensor).
    Otherwise, len(list_or_tensor).

  Raises:
    ValueError: if the argument is a tensor whose static rank is zero or
      unknown.
  """
  if not tensor_util.is_tensor(list_or_tensor):
    return len(list_or_tensor)

  # Check the static rank explicitly rather than the truthiness of the shape
  # object: TensorShape's truth value reports whether the rank is known, so a
  # scalar's shape is truthy and would slip past a plain `not shape` test.
  # ndims is None for unknown rank and 0 for scalars; both are rejected here.
  if not list_or_tensor.shape.ndims:
    raise ValueError(
        'len requires non-zero rank for tensor "%s"' % list_or_tensor)
  return array_ops.shape(list_or_tensor)[0]


def dynamic_int(num_or_tensor, **kwargs):
  """Implementation of int() using dynamic dispatch.

  Args:
    num_or_tensor: either a value for which tensor_util.is_tensor is true, or
      any object accepted by the builtin int.
    **kwargs: extra keyword arguments forwarded to math_ops.cast; they are
      ignored when the argument is not a tensor.

  Returns:
    For tensors, the result of casting to dtypes.int32. Otherwise,
    int(num_or_tensor).
  """
  if not tensor_util.is_tensor(num_or_tensor):
    return int(num_or_tensor)
  return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)


def dynamic_float(num_or_tensor, **kwargs):
  """Implementation of float() using dynamic dispatch.

  Args:
    num_or_tensor: either a value for which tensor_util.is_tensor is true, or
      any object accepted by the builtin float.
    **kwargs: extra keyword arguments forwarded to math_ops.cast; they are
      ignored when the argument is not a tensor.

  Returns:
    For tensors, the result of casting to dtypes.float32. Otherwise,
    float(num_or_tensor).
  """
  if not tensor_util.is_tensor(num_or_tensor):
    return float(num_or_tensor)
  return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)


def dynamic_range(start_or_stop, stop=None, step=None):
  """Implementation of range using dynamic dispatch.

  Mirrors the one-, two- and three-argument forms of the builtin range. Only
  the arguments that were actually supplied (i.e. are not None) are passed on.

  Args:
    start_or_stop: the stop value when it is the only argument, otherwise the
      start value.
    stop: optional stop value.
    step: optional step value.

  Returns:
    The result of math_ops.range when type_check.is_tensor accepts the
    arguments, otherwise the result of the builtin range.
  """
  if type_check.is_tensor(start_or_stop, stop, step):
    range_fn = math_ops.range
  else:
    range_fn = range

  # step takes precedence over stop, matching the positional call forms.
  if step is not None:
    call_args = (start_or_stop, stop, step)
  elif stop is not None:
    call_args = (start_or_stop, stop)
  else:
    call_args = (start_or_stop,)

  return range_fn(*call_args)


def is_tf_print_compatible(value):
  """Reports whether a value may be printed using tf.Print.

  This is currently a stub that rejects every value.

  Args:
    value: the candidate value; not inspected.

  Returns:
    Always False.
  """
  # TODO(mdan): Enable once we can reliably test this.
  # The check stays disabled because the output of op kernels can't be
  # captured from Python.
  del value  # Unused while the check is disabled.
  return False


def dynamic_print(*values):
  """Implementation of print using dynamic dispatch.

  The function attempts to use tf.Print if all the values are compatible.
  Otherwise, it will fall back to py_func.

  Args:
    *values: values to print
  Returns:
    A dummy value indicating the print completed. If tf.Print is used, this is
    the result of logging_ops.Print applied to the constant 1; otherwise it is
    the dummy return produced by py_func.wrap_py_func.
  """

  # all() is vacuously true for an empty sequence. Require at least one value
  # so that a print with no arguments takes the same py_func route as every
  # other call, instead of slipping into the tf.Print branch.
  if values and all(map(is_tf_print_compatible, values)):
    return logging_ops.Print(1, values)

  def print_wrapper(*vals):
    """Prints the values as the Python builtin would, then flushes stdout."""
    if six.PY3:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func. This causes the print to add a "b'" wrapper to the output,
      # which is probably never what you want.
      vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
    print(*vals)
    # The flush helps avoid garbled output in IPython.
    sys.stdout.flush()

  return py_func.wrap_py_func(
      print_wrapper, None, values, use_dummy_return=True)