diff options
Diffstat (limited to 'tensorflow/python/autograph/converters/builtin_functions.py')
-rw-r--r-- | tensorflow/python/autograph/converters/builtin_functions.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py new file mode 100644 index 0000000000..b8b268d8ce --- /dev/null +++ b/tensorflow/python/autograph/converters/builtin_functions.py @@ -0,0 +1,65 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles builtins and other special functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.operators import py_builtins +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import templates + + +class BuiltinFunctionTransformer(converter.Base): + """Handles builtin functions. + + This transformer only covers functions that are translated into a + TF equivalent, like `len`. + """ + + def _convert_builtin(self, f, args, as_expression): + template = """ + ag__.func(args) + """ + if as_expression: + return templates.replace_as_expression( + template, func=py_builtins.overload_of(f).__name__, args=args) + else: + return templates.replace( + template, func=py_builtins.overload_of(f).__name__, args=args) + + def visit_Call(self, node): + node = self.generic_visit(node) + if anno.hasanno(node.func, 'live_val'): + live_val = anno.getanno(node.func, 'live_val') + if live_val in py_builtins.SUPPORTED_BUILTINS: + node = self._convert_builtin(live_val, node.args, as_expression=True) + return node + + def visit_Print(self, node): + node = self.generic_visit(node) + args = node.values + # Following is the case when calling print(a, b) + if len(args) == 1 and isinstance(args[0], gast.Tuple): + args = args[0].elts + return self._convert_builtin(print, args, as_expression=False) + + +def transform(node, ctx): + return BuiltinFunctionTransformer(ctx).visit(node) |