aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-10-20 18:20:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-20 18:26:50 -0700
commit93e8f3c67d82c2d43b8dddd4cb8d7f02259d0e7e (patch)
tree29cbd370d0038f76a9cc4e7c7b245be8afcfc9c5
parent0d6a2e35312c71cf8a145a7c40d69883e254daee (diff)
Adding Python ApiDef overrides.
PiperOrigin-RevId: 172960496
-rw-r--r--tensorflow/core/BUILD5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_B.pbtxt18
-rw-r--r--tensorflow/core/api_def/python_api/api_def_C.pbtxt15
-rw-r--r--tensorflow/core/api_def/python_api/api_def_D.pbtxt54
-rw-r--r--tensorflow/core/api_def/python_api/api_def_E.pbtxt30
-rw-r--r--tensorflow/core/api_def/python_api/api_def_F.pbtxt21
-rw-r--r--tensorflow/core/api_def/python_api/api_def_H.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_I.pbtxt15
-rw-r--r--tensorflow/core/api_def/python_api/api_def_L.pbtxt24
-rw-r--r--tensorflow/core/api_def/python_api/api_def_M.pbtxt78
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Q.pbtxt27
-rw-r--r--tensorflow/core/api_def/python_api/api_def_R.pbtxt36
-rw-r--r--tensorflow/core/api_def/python_api/api_def_S.pbtxt36
-rw-r--r--tensorflow/tools/api/tests/BUILD15
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py177
-rw-r--r--tensorflow/tools/api/tests/convert_from_multiline.cc63
16 files changed, 620 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d198a796a7..6ad93a73f4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3326,6 +3326,11 @@ filegroup(
data = glob(["api_def/base_api/*"]),
)
+filegroup(
+ name = "python_api_def",
+ data = glob(["api_def/python_api/*"]),
+)
+
tf_cc_test(
name = "api_test",
srcs = ["api_def/api_test.cc"],
diff --git a/tensorflow/core/api_def/python_api/api_def_B.pbtxt b/tensorflow/core/api_def/python_api/api_def_B.pbtxt
new file mode 100644
index 0000000000..9b5df58eba
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_B.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "BitwiseAnd"
+ endpoint {
+ name: "bitwise.bitwise_and"
+ }
+}
+op {
+ graph_op_name: "BitwiseOr"
+ endpoint {
+ name: "bitwise.bitwise_or"
+ }
+}
+op {
+ graph_op_name: "BitwiseXor"
+ endpoint {
+ name: "bitwise.bitwise_xor"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_C.pbtxt b/tensorflow/core/api_def/python_api/api_def_C.pbtxt
new file mode 100644
index 0000000000..cf8d0622be
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_C.pbtxt
@@ -0,0 +1,15 @@
+op {
+ graph_op_name: "Cholesky"
+ endpoint {
+ name: "cholesky"
+ }
+ endpoint {
+ name: "linalg.cholesky"
+ }
+}
+op {
+ graph_op_name: "CropAndResize"
+ endpoint {
+ name: "image.crop_and_resize"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_D.pbtxt b/tensorflow/core/api_def/python_api/api_def_D.pbtxt
new file mode 100644
index 0000000000..12e0dbec1c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_D.pbtxt
@@ -0,0 +1,54 @@
+op {
+ graph_op_name: "DecodeAndCropJpeg"
+ endpoint {
+ name: "image.decode_and_crop_jpeg"
+ }
+}
+op {
+ graph_op_name: "DecodeBmp"
+ endpoint {
+ name: "image.decode_bmp"
+ }
+}
+op {
+ graph_op_name: "DecodeGif"
+ endpoint {
+ name: "image.decode_gif"
+ }
+}
+op {
+ graph_op_name: "DecodeJpeg"
+ endpoint {
+ name: "image.decode_jpeg"
+ }
+}
+op {
+ graph_op_name: "DecodePng"
+ endpoint {
+ name: "image.decode_png"
+ }
+}
+op {
+ graph_op_name: "DepthwiseConv2dNative"
+ endpoint {
+ name: "nn.depthwise_conv2d_native"
+ }
+}
+op {
+ graph_op_name: "DepthwiseConv2dNativeBackpropFilter"
+ endpoint {
+ name: "nn.depthwise_conv2d_native_backprop_filter"
+ }
+}
+op {
+ graph_op_name: "DepthwiseConv2dNativeBackpropInput"
+ endpoint {
+ name: "nn.depthwise_conv2d_native_backprop_input"
+ }
+}
+op {
+ graph_op_name: "DrawBoundingBoxes"
+ endpoint {
+ name: "image.draw_bounding_boxes"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_E.pbtxt b/tensorflow/core/api_def/python_api/api_def_E.pbtxt
new file mode 100644
index 0000000000..f6871f7138
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_E.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "Elu"
+ endpoint {
+ name: "nn.elu"
+ }
+}
+op {
+ graph_op_name: "EncodeJpeg"
+ endpoint {
+ name: "image.encode_jpeg"
+ }
+}
+op {
+ graph_op_name: "EncodePng"
+ endpoint {
+ name: "image.encode_png"
+ }
+}
+op {
+ graph_op_name: "ExtractGlimpse"
+ endpoint {
+ name: "image.extract_glimpse"
+ }
+}
+op {
+ graph_op_name: "ExtractJpegShape"
+ endpoint {
+ name: "image.extract_jpeg_shape"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_F.pbtxt b/tensorflow/core/api_def/python_api/api_def_F.pbtxt
new file mode 100644
index 0000000000..844a1348a3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_F.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "FFT"
+ endpoint {
+ name: "fft"
+ }
+ endpoint {
+ name: "spectral.fft"
+ }
+}
+op {
+ graph_op_name: "FractionalAvgPool"
+ endpoint {
+ name: "nn.fractional_avg_pool"
+ }
+}
+op {
+ graph_op_name: "FractionalMaxPool"
+ endpoint {
+ name: "nn.fractional_max_pool"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_H.pbtxt b/tensorflow/core/api_def/python_api/api_def_H.pbtxt
new file mode 100644
index 0000000000..55998189f4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_H.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "HSVToRGB"
+ endpoint {
+ name: "image.hsv_to_rgb"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_I.pbtxt b/tensorflow/core/api_def/python_api/api_def_I.pbtxt
new file mode 100644
index 0000000000..6c794fab0d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_I.pbtxt
@@ -0,0 +1,15 @@
+op {
+ graph_op_name: "IFFT"
+ endpoint {
+ name: "ifft"
+ }
+ endpoint {
+ name: "spectral.ifft"
+ }
+}
+op {
+ graph_op_name: "Invert"
+ endpoint {
+ name: "bitwise.invert"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_L.pbtxt b/tensorflow/core/api_def/python_api/api_def_L.pbtxt
new file mode 100644
index 0000000000..38ba26a8e8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_L.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "L2Loss"
+ endpoint {
+ name: "nn.l2_loss"
+ }
+}
+op {
+ graph_op_name: "LRN"
+ endpoint {
+ name: "nn.local_response_normalization"
+ }
+ endpoint {
+ name: "nn.lrn"
+ }
+}
+op {
+ graph_op_name: "LinSpace"
+ endpoint {
+ name: "lin_space"
+ }
+ endpoint {
+ name: "linspace"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_M.pbtxt b/tensorflow/core/api_def/python_api/api_def_M.pbtxt
new file mode 100644
index 0000000000..154071f6bc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_M.pbtxt
@@ -0,0 +1,78 @@
+op {
+ graph_op_name: "MatrixBandPart"
+ endpoint {
+ name: "linalg.band_part"
+ }
+ endpoint {
+ name: "matrix_band_part"
+ }
+}
+op {
+ graph_op_name: "MatrixDeterminant"
+ endpoint {
+ name: "linalg.det"
+ }
+ endpoint {
+ name: "matrix_determinant"
+ }
+}
+op {
+ graph_op_name: "MatrixDiag"
+ endpoint {
+ name: "linalg.diag"
+ }
+ endpoint {
+ name: "matrix_diag"
+ }
+}
+op {
+ graph_op_name: "MatrixDiagPart"
+ endpoint {
+ name: "linalg.diag_part"
+ }
+ endpoint {
+ name: "matrix_diag_part"
+ }
+}
+op {
+ graph_op_name: "MatrixInverse"
+ endpoint {
+ name: "linalg.inv"
+ }
+ endpoint {
+ name: "matrix_inverse"
+ }
+}
+op {
+ graph_op_name: "MatrixSetDiag"
+ endpoint {
+ name: "linalg.set_diag"
+ }
+ endpoint {
+ name: "matrix_set_diag"
+ }
+}
+op {
+ graph_op_name: "MatrixSolve"
+ endpoint {
+ name: "linalg.solve"
+ }
+ endpoint {
+ name: "matrix_solve"
+ }
+}
+op {
+ graph_op_name: "MatrixTriangularSolve"
+ endpoint {
+ name: "linalg.triangular_solve"
+ }
+ endpoint {
+ name: "matrix_triangular_solve"
+ }
+}
+op {
+ graph_op_name: "MaxPoolWithArgmax"
+ endpoint {
+ name: "nn.max_pool_with_argmax"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Q.pbtxt b/tensorflow/core/api_def/python_api/api_def_Q.pbtxt
new file mode 100644
index 0000000000..cba032880f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Q.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "Qr"
+ endpoint {
+ name: "linalg.qr"
+ }
+ endpoint {
+ name: "qr"
+ }
+}
+op {
+ graph_op_name: "QuantizedAvgPool"
+ endpoint {
+ name: "nn.quantized_avg_pool"
+ }
+}
+op {
+ graph_op_name: "QuantizedMaxPool"
+ endpoint {
+ name: "nn.quantized_max_pool"
+ }
+}
+op {
+ graph_op_name: "QuantizedReluX"
+ endpoint {
+ name: "nn.quantized_relu_x"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_R.pbtxt b/tensorflow/core/api_def/python_api/api_def_R.pbtxt
new file mode 100644
index 0000000000..9a57e72be0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_R.pbtxt
@@ -0,0 +1,36 @@
+op {
+ graph_op_name: "RGBToHSV"
+ endpoint {
+ name: "image.rgb_to_hsv"
+ }
+}
+op {
+ graph_op_name: "Relu"
+ endpoint {
+ name: "nn.relu"
+ }
+}
+op {
+ graph_op_name: "ResizeArea"
+ endpoint {
+ name: "image.resize_area"
+ }
+}
+op {
+ graph_op_name: "ResizeBicubic"
+ endpoint {
+ name: "image.resize_bicubic"
+ }
+}
+op {
+ graph_op_name: "ResizeBilinear"
+ endpoint {
+ name: "image.resize_bilinear"
+ }
+}
+op {
+ graph_op_name: "ResizeNearestNeighbor"
+ endpoint {
+ name: "image.resize_nearest_neighbor"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_S.pbtxt b/tensorflow/core/api_def/python_api/api_def_S.pbtxt
new file mode 100644
index 0000000000..9c7a39038e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_S.pbtxt
@@ -0,0 +1,36 @@
+op {
+ graph_op_name: "SdcaFprint"
+ endpoint {
+ name: "train.sdca_fprint"
+ }
+}
+op {
+ graph_op_name: "SdcaOptimizer"
+ endpoint {
+ name: "train.sdca_optimizer"
+ }
+}
+op {
+ graph_op_name: "SdcaShrinkL1"
+ endpoint {
+ name: "train.sdca_shrink_l1"
+ }
+}
+op {
+ graph_op_name: "Selu"
+ endpoint {
+ name: "nn.selu"
+ }
+}
+op {
+ graph_op_name: "Softplus"
+ endpoint {
+ name: "nn.softplus"
+ }
+}
+op {
+ graph_op_name: "Softsign"
+ endpoint {
+ name: "nn.softsign"
+ }
+}
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index e99cc0572f..a913e35101 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -11,10 +11,15 @@ exports_files([
"API_UPDATE_WARNING.txt",
])
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
+
py_test(
name = "api_compatibility_test",
srcs = ["api_compatibility_test.py"],
data = [
+ ":convert_from_multiline",
+ "//tensorflow/core:base_api_def",
+ "//tensorflow/core:python_api_def",
"//tensorflow/tools/api/golden:api_golden",
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
"//tensorflow/tools/api/tests:README.txt",
@@ -23,6 +28,7 @@ py_test(
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
@@ -31,6 +37,15 @@ py_test(
],
)
+tf_cc_binary(
+ name = "convert_from_multiline",
+ srcs = ["convert_from_multiline.cc"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:op_gen_lib",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 1ffa8fc26c..f350c12d41 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -28,8 +28,11 @@ from __future__ import division
from __future__ import print_function
import argparse
+from collections import defaultdict
+from operator import attrgetter
import os
import re
+import subprocess
import sys
import unittest
@@ -37,6 +40,7 @@ import tensorflow as tf
from google.protobuf import text_format
+from tensorflow.core.framework import api_def_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@@ -64,6 +68,11 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
+_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline'
+_BASE_API_DIR = 'tensorflow/core/api_def/base_api'
+_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api'
+
def _KeyToFilePath(key):
"""From a given key, construct a filepath."""
@@ -88,6 +97,30 @@ def _FileNameToKey(filename):
return api_object_key
+def _GetSymbol(symbol_id):
+ """Get TensorFlow symbol based on the given identifier.
+
+ Args:
+ symbol_id: Symbol identifier in the form module1.module2. ... .sym.
+
+ Returns:
+ Symbol corresponding to the given id.
+ """
+ # Ignore first module which should be tensorflow
+ symbol_id_split = symbol_id.split('.')[1:]
+ symbol = tf
+ for sym in symbol_id_split:
+ symbol = getattr(symbol, sym)
+ return symbol
+
+
+def _IsGenModule(module_name):
+ if not module_name:
+ return False
+ module_name_split = module_name.split('.')
+ return module_name_split[-1].startswith('gen_')
+
+
class ApiCompatibilityTest(test.TestCase):
def __init__(self, *args, **kwargs):
@@ -229,6 +262,150 @@ class ApiCompatibilityTest(test.TestCase):
update_goldens=FLAGS.update_goldens)
+class ApiDefTest(test.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(ApiDefTest, self).__init__(*args, **kwargs)
+ self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)')
+ self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])')
+
+ def _GenerateLowerCaseOpName(self, op_name):
+ lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name)
+ return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower()
+
+ def _CreatePythonApiDef(self, base_api_def, endpoint_names):
+ """Creates Python ApiDef that overrides base_api_def if needed.
+
+ Args:
+ base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
+ endpoint_names: List of Python endpoint names.
+
+ Returns:
+ api_def_pb2.ApiDef instance with overrides for base_api_def
+ if module.name endpoint is different from any existing
+ endpoints in base_api_def. Otherwise, returns None.
+ """
+ endpoint_names_set = set(endpoint_names)
+ base_endpoint_names_set = {
+ self._GenerateLowerCaseOpName(endpoint.name)
+ for endpoint in base_api_def.endpoint}
+
+ if endpoint_names_set == base_endpoint_names_set:
+ return None # All endpoints are the same
+
+ api_def = api_def_pb2.ApiDef()
+ api_def.graph_op_name = base_api_def.graph_op_name
+
+ for endpoint_name in sorted(endpoint_names):
+ new_endpoint = api_def.endpoint.add()
+ new_endpoint.name = endpoint_name
+
+ return api_def
+
+ def _GetBaseApiMap(self):
+ """Get a map from graph op name to its base ApiDef.
+
+ Returns:
+ Dictionary mapping graph op name to corresponding ApiDef.
+ """
+ # Convert base ApiDef in Multiline format to Proto format.
+ converted_base_api_dir = os.path.join(
+ test.get_temp_dir(), 'temp_base_api_defs')
+ subprocess.check_call(
+ [os.path.join(resource_loader.get_root_dir_with_all_resources(),
+ _CONVERT_FROM_MULTILINE_SCRIPT),
+ _BASE_API_DIR, converted_base_api_dir])
+
+ name_to_base_api_def = {}
+ base_api_files = file_io.get_matching_files(
+ os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
+ for base_api_file in base_api_files:
+ if file_io.file_exists(base_api_file):
+ api_defs = api_def_pb2.ApiDefs()
+ text_format.Merge(
+ file_io.read_file_to_string(base_api_file), api_defs)
+ for api_def in api_defs.op:
+ lower_case_name = self._GenerateLowerCaseOpName(api_def.graph_op_name)
+ name_to_base_api_def[lower_case_name] = api_def
+ return name_to_base_api_def
+
+ @unittest.skipUnless(
+ sys.version_info.major == 2 and os.uname()[0] == 'Linux',
+ 'API compabitility test goldens are generated using python2 on Linux.')
+ def testAPIDefCompatibility(self):
+ # Get base ApiDef
+ name_to_base_api_def = self._GetBaseApiMap()
+ # Extract Python API
+ visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
+ public_api_visitor = public_api.PublicAPIVisitor(visitor)
+ public_api_visitor.do_not_descend_map['tf'].append('contrib')
+ traverse.traverse(tf, public_api_visitor)
+ proto_dict = visitor.GetProtos()
+
+ # Map from first character of op name to Python ApiDefs.
+ api_def_map = defaultdict(api_def_pb2.ApiDefs)
+ # We need to override all endpoints even if 1 endpoint differs from base
+ # ApiDef. So, we first create a map from an op to all its endpoints.
+ op_to_endpoint_name = defaultdict(list)
+
+ # Generate map from generated python op to endpoint names.
+ for public_module, value in proto_dict.items():
+ module_obj = _GetSymbol(public_module)
+ for sym in value.tf_module.member_method:
+ obj = getattr(module_obj, sym.name)
+
+ # Check if object is defined in gen_* module. That is,
+ # the object has been generated from OpDef.
+ if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
+ if obj.__name__ not in name_to_base_api_def:
+ # Symbol might be defined only in Python and not generated from
+ # C++ api.
+ continue
+ relative_public_module = public_module[len('tensorflow.'):]
+ full_name = (relative_public_module + '.' + sym.name
+ if relative_public_module else sym.name)
+ op_to_endpoint_name[obj].append(full_name)
+
+ # Generate Python ApiDef overrides.
+ for op, endpoint_names in op_to_endpoint_name.items():
+ api_def = self._CreatePythonApiDef(
+ name_to_base_api_def[op.__name__], endpoint_names)
+ if api_def:
+ api_defs = api_def_map[op.__name__[0].upper()]
+ api_defs.op.extend([api_def])
+
+ for key in _ALPHABET:
+ # Get new ApiDef for the given key.
+ new_api_defs_str = ''
+ if key in api_def_map:
+ new_api_defs = api_def_map[key]
+ new_api_defs.op.sort(key=attrgetter('graph_op_name'))
+ new_api_defs_str = str(new_api_defs)
+
+ # Get current ApiDef for the given key.
+ api_defs_file_path = os.path.join(
+ _PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
+ old_api_defs_str = ''
+ if file_io.file_exists(api_defs_file_path):
+ old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)
+
+ if old_api_defs_str == new_api_defs_str:
+ continue
+
+ if FLAGS.update_goldens:
+ if not new_api_defs_str:
+ logging.info('Deleting %s...' % api_defs_file_path)
+ file_io.delete_file(api_defs_file_path)
+ else:
+ logging.info('Updating %s...' % api_defs_file_path)
+ file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
+ else:
+ self.assertMultiLineEqual(
+ old_api_defs_str, new_api_defs_str,
+ 'To update golden API files, run api_compatibility_test locally '
+ 'with --update_goldens=True flag.')
+
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/tensorflow/tools/api/tests/convert_from_multiline.cc b/tensorflow/tools/api/tests/convert_from_multiline.cc
new file mode 100644
index 0000000000..5c5aaa4f06
--- /dev/null
+++ b/tensorflow/tools/api/tests/convert_from_multiline.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Converts all *.pbtxt files in a directory from Multiline to proto format.
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+
+namespace tensorflow {
+
+namespace {
+constexpr char kApiDefFilePattern[] = "*.pbtxt";
+
+Status ConvertFilesFromMultiline(const string& input_dir,
+ const string& output_dir) {
+ Env* env = Env::Default();
+
+ const string file_pattern = io::JoinPath(input_dir, kApiDefFilePattern);
+ std::vector<string> matching_paths;
+ TF_CHECK_OK(env->GetMatchingPaths(file_pattern, &matching_paths));
+
+ if (!env->IsDirectory(output_dir).ok()) {
+ TF_RETURN_IF_ERROR(env->CreateDir(output_dir));
+ }
+
+ for (const auto& path : matching_paths) {
+ string contents;
+ TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, path, &contents));
+ contents = tensorflow::PBTxtFromMultiline(contents);
+ string output_path = io::JoinPath(output_dir, io::Basename(path));
+ // Write contents to output_path
+ TF_RETURN_IF_ERROR(
+ tensorflow::WriteStringToFile(env, output_path, contents));
+ }
+ return Status::OK();
+}
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ const std::string usage =
+ "Usage: convert_from_multiline input_dir output_dir";
+ if (argc != 3) {
+ std::cerr << usage << std::endl;
+ return -1;
+ }
+ TF_CHECK_OK(tensorflow::ConvertFilesFromMultiline(argv[1], argv[2]));
+ return 0;
+}