diff options
author | 2018-04-05 07:34:25 -0700 | |
---|---|---|
committer | 2018-04-05 07:34:25 -0700 | |
commit | c9c17e3d277fffba647d76f1c3a1cfa4b3001761 (patch) | |
tree | 1073e8354148c398d6abb87817e2d70e7eef582a | |
parent | c1c819b28476d72c1f086fc4e78ff7f013c225ce (diff) | |
parent | 361a13cf0c2b65d26f6e2b5b68875adfcea98dd0 (diff) |
Merge commit for internal changes
486 files changed, 14248 insertions, 3991 deletions
diff --git a/configure.py b/configure.py index 6744082d5d..81d5ad77ee 100644 --- a/configure.py +++ b/configure.py @@ -35,6 +35,7 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' +_DEFAULT_NCCL_VERSION = '1.3' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -484,6 +485,8 @@ def set_cc_opt_flags(environ_cp): if is_ppc64le(): # gcc on ppc64le does not support -march, use mcpu instead default_cc_opt_flags = '-mcpu=native' + elif is_windows(): + default_cc_opt_flags = '/arch:AVX' else: default_cc_opt_flags = '-march=native' question = ('Please specify optimization flags to use during compilation when' @@ -494,7 +497,7 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - if not is_ppc64le(): + if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get @@ -1102,6 +1105,81 @@ def set_tf_tensorrt_install_path(environ_cp): write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) +def set_tf_nccl_install_path(environ_cp): + """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently NCCL is only supported on Linux platforms.') + + ask_nccl_version = ( + 'Please specify the NCCL version you want to use. ' + '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + tf_nccl_version = get_from_env_or_user_or_default( + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + + if tf_nccl_version == '1': + break # No need to get install path, NCCL 1 is a GitHub repo. + + # TODO(csigg): Look with ldconfig first if we can find the library in paths + # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding + # include directory. This is where the NCCL .deb packages install them. + # Then ask the user if we should use that. Instead of a single + # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to + # nccl_configure.bzl + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = (r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + break + + # Reset and Retry + print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, + nccl_hdr_path)) + + environ_cp['TF_NCCL_VERSION'] = '' + else: + raise UserInputError('Invalid TF_NCCL setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TF_NCCL_VERSION + environ_cp['TF_NCCL_VERSION'] = tf_nccl_version + write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1438,6 +1516,7 @@ def main(): set_tf_cudnn_version(environ_cp) if is_linux(): set_tf_tensorrt_install_path(environ_cp) + set_tf_nccl_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 'LD_LIBRARY_PATH') != '1': diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index b32f574628..fe85f8ee0e 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieves the type of the device at the given index. // @@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieve the amount of memory associated with a given device. // // If index is out of bounds, an error code will be set in the status object, // and -1 will be returned. TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status*); + const TF_DeviceList* list, int index, TF_Status* status); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 028f146be3..ca80db23ed 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -53,7 +53,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index bb1492fca2..c96a38dec3 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -496,9 +496,11 @@ tensorflow::Status ValidateInputTypeAndPlacement( expected_device->name(), " but is actually on ", actual_device->name(), " (operation running on ", op_device->name(), ")", - " Tensors can be copied explicitly using .gpu() or .cpu()," - " or transparently copied by using tfe.enable_eager_execution(" - "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" + " Tensors can be copied explicitly using .gpu() or .cpu() " + "methods," + " or transparently copied by using tf.enable_eager_execution(" + "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors " + "between devices" " may slow down your model"); case tensorflow::DEVICE_PLACEMENT_WARN: LOG(WARNING) << "before computing " << op->name << " input #" << i diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 4c64d2cfe3..72b8bc1871 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -133,9 +134,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied " - "tags: { missing-tag }")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); } @@ -149,9 +150,9 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags: ")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) << st.error_message(); } @@ -169,7 +170,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) << st.error_message(); } diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index 3675d72ee3..5dbc4f5f6a 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -166,7 +167,8 @@ namespace { bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, int32* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag) && + tensorflow::str_util::ConsumePrefix(&arg, "=")) { char extra; return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); } @@ -176,7 +178,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool* dst) { - if (arg.Consume(flag)) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *dst = true; return true; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 53ec6c1e60..b04b333141 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -825,6 +825,7 @@ Status Encapsulator::Subgraph::AddHostComputes( builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 56efe98fdb..8599a7038a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -902,7 +902,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice<DataType>({})}}, + {"shapes", gtl::ArraySlice<DataType>({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1046,7 +1047,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice<DataType>({})}}, + {"shapes", gtl::ArraySlice<DataType>({})}, + {"_outside_compilation_subgraph", "O2"}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", @@ -1056,7 +1058,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice<DataType>({})}}, + {"shapes", gtl::ArraySlice<DataType>({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1193,7 +1196,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice<DataType>({})}}, + {"shapes", gtl::ArraySlice<DataType>({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -1214,7 +1218,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}}, + gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1321,7 +1326,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}}, + gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1403,7 +1409,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}, + gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1482,7 +1489,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"Toutputs", gtl::ArraySlice<DataType>({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}}, + {"shapes", gtl::ArraySlice<TensorShapeProto>({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1561,7 +1569,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"Toutputs", gtl::ArraySlice<DataType>({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}}, + {"shapes", gtl::ArraySlice<TensorShapeProto>({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1725,7 +1734,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice<DataType>({})}}, + {"shapes", gtl::ArraySlice<DataType>({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 381c0205fd..2e362e0a63 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -138,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, UnsupportedTypes) { +TEST(XlaCompilationTest, Complex128Unsupported) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -158,6 +158,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, HalfSupported) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Tensor t(DT_HALF, TensorShape()); + t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_HALF) + .WithAttr("value", t)); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_FALSE(clusters.empty()); +} + TEST(XlaCompilationTest, ConcatWithConstArg) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphDef graphdef; diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index d2dfdeea68..bc07dbd7bd 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -62,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 6> kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 7> kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 5a1db81774..ac60423d95 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -62,8 +62,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 6> kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 8> kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 204a2a2f90..edabdc218a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -375,7 +375,6 @@ tf_xla_py_test( name = "momentum_test", size = "small", srcs = ["momentum_test.py"], - tags = ["no_oss"], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index ba7b9bacd2..d1d7379c0a 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -190,19 +190,24 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) - self._testBinary( - gen_nn_ops.sparse_softmax_cross_entropy_with_logits, - np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], - [0.9, 1.0, 1.1, 1.2]], dtype=dtype), - np.array([2, 1, 7], dtype=np.int32), - expected=[ - np.array([1.342536, 1.442536, np.nan], dtype=dtype), - np.array([[0.213838, 0.236328, -0.738817, 0.288651], - [0.213838, -0.763672, 0.261183, 0.288651], - [np.nan, np.nan, np.nan, np.nan]], - dtype=dtype), - ], - equality_test=self.ListsAreClose) + # TODO(b/68813416): Fails with bfloat16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits, + np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2]], + dtype=dtype), + np.array([2, 1, 7], dtype=np.int32), + expected=[ + np.array([1.342536, 1.442536, np.nan], dtype=dtype), + np.array( + [[0.213838, 0.236328, -0.738817, 0.288651], [ + 0.213838, -0.763672, 0.261183, 0.288651 + ], [np.nan, np.nan, np.nan, np.nan]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) def testIntOps(self): for dtype in self.int_types: @@ -260,12 +265,6 @@ class BinaryOpsTest(XLATestCase): np.array([[1], [2]], dtype=dtype), dtype(7), expected=np.array([[8], [9]], dtype=dtype)) - self._testBinary( - math_ops.add, - np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), - np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), - expected=np.array( - [1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) self._testBinary( math_ops.subtract, @@ -361,6 +360,12 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), + np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), + expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) + def testComplexOps(self): for dtype in self.complex_types: ctypes = {np.complex64: np.float32} diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 0528a5415d..a9db1c173d 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -56,7 +56,7 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: @@ -89,4 +89,3 @@ def generate_backend_suites(backends=[]): backends = all_backends() for backend in backends: native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) - diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 5010fe5e21..1a8989d7c2 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -34,6 +34,13 @@ from tensorflow.python.platform import test class CholeskyOpTest(XLATestCase): + # Cholesky defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/cholesky) + @property + def float_types(self): + return set(super(CholeskyOpTest, self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) self.assertAllClose(x, verification_np, atol=atol) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index cccb7f5789..5819b2bf2b 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -37,6 +37,14 @@ def MakePlaceholder(x): class MatrixTriangularSolveOpTest(XLATestCase): + # MatrixTriangularSolve defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) + @property + def float_types(self): + return set(super(MatrixTriangularSolveOpTest, + self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol): diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 92518aadc4..6083981493 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import test @@ -156,6 +157,12 @@ class SpaceToBatchNDTest(XLATestCase): paddings = np.array(paddings).reshape((len(block_shape), 2)) with self.test_session() as sess, self.test_scope(): for dtype in self.float_types: + # TODO(b/68813416): Skip bfloat16's as the input type for direct is + # float32 and results in a mismatch, while making testDirect provide the + # correctly typed input results in 'no fill-function for data-type' + # error. + if dtype == dtypes.bfloat16.as_numpy_dtype: + continue placeholder = array_ops.placeholder(dtype) # outputs = space_to_batch(inputs) x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a8ab235378..17149aa1c8 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -793,7 +793,10 @@ class UnaryOpsTest(XLATestCase): self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) self._assertSoftplusMatchesExpected( [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) - log_eps = np.log(np.finfo(dtype).eps) + if dtype == dtypes.bfloat16.as_numpy_dtype: + log_eps = np.log(np.finfo(np.float32).eps) + else: + log_eps = np.log(np.finfo(dtype).eps) one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index b08d6ab21e..8ecad00f6e 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -230,7 +230,10 @@ class SliceAssignTest(XLATestCase): # shrink shape changes checker[1:2, 1] = [66] checker[1, 1:2] = [66] - checker[1, 1] = 66 + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker[1, 1] = 66 # newaxis shape changes checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] # shrink and newaxis @@ -243,8 +246,11 @@ class SliceAssignTest(XLATestCase): # Assign vector to scalar (rank-0) using newaxis checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) - checker2[()] = 6 # no indices - checker2[...] = 6 # ellipsis + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis checker2[None] = [6] # new axis def testUninitialized(self): diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index ff7453194a..e255b01dd7 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,13 +51,13 @@ constexpr std::array<DataType, 9> kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array<DataType, 8> kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, +constexpr std::array<DataType, 9> kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array<DataType, 8> kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 10> kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index c4c8894374..3f45167fcb 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -324,8 +324,38 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel( StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) { - return Unimplemented( - "ExecuteParallel is not yet implemented for XlaComputation."); + ExecuteGraphParallelRequest request; + + for (const XlaComputationInstance& computation : computations) { + ExecuteGraphRequest single_request; + *single_request.mutable_computation() = computation.computation.proto(); + for (GlobalData* argument : computation.arguments) { + *single_request.add_arguments() = argument->handle(); + } + *single_request.mutable_execution_options() = computation.execution_options; + *request.add_requests() = single_request; + } + + ExecuteParallelResponse response; + VLOG(1) << "making execute-graph-parallel request: " + << request.ShortDebugString(); + tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector<std::unique_ptr<GlobalData>> outputs; + for (size_t i = 0; i < computations.size(); ++i) { + outputs.push_back( + MakeUnique<GlobalData>(stub_, response.responses(i).output())); + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = response.responses(i).profile(); + } + } + + return std::move(outputs); } StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles( diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d02972f2c0..f4673a8204 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -24,6 +24,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 24048a1e5a..63df449e0b 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -26,6 +26,7 @@ limitations under the License. namespace xla { namespace { + using InstructionGenerator = ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, const ComputationDataHandle&); @@ -47,6 +48,27 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type, generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } + +using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); + +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator) { + std::unique_ptr<XlaBuilder> b; + if (type == PRED) { + b = builder->CreateSubBuilder(name); + } else { + b = builder->CreateSubBuilder( + tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + } + + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + generator(b.get(), lhs, rhs); + return b->BuildAndNoteError(); +} + } // namespace Computation CreateScalarAddComputation(PrimitiveType type, @@ -60,7 +82,7 @@ Computation CreateScalarAddComputation(PrimitiveType type, Computation CreateScalarMultiplyComputation(PrimitiveType type, ComputationBuilder* builder) { return CreateScalarComputation( - "add", type, builder, + "mul", type, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); } @@ -114,4 +136,75 @@ StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates, return builder->Reduce(predicates, f, logical_or, all_dimensions); } +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "add", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Add(lhs, rhs); + }); +} + +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "mul", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Mul(lhs, rhs); + }); +} + +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Ge(lhs, rhs); + }); +} + +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "max", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Max(lhs, rhs); + }); +} + +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "min", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Min(lhs, rhs); + }); +} + +XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "and", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->And(lhs, rhs); + }); +} + +XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "or", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Or(lhs, rhs); + }); +} + +StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) { + auto f = builder->ConstantR0<bool>(false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return builder->Reduce(predicates, f, logical_or, all_dimensions); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index ae89784bc2..f4d3fc8015 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -56,6 +58,48 @@ Computation CreateScalarOrComputation(ComputationBuilder* builder); StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates, ComputationBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar add computation and returns it. +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar multiply computation and returns it. +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar ge computation and returns it. +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar max computation and returns it. +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar min computation and returns it. +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical AND computation and returns it. +XlaComputation CreateScalarAndComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical OR computation and returns it. +XlaComputation CreateScalarOrComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index e51a8b14c0..2d587cc3b9 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include <functional> #include <numeric> #include <string> #include <utility> @@ -44,6 +45,7 @@ int64 GetUniqueId() { bool CanBeRoot(HloOpcode opcode) { switch (opcode) { case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kOutfeed: case HloOpcode::kTrace: return false; @@ -52,20 +54,35 @@ bool CanBeRoot(HloOpcode opcode) { } } +StatusOr<std::vector<Shape>> GetOperandShapes( + tensorflow::gtl::ArraySlice<XlaOp> operands) { + std::vector<Shape> operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); + operand_shapes.push_back(shape); + } + return operand_shapes; +} + } // namespace StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); return instr->shape(); } StatusOr<Shape> XlaOp::GetShape() const { - TF_RET_CHECK(builder_ != nullptr); + if (builder_ == nullptr) { + return InvalidArgument( + "cannot GetShape for an invalid XlaOp with handle %lld", handle()); + } return builder_->GetShape(*this); } XlaBuilder::XlaBuilder(const string& computation_name) - : name_(computation_name) {} + : name_(computation_name), unique_id_(GetUniqueId()) {} XlaBuilder::~XlaBuilder() {} @@ -81,7 +98,22 @@ void XlaBuilder::NoteError(const Status& error) { } } +XlaOp XlaBuilder::NoteErrorOrReturn( + const std::function<StatusOr<XlaOp>()>& op_creator) { + if (!first_error_.ok()) { + return {}; + } + auto op = op_creator(); + if (!op.ok()) { + NoteError(op.status()); + return {}; + } + return op.ConsumeValueOrDie(); +} + StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) { + TF_RETURN_IF_ERROR(first_error_); + TF_RET_CHECK(root_id != nullptr); ProgramShape program_shape; @@ -148,7 +180,6 @@ StatusOr<XlaComputation> XlaBuilder::Build() { } HloComputationProto entry; - entry.set_name(name_); { int64 root_id; @@ -162,9 +193,9 @@ StatusOr<XlaComputation> XlaBuilder::Build() { entry.add_instructions()->Swap(&instruction); } - const int64 id = GetUniqueId(); - entry.set_id(id); - XlaComputation computation(id); + entry.set_id(unique_id_); + entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. + XlaComputation computation(entry.id()); HloModuleProto* module = computation.mutable_proto(); module->set_name(entry.name()); module->set_id(entry.id()); @@ -187,6 +218,8 @@ StatusOr<XlaComputation> XlaBuilder::Build() { StatusOr<XlaOp> XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + TF_RETURN_IF_ERROR(first_error_); + HloInstructionProto instr; *instr.mutable_shape() = shape; for (int64 dim : broadcast_dimensions) { @@ -197,6 +230,8 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast( StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape, const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); CHECK(ShapeUtil::IsScalar(operand_shape) || @@ -240,7 +275,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferUnaryOpShape(unop, operand_shape)); return AddInstruction(std::move(instr), unop, {operand}); - }()); + }); } XlaOp XlaBuilder::BinaryOp( @@ -297,7 +332,7 @@ XlaOp XlaBuilder::BinaryOp( } return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); - }()); + }); } XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, @@ -335,7 +370,7 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, } return AddInstruction(std::move(instr), triop, {updated_lhs, updated_rhs, updated_ehs}); - }()); + }); } XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, @@ -354,7 +389,7 @@ XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { *instr.mutable_shape() = literal.shape(); *instr.mutable_literal() = literal.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConstant); - }()); + }); } XlaOp XlaBuilder::Call(const XlaComputation& computation, @@ -362,11 +397,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; - std::vector<Shape> operand_shapes; - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); - operand_shapes.push_back(shape); - } + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, @@ -376,15 +407,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, ShapeInference::InferCallShape(operand_shape_ptrs, /*to_apply=*/called_program_shape)); - // Add called computation. - instr.add_called_computation_ids( - computation.proto().entry_computation_id()); - for (const HloComputationProto& e : computation.proto().computations()) { - embedded_.insert({e.id(), e}); - } + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCall, operands); - }()); + }); } XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, @@ -400,7 +426,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, instr.set_name(name); *instr.mutable_shape() = shape; return AddInstruction(std::move(instr), HloOpcode::kParameter); - }()); + }); } XlaOp XlaBuilder::Broadcast( @@ -424,10 +450,12 @@ XlaOp XlaBuilder::Broadcast( dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; } return InDimBroadcast(shape, operand, dimensions); - }()); + }); } StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + HloInstructionProto instr; *instr.mutable_shape() = shape; return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); @@ -437,7 +465,22 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices, tensorflow::gtl::ArraySlice<int64> limit_indices, tensorflow::gtl::ArraySlice<int64> strides) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferSliceShape(operand_shape, start_indices, + limit_indices, strides)); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + }); } XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, @@ -447,17 +490,60 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shape, slice_sizes)); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, + {operand, start_indices}); + }); } XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, start_indices_shape)); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); } XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + std::vector<const Shape*> operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + }); } XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, @@ -477,7 +563,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, ? operand : Transpose(operand, dimensions); return Reshape(shape, transposed); - }()); + }); } XlaOp XlaBuilder::Reshape(const XlaOp& operand, @@ -487,7 +573,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, std::vector<int64> dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); - }()); + }); } XlaOp XlaBuilder::Collapse(const XlaOp& operand, @@ -496,7 +582,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - UnimplementedOp(); + NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, @@ -508,18 +599,14 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; - std::vector<Shape> operand_shapes; - for (const XlaOp& e : elements) { - TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(e)); - operand_shapes.push_back(shape); - } + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); - }()); + }); } XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { @@ -538,7 +625,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, {tuple_data}); - }()); + }); } XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, @@ -572,12 +659,29 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape.dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); + }); } XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, + dimension_numbers)); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); + }); } XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -788,7 +892,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, instr.add_dimensions(dim); } return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); - }()); + }); } XlaOp XlaBuilder::Rev(const XlaOp& operand, @@ -812,7 +916,14 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + }); } XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, @@ -846,19 +957,64 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands, return UnimplementedOp(); } +XlaOp XlaBuilder::RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice<XlaOp> parameters, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + // Check the number of parameters per RNG distribution. + switch (distribution) { + case RandomDistribution::RNG_NORMAL: + case RandomDistribution::RNG_UNIFORM: + if (parameters.size() != 2) { + return InvalidArgument( + "RNG distribution (%s) expects 2 parameters, but got %ld", + RandomDistribution_Name(distribution).c_str(), parameters.size()); + } + break; + default: + LOG(FATAL) << "unhandled distribution " << distribution; + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + *instr.mutable_shape() = shape; + + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + }); +} + XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) { - return UnimplementedOp(); + return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); } XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) { - return UnimplementedOp(); + return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); } XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + // Infer shape. + TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, + condition.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferWhileShape(condition_program_shape, + body_program_shape, init_shape)); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + }); } XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, @@ -878,7 +1034,27 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape, init_shape, dimensions_to_reduce, + called_program_shape)); + + for (int64 dim : dimensions_to_reduce) { + instr.add_dimensions(dim); + } + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kReduce, + {operand, init_value}); + }); } XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, @@ -952,11 +1128,43 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, } void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { - UnimplementedOp(); + NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + // Send instruction produces a tuple of {aliased operand, U32 context}. + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN( + XlaOp send, + AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + send_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + // Recv instruction produces a tuple of {receive buffer, U32 context}. + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, + AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = shape; + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); } StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand, @@ -1055,20 +1263,27 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { StatusOr<XlaOp> XlaBuilder::AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice<XlaOp> operands) { + TF_RETURN_IF_ERROR(first_error_); + const int64 handle = instructions_.size(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode(), ".", handle)); + instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle)); } else { // Append the handle to make sure the name is unique. - instr.set_name(StrCat(instr.name(), ".", handle)); + instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle)); } for (const auto& operand : operands) { - TF_RET_CHECK(operand.builder_ != nullptr); - TF_RET_CHECK(operand.builder_ == this) - << "Do not add XlaOp from builder " << operand.builder_->name() - << " to builder " << this->name(); + if (operand.builder_ == nullptr) { + return InvalidArgument("invalid XlaOp with handle %lld", + operand.handle()); + } + if (operand.builder_ != this) { + return InvalidArgument("Do not add XlaOp from builder %s to builder %s", + operand.builder_->name().c_str(), + this->name().c_str()); + } instr.add_operand_ids(operand.handle()); } @@ -1083,8 +1298,22 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction( return op; } +void XlaBuilder::AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr) { + instr->add_called_computation_ids(computation.proto().entry_computation_id()); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } +} + StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction( const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + if (op.builder_ != this) { + return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + } + TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index f66feb93ce..0673b86646 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -803,19 +803,16 @@ class XlaBuilder { HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice<XlaOp> operands = {}); + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + // Notes that the error occurred by: // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to Build()) // * dying if die_immediately_on_error_ is true void NoteError(const Status& error); - XlaOp NoteErrorOrReturn(StatusOr<XlaOp>&& op) { - if (!op.ok()) { - NoteError(op.status()); - return XlaOp(); - } - return op.ConsumeValueOrDie(); - } + XlaOp NoteErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); // Helper method that creates an empty op and notes error. XlaOp UnimplementedOp(); @@ -835,6 +832,10 @@ class XlaBuilder { XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs); + XlaOp RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice<XlaOp> parameters, + const Shape& shape); + StatusOr<XlaOp> InDimBroadcast( const Shape& shape, const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); @@ -852,7 +853,8 @@ class XlaBuilder { // computation and fills the root_id in the pointer. StatusOr<ProgramShape> GetProgramShape(int64* root_id); - string name_; // Name to use for the built computation. + string name_; // Name to use for the built computation. + int64 unique_id_; // The unique id for the built computation. // The first error encountered while building the computation. // This is OK until the first error is encountered. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b21ab3044f..2bacc6a914 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -521,6 +521,17 @@ ComputationDataHandle LocalComputationBuilder::Conditional( false_computation.computation()); } +StatusOr<bool> LocalComputationBuilder::IsConstant( + const ComputationDataHandle& operand, int64 num_parameters) { + return builder_.IsConstant(operand, num_parameters); +} + +StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice<Literal> parameters) { + return builder_.ComputeConstant(operand, output_layout, parameters); +} + #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ return builder_.method_name args; \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a7375c8965..31046e60f1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -268,6 +268,13 @@ class LocalComputationBuilder { const ComputationDataHandle& false_operand, const LocalComputation& false_computation); + StatusOr<bool> IsConstant(const ComputationDataHandle& operand, + int64 num_parameters); + + StatusOr<std::unique_ptr<Literal> > ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice<Literal> parameters); + #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 8f231d1a12..ac792e8189 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -182,7 +182,7 @@ tensorflow::ImportNumpy(); %typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { const int64 handle = numpy::PyIntOrPyLongToLong($input); if (handle == -1 && PyErr_Occurred()) { - return NULL; + SWIG_fail; } temp.set_handle(handle); $1 = &temp; @@ -201,7 +201,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -211,7 +211,7 @@ tensorflow::ImportNumpy(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -224,7 +224,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -233,7 +233,16 @@ tensorflow::ImportNumpy(); $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; + } +} + +%typemap(out) StatusOr<bool> { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; } } @@ -241,7 +250,7 @@ tensorflow::ImportNumpy(); if (!$1.ok()) { PyErr_SetString( PyExc_RuntimeError, $1.ToString().c_str()); - return NULL; + SWIG_fail; } Py_INCREF(Py_None); $result = Py_None; @@ -253,7 +262,7 @@ tensorflow::ImportNumpy(); (std::vector<int64> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -265,13 +274,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "Argument sequence element cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i] = numpy::PyIntOrPyLongToLong(py_int); if (temps[i] == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } Py_DECREF(py_int); Py_DECREF(o); @@ -285,7 +294,7 @@ tensorflow::ImportNumpy(); (std::vector<ComputationDataHandle> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -296,13 +305,13 @@ tensorflow::ImportNumpy(); PyErr_SetString( PyExc_TypeError, "Argument sequence element cannot be converted to int"); - return NULL; + SWIG_fail; } const int64 handle = numpy::PyIntOrPyLongToLong(py_int); if (handle == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i].set_handle(handle); Py_DECREF(py_int); @@ -317,7 +326,7 @@ tensorflow::ImportNumpy(); (std::vector<LocalShapedBuffer*> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); @@ -326,7 +335,7 @@ tensorflow::ImportNumpy(); LocalShapedBuffer* lsbp; if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION)) == -1) { - return NULL; + SWIG_fail; } temps.push_back(lsbp); Py_DECREF(o); @@ -340,7 +349,7 @@ tensorflow::ImportNumpy(); literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - return NULL; + SWIG_fail; } $1 = literal_status.ValueOrDie().get(); } @@ -352,7 +361,7 @@ tensorflow::ImportNumpy(); %typemap(out) StatusOr< std::unique_ptr<Literal> > { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); } @@ -360,7 +369,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -369,7 +378,7 @@ tensorflow::ImportNumpy(); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); Py_DECREF(o); @@ -383,7 +392,7 @@ tensorflow::ImportNumpy(); StatusOr<OpMetadata> statusor = numpy::OpMetadataFromPyObject($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -395,7 +404,7 @@ tensorflow::ImportNumpy(); StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -410,7 +419,7 @@ tensorflow::ImportNumpy(); StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -424,7 +433,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -433,7 +442,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -444,7 +453,7 @@ tensorflow::ImportNumpy(); std::vector<tensorflow::gtl::optional<Shape> > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -456,7 +465,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -470,18 +479,18 @@ tensorflow::ImportNumpy(); PyObject* py_int = numpy::PyNumberToPyInt($input); if (!py_int) { PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - return NULL; + SWIG_fail; } const long value = numpy::PyIntOrPyLongToLong(py_int); if (value == -1 && PyErr_Occurred()) { Py_DECREF(py_int); - return NULL; + SWIG_fail; } if (!PrimitiveType_IsValid(value)) { PyErr_SetString( PyExc_TypeError, "Argument not valid for PrimitiveType enum"); Py_DECREF(py_int); - return NULL; + SWIG_fail; } $1 = static_cast<PrimitiveType>(value); } @@ -492,19 +501,19 @@ tensorflow::ImportNumpy(); (std::vector<std::pair<int64, int64> > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (!o) { - return NULL; + SWIG_fail; } PyObject* first = PyTuple_GetItem(o, 0); if (!first) { Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* first_pyint = numpy::PyNumberToPyInt(first); if (!first_pyint) { @@ -512,13 +521,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "First pair item cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* second = PyTuple_GetItem(o, 1); if (!second) { Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } PyObject* second_pyint = numpy::PyNumberToPyInt(second); if (!second_pyint) { @@ -527,21 +536,21 @@ tensorflow::ImportNumpy(); "Second pair item cannot be converted to int"); Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); if (first_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); if (second_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } temps.push_back(std::make_pair(first_value, second_value)); Py_DECREF(o); @@ -559,26 +568,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( $input, "lhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_contracting_dimensions); if (length == -1) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); if (!item) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -589,26 +598,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( $input, "rhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_contracting_dimensions); if (length == -1) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); if (!item) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -619,26 +628,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_batch_dimensions = PyObject_GetAttrString( $input, "lhs_batch_dimensions"); if (!lhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_batch_dimensions); if (length == -1) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); if (!item) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_batch_dimensions(dimension); Py_DECREF(item); @@ -649,26 +658,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_batch_dimensions = PyObject_GetAttrString( $input, "rhs_batch_dimensions"); if (!rhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_batch_dimensions); if (length == -1) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); if (!item) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_batch_dimensions(dimension); Py_DECREF(item); @@ -684,20 +693,20 @@ tensorflow::ImportNumpy(); (PaddingConfig padding_config) { PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); if (!dimensions) { - return NULL; + SWIG_fail; } int length = PySequence_Size(dimensions); if (length == -1) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(dimensions, i); if (!item) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } int64 edge_padding_low, edge_padding_high, interior_padding; if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) @@ -705,7 +714,7 @@ tensorflow::ImportNumpy(); || !GetIntAttr(item, "interior_padding", &interior_padding)) { Py_DECREF(item); Py_DECREF(dimensions); - return NULL; + SWIG_fail; } Py_DECREF(item); @@ -727,32 +736,32 @@ tensorflow::ImportNumpy(); int64 value; if (!GetIntAttr($input, "input_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_batch_dimension(value); if (!GetIntAttr($input, "input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_feature_dimension(value); if (!GetIntAttr($input, "output_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_batch_dimension(value); if (!GetIntAttr($input, "output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_input_feature_dimension(value); @@ -761,24 +770,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "input_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_input_spatial_dimensions(dimension); Py_DECREF(item); @@ -787,24 +796,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_kernel_spatial_dimensions(dimension); Py_DECREF(item); @@ -813,24 +822,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "output_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_output_spatial_dimensions(dimension); Py_DECREF(item); @@ -865,12 +874,12 @@ tensorflow::ImportNumpy(); PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); if (o == NULL) { - return NULL; + SWIG_fail; } if (o != Py_None) { if (!PyBool_Check(o)) { PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); - return NULL; + SWIG_fail; } build_options.set_hlo_profile(o == Py_True); } @@ -885,7 +894,7 @@ tensorflow::ImportNumpy(); if (!statusor.ok()) { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } build_options.set_result_layout(statusor.ValueOrDie()); } @@ -951,6 +960,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; %unignore xla::swig::LocalComputationBuilder::Conditional; +%unignore xla::swig::LocalComputationBuilder::IsConstant; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index e548d420f4..9c81f6439d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1028,6 +1028,20 @@ class ComputationBuilder(object): _unwrap_data_handle(false_operand), false_computation.c_local_computation)) + def IsConstant(self, operand, num_parameters=0): + """Enqueues an IsConstant operation onto the computation. + + Args: + operand: a ComputationDataHandle to test. + num_parameters: optional int, number of computation parameters to treat as + constant (default 0). + + Returns: bool indicating whether `operand` is a compile-time constant, + meaning its value does not depend on parameters with index greater than or + equal to `num_parameters`. + """ + return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 4c16c1f8b0..d97264ea64 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -855,6 +855,17 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testIsConstant(self): + c = self._NewComputation() + a = c.ConstantS32Scalar(3) + b = c.ConstantS32Scalar(1) + x = c.ParameterFromNumpy(NumpyArrayS32(0)) + const_expr = c.Sub(b, a) + non_const_expr = c.Mul(const_expr, x) + self.assertTrue(c.IsConstant(const_expr)) + self.assertFalse(c.IsConstant(non_const_expr)) + # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 80c24eaccf..4198260a22 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -87,7 +87,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - execution_session_(string_pool_), symbol_resolver_(llvm::orc::createLegacyLookupResolver( [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index aaeff2de87..f4260a95bc 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -102,7 +102,6 @@ class SimpleOrcJIT { std::unique_ptr<llvm::TargetMachine> target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; - llvm::orc::SymbolStringPool string_pool_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_; ObjLayerT object_layer_; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index be92b1629a..471d2fd6ce 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -80,6 +80,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> Executable::ExecuteOnStreamWrapper( StatusOr<std::unique_ptr<ShapedBuffer>> return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); + TF_RETURN_IF_ERROR(return_value.status()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 6f983d0b95..594413e88f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -304,19 +304,15 @@ void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet<HloComputation*>* visited, std::list<HloComputation*>* post_order) { - if (visited->count(computation) > 0) { - return; - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); + if (visited->insert(computation).second) { + for (auto* instruction : computation->instructions()) { + for (HloComputation* called_computation : + instruction->called_computations()) { + ComputeComputationPostOrder(called_computation, visited, post_order); + } } + post_order->push_back(computation); } - - visited->insert(computation); - post_order->push_back(computation); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 693004d364..9d7251b6ae 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1520,14 +1520,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { arg_dim_counts[dim] = arg_dimensions[dim]; } - // Create mapping from result index to arg index. - const int64 result_rank = ShapeUtil::Rank(result->shape()); - int64 result_dim = 0; - std::vector<int64> result_to_arg_index(result_rank); + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector<int64> result_to_arg_index; for (int64 i = 0; i < arg_dimensions.size(); ++i) { if (arg_dim_steps[i] == 0) { - result_to_arg_index[result_dim] = i; - ++result_dim; + result_to_arg_index.push_back(i); } } @@ -1542,6 +1540,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { base[result_to_arg_index[i]] = multi_index[i]; } + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](ArraySlice<int64> input_index) { + computed_result += arg_literal.Get<float>(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast<ReturnT>(computed_result); + } auto func = [&](ArraySlice<int64> input_index) { auto curr_val = arg_literal.Get<ReturnT>(input_index); @@ -1554,19 +1566,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::unique_ptr<Literal> computed_result = embedded_evaluator.Evaluate<const Literal*>(*function, args) .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on + // Clear visit states so that we can use the evaluator again on // the same computation. embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. result_val = computed_result->Get<ReturnT>({}); - return true; }; - + // Computes one element of the result, reducing all dimensions that + // contribute to that element. ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); - return result_val; })); @@ -1574,6 +1584,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { auto operand = select_and_scatter->operand(0); auto source = select_and_scatter->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 685cacd7f7..dd14dd3853 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1205,6 +1206,80 @@ TEST_P(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; + +// Tests that Reduce doesn't lose precision when adding many numbers (because +// it accumulates its result in a double). +TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + HloComputation::Builder b(TestName()); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector<float> v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<float>(v))); + HloInstruction* init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module().AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + std::unique_ptr<Literal> result = + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result); +} + +// Reducing many numbers should be fast because it doesn't create +// intermediate Literals; the microbenchmark should finish in < 1 msec. +void BM_ReducePrecisely(int num_iters) { + tensorflow::testing::StopTiming(); + HloComputation::Builder b("BM_ReducePrecisely"); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector<float> v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<float>(v))); + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module.AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + tensorflow::testing::StartTiming(); + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_ReducePrecisely); + TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a2a2c1e615..fcf9ebf5f7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -98,6 +98,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } } + if (instruction->opcode() == HloOpcode::kTrace) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "Trace instruction should have 1 operand but sees " + << instruction->operands().size(); + instruction->mutable_operand(0)->set_tracing(instruction.get()); + } + TF_RET_CHECK(!proto.name().empty()); instruction->name_ = proto.name(); @@ -170,6 +177,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_ = Literal::CreateR1U8(tag); + operand->set_tracing(instruction.get()); return instruction; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a94ba145df..80f8408244 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -928,6 +928,13 @@ class HloInstruction { const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } + // Returns the sharding unique device, if any. + tensorflow::gtl::optional<int64> sharding_unique_device() const { + if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + return tensorflow::gtl::optional<int64>(); + } + return sharding_->UniqueDevice().ValueOrDie(); + } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index fa5dcb0b36..54c34ce116 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -313,6 +313,27 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } + const HloModule* send_module = channel.send->parent()->parent(); + const HloModule* send_done_module = channel.send_done->parent()->parent(); + if (send_module != send_done_module) { + return FailedPrecondition( + "send and send-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + } + const HloModule* recv_module = channel.recv->parent()->parent(); + const HloModule* recv_done_module = channel.recv_done->parent()->parent(); + if (recv_module != recv_done_module) { + return FailedPrecondition( + "recv and recv-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + } + if (send_module == recv_module) { + return FailedPrecondition( + "send and recv (channel=%lld) must be on different devices: %lld", + channel.id, GetModuleId(send_module)); + } } // Check if channel instructions are used only in allowed computations. diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 18d406f370..06204acbca 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -94,6 +94,10 @@ class HloSharding { // Create a new sharding from a protobuf OpSharding. static StatusOr<HloSharding> FromProto(const OpSharding& proto); + // Checks whether device is a reserved device number. A reserved device number + // has usually a special meaning, with dedicated handling logic. + static bool IsReservedDevice(int64 device) { return device < 0; } + OpSharding ToProto() const; string ToString() const; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2a282f3be7..ec04239b4f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -762,7 +763,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index f15117f45c..49ec38eb62 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,16 +53,8 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } -// Returns true if `a` is a broadcast instruction to target shape `shape` and -// its operand is a scalar. -bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) { - return a->opcode() == HloOpcode::kBroadcast && - ShapeUtil::SameDimensions(a->shape(), shape) && - ShapeUtil::IsScalar(a->operand(0)->shape()); -} - -// Returns true iff `instruction` can change its shape simply by adjusting -// metadata. +// Returns true if `instruction` can change its shape simply by adjusting +// metadata or if `instruction` is a broadcast of a scalar value. bool CanTriviallyChangeShape(const HloInstruction* instruction) { // NOTE: Technically a sequence of reshape(reshape(constant)) is also // trivially reshapable, so we might be tempted to simply recurse if @@ -97,19 +89,30 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { return true; } + // A broadcase of scalar can trivially change its shape. + if (instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape())) { + return true; + } + return false; } -// Finds the first non-scalar operand of an instruction that is a non-trivial -// reshape or transpose. Returns the operand if it is found or nullptr if not -// found. +// Returns true iff `instruction` is a reshape/transpose instruction for which +// a shape change is nontrivial. +bool IsNontrivialReshape(const HloInstruction* instruction) { + return !ShapeUtil::IsScalar(instruction->shape()) && + IsReshapeOrTranspose(instruction) && + !CanTriviallyChangeShape(instruction->operand(0)); +} + +// Finds the first operand of an instruction that is a non-trivial reshape or +// transpose. Returns such an operand or nullptr if not found. HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { - if (!ShapeUtil::IsScalar(operand->shape()) && - IsReshapeOrTranspose(operand) && - !CanTriviallyChangeShape(operand->operand(0))) { - VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " + if (IsNontrivialReshape(operand)) { + VLOG(5) << "Found first non-trivial reshape operand of " << hlo->ToString(HloPrintOptions().set_print_metadata(false)) << ":\n\t" << operand->ToString(HloPrintOptions().set_print_metadata(false)); @@ -119,7 +122,7 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( return nullptr; } -// Returns whether `a` and `b` are equivalent for the purposes of this pass. +// Returns whether `a` and `b` are equivalent reshapes/transposes. bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { if (a->opcode() != b->opcode() || !ShapeUtil::SameDimensions(a->shape(), b->shape())) { @@ -136,85 +139,14 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if all operands of `instruction` can easily change shape. -// Operands can easily change shape if they are all reshapes/transposes to and -// from the same shape. Additionally, operands like constant, rng, and any -// scalar change shape with only an adjustment of metadata. -bool AllOperandsHaveEasyShapeChanges( - const HloInstruction* instruction, - const HloInstruction* first_reshape_operand) { - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, it may be - // implicitly broadcast, which can confound the movement's - // correctness. - // - // And one of the following: - // 1. Are reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. Are one of kConstant, kRng, and scalars that can change shape - // trivially, - // or - // 3. Are broadcast with a scalar operand. - for (const HloInstruction* operand : instruction->operands()) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToString(print_no_metadata) << "\n\tinstruction: " - << instruction->ToString(print_no_metadata); - return false; - } - - // Skip the rest checks if the current operand is first_reshape_operand - // itself. - if (first_reshape_operand == operand) { - continue; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToString(print_no_metadata) - << "\n\toperand: " << operand->ToString(print_no_metadata); - continue; - } - - if (CanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToString(print_no_metadata); - continue; - } - - if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) { - VLOG(5) << "Broadcast scalar to shape: " - << operand->ToString(print_no_metadata); - continue; - } - - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " - << operand->ToString(print_no_metadata); - return false; - } - - VLOG(3) << "All operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - return true; -} - // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that // plays nicely with `new_operand_shape`; either it has the same shape (of the // correct type), or it is a scalar that may be implicitly broadcast. -HloInstruction* UpdateOperand(HloComputation* computation, - const HloInstruction* first_reshape_operand, +HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { + HloComputation* computation = operand->parent(); const PrimitiveType element_type = operand->shape().element_type(); const Shape new_shape = ShapeUtil::ChangeElementType(new_operand_shape, element_type); @@ -245,42 +177,24 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } - case HloOpcode::kBroadcast: - CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape())); - VLOG(5) << "Changing broadcast"; - return computation->AddInstruction( + case HloOpcode::kBroadcast: { + CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape())); + HloInstruction* inst = computation->AddInstruction( operand->CloneWithNewOperands(new_shape, operand->operands())); + VLOG(5) << "Changing broadcast from " << operand->ToString() << " to " + << inst->ToString(); + return inst; + } default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } } -// Try to sink any reshape or transpose operands of `instruction` across it. We -// do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. -StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - // Only perform sinks for live elementwise instructions with operands. - const bool is_dead = instruction->user_count() == 0 && - instruction != computation->root_instruction(); - if (!instruction->IsElementwise() || instruction->operands().empty() || - is_dead) { - return false; - } - - // Only perform sinks if there are any nontrivial reshape/transpose operands. - const HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - if (!first_reshape_operand) { - return false; - } - - // Only perform sinks if all operands can easily change shape. - if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { - return false; - } - +// Actually performs the reshape-move transformation -- that is, sinks the +// reshape or transpose operands of `instruction` across it. +StatusOr<bool> PerformSinkReshapeOrTranspose( + HloInstruction* instruction, const HloInstruction* first_reshape_operand) { auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); @@ -301,8 +215,8 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, } VLOG(3) << "Updating operand #" << i << ": " << operands[i]->ToString(print_no_metadata); - operands[i] = UpdateOperand(computation, first_reshape_operand, - new_operand_shape, operands[i]); + operands[i] = + UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is @@ -314,6 +228,7 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, *shape->mutable_layout() = new_operand_shape.layout(); } } + HloComputation* computation = instruction->parent(); HloInstruction* new_elementwise = computation->AddInstruction(instruction->CloneWithNewOperands( // `instruction` may change the element type, e.g., from @@ -348,6 +263,141 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, return true; } +// Returns true if the instruction is a reshape-move candidate. +// +// An instruction is a reshape-move candidate if the instruction is elementwise, +// has at least one nontrivial reshape/transpose operand, and its operands are +// either trivially reshapable or are equivalent nontrivial reshapes/transposes. +bool IsReshapeMoveCandidate(HloInstruction* instruction) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + VLOG(5) << "** Checking instruction: " + << instruction->ToString(print_no_metadata); + + // Only perform reshape-move for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { + return false; + } + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, they may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars + // that can change shape trivially. + const HloInstruction* first_reshape_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); + return false; + } + + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToString(print_no_metadata); + continue; + } + + if (!IsNontrivialReshape(operand)) { + VLOG(5) << "Operand can't trivially change shape: " + << operand->ToString(print_no_metadata); + return false; + } + + if (first_reshape_operand == nullptr) { + first_reshape_operand = operand; + VLOG(5) << "First reshape operand " + << operand->ToString(print_no_metadata); + } else if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) + << "Operand is an equivalent reshape of the first reshape operand " + << operand->ToString(print_no_metadata); + } else { + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is a reshape but is not equivalent to the first " + "Reshape operand" + << operand->ToString(print_no_metadata); + return false; + } + } + + if (first_reshape_operand) { + VLOG(5) << "All operands have easy shape changes: " + << instruction->ToString(print_no_metadata); + } + + return first_reshape_operand != nullptr; +} + +// Reshape-moves all qualifying instructions in reshape_candidates. Returns +// true if it makes changes. +// +// `reshape_candidates` is a set of HloInstructions with nontrivial reshape +// operands, and a instruction in the set can be reshape-moved iff all the users +// of its nontrivial reshape operands can also be reshaped-moved. +// +// The algorithm here iteratively finds the nontrivial operands with users that +// are outside the set of `reshape_candidates`, and removes their users from +// `reshape_candidates`, until either `reshape_candidates` becomes empty or none +// of the remaining nontrivial operands have users outside `reshape_candidates`. +// In the later case, all the remaining instructions in `reshape_candidates` +// are reshape-moved and the routine returns true. +StatusOr<bool> TryReshapeMoveOnCandidates( + HloInstructionSet* reshape_candidates) { + bool removed = true; + while (!reshape_candidates->empty() && removed) { + if (VLOG_IS_ON(5)) { + for (const HloInstruction* instruction : *reshape_candidates) { + VLOG(5) << "candidate " << instruction->ToString(); + } + } + ConstHloInstructionSet nontrivial_operands; + for (const HloInstruction* instruction : *reshape_candidates) { + for (const auto* operand : instruction->operands()) { + if (IsNontrivialReshape(operand)) { + nontrivial_operands.insert(operand); + } + } + } + + removed = false; + for (auto operand : nontrivial_operands) { + if (c_any_of(operand->users(), [&](HloInstruction* user) { + return !reshape_candidates->count(user); + })) { + for (auto* user : operand->users()) { + removed |= reshape_candidates->erase(user) > 0; + } + } + } + } + + if (reshape_candidates->empty()) { + return false; + } + for (HloInstruction* instruction : *reshape_candidates) { + const HloInstruction* first_reshape_operand = + FirstNonScalarAndNonTrivialReshapeOperand(instruction); + TF_ASSIGN_OR_RETURN( + bool did_change, + PerformSinkReshapeOrTranspose(instruction, first_reshape_operand)); + CHECK(did_change); + } + return true; +} + } // namespace StatusOr<bool> ReshapeMover::Run(HloModule* module) { @@ -355,11 +405,15 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) { VLOG(2) << "Pre ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp, instruction)); - changed |= did_change; + HloInstructionSet reshape_candidates; + for (HloInstruction* instruction : comp->instructions()) { + if (IsReshapeMoveCandidate(instruction)) { + reshape_candidates.insert(instruction); + } } + TF_ASSIGN_OR_RETURN(bool did_change, + TryReshapeMoveOnCandidates(&reshape_candidates)); + changed |= did_change; } VLOG(2) << "Post ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 4e0a0a8832..094f7319f4 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -564,15 +564,15 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { const string hlo_string = R"( HloModule TransposeMulInversedTransposeModule ENTRY TransposeMulInversedTranspose { - src0 = f32[1,20,8,32]{3,2,1,0} parameter(0) - transpose0 = f32[1,8,20,32]{3,2,1,0} transpose(src0), dimensions={0,2,1,3} + src0 = f32[20,8]{1,0} parameter(0) + transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0} src1 = f32[] parameter(1) - broadcast0 = f32[1,8,20,32]{3,2,1,0} broadcast(src1), dimensions={} - ROOT multiply0 = f32[1,8,20,32]{3,2,1,0} multiply(transpose0, broadcast0) + broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0) } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); EXPECT_TRUE(changed); @@ -580,5 +580,75 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { op::Transpose(op::Multiply())); } +TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { + const string hlo_string = R"( + HloModule ReshapeWithUsersOutsideCandidates + ENTRY ReshapeWithMultipleUsers { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + param4 = f32[8,20]{1,0} parameter(4) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, param4) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates1 + ENTRY ReshapeWithMultipleUsers1 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, reshape2) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates2 + ENTRY ReshapeWithMultipleUsers2 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Reshape(op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index ca8071b7bb..ec883a6cf3 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -409,6 +409,37 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( return std::move(executables); } +StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( + const std::vector<const HloModuleProto*>& module_protos, + std::vector<std::unique_ptr<HloModuleConfig>> module_configs, + Backend* backend, + std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors, + DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf("BuildExecutable on service %p", this); + + VLOG(1) << "Computations:"; + for (const HloModuleProto* proto : module_protos) { + VLOG(1) << proto->name(); + } + + CHECK_EQ(module_protos.size(), module_configs.size()); + std::vector<std::unique_ptr<HloModule>> modules; + for (int64 i = 0; i < module_protos.size(); ++i) { + const HloModuleProto* proto = module_protos[i]; + const HloModuleConfig& config = *module_configs[i]; + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(*proto, config)); + modules.push_back(std::move(module)); + } + + TF_ASSIGN_OR_RETURN( + std::vector<std::unique_ptr<Executable>> executables, + backend->compiler()->Compile(std::move(modules), std::move(executors), + device_allocator)); + + return std::move(executables); +} + StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr<HloModuleConfig> module_config, Backend* backend, @@ -703,6 +734,47 @@ tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, return computation->SetReturnValue(arg->operand()); } +StatusOr<std::vector<perftools::gputools::StreamExecutor*>> +Service::GetExecutors(const ExecutionOptions& execution_options, + int64 requests_size, int64 request_index) const { + if (execution_options.device_handles().empty()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); + } + if (requests_size > 1 && execution_options.device_handles_size() > 1) { + return InvalidArgument( + "Parallel requests with multiple device handles is not supported. " + "Found %lld parallel requests, with request %lld containing %d device " + "handles.", + requests_size, request_index, execution_options.device_handles_size()); + } + std::vector<perftools::gputools::StreamExecutor*> executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } + return executors; +} + +StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) { + // Resolve the allocations for the arguments of the computation, and create + // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, execution_options.device_handles(0))); + TF_ASSIGN_OR_RETURN( + std::vector<std::vector<const ShapedBuffer*>> replicated_arguments, + ResolveAndValidateArguments(arguments, replicas)); + return replicated_arguments; +} + tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); @@ -731,26 +803,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // is one of the executors to run the replicated computation. const ExecutionOptions& execution_options = arg->requests(i).execution_options(); - if (execution_options.device_handles().empty()) { - return FailedPrecondition( - "device handles must be given to execute parallel computations"); - } - if (arg->requests_size() > 1 && - execution_options.device_handles_size() > 1) { - return InvalidArgument( - "Parallel requests with multiple device handles is not supported. " - "Found %d parallel requests, with request %lld containing %d device " - "handles.", - arg->requests_size(), i, execution_options.device_handles_size()); - } - std::vector<perftools::gputools::StreamExecutor*> executors; - for (const auto& device_handle : execution_options.device_handles()) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, device_handle)); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); - executors.push_back(executor); - } + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -767,16 +823,9 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::shared_ptr<const ProgramShape> program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); - // Resolve the allocations for the arguments of the computation, and create - // a vector of device memory offsets for the arguments from the allocations. - // In the case of partitioned computations, assume all arguments go on the - // zeroth core. - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, execution_options.device_handles(0))); - TF_ASSIGN_OR_RETURN( - std::vector<std::vector<const ShapedBuffer*>> replicated_arguments, - ResolveAndValidateArguments(request.arguments(), replicas)); + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. Here, we care only about the @@ -839,7 +888,103 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, tensorflow::Status Service::ExecuteGraphParallel( const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { - return Unimplemented("execute-graph-parallel is not yet implemented"); + VLOG(1) << "running execute-graph-parallel request"; + + std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments; + std::vector<std::vector<perftools::gputools::StreamExecutor*>> all_executors; + std::vector<const HloModuleProto*> module_protos; + std::vector<std::unique_ptr<HloModuleConfig>> module_configs; + std::vector<string> computation_names; + std::vector<DeviceHandle> device_handles; + + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteGraphRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > + execute_backend_->device_count()) { + return FailedPrecondition( + "there are not enough stream executors to execute %d computations", + num_requested_devices); + } + + for (int64 i = 0; i < arg->requests_size(); ++i) { + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + const ExecuteGraphRequest& request = arg->requests(i); + TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; + TF_RET_CHECK(request.computation().has_program_shape()) + << "programe shape may not be empty"; + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); + + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); + + // Create an HloModuleConfig object for the computation, given the shape of + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. + TF_ASSIGN_OR_RETURN( + std::unique_ptr<HloModuleConfig> module_config, + CreateModuleConfig(request.computation().program_shape(), + replicated_arguments.front(), + request.execution_options(), + /*user_computation=*/nullptr)); + VLOG(3) + << "ExecuteGraphParallel created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + // Adds to the vectors to build and execute the computations after the loop. + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); + module_protos.push_back(&request.computation()); + module_configs.push_back(std::move(module_config)); + computation_names.insert(computation_names.end(), executors.size(), + request.computation().name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); + } + + // Build the HloModules and compile to generate the executables. + // + // TODO(jlebar): There's currently no way to pass a device allocator to + // ExecuteGraphParallel, so we have to pass a null device_allocator below. + TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables, + BuildExecutables(module_protos, std::move(module_configs), + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); + std::vector<Executable*> executable_ptrs; + executable_ptrs.reserve(executables.size()); + for (const auto& executable : executables) { + executable_ptrs.push_back(executable.get()); + } + + // Execute the generated executables in parallel and return the device + // handles for each computation's output. + ExecutionProfile profile; + TF_ASSIGN_OR_RETURN( + std::vector<GlobalDataHandle> outputs, + ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, + execute_backend_.get(), device_handles, + computation_names, &profile)); + for (const GlobalDataHandle& output : outputs) { + ExecuteResponse response; + *response.mutable_output() = output; + *response.mutable_profile() = profile; + *result->add_responses() = response; + } + + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; + return tensorflow::Status::OK(); } tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, @@ -872,6 +1017,20 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + ExecuteGraphParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::PickParallelResponse( + const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. We have b/76035356 for restructuring the client API to clean // up the current asymmetries and support more functionalities. @@ -999,8 +1158,14 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("programe shape may not be empty"); + } - // TODO(b/74197823): Handle partitioning. + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + return ExecuteOneToN(arg, result); + } TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index ebe4a2e043..e09d58bbe7 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -278,6 +278,20 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, const UserComputation* user_computation = nullptr); + // Picks a parallel response and fills the result. + Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, + ExecuteResponse* result); + + // Prepare the executors for executing parallel. + StatusOr<std::vector<perftools::gputools::StreamExecutor*>> GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const; + + // Prepare the arguments for executing parallel. + StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments); + protected: friend class LocalExecutable; @@ -334,6 +348,12 @@ class Service : public ServiceInterface { Backend* backend, std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors, DeviceMemoryAllocator* device_allocator); + StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables( + const std::vector<const HloModuleProto*>& module_protos, + std::vector<std::unique_ptr<HloModuleConfig>> module_configs, + Backend* backend, + std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors, + DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and @@ -378,6 +398,8 @@ class Service : public ServiceInterface { // will be the result of this computation. tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); + tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fcdb2e01fb..532f7fd5bf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit( HloInstruction* operand = lookup_instruction(trace_request.operand()); hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); break; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e337669aeb..6f58c20f34 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -347,10 +347,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -937,8 +937,8 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -977,9 +977,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1444,9 +1443,9 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1566,6 +1565,8 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 4a9faef1dc..17c6a83c1a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -601,6 +601,12 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, + XlaBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + template void ClientLibraryTestBase::ComputeAndCompareLiteral( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index be90f14c8e..52f31b0669 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -312,6 +312,7 @@ class ClientLibraryTestBase : public ::testing::Test { // will be converted to BF16s. ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, ComputationBuilder* builder); + XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be @@ -322,6 +323,12 @@ class ClientLibraryTestBase : public ::testing::Test { return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); } + template <typename NativeT> + XlaOp CreateConstantFromArray(const Array<NativeT>& array, + XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + // Same as CreateConstantFromArray, but for scalars. template <typename NativeT> ComputationDataHandle CreateConstantFromScalar(NativeT value, @@ -330,6 +337,12 @@ class ClientLibraryTestBase : public ::testing::Test { builder); } + template <typename NativeT> + XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value), + builder); + } + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 045148cdd1..32e2f2c084 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -109,14 +111,14 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XLA_TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { - Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> const_arg, client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}}))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Add(b.Parameter(0, shape, "param_0"), b.ConstantR2<int32>({{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); @@ -124,14 +126,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector<Client::ComputationInstance> computation_instances; + std::vector<Client::XlaComputationInstance> computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::ComputationInstance( + computation_instances.push_back(Client::XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index fb0e9c724a..a4c8a83eb1 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -38,9 +38,9 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { - ComputationBuilder builder(client_, TestName()); - auto concatenated = builder.ConcatInDim({}, 0); - StatusOr<Computation> computation_status = builder.Build(); + XlaBuilder builder(TestName()); + builder.ConcatInDim({}, 0); + StatusOr<XlaComputation> computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("Concatenate expects at least one argument")); @@ -48,18 +48,18 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0, 64.0}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector<float> expected = {42, 64}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector<float> expected = {}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -68,51 +68,51 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0<float>(42.0); auto b = builder.ConstantR0<float>(64.0); - auto concatenated = builder.ConcatInDim({a, b}, 0); - StatusOr<Computation> computation_status = builder.Build(); + builder.ConcatInDim({a, b}, 0); + StatusOr<XlaComputation> computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({}); auto b = builder.ConstantR1<float>({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector<float> expected = {}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({}); auto b = builder.ConstantR1<float>({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector<float> expected = {256}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0, 64.0}); auto b = builder.ConstantR1<float>({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector<float> expected = {42, 64}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0, 64.0}); auto b = builder.ConstantR1<float>({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector<float> expected = {42, 64, 256}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -129,20 +129,20 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { expected[253 + i] = rhs[i] = 253 + i + 1; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>(lhs); auto b = builder.ConstantR1<float>(rhs); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0)); auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0)); - auto concatenated = builder.ConcatInDim({a, b}, dim); + builder.ConcatInDim({a, b}, dim); ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, ErrorSpec(0.0001)); @@ -150,26 +150,27 @@ XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D<float> expected({ - {0}, {64}, + {0}, + {64}, }); ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D<float> expected({ {0, 64}, @@ -178,22 +179,22 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { } XLA_TEST_F(ConcatTest, Concat2x0With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0)); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat2x3With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D<float> expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -203,22 +204,22 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { } XLA_TEST_F(ConcatTest, Concat3x2With0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2)); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat3x2With5x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D<float> expected({ {0, 1}, @@ -234,16 +235,16 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { } XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2)); auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1)); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D<float> a_array({ // 3x1x2 {{0, 1}}, @@ -258,27 +259,29 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { }); auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); Array3D<float> expected({ - {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, + {{0, 1, 6}}, + {{2, 3, 7}}, + {{4, 5, 8}}, }); ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0}); auto b = builder.ConstantR1<float>({64.0}); auto c = builder.ConstantR1<float>({256.0}); - auto concatenated = builder.ConcatInDim({a, b, c}, 0); + builder.ConcatInDim({a, b, c}, 0); std::vector<float> expected = {42, 64, 256}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D<float> a_array({ // 3x1x2 {{0, 1}}, @@ -300,35 +303,35 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); auto c = builder.ConstantR3FromArray3D(c_array); - auto concatenated = builder.ConcatInDim({a, b, c}, 2); + builder.ConcatInDim({a, b, c}, 2); Array3D<float> expected({ - {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, + {{0, 1, 2, 3}}, + {{4, 5, 6, 7}}, + {{8, 9, 10, 11}}, }); ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0}); auto b = builder.ConstantR1<float>({64.0}); auto c = builder.ConstantR1<float>({256.0}); // concatenated = (a concat b) concat c - auto concatenated = - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); std::vector<float> expected = {42, 64, 256}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1<float>({42.0}); auto b = builder.ConstantR1<float>({64.0}); auto c = builder.ConstantR1<float>({256.0}); // concatenated = a concat (b concat c) - auto concatenated = - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); std::vector<float> expected = {42, 64, 256}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -342,7 +345,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D<float>(lhs); auto b = builder.ConstantR2FromArray2D<float>(rhs); builder.ConcatInDim({a, b}, 0); @@ -363,7 +366,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D<float>(lhs); auto b = builder.ConstantR2FromArray2D<float>(rhs); builder.ConcatInDim({a, b}, 1); @@ -388,7 +391,7 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D<float>(lhs); auto b = builder.ConstantR2FromArray2D<float>(rhs); builder.ConcatInDim({a, b}, 1); @@ -404,13 +407,13 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { // Show that we can't concatenate with an opaques. XLA_TEST_F(ConcatTest, CannotConcatOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); auto x = builder.Parameter(0, r1f32, "x"); auto y = builder.Parameter(1, opaque_shape, "y"); - auto concatenated = builder.ConcatInDim({x, y}, 0); - StatusOr<Computation> computation_status = builder.Build(); + builder.ConcatInDim({x, y}, 0); + StatusOr<XlaComputation> computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), @@ -418,23 +421,23 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p0 = builder.ConstantR1<bool>({true}); auto p1 = builder.ConstantR1<bool>({false}); auto p2 = builder.ConstantR1<bool>({true}); - auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); + builder.ConcatInDim({p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1<bool>(&builder, expected, {}); } XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR1<int32>({1}); auto a1 = builder.ConstantR1<int32>({2, 3}); auto a2 = builder.ConstantR1<int32>({4, 5, 6}); auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10}); - auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); + builder.ConcatInDim({a0, a1, a2, a3}, 0); std::vector<int32> expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -442,7 +445,7 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { } XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D<float> arr0(9, 17, 1); arr0.Fill(1); @@ -462,14 +465,14 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { } } - ComputationDataHandle h0; + XlaOp h0; auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - auto concatenated = builder.ConcatInDim({h0, h1}, 2); + builder.ConcatInDim({h0, h1}, 2); ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()}); } @@ -495,7 +498,7 @@ TEST_P(ConcatR2BinaryTest, DoIt) { Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1); rhs.FillUnique(1000); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR2FromArray2D<int32>(lhs); auto a1 = builder.ConstantR2FromArray2D<int32>(rhs); builder.ConcatInDim({a0, a1}, spec.concat_dimension); @@ -521,7 +524,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, f32_scalar, "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto mul = builder.Mul(x, y); @@ -545,7 +548,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "z"); @@ -573,7 +576,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "y"); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 9a899b7914..0842a8918b 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -230,6 +230,43 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()}); } +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + // Test cases from compiler_rt library. + std::vector<float> arg{0.0f, + 0.5f, + 0.99f, + 1.0f, + 1.5f, + 1.99f, + 2.0f, + 2.01f, + 2147483648.f, + -0.5f, + -0.99f, + -1.0f, + -1.5f, + -1.99f, + -2.0f, + -2.01f, + 0x1.FFFFFEp+62F, + 0x1.FFFFFCp+62F, + -0x1.FFFFFEp+62F, + -0x1.FFFFFCp+62F}; + std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr<GlobalData> arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector<int64> expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast<int64>(arg[i]); + } + ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()}); +} + XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1<uint8_t>({32, 64}); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 4f354e6aef..5f00c34002 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3<IndexT, DataT>( - {{{1, 2}, {3, 4}, {5, 6}}, - {{7, 8}, {9, 10}, {11, 12}}}, - {0, 2, 1}, {2, 1, 2}, - {{{6, 5}}, {{12, 11}}}); + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, + {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); } template <typename IndexT, typename DataT> @@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -282,6 +279,15 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { class DynamicUpdateSliceTest : public ClientLibraryTestBase { protected: template <typename IndexT, typename DataT> + void TestR0() { + // Disable algebraic simplifier, otherwise the op will be replaced by a + // constant. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); + RunR0<IndexT, DataT>(0, 123, {}, 123); + } + + template <typename IndexT, typename DataT> void TestR1() { // Slice at dimension start. RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0}, @@ -342,6 +348,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template <typename IndexT, typename DataT> + void RunR0(int input_value_int, int update_value_int, + const std::vector<IndexT> slice_starts, int expected_value_int) { + Literal input_value = + std::move(*Literal::CreateR0(input_value_int) + ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) + .ValueOrDie()); + Literal update_value = + std::move(*Literal::CreateR0(update_value_int) + ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) + .ValueOrDie()); + Literal expected_value = + std::move(*Literal::CreateR0(expected_value_int) + ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) + .ValueOrDie()); + + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantLiteral(input_value); + auto update = builder.ConstantLiteral(update_value); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + } + + template <typename IndexT, typename DataT> void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int, tensorflow::gtl::ArraySlice<int> update_values_int, const std::vector<IndexT> slice_starts, @@ -359,9 +394,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -390,9 +425,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -421,9 +456,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -474,13 +509,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } // Build dynamic slice computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer input parameter. - ComputationDataHandle input; + XlaOp input; std::unique_ptr<GlobalData> input_data = CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. - ComputationDataHandle update; + XlaOp update; std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1<int32>({index, 0, 0}); @@ -500,6 +535,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0<int32, bfloat16>(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0<int32, float>(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64, float>(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64, float>(); } + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { TestR1<int32, bfloat16>(); @@ -672,7 +712,7 @@ void BM_DynamicSlice(int num_iters) { TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); - ComputationBuilder builder(client, "DynamicSlice"); + XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] auto input_literal = Literal::CreateR4( diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 3a097a01ab..d24927d22b 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -57,6 +57,11 @@ limitations under the License. namespace xla { namespace { +using FuncGeneratorForType = Computation (*)(PrimitiveType, + ComputationBuilder*); + +using FuncGenerator = Computation (*)(ComputationBuilder*); + class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { @@ -755,53 +760,57 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { } XLA_TEST_F(ReduceTest, VectorizedReduce_Add) { - RunVectorizedReduceTest(CreateScalarAddComputation, - [](float a, float b) { return a + b; }, - [](int32 a, int32 b) { - return static_cast<int32>(static_cast<uint32>(a) + - static_cast<uint32>(b)); - }, - [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); + RunVectorizedReduceTest( + static_cast<FuncGeneratorForType>(CreateScalarAddComputation), + [](float a, float b) { return a + b; }, + [](int32 a, int32 b) { + return static_cast<int32>(static_cast<uint32>(a) + + static_cast<uint32>(b)); + }, + [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); } XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) { - RunVectorizedReduceTest(CreateScalarMultiplyComputation, - [](float a, float b) { return a * b; }, - [](int32 a, int32 b) { - return static_cast<int32>(static_cast<uint32>(a) * - static_cast<uint32>(b)); - }, - [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); + RunVectorizedReduceTest( + static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation), + [](float a, float b) { return a * b; }, + [](int32 a, int32 b) { + return static_cast<int32>(static_cast<uint32>(a) * + static_cast<uint32>(b)); + }, + [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); } XLA_TEST_F(ReduceTest, VectorizedReduce_Max) { - RunVectorizedReduceTest(CreateScalarMaxComputation, - [](float a, float b) { return std::max(a, b); }, - [](int32 a, int32 b) { return std::max(a, b); }, - [](uint32 a, uint32 b) { return std::max(a, b); }, - std::numeric_limits<float>::min(), - std::numeric_limits<int32>::min(), - std::numeric_limits<uint32>::min()); + RunVectorizedReduceTest( + static_cast<FuncGeneratorForType>(CreateScalarMaxComputation), + [](float a, float b) { return std::max(a, b); }, + [](int32 a, int32 b) { return std::max(a, b); }, + [](uint32 a, uint32 b) { return std::max(a, b); }, + std::numeric_limits<float>::min(), std::numeric_limits<int32>::min(), + std::numeric_limits<uint32>::min()); } XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { - RunVectorizedReduceTest(CreateScalarMinComputation, - [](float a, float b) { return std::min(a, b); }, - [](int32 a, int32 b) { return std::min(a, b); }, - [](uint32 a, uint32 b) { return std::min(a, b); }, - std::numeric_limits<float>::max(), - std::numeric_limits<int32>::max(), - std::numeric_limits<uint32>::max()); + RunVectorizedReduceTest( + static_cast<FuncGeneratorForType>(CreateScalarMinComputation), + [](float a, float b) { return std::min(a, b); }, + [](int32 a, int32 b) { return std::min(a, b); }, + [](uint32 a, uint32 b) { return std::min(a, b); }, + std::numeric_limits<float>::max(), std::numeric_limits<int32>::max(), + std::numeric_limits<uint32>::max()); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType<bool>( - CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); + static_cast<FuncGenerator>(CreateScalarAndComputation), + [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType<bool>( - CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); + static_cast<FuncGenerator>(CreateScalarOrComputation), + [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 9c317fe579..8dd24f1237 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -252,6 +252,48 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { DefaultErrorSpec()); } +// Tests the super windowing logic w.r.t handling prime number of windows in a +// major dimension with reduction. +TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { + Array4D<float> input_array(15, 15, 4, 128); + input_array.FillRandom(2.f, 4.f); + + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + +TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { + Array4D<float> input_array(19, 17, 8, 256); + input_array.FillWithMinorDimNum(); + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + // Tests a reduction function that is not a simple add/min/max/etc. XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D<float> input_array(1, 2, 2, 1); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 574c494c6d..69fbe98bd6 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -41,7 +41,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { Array3D<float> values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D<float>(values); builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); @@ -54,7 +54,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { Array3D<float> values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D<float>(values); builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); @@ -67,7 +67,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { Array3D<float> values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D<float>(values); builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); @@ -77,7 +77,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { } XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)); builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); @@ -85,7 +85,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20)); builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); @@ -93,7 +93,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0)); builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); @@ -108,7 +108,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); @@ -126,7 +126,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { Array2D<float> values(1, 4096); std::iota(values.data(), values.data() + 4096, 0.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); @@ -147,7 +147,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001)); @@ -159,7 +159,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); @@ -172,7 +172,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { /*strides=*/{{1, 1, 2, 1}}); auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), @@ -198,7 +198,7 @@ class SliceR1Test : public ClientLibraryTestBase, tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR1<NativeT>(input); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -363,7 +363,7 @@ XLA_TEST_P(SliceR2Test, DoIt) { Array2D<int32> input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2DWithLayout<int32>( input, LayoutUtil::MakeLayout(spec.layout)); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); @@ -453,7 +453,7 @@ class SliceR4Test : public ClientLibraryTestBase, values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = builder.Parameter(0, literal->shape(), "p0"); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0bc7df2a65..821432ef7d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -23,14 +23,14 @@ namespace xla { namespace { -template <typename FloatT> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { +template <typename FloatT, typename GeneratorT> +void PopulateWithRandomFloatingPointDataImpl(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. - std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f); + std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f); const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate<FloatT>( [&](tensorflow::gtl::ArraySlice<int64> indices) { @@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal, FloatT index_bias = static_cast<FloatT>(index_product % 113 - negative_bias) / static_cast<FloatT>(256.0f); - return (generator(*engine) - 1.0625) + index_bias; + return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias; })); } +template <typename FloatT> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine); +} + +template <> +void PopulateWithRandomFloatingPointData<half>(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine); +} + // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> @@ -100,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( case BF16: PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine); break; + case F16: + PopulateWithRandomFloatingPointData<half>(literal.get(), engine); + break; case F32: PopulateWithRandomFloatingPointData<float>(literal.get(), engine); break; diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 33d457c70b..89ce2ce797 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int32>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int64>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int64>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int64>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int64>(&builder, 5, {}); } @@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int32>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1), builder.ConstantR0<int32>(0), CreateScalarAddComputation(S32, &builder), {0}); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0<bool>(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Or(prev, builder.ConstantR0<bool>(true)); + builder.Or(prev, builder.ConstantR0<bool>(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Ne(builder.ConstantR0<bool>(false), builder.ConstantR0<bool>(true)); - auto result = builder.While(condition, body, init); + builder.While(condition, body, init); ComputeAndCompareR0<bool>(&builder, true, {}); } @@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>({}); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>({}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001)); } @@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>(8, 0.f); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); builder.Tuple({result}); // Individual elements with increase by 1/8 each time through the loop, so @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(N), iteration); @@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(N); auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f}); @@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(N), iteration); @@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the body. // Add 1 to the iteration variable permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); @@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); std::vector<float> expected = {6.f, 6.f, 6.f}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR1<float>( @@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true)); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple({builder.ConstantR0<int32>(0), builder.Ne(builder.ConstantR0<bool>(false), builder.ConstantR0<bool>(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_predicate = Literal::CreateR0<bool>(true); @@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the body. // Add 1 to the iteration variable and set the other tuple element to a // constant. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); - auto result = - builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), - builder.ConstantR0<int32>(7)}); + builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), + builder.ConstantR0<int32>(7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR0<int32>(7); @@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - Computation body2; + XlaComputation body2; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR1<float>( @@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); + {}); builder.Gt(builder.ConstantR0<int32>(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1<int32>({1}), @@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { builder.ConstantR0<int32>(100), ShapeUtil::MakeShape(S32, {5}))}, 0); - auto result = builder.Add(inc, prev); + builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0}); - auto result = builder.While(build_condition(count), body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(build_condition(count), body, init); return builder.Build(); }; @@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - Computation inner_condition; + XlaComputation inner_condition; { - ComputationBuilder builder(client_, "inner_condition"); + XlaBuilder builder("inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0<int32>(7)); @@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the outer loop condition: // repeat while result < 30. - Computation outer_condition; + XlaComputation outer_condition; { - ComputationBuilder builder(client_, "outer_condition"); + XlaBuilder builder("outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0<int32>(30)); outer_condition = builder.Build().ConsumeValueOrDie(); @@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. - Computation inner_body; + XlaComputation inner_body; { - ComputationBuilder builder(client_, "inner_body"); + XlaBuilder builder("inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0<int32>(1), i); result = builder.Add(builder.ConstantR0<int32>(2), result); - auto output = builder.Tuple({i, result}); + builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. - Computation outer_body; + XlaComputation outer_body; { - ComputationBuilder builder(client_, "outer_body"); + XlaBuilder builder("outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev}); auto result = builder.While(inner_condition, inner_body, init); - auto output = builder.GetTupleElement(result, 1); + builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(outer_condition, outer_body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(outer_condition, outer_body, init); ComputeAndCompareR0<int32>(&builder, 42, {}); } @@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition_callee; + XlaComputation condition_callee; { - ComputationBuilder builder(client_, "condition_callee"); + XlaBuilder builder("condition_callee"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.Call(condition_callee, {prev}); builder.GetTupleElement(result, 0); @@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_shape, "state"); builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_shape, "state"); auto indvar = builder.GetTupleElement(state, 0); auto input_0 = builder.GetTupleElement(state, 1); auto input_1 = builder.GetTupleElement(state, 2); auto output = builder.Tanh(builder.Dot(input_0, input_1)); auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1)); - auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + builder.Tuple({indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input}); @@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) { // Create while condition computation with 'loop_limit'. const int32 loop_limit = 100; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit)); @@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) { } // Create while body computation with unit loop increment. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) { auto starts = builder.ConstantR1<int32>({0, 0, 0}); // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. - ComputationBuilder builder(client, "while"); + XlaBuilder builder("while"); auto zero = builder.ConstantR0<float>(0.0); auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); auto init = builder.Tuple({builder.ConstantR0<int32>(0), input}); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 863081d654..adc8b1d620 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -894,7 +895,7 @@ class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface<TestData> { protected: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0cebb49afb..bf69144ad8 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -8,6 +8,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("//tensorflow:tensorflow.bzl", "if_not_windows") py_library( name = "contrib_py", @@ -40,7 +41,6 @@ py_library( "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -117,7 +116,10 @@ py_library( "//tensorflow/contrib/kafka", ], "//conditions:default": [], - }), + }) + if_not_windows([ + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code + ]), ) cc_library( diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index a8e05df708..1c5b00f92e 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=g-import-not-at-top # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import batching from tensorflow.contrib import bayesflow @@ -84,7 +87,8 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite +if os.name != "nt": + from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph @@ -94,6 +98,7 @@ from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") +del os del LazyLoader del absolute_import diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 380a652435..513d519eab 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" namespace tensorflow { namespace { @@ -228,9 +229,8 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { } string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { - string output(name); - StringPiece piece(output); - piece.Consume(prefix_); + StringPiece piece(name); + str_util::ConsumePrefix(&piece, prefix_); return piece.ToString(); } @@ -243,6 +243,11 @@ bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { return AAssetDir_getNextFileName(dir.get()) != NULL; } +Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern, + std::vector<string>* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status AssetManagerFileSystem::NewWritableFile( const string& fname, std::unique_ptr<WritableFile>* result) { return errors::Unimplemented("Asset storage is read only."); diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 665304b5ee..a87ff42ae2 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -66,6 +66,9 @@ class AssetManagerFileSystem : public FileSystem { Status DeleteDir(const string& d) override; Status RenameFile(const string& s, const string& t) override; + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override; + private: string RemoveAssetPrefix(const string& name); diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 608bd82722..c5a0dc1095 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -61,6 +61,7 @@ py_test( name = "asserts_test", srcs = ["asserts_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", "//tensorflow/python:client_testlib", @@ -81,6 +82,7 @@ py_test( name = "builtin_functions_test", srcs = ["builtin_functions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", "//tensorflow/python:client_testlib", @@ -92,6 +94,7 @@ py_test( size = "large", srcs = ["call_trees_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", "//tensorflow/contrib/autograph/impl", diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index e468176da1..54424e2647 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -26,6 +26,7 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/operators", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", "//tensorflow/contrib/autograph/utils", @@ -38,6 +39,7 @@ py_test( name = "api_test", srcs = ["api_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", "//tensorflow/contrib/autograph/utils", @@ -50,6 +52,7 @@ py_test( name = "conversion_test", srcs = ["conversion_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/impl/config.py index 543c1486e6..26326465e2 100644 --- a/tensorflow/contrib/autograph/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -41,10 +41,15 @@ DEFAULT_UNCOMPILED_MODULES = set(( NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) -# TODO(mdan): Also allow controlling the generated names (for testability). +# TODO(mdan): Also allow controlling the generated names. +# TODO(mdan); Consolidate all internal imports into a single __ag module. COMPILED_IMPORT_STATEMENTS = ( - 'from __future__ import print_function', 'import tensorflow as tf', - 'from tensorflow.contrib.autograph.impl import api as ' - 'autograph_api', - 'from tensorflow.contrib.autograph import utils as ' - 'autograph_utils') + 'from __future__ import print_function', + 'import tensorflow as tf', + 'from tensorflow.contrib.autograph.impl import api' + ' as autograph_api', + 'from tensorflow.contrib.autograph import utils' + ' as autograph_utils', + 'from tensorflow.contrib.autograph import operators' + ' as __ops', +) diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD new file mode 100644 index 0000000000..7856c253bd --- /dev/null +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "operators", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py new file mode 100644 index 0000000000..c3f4cab69e --- /dev/null +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module implements operators that we overload. + +Note that "operator" is used loosely here, and includes control structures like +conditionals and loops, implemented in functional form, using for example +closures for the body. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index edec5f7712..c483ff68c4 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -66,6 +66,7 @@ py_test( name = "compiler_test", srcs = ["compiler_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index d192bc7aab..83f3bafc42 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -34,6 +34,7 @@ py_test( name = "activity_test", srcs = ["activity_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", @@ -46,6 +47,7 @@ py_test( name = "live_values_test", srcs = ["live_values_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index b53fbb5c18..d3a1b94688 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -44,6 +44,7 @@ py_test( name = "builtins_test", srcs = ["builtins_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -84,6 +85,7 @@ py_test( name = "py_func_test", srcs = ["py_func_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index da5e744851..a3b1b013e3 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -48,9 +48,9 @@ class BatchFeatures { Status GetFeatureColumnSizes(int64* const num_dense_float_features, int64* const num_sparse_float_features, int64* const num_sparse_int_features) const { - QCHECK_NE(num_dense_float_features, nullptr); - QCHECK_NE(num_sparse_float_features, nullptr); - QCHECK_NE(num_sparse_int_features, nullptr); + QCHECK_NE(num_dense_float_features, static_cast<int64*>(nullptr)); + QCHECK_NE(num_sparse_float_features, static_cast<int64*>(nullptr)); + QCHECK_NE(num_sparse_int_features, static_cast<int64*>(nullptr)); *num_dense_float_features = dense_float_feature_columns_.size(); *num_sparse_float_features = sparse_float_feature_columns_.size(); *num_sparse_int_features = sparse_int_feature_columns_.size(); diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b776307924..fae45ead5c 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 35312f06b3..7bb0dc1c0f 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -8,6 +8,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", + "if_not_windows", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -31,12 +32,17 @@ py_library( ], ) +cc_library( + name = "lib_proto_parsing_for_dataset_ops", + deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), +) + tf_custom_op_library( name = "_dataset_ops.so", srcs = ["ops/dataset_ops.cc"], deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + if_static( - extra_deps = ["//tensorflow/core:lib_proto_parsing"], + extra_deps = [":lib_proto_parsing_for_dataset_ops"], otherwise = [], ), ) diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 676959a900..4b50260670 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class StagingAreaOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test.TestCase): def setUp(self): self._event = threading.Event() @@ -200,6 +200,9 @@ class StagingAreaOpsTest(test.TestCase): sess.run(destroy_op) + +class PrefetchToDeviceTest(test.TestCase): + def testPrefetchToDevice(self): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( @@ -231,6 +234,37 @@ class StagingAreaOpsTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testPrefetchToDeviceGpu(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -248,5 +282,62 @@ class StagingAreaOpsTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_initializable_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 98651bb568..77e23d0319 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib # TODO(rohanj): Add a python class that constructs resource in the __init__ @@ -67,28 +68,77 @@ def function_buffering_resource_reset(function_buffer_resource, name=None): # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + """A replacement for @{tf.data.Iterator} that prefetches to another device. - def __init__(self, input_dataset, device, buffer_size): + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + device, + buffer_size, + shared_name=None): self._input_dataset = input_dataset self._get_next_call_count = 0 - input_iterator = input_dataset.make_one_shot_iterator() - input_iterator_handle = input_iterator.string_handle() + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() @function.Defun(dtypes.string) def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, input_iterator.output_types, input_iterator.output_shapes, - input_iterator.output_classes) - return remote_iterator.get_next() + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + + # Convert any `SparseTensorValue`s to `SparseTensor`s. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + ]) + + # Serialize any sparse tensors and convert result to tensors. + ret = nest.pack_sequence_as(ret, [ + ops.convert_to_tensor(t) + for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) + ]) + return nest.flatten(ret) with ops.device(device): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, target_device=gen_dataset_ops.iterator_get_device( - input_iterator._iterator_resource), + self._input_iterator._iterator_resource), string_arg=input_iterator_handle, - buffer_size=buffer_size) + buffer_size=buffer_size, + shared_name=shared_name) + + if not self._one_shot: + reset_op = function_buffering_resource_reset(self._buffering_resource) + with ops.control_dependencies([reset_op]): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) def get_next(self, name=None): """See @{tf.data.Iterator.get_next}.""" @@ -113,6 +163,12 @@ class _PrefetchToDeviceIterator(object): return ret @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property def output_classes(self): return self._input_dataset.output_classes @@ -135,13 +191,19 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): self._buffer_size = buffer_size if buffer_size is not None else 1 def make_one_shot_iterator(self): - return _PrefetchToDeviceIterator(self._input_dataset, self._device, - self._buffer_size) + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + device=self._device, + buffer_size=self._buffer_size) def make_initializable_iterator(self, shared_name=None): - raise NotImplementedError("`prefetch_to_device()` is not currently " - "compatible with initializable iterators. Use " - "`make_one_shot_iterator()` instead.") + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + device=self._device, + buffer_size=self._buffer_size, + shared_name=shared_name) def _as_variant_tensor(self): # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 4af51bec1a..28483f4c88 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -77,7 +77,7 @@ parameter of `Estimator`. ```python distribution = tf.contrib.distribute.MirroredStrategy() -config = tf.estimator.RunConfig(distribute=distribution) +config = tf.estimator.RunConfig(train_distribute=distribution) classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) classifier.train(input_fn=input_fn) ``` diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 9be186a724..2b49b8f4ef 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -95,7 +95,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, # TODO(isaprykin): Work around the colocate_with error. dnn_optimizer=adagrad.AdagradOptimizer(0.001), linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig(distribute=distribution)) + config=run_config.RunConfig(train_distribute=distribution)) num_steps = 10 estimator.train(train_input_fn, steps=num_steps) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 5d6e02b4b9..00c25c7a24 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -59,7 +59,7 @@ def build_model_fn_optimizer(): def main(_): distribution = tf.contrib.distribute.MirroredStrategy( ["/device:GPU:0", "/device:GPU:1"]) - config = tf.estimator.RunConfig(distribute=distribution) + config = tf.estimator.RunConfig(train_distribute=distribution) def input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index e714255f69..b87224251c 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -41,7 +41,7 @@ def main(args): strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(distribute=strategy) + config = tf.estimator.RunConfig(train_distribute=strategy) optimizer = tf.train.GradientDescentOptimizer(0.2) model = tf.keras.Sequential() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index fe80bb4df5..7644acedc9 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import context +from tensorflow.python.framework import errors from tensorflow.python.ops import variables @@ -55,7 +56,9 @@ class Monitor(object): def run_steps(self, num_steps=None): step = 0 - done = False - while done is not None and (num_steps is None or step < num_steps): - done = self._run_step() - step += 1 + while num_steps is None or step < num_steps: + try: + self._run_step() + step += 1 + except errors.OutOfRangeError: + break diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index de08eb491b..9799901483 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -454,6 +454,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( @@ -501,12 +502,6 @@ cuda_py_test( "//tensorflow/python:client_testlib", ], shard_count = 4, - tags = [ - "manual", - "noasan", - "noguitar", - "optonly", - ], ) cuda_py_test( @@ -1128,6 +1123,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 4d2f40e27f..c6c8d2cf6e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops @@ -514,6 +515,42 @@ class _BatchReshapeTest(object): batch_shape=new_batch_shape_ph, validate_args=True).sample().eval() + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index c7ee9b2117..3e6c35e0d6 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -115,7 +115,7 @@ class BatchReshape(distribution_lib.Distribution): self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) if self._batch_shape_static is not None: self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = make_runtime_assertions( + self._runtime_assertions = validate_init_args( self._distribution, self._batch_shape_, validate_args, @@ -229,7 +229,8 @@ class BatchReshape(distribution_lib.Distribution): def _call_reshape_input_output(self, fn, x): """Calls `fn`, appropriately reshaping its input `x` and output.""" - with ops.control_dependencies(self._runtime_assertions): + with ops.control_dependencies( + self._runtime_assertions + self._validate_sample_arg(x)): sample_shape, static_sample_shape = self._sample_shape(x) old_shape = array_ops.concat([ sample_shape, @@ -273,61 +274,142 @@ class BatchReshape(distribution_lib.Distribution): result.set_shape(static_shape) return result - -def make_runtime_assertions( + def _validate_sample_arg(self, x): + """Helper which validates sample arg, e.g., input to `log_prob`.""" + with ops.name_scope(name="validate_sample_arg", values=[x]): + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + expected_batch_event_ndims = batch_ndims + event_ndims + + if (isinstance(x_ndims, int) and + isinstance(expected_batch_event_ndims, int)): + if x_ndims < expected_batch_event_ndims: + raise NotImplementedError( + "Broadcasting is not supported; too few event dims " + "(expected at least {}, saw {}).".format( + expected_batch_event_ndims, x_ndims)) + ndims_assertion = [] + elif self.validate_args: + ndims_assertion = [ + check_ops.assert_greater_equal( + x_ndims, + expected_batch_event_ndims, + message="Broadcasting is not supported; too few event dims.", + name="assert_batch_and_event_ndims_large_enough"), + ] + + if (self.batch_shape.is_fully_defined() and + self.event_shape.is_fully_defined()): + expected_batch_event_shape = np.int32(self.batch_shape.concatenate( + self.event_shape).as_list()) + else: + expected_batch_event_shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + + sample_ndims = x_ndims - expected_batch_event_ndims + if isinstance(sample_ndims, int): + sample_ndims = max(sample_ndims, 0) + if (isinstance(sample_ndims, int) and + x.shape[sample_ndims:].is_fully_defined()): + actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) + else: + sample_ndims = math_ops.maximum(sample_ndims, 0) + actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] + + if (isinstance(expected_batch_event_shape, np.ndarray) and + isinstance(actual_batch_event_shape, np.ndarray)): + if any(expected_batch_event_shape != actual_batch_event_shape): + raise NotImplementedError("Broadcasting is not supported; " + "unexpected batch and event shape " + "(expected {}, saw {}).".format( + expected_batch_event_shape, + actual_batch_event_shape)) + # We need to set the final runtime-assertions to `ndims_assertion` since + # its possible this assertion was created. We could add a condition to + # only do so if `self.validate_args == True`, however this is redundant + # as `ndims_assertion` already encodes this information. + runtime_assertions = ndims_assertion + elif self.validate_args: + # We need to make the `ndims_assertion` a control dep because otherwise + # TF itself might raise an exception owing to this assertion being + # ill-defined, ie, one cannot even compare different rank Tensors. + with ops.control_dependencies(ndims_assertion): + shape_assertion = check_ops.assert_equal( + expected_batch_event_shape, + actual_batch_event_shape, + message=("Broadcasting is not supported; " + "unexpected batch and event shape."), + name="assert_batch_and_event_shape_same") + runtime_assertions = [shape_assertion] + else: + runtime_assertions = [] + + return runtime_assertions + + +def validate_init_args( distribution, batch_shape, validate_args, batch_shape_static): """Helper to __init__ which makes or raises assertions.""" - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + with ops.name_scope(name="validate_init_args", + values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access + runtime_assertions = [] + + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format( + batch_shape.shape.ndims)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_rank( + batch_shape, + 1, + message="`batch_shape` must be a vector.", + name="assert_batch_shape_is_vector"), + ] + + batch_size_static = np.prod(batch_shape_static) + dist_batch_size_static = ( + None if not distribution.batch_shape.is_fully_defined() + else np.prod(distribution.batch_shape).value) + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, + dist_batch_size_static)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_equal( + math_ops.reduce_prod(batch_shape), + math_ops.reduce_prod(distribution.batch_shape_tensor()), + message=("`batch_shape` size must match " + "`distributions.batch_shape` size."), + name="assert_batch_size"), + ] + + if batch_shape_static is not None: + if np.any(batch_shape_static < 1): + raise ValueError("`batch_shape` elements must be positive " + "(i.e., larger than zero).") + elif validate_args: + runtime_assertions += [ + check_ops.assert_positive( + batch_shape, + message=("`batch_shape` elements must be positive " + "(i.e., larger than zero)."), + name="assert_batch_shape_positive") + ] + + return runtime_assertions diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index 91a7aded11..34cb8d0e08 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -19,6 +19,7 @@ from __future__ import print_function import abc import collections +import functools import weakref from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 @@ -867,3 +868,115 @@ class Checkpoint(core_checkpointable.Checkpointable): # initialization when executing eagerly. self._maybe_create_save_counter() return status + + +class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + """Wraps save and restore callbacks as a `SaveableObject`.""" + + def __init__(self, name, dtype, save_callback, restore_callback): + self._restore_callback = restore_callback + spec = saver_lib.BaseSaverBuilder.SaveSpec( + tensor=save_callback, + slice_spec="", + name=name, + dtype=dtype) + super(_CallbackSaveable, self).__init__( + save_callback, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return self._restore_callback(tensor) + + +class _SplitDependency(core_checkpointable.CheckpointableBase): + """Looks like a regular variable while synchronizing save/restores.""" + + def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, + fill_save_buffer_fn, consume_restore_buffer_fn): + self._save_buffer = save_buffer + self._restore_buffer = restore_buffer + self._name = name + self._dtype = dtype + self._num_components = num_components + self._fill_save_buffer_fn = fill_save_buffer_fn + self._consume_restore_buffer_fn = consume_restore_buffer_fn + + def _save(self): + """Pull from the shared buffer, populating it if necessary.""" + if self._name not in self._save_buffer: + if self._save_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be saved together.") % (self._name, self)) + self._fill_save_buffer_fn(self._save_buffer) + return self._save_buffer.pop(self._name) + + def _restore(self, tensor): + """Push into the shared buffer, flushing it if necessary.""" + if self._name in self._restore_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be restored together.") % (self._name, self)) + self._restore_buffer[self._name] = tensor + if len(self._restore_buffer) == self._num_components: + op = self._consume_restore_buffer_fn(self._restore_buffer) + self._restore_buffer.clear() + return op + else: + return control_flow_ops.no_op() + + def _gather_saveables_for_checkpoint(self): + """Looks to Checkpointable like a regular variable.""" + return { + core_checkpointable.VARIABLE_VALUE_KEY: + functools.partial(_CallbackSaveable, + dtype=self._dtype, + save_callback=self._save, + restore_callback=self._restore) + } + + +def split_dependency(component_names, component_dtypes, + fill_save_buffer_fn, consume_restore_buffer_fn): + """Creates multiple dependencies with a synchronized save/restore. + + Useful when a single op produces `Tensor`s which should each be saved under + different objects, or when `Tensor`s saved with many different objects need to + be restored together as inputs to a single op (i.e. an object which uses a + single fused op may be swapped out for a subgraph of objects, and these two + programs are checkpoint compatible). + + Args: + component_names: A sequence of names for the split + dependencies. `fill_save_buffer_fn` must add these keys to the dictionary + it is passed, and `consume_restore_buffer_fn` will receive a dictionary + with these keys. + component_dtypes: Data types for the `Tensor`s being saved and restored, a + sequence corresponding to `component_names`. + fill_save_buffer_fn: A function which takes an empty dictionary as an + argument and adds `Tensor`s with `component_names` as keys. These + `Tensor`s will be saved as if they were individual variables. + consume_restore_buffer_fn: A function which takes a dictionary with + `component_names` as keys mapping to restored individual `Tensor`s and + returns a restore op (or if executing eagerly, runs the restoration and + may return `None`). + + Returns: + A dictionary mapping from names to Checkpointable objects. If one is + reachable from an object as a dependency, the others should be too; adding + dependencies on some but not all of the objects will result in errors. + """ + save_buffer = {} + restore_buffer = {} + split_dependencies = {} + for name, dtype in zip(component_names, component_dtypes): + split_dependencies[name] = _SplitDependency( + save_buffer=save_buffer, + restore_buffer=restore_buffer, + name=name, + dtype=dtype, + num_components=len(component_names), + fill_save_buffer_fn=fill_save_buffer_fn, + consume_restore_buffer_fn=consume_restore_buffer_fn) + return split_dependencies diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index 5e1b64728a..891c093a0f 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.keras._impl.keras.engine import sequential from tensorflow.python.keras._impl.keras.engine import training from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -69,6 +70,87 @@ class MyModel(training.Model): return ret +def _split_variable_closure(variable): + def _fill_save_buffer_fn(save_buffer): + save_buffer["first_half"] = variable[:2] + save_buffer["second_half"] = variable[2:] + return _fill_save_buffer_fn + + +def _combine_variable_closure(variable): + def _consume_restore_buffer_fn(restore_buffer): + return variable.assign( + array_ops.concat([restore_buffer["first_half"], + restore_buffer["second_half"]], + axis=0)) + return _consume_restore_buffer_fn + + +class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): + + def __init__(self): + self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) + split_dependencies = checkpointable_utils.split_dependency( + component_names=("first_half", "second_half"), + component_dtypes=(self.combined.dtype,) * 2, + fill_save_buffer_fn=_split_variable_closure( + self.combined), + consume_restore_buffer_fn=_combine_variable_closure( + self.combined)) + for name, dep in split_dependencies.items(): + self._track_checkpointable(dep, name=name) + + +class HasRegularDeps(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class OnlyOneDep(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class SplitTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testSaveRestoreSplitDep(self): + save_checkpoint = checkpointable_utils.Checkpoint( + dep=SaveTensorSlicesAsDeps()) + self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_checkpoint.save(checkpoint_prefix) + + regular_deps = HasRegularDeps() + regular_restore_checkpoint = checkpointable_utils.Checkpoint( + dep=regular_deps) + regular_restore_checkpoint.restore( + save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) + self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) + + one_dep = OnlyOneDep() + one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + status = one_dep_restore_checkpoint.restore(save_path) + with self.assertRaises(AssertionError): + # Missing the second dependency. + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) + + restore_checkpoint = checkpointable_utils.Checkpoint() + status = restore_checkpoint.restore(save_path) + restore_checkpoint.dep = SaveTensorSlicesAsDeps() + status.assert_consumed().run_restore_ops() + self.assertAllEqual( + [1., 2., 3., 4.], + self.evaluate(restore_checkpoint.dep.combined)) + + class InterfaceTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 60453006f4..99b1e098d5 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -107,16 +107,20 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - if self._buffer_resource_handle is not None: - with ops.device(self._device): - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - else: - return super(Iterator, self)._next_internal() + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + if self._buffer_resource_handle is not None: + with ops.device(self._device): + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + else: + return super(Iterator, self)._next_internal() # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset # attributes(potential). diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index f86331af6f..2f6cfdf31e 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -22,6 +22,7 @@ cuda_py_test( ":linear_regression", "//tensorflow:tensorflow_py", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 2be62c9438..bec0329ebb 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -89,6 +89,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -129,6 +130,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -266,6 +268,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index b5e4d34dc7..dd009a6753 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -34,6 +34,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -52,7 +53,9 @@ def _dnn_only_estimator_fn( config=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, dnn_feature_columns=feature_columns, dnn_optimizer=optimizer, @@ -100,7 +103,9 @@ def _linear_only_estimator_fn( partitioner=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, linear_feature_columns=feature_columns, linear_optimizer=optimizer, diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 71f810acec..75e3107670 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -41,7 +42,9 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 74da2cbb3f..85ef3291ba 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -178,7 +178,7 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, inverse_link_fn=None, name=None): @@ -218,7 +218,9 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. loss_fn: Optional loss function. Defaults to `mean_squared_error`. inverse_link_fn: Optional inverse link function, also known as 'mean function'. Defaults to identity. @@ -243,7 +245,7 @@ def regression_head(weight_column=None, def poisson_regression_head( weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, compute_full_loss=True, name=None): """Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`. @@ -275,7 +277,9 @@ def poisson_regression_head( of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. compute_full_loss: Whether to include the constant `log(z!)` term in computing the poisson loss. See `tf.nn.log_poisson_loss` for the full documentation. diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 8837dfdc6c..98962ca427 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -1162,8 +1162,8 @@ class PoissonRegressionHead(test.TestCase): # exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2), # exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)] # = [1.0, 3.020, 1.482] - # sum_loss = 5.502 - expected_loss = 5.502 + # training_loss = (1.0 + 3.020 + 1.482) / 3 + expected_loss = 1.834 atol = 0.001 expected_train_result = b'my_train_op' def _train_op_fn(loss): diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py index c63514eb68..c41996b9c6 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear_test.py +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -42,7 +43,9 @@ def _linear_estimator_fn( """Returns a LinearEstimator that uses regression_head.""" return linear.LinearEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 74d3d6d728..d9e5aca295 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -483,14 +483,14 @@ class MultiHeadTest(test.TestCase): [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), } # Loss for the first head: - # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + - # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 - # = 28 + # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8 + # = 3.5 # Loss for the second head: - # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + - # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 - # = 74 - expected_training_loss = 28. + 74. + # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12 + # = 6.167 + expected_training_loss = 3.5 + 6.167 training_loss = multi_head.create_loss( features={}, diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 7ba9d4ffa9..4c3879d4fc 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -170,6 +170,30 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func1_args) self.assertDictEqual(kwargs, func1_kwargs) + def testNestedArgScopeObjectCreatedOutsideScopeOverridesArgScope(self): + + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + scope_object = get_scope_object() + with arg_scope([func1], b=2, d=10): + with arg_scope(scope_object): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1]}) + + def testArgScopeObjectCreatedWithinScopeInheritsArgScope(self): + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + with arg_scope([func1], b=2, d=10): + with arg_scope(get_scope_object()): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1], 'd': 10}) + def testSharedArgScope(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 9e56d3c039..461066bbb4 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -354,6 +354,7 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 0d1afad72d..508f487722 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -31,6 +31,7 @@ __all__ = [ 'add_image_comparison_summaries', 'add_gan_model_summaries', 'add_regularization_loss_summaries', + 'add_cyclegan_image_summaries', ] @@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): ValueError: If real and generated data aren't images. """ if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_summaries'): - add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_summaries'): - add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) - return - + raise ValueError( + '`add_gan_model_image_summaries` does not take CycleGANModels. Please ' + 'use `add_cyclegan_image_summaries` instead.') _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): add_gan_model_summaries(gan_model) +def add_cyclegan_image_summaries(cyclegan_model): + """Adds image summaries for CycleGAN. + + There are two summaries, one for each generator. The first image is the + generator input, the second is the generator output, and the third is G(F(x)). + + Args: + cyclegan_model: A CycleGANModel tuple. + + Raises: + ValueError: If `cyclegan_model` isn't a CycleGANModel. + ValueError: If generated data, generator inputs, and reconstructions aren't + images. + ValueError: If the generator input, generated data, and reconstructions + aren't all the same size. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was ' + '%s' % type(cyclegan_model)) + + _assert_is_image(cyclegan_model.model_x2y.generator_inputs) + _assert_is_image(cyclegan_model.model_x2y.generated_data) + _assert_is_image(cyclegan_model.reconstructed_x) + _assert_is_image(cyclegan_model.model_y2x.generator_inputs) + _assert_is_image(cyclegan_model.model_y2x.generated_data) + _assert_is_image(cyclegan_model.reconstructed_y) + + def _add_comparison_summary(gan_model, reconstructions): + image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) + + array_ops.unstack(gan_model.generated_data[:1]) + + array_ops.unstack(reconstructions[:1])) + summary.image( + 'image_comparison', eval_utils.image_reshaper( + image_list, num_cols=len(image_list)), max_outputs=1) + + with ops.name_scope('x2y_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_x2y, cyclegan_model.reconstructed_x) + with ops.name_scope('y2x_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_y2x, cyclegan_model.reconstructed_y) + + def add_image_comparison_summaries(gan_model, num_comparisons=2, display_diffs=False): """Adds image summaries to compare triplets of images. @@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ - if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_y2x, **saved_params) - return - _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 45eb108586..33d51bfc21 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -65,15 +65,14 @@ def get_cyclegan_model(): return namedtuples.CycleGANModel( model_x2y=model_x2y, model_y2x=model_y2x, - reconstructed_x=array_ops.zeros([3, 30, 35, 6]), - reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + reconstructed_x=array_ops.zeros([4, 32, 32, 3]), + reconstructed_y=array_ops.zeros([4, 32, 32, 3])) class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops, - model_summaries): + def _test_add_gan_model_image_summaries_impl( + self, get_model_fn, expected_num_summary_ops, model_summaries): summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, model_summaries=model_summaries) @@ -89,8 +88,9 @@ class SummariesTest(test.TestCase): def test_add_gan_model_image_summaries_no_model(self): self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) + def test_cyclegan_image_summaries_dont_work(self): + with self.assertRaises(ValueError): + summaries.add_gan_model_image_summaries(get_cyclegan_model()) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): @@ -137,7 +137,11 @@ class SummariesTest(test.TestCase): self._test_add_image_comparison_summaries_impl(get_gan_model, 1) def test_add_image_comparison_summaries_for_cyclegan(self): - self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + summaries.add_cyclegan_image_summaries(get_cyclegan_model()) + + self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + with self.test_session(use_gpu=True): + summary.merge_all().eval() if __name__ == '__main__': diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD index 7dd40c19c5..8186fa1c62 100644 --- a/tensorflow/contrib/kfac/examples/BUILD +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -28,8 +28,28 @@ py_library( ) py_binary( - name = "convnet_mnist_main", - srcs = ["convnet_mnist_main.py"], + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], srcs_version = "PY2AND3", deps = [ ":convnet", diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 39d80addaa..e8e3353091 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -37,6 +37,8 @@ import tensorflow as tf from tensorflow.contrib.kfac.examples import mlp from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + lc = tf.contrib.kfac.layer_collection oq = tf.contrib.kfac.op_queue @@ -48,12 +50,18 @@ __all__ = [ "linear_layer", "build_model", "minimize_loss_single_machine", - "minimize_loss_distributed", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", "train_mnist_single_machine", - "train_mnist_distributed", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" ] +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + def conv_layer(layer_id, inputs, kernel_size, out_channels): """Builds a convolutional layer with ReLU non-linearity. @@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection): accuracy = tf.reduce_mean( tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) # Register parameters. K-FAC needs to know about the inputs, outputs, and # parameters of each conv/fully connected layer and the logits powering the @@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection): def minimize_loss_single_machine(loss, accuracy, layer_collection, + device="/gpu:0", session_config=None): """Minimize loss with K-FAC on a single machine. - A single Session is responsible for running all of K-FAC's ops. + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() + g_step = tf.train.get_or_create_global_step() optimizer = opt.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], momentum=0.9) - train_op = optimizer.minimize(loss, global_step=global_step) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) - if global_step_ % 100 == 0: + if (global_step_ + 1) % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks): return int(np.ceil(0.6 * num_tasks)) -def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection): - """Minimize loss with an synchronous implementation of K-FAC. +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. - Different tasks are responsible for different parts of K-FAC's Ops. The first - 60% of tasks update weights; the next 20% accumulate covariance statistics; - the last 20% invert the matrices used to precondition gradients. + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. + + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution untill the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. @@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. @@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, Raises: ValueError: if task_id >= num_worker_tasks. """ - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) - train_op = sync_optimizer.minimize(loss, global_step=global_step) + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + update_op = tf.cond( + tf.equal(tf.mod(global_step + 1, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + else: + update_op = train_op + + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, update_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) tf.logging.info("Starting training.") is_chief = (task_id == 0) @@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = optimizer.cov_update_op + learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): # TODO(duckworthd): Running this op before cov_update_op has been run a # few times can result in "InvalidArgumentError: Cholesky decomposition @@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, return accuracy_ -def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): """Train a ConvNet on MNIST. Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. @@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): examples, labels, num_labels=10, layer_collection=layer_collection) # Fit model. - return minimize_loss_single_machine(loss, accuracy, layer_collection) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True): + use_fake_data=True, devices=None): """Train a ConvNet on MNIST. + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 @@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, layer_collection = lc.LayerCollection() tower_results = [] for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): + with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) @@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, accuracy = tf.reduce_mean(accuracies) # Fit model. + session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize_loss_single_machine( - loss, accuracy, layer_collection, session_config=session_config) + allow_soft_placement=False, + device_count=device_count, + ) + + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = optimizer.minimize(loss, global_step=g_step) -def train_mnist_distributed(task_id, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - use_fake_data=False): - """Train a ConvNet on MNIST. + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) + + if (global_step_ + 1) % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covaraiance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") @@ -448,9 +649,17 @@ def train_mnist_distributed(task_id, # Fit model. checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, - master, checkpoint_dir, loss, accuracy, - layer_collection) + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) if __name__ == "__main__": diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000..b4c2d4a9e9 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py new file mode 100644 index 0000000000..4249bf8a8d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Multi tower training mode. See `convnet.train_mnist_multitower` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") + + +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) + + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py index b0c6fbde19..3aa52aff19 100644 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -14,44 +14,26 @@ # ============================================================================== r"""Train a ConvNet on MNIST using K-FAC. -See convnet.py for details. +Train on single machine. See `convnet.train_mnist_single_machine` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import flags import tensorflow as tf from tensorflow.contrib.kfac.examples import convnet -FLAGS = None +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") -def main(argv): - _ = argv - - if FLAGS.num_towers > 1: - convnet.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) +def main(unused_argv): + convnet.train_mnist_single_gpu(FLAGS.data_dir, num_epochs=200) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 8d86c2bb51..6de775cc79 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase): def testMinimizeLossSingleMachine(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, - layer_collection) - self.assertLess(accuracy_, 1.0) + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) def testMinimizeLossDistributed(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_distributed( + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", @@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase): loss=loss, accuracy=accuracy, layer_collection=layer_collection) - self.assertLess(accuracy_, 1.0) + self.assertLess(accuracy_, 2.0) def testTrainMnistSingleMachine(self): with tf.Graph().as_default(): @@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase): # but there are too few parameters for the model to effectively memorize # the training set the way an MLP can. convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True) + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") def testTrainMnistMultitower(self): with tf.Graph().as_default(): @@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase): def testTrainMnistDistributed(self): with tf.Graph().as_default(): # Ensure model training doesn't crash. - convnet.train_mnist_distributed( + convnet.train_mnist_distributed_sync_replicas( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", data_dir=None, num_epochs=1, + op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index f73c24f8fb..2477d2bfc1 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -114,6 +114,7 @@ py_test( name = "utils_test", srcs = ["utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/contrib/tpu", diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 586a004f88..19608aca47 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -990,9 +990,11 @@ class LayerCollection(object): num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_conv2d_multi(self, params, @@ -1066,9 +1068,11 @@ class LayerCollection(object): reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) # TODO(b/74108452): change the loss registration functions names to refer # to "loss functions" instead of distributions. Following naming convention @@ -1088,7 +1092,7 @@ class LayerCollection(object): inputs: A list of Tensors, each of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. The list indexes each use in the graph (which might correspond to a "time-step" in an RNN). - OR, can be single Tensor, of shape [num_uses, batch_size, input_size], + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], which is a reshaped version of a Tensor of shape [num_uses, batch_size, input_size]. outputs: A list of Tensors, each of shape [batch_size, embedding_size]. @@ -1129,7 +1133,10 @@ class LayerCollection(object): params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 18b265ae80..c8812d4b23 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,6 +70,7 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":_typecheck", ":core", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 4be55468db..d5b3b279a1 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -188,6 +188,7 @@ py_test( size = "small", srcs = ["python/layers/normalization_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/contrib/framework:framework_py", @@ -353,6 +354,7 @@ py_test( size = "small", srcs = ["python/ops/sparse_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 337c9e06b8..00f03a111a 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -104,6 +104,7 @@ See the @{$python/contrib.layers} guide. @@infer_real_valued_columns @@sequence_input_from_feature_columns +@@group_norm @@instance_norm """ @@ -122,6 +123,7 @@ _allowed_symbols = ['bias_add', 'conv3d', 'elu', 'feature_column', + 'group_norm', 'instance_norm', 'legacy_fully_connected', 'legacy_linear', diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index e7d4080ff7..c807ab0f2e 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope __all__ = [ + 'group_norm', 'instance_norm', ] @@ -158,3 +160,196 @@ def instance_norm(inputs, if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + +@add_arg_scope +def group_norm(inputs, + groups=32, + channels_axis=-1, + reduction_axes=(-3, -2), + center=True, + scale=True, + epsilon=1e-6, + activation_fn=None, + param_initializers=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Functional interface for the group normalization layer. + + Reference: https://arxiv.org/abs/1803.08494. + + "Group Normalization", Yuxin Wu, Kaiming He + + Args: + inputs: A Tensor with at least 2 dimensions one which is channels. All + shape dimensions must be fully defined. + groups: Integer. Divide the channels into this number of groups over which + normalization statistics are computed. This number must be commensurate + with the number of channels in `inputs`. + channels_axis: An integer. Specifies index of channels axis which will be + broken into `groups`, each of which whose statistics will be computed + across. Must be mutually exclusive with `reduction_axes`. Preferred usage + is to specify negative integers to be agnostic as to whether a batch + dimension is included. + reduction_axes: Tuple of integers. Specifies dimensions over which + statistics will be accumulated. Must be mutually exclusive with + `channels_axis`. Statistics will not be accumulated across axes not + specified in `reduction_axes` nor `channel_axis`. Preferred usage is to + specify negative integers to be agnostic to whether a batch dimension is + included. + + Some sample usage cases: + NHWC format: channels_axis=-1, reduction_axes=[-3, -2] + NCHW format: channels_axis=-3, reduction_axes=[-2, -1] + + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + epsilon: Small float added to variance to avoid dividing by zero. + activation_fn: Activation function, default set to None to skip it and + maintain a linear activation. + param_initializers: Optional initializers for beta, gamma, moving mean and + moving variance. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional collections for the variables. + outputs_collections: Collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + scope: Optional scope for `variable_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: If the rank of `inputs` is undefined. + ValueError: If rank or channels dimension of `inputs` is undefined. + ValueError: If number of groups is not commensurate with number of channels. + ValueError: If reduction_axes or channels_axis are out of bounds. + ValueError: If reduction_axes are not mutually exclusive with channels_axis. + """ + # TODO(shlens): Support partially defined shapes for the inputs. + inputs = ops.convert_to_tensor(inputs) + original_shape = inputs.shape + + if inputs.shape.ndims is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + if channels_axis > (inputs.shape.ndims - 1): + raise ValueError('Axis is out of bounds.') + + # Standardize the channels_axis to be positive and identify # of channels. + if channels_axis < 0: + channels_axis = inputs.shape.ndims + channels_axis + channels = inputs.shape[channels_axis].value + + if channels is None: + raise ValueError('Inputs %s has undefined channel dimension: %d.' % ( + inputs.name, channels_axis)) + + # Standardize the reduction_axes to be positive. + reduction_axes = list(reduction_axes) + for i in range(len(reduction_axes)): + if reduction_axes[i] < 0: + reduction_axes[i] += inputs.shape.ndims + + for a in reduction_axes: + if a > inputs.shape.ndims: + raise ValueError('Axis is out of bounds.') + if inputs.shape[a].value is None: + raise ValueError('Inputs %s has undefined dimensions %d.' % ( + inputs.name, a)) + if channels_axis == a: + raise ValueError('reduction_axis must be mutually exclusive ' + 'with channels_axis') + if groups > channels: + raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) + if channels % groups != 0: + raise ValueError('%d channels is not commensurate with %d groups.' % + (channels, groups)) + + # Determine axes before channels. Some examples of common image formats: + # 'NCHW': before = [N], after = [HW] + # 'NHWC': before = [NHW], after = [] + axes_before_channels = inputs.shape.as_list()[:channels_axis] + axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + + # Manually broadcast the parameters to conform to the number of groups. + params_shape_broadcast = ([1] * len(axes_before_channels) + + [groups, channels // groups] + + [1] * len(axes_after_channels)) + + # Reshape the input by the group within the channel dimension. + inputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + inputs = array_ops.reshape(inputs, inputs_shape) + + # Determine the dimensions across which moments are calculated. + moments_axes = [channels_axis + 1] + for a in reduction_axes: + if a > channels_axis: + moments_axes.append(a + 1) + else: + moments_axes.append(a) + + with variable_scope.variable_scope( + scope, 'GroupNorm', [inputs], reuse=reuse) as sc: + # Note that the params_shape is the number of channels always. + params_shape = [channels] + + # Allocate parameters for the beta and gamma of the normalization. + beta, gamma = None, None + dtype = inputs.dtype.base_dtype + if param_initializers is None: + param_initializers = {} + if center: + beta_collections = utils.get_variable_collections( + variables_collections, 'beta') + beta_initializer = param_initializers.get( + 'beta', init_ops.zeros_initializer()) + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) + beta = array_ops.reshape(beta, params_shape_broadcast) + + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get( + 'gamma', init_ops.ones_initializer()) + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + gamma = array_ops.reshape(gamma, params_shape_broadcast) + + # Calculate the moments. + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + + # Compute normalization. + # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor + # appropriately so that this operation may be faster. + gain = math_ops.rsqrt(variance + epsilon) + offset = -mean * gain + if gamma is not None: + gain *= gamma + offset *= gamma + if beta is not None: + offset += beta + outputs = inputs * gain + offset + + # Collapse the groups into the channel dimension. + outputs = array_ops.reshape(outputs, original_shape) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index 5cff1bf0eb..b6e96350db 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase): def testOutputBigInput5DNCHW(self): self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3) + +class GroupNormTest(test.TestCase): + + def testInvalidGroupSize(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10)) + with self.assertRaisesRegexp(ValueError, + 'Invalid groups 10 for 2 channels.'): + normalization.group_norm(inputs, groups=10, + reduction_axes=[-2, -1], channels_axis=-3) + + def testBadCommensurateGroup(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10)) + with self.assertRaisesRegexp(ValueError, + '4 channels is not commensurate with ' + '3 groups.'): + normalization.group_norm(inputs, groups=3, + reduction_axes=[-2, -1], channels_axis=-3) + + def testAxisIsBad(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5)) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, channels_axis=5) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, reduction_axes=[1, 5]) + + def testNotMutuallyExclusiveAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32)) + # Specify axis with negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2]) + # Specify axis with positive values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3]) + # Specify axis with mixed positive and negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2]) + + def testUnknownShape(self): + inputs = array_ops.placeholder(dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'undefined rank'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedReductionAxes(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4)) + with self.assertRaisesRegexp(ValueError, 'undefined dimensions'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedChannelsAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None)) + with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'): + normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2]) + + def testCreateOp(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) + output = normalization.group_norm(images, groups=groups, channels_axis=-1, + reduction_axes=[-3, -2]) + print('name: ', output.op.name) + self.assertListEqual([5, height, width, 2*groups], output.shape.as_list()) + + def testCreateOpFloat64(self): + height, width, groups = 3, 3, 5 + images = random_ops.random_uniform( + (5, height, width, 4*groups), dtype=dtypes.float64, seed=1) + output = normalization.group_norm(images, groups=groups) + self.assertEqual(dtypes.float64, output.dtype) + self.assertListEqual([5, height, width, 4*groups], output.shape.as_list()) + + def testCreateOpNoScaleCenter(self): + height, width, groups = 3, 3, 7 + images = random_ops.random_uniform( + (5, height, width, 3*groups), dtype=dtypes.float32, seed=1) + output = normalization.group_norm(images, groups=groups, center=False, + scale=False) + self.assertListEqual([5, height, width, 3*groups], output.shape.as_list()) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) + + def testCreateVariables_NHWC(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 8), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-1, reduction_axes=(-3, -2), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testCreateVariables_NCHW(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, 2*groups, height, width), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-3, reduction_axes=(-2, -1), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testReuseVariables(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 4), seed=1) + normalization.group_norm(images, groups=2, scale=True, scope='IN') + normalization.group_norm(images, groups=2, scale=True, scope='IN', + reuse=True) + beta = contrib_variables.get_variables_by_name('beta') + gamma = contrib_variables.get_variables_by_name('gamma') + self.assertEqual(1, len(beta)) + self.assertEqual(1, len(gamma)) + + def testValueCorrectWithReuseVars(self): + height, width = 3, 3 + image_shape = (10, height, width, 4) + images = random_ops.random_uniform(image_shape, seed=1) + output_train = normalization.group_norm(images, groups=2, scope='IN') + output_eval = normalization.group_norm(images, groups=2, scope='IN', + reuse=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + # output_train and output_eval should be the same. + train_np, eval_np = sess.run([output_train, output_eval]) + self.assertAllClose(train_np, eval_np) + + def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, + groups=2, tol=1e-2): + # Select the axis for the channel and the dimensions along which statistics + # are accumulated. + if channels_axis < 0: + channels_axis += len(input_shape) + reduced_axes = [channels_axis + 1] + for a in reduction_axes: + if a < 0: + a += len(input_shape) + if a < channels_axis: + reduced_axes.append(a) + else: + reduced_axes.append(a+1) + reduced_axes = tuple(reduced_axes) + + # Calculate the final shape for the output Tensor. + axes_before_channels = input_shape[:channels_axis] + axes_after_channels = input_shape[channels_axis+1:] + channels = input_shape[channels_axis] + outputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + + # Calculate the final shape for the output statistics. + reduced_shape = [] + for i, a in enumerate(outputs_shape): + if i not in reduced_axes: + reduced_shape.append(a) + + for mu in (0.0, 1e2): + for sigma in (1.0, 0.1): + # Determine shape of Tensor after normalization. + expected_mean = np.zeros(reduced_shape) + expected_var = np.ones(reduced_shape) + + inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + output_op = normalization.group_norm( + inputs, groups=groups, center=False, scale=False, + channels_axis=channels_axis, + reduction_axes=reduction_axes) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outputs = sess.run(output_op) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(outputs).any()) + + outputs = np.reshape(outputs, outputs_shape) + mean = np.mean(outputs, axis=reduced_axes) + var = np.var(outputs, axis=reduced_axes) + # The mean and variance of each example should be close to 0 and 1 + # respectively. + self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol) + self.assertAllClose(expected_var, var, rtol=tol, atol=tol) + + def testOutputSmallInput4D_NHWC(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput3D_NHWC(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput4D_NCHW(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputSmallInput3D_NCHW(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputBigInput4D_NHWC(self): + self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], + groups=1) + + def testOutputBigInput4D_NCHW(self): + self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], + groups=4) + + def testOutputSmallInput2D_NC(self): + self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + + def testOutputSmallInput5D_NCXXX(self): + self.doOutputTest([10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index ba55365c14..d665fc9335 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -117,6 +117,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -172,6 +173,7 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/estimator", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -190,6 +192,7 @@ py_test( size = "small", srcs = ["python/learn/graph_actions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -591,6 +594,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/io_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/learn/python/learn/datasets", @@ -820,6 +824,7 @@ py_test( size = "small", srcs = ["python/learn/utils/saved_model_export_utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index f3500bf56f..8c85c431be 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -298,7 +298,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # core_run_config.RunConfig.__init__(self) # so instead of breaking compatibility with that assumption, we # just manually initialize this field: - self._distribute = None + self._train_distribute = None gpu_options = config_pb2.GPUOptions( per_process_gpu_memory_fraction=gpu_memory_fraction) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index ac269d540a..9c4533079c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -89,6 +89,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], + deps = [":context"], ) cc_library( diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index f84b3dad95..e9d0fbc5a9 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { -class AllocationInfo; +struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5fc8954743..2b6c24768c 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -17,6 +17,8 @@ limitations under the License. #include <stdint.h> +#include "tensorflow/contrib/lite/context.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -174,6 +176,11 @@ typedef struct { int block_size; } TfLiteSpaceToDepthParams; +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + typedef enum { kTfLiteCombinerTypeSum = 0, kTfLiteCombinerTypeMean = 1, diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 14f461f5f9..a33959dca4 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -68,6 +68,19 @@ public final class Interpreter implements AutoCloseable { } /** + * Initializes a {@code Interpreter} and specifies the number of threads used for inference. + * + * @param modelFile: a file of a pre-trained TF Lite model + * @param numThreads: number of threads to use for inference + */ + public Interpreter(@NonNull File modelFile, int numThreads) { + if (modelFile == null) { + return; + } + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads); + } + + /** * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. * * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index dbf8f8f7cc..fc8187acfe 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -32,9 +32,13 @@ import java.util.Map; final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { + this(modelPath, /* numThreads= */ -1); + } + + NativeInterpreterWrapper(String modelPath, int numThreads) { errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1); + interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; } @@ -44,11 +48,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { * NativeInterpreterWrapper}. */ NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { - modelByteBuffer = mappedByteBuffer; - errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); - modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1); - isMemoryAllocated = true; + this(mappedByteBuffer, /* numThreads= */ -1); } /** diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 19942de7bc..17ef2c572e 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -34,6 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): these two checks would make the new implementation + // incompatible with some existing models, where params is not specified. It + // is OK not to have them because toco would have set input and output types + // to match the parameters. + // auto* params = reinterpret_cast<TfLiteCastParams*>(node->builtin_data); + // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type); + // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type); + return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index ee8bfe56d9..e67f4e06f3 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -45,10 +45,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumDimensions(input) <= 4); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE( + context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8); TF_LITE_ENSURE_EQ(context, input->type, output->type); + if (output->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.)); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128); + } + // TODO(ahentz): For some reason our implementations don't support // activations. TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); @@ -75,6 +80,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_L2NORM(optimized_ops); } #undef TF_LITE_L2NORM + } else if (output->type == kTfLiteUInt8) { +#define TF_LITE_L2NORM(type) \ + type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \ + input->params.zero_point, \ + GetTensorData<uint8>(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_L2NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_L2NORM(optimized_ops); + } +#undef TF_LITE_L2NORM } else { context->ReportError(context, "Inputs and outputs not all float types."); return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index 30e103f330..042314ccf5 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -25,10 +25,22 @@ using ::testing::ElementsAreArray; class L2NormOpModel : public SingleOpModel { public: - L2NormOpModel(std::initializer_list<int> input_shape, - ActivationFunctionType activation_type) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); + L2NormOpModel(const std::initializer_list<int> input_shape, + const TensorType tensor_type, + const ActivationFunctionType activation_type) { + TensorData data = TensorData{tensor_type}; + if (tensor_type != TensorType_FLOAT32) { + data.min = -2.0; + data.max = 2.0; + data.scale = 2.0; + data.zero_point = 128; + } + input_ = AddInput(data); + if (tensor_type != TensorType_FLOAT32) { + data.min = -1.0; + data.max = 127.0 / 128.0; + } + output_ = AddOutput(data); SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, CreateL2NormOptions(builder_, activation_type).Union()); BuildInterpreter({input_shape}); @@ -38,7 +50,17 @@ class L2NormOpModel : public SingleOpModel { PopulateTensor(input_, data); } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + template <typename T> + std::vector<T> GetOutput() { + return ExtractVector<T>(output_); + } + + std::vector<float> GetDequantizedOutput() { + return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + int input() const { return input_; } private: int input_; @@ -46,13 +68,26 @@ class L2NormOpModel : public SingleOpModel { }; TEST(L2NormOpTest, SimpleTest) { - L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE); + L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32, + ActivationFunctionType_NONE); m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } +TEST(L2NormOpTest, SimpleUint8Test) { + L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput<uint8_t>(), + ElementsAreArray({58, 166, 173, 205, 83, 134})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/maximum.cc b/tensorflow/contrib/lite/kernels/maximum.cc index 9fdf2b47ea..13c40603ce 100644 --- a/tensorflow/contrib/lite/kernels/maximum.cc +++ b/tensorflow/contrib/lite/kernels/maximum.cc @@ -52,9 +52,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { MaximumContext op_context(context, node); TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type); - TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input2->dims); - op_context.output->type = op_context.input2->type; - return context->ResizeTensor(context, op_context.output, output_dims); + op_context.output->type = op_context.input1->type; + + bool requires_broadcast = + !HaveSameShapes(op_context.input1, op_context.input2); + + TfLiteIntArray* output_size = nullptr; + if (requires_broadcast) { + TF_LITE_ENSURE_OK( + context, CalculateShapeForBroadcast(context, op_context.input1, + op_context.input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(op_context.input1->dims); + } + + return context->ResizeTensor(context, op_context.output, output_size); } template <KernelType kernel_type> diff --git a/tensorflow/contrib/lite/kernels/maximum_test.cc b/tensorflow/contrib/lite/kernels/maximum_test.cc index b3fd7d4e6f..df2bf29c20 100644 --- a/tensorflow/contrib/lite/kernels/maximum_test.cc +++ b/tensorflow/contrib/lite/kernels/maximum_test.cc @@ -71,6 +71,20 @@ TEST(MaximumOpTest, FloatTest) { ElementsAreArray(ArrayFloatNear({1.0, 0.0, 1.0, 12.0, -2.0, -1.43}))); } +TEST(MaximumOpTest, FloatWithBroadcastTest) { + std::initializer_list<float> data1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0}; + std::initializer_list<float> data2 = {0.5, 2.0}; + MaximumOpModel m({TensorType_FLOAT32, {3, 1, 2}}, {TensorType_FLOAT32, {2}}, + TensorType_FLOAT32); + m.SetInput1<float>(data1); + m.SetInput2<float>(data2); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT( + m.GetOutput<float>(), + ElementsAreArray(ArrayFloatNear({1.0, 2.0, 0.5, 2.0, 0.5, 11.0}))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index eb374d9031..e6d5c300dc 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -228,6 +228,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_STRIDED_SLICE(reference_ops, int64_t); } break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, uint8_t); + } + break; default: context->ReportError(context, "Type is currently not supported " diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index 5c98c5f431..22d7b097cb 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -24,6 +24,8 @@ namespace { using ::int32; using ::testing::ElementsAreArray; +template <typename input_type = float, + TensorType tensor_input_type = TensorType_FLOAT32> class StridedSliceOpModel : public SingleOpModel { public: StridedSliceOpModel(std::initializer_list<int> input_shape, @@ -32,11 +34,11 @@ class StridedSliceOpModel : public SingleOpModel { std::initializer_list<int> strides_shape, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask) { - input_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(tensor_input_type); begin_ = AddInput(TensorType_INT32); end_ = AddInput(TensorType_INT32); strides_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(tensor_input_type); SetBuiltinOp( BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, @@ -45,8 +47,8 @@ class StridedSliceOpModel : public SingleOpModel { BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); } - void SetInput(std::initializer_list<float> data) { - PopulateTensor<float>(input_, data); + void SetInput(std::initializer_list<input_type> data) { + PopulateTensor<input_type>(input_, data); } void SetBegin(std::initializer_list<int32> data) { PopulateTensor<int32>(begin_, data); @@ -58,7 +60,9 @@ class StridedSliceOpModel : public SingleOpModel { PopulateTensor<int32>(strides_, data); } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<input_type> GetOutput() { + return ExtractVector<input_type>(output_); + } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } private: @@ -71,19 +75,19 @@ class StridedSliceOpModel : public SingleOpModel { TEST(StridedSliceOpTest, UnsupportedInputSize) { EXPECT_DEATH( - StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), + StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), "StridedSlice op only supports 1D-4D input arrays."); } TEST(StridedSliceOpTest, UnssupportedArgs) { - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), "ellipsis_mask is not implemented yet."); - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), "new_axis_mask is not implemented yet."); } TEST(StridedSliceOpTest, In1D) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -94,7 +98,7 @@ TEST(StridedSliceOpTest, In1D) { } TEST(StridedSliceOpTest, In1D_EmptyOutput) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({10}); m.SetEnd({3}); @@ -104,7 +108,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutput) { } TEST(StridedSliceOpTest, In1D_NegativeBegin) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({3}); @@ -115,7 +119,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBegin) { } TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-5}); m.SetEnd({3}); @@ -126,7 +130,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { } TEST(StridedSliceOpTest, In1D_NegativeEnd) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({-2}); @@ -137,7 +141,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEnd) { } TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({5}); @@ -148,7 +152,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { } TEST(StridedSliceOpTest, In1D_BeginMask) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -159,7 +163,7 @@ TEST(StridedSliceOpTest, In1D_BeginMask) { } TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-2}); m.SetEnd({-3}); @@ -170,7 +174,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { } TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({5}); m.SetEnd({2}); @@ -181,7 +185,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { } TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({2}); m.SetEnd({-4}); @@ -192,7 +196,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { } TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({-5}); @@ -203,7 +207,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { } TEST(StridedSliceOpTest, In1D_EndMask) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -214,7 +218,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) { } TEST(StridedSliceOpTest, In1D_NegStride) { - StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({-1}); m.SetEnd({-4}); @@ -225,7 +229,7 @@ TEST(StridedSliceOpTest, In1D_NegStride) { } TEST(StridedSliceOpTest, In1D_EvenLenStride2) { - StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2}); m.SetBegin({0}); m.SetEnd({2}); @@ -236,7 +240,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) { } TEST(StridedSliceOpTest, In1D_OddLenStride2) { - StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({0}); m.SetEnd({3}); @@ -247,7 +251,7 @@ TEST(StridedSliceOpTest, In1D_OddLenStride2) { } TEST(StridedSliceOpTest, In2D_Identity) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -258,7 +262,7 @@ TEST(StridedSliceOpTest, In2D_Identity) { } TEST(StridedSliceOpTest, In2D) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -269,7 +273,7 @@ TEST(StridedSliceOpTest, In2D) { } TEST(StridedSliceOpTest, In2D_Stride2) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -280,7 +284,7 @@ TEST(StridedSliceOpTest, In2D_Stride2) { } TEST(StridedSliceOpTest, In2D_NegStride) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -1}); m.SetEnd({2, -4}); @@ -291,7 +295,7 @@ TEST(StridedSliceOpTest, In2D_NegStride) { } TEST(StridedSliceOpTest, In2D_BeginMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -302,7 +306,7 @@ TEST(StridedSliceOpTest, In2D_BeginMask) { } TEST(StridedSliceOpTest, In2D_EndMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -313,7 +317,7 @@ TEST(StridedSliceOpTest, In2D_EndMask) { } TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -4}); @@ -324,7 +328,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { } TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -3}); @@ -335,7 +339,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { } TEST(StridedSliceOpTest, In3D_Identity) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -347,7 +351,7 @@ TEST(StridedSliceOpTest, In3D_Identity) { } TEST(StridedSliceOpTest, In3D_NegStride) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({-1, -1, -1}); m.SetEnd({-3, -4, -3}); @@ -359,7 +363,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) { } TEST(StridedSliceOpTest, In3D_Strided2) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -370,7 +374,7 @@ TEST(StridedSliceOpTest, In3D_Strided2) { } TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -381,7 +385,7 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({2}); m.SetEnd({1}); @@ -392,7 +396,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -403,7 +407,7 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({-2}); m.SetEnd({-3}); @@ -414,7 +418,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -425,7 +429,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -436,7 +440,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -447,7 +451,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -458,7 +462,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -469,7 +473,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -480,7 +484,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -491,7 +495,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -502,7 +506,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -513,7 +517,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -525,7 +529,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { // This tests catches a very subtle bug that was fixed by cl/188403234. TEST(StridedSliceOpTest, RunTwice) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); auto setup_inputs = [&m]() { m.SetInput({1, 2, 3, 4, 5, 6}); @@ -544,6 +548,17 @@ TEST(StridedSliceOpTest, RunTwice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); } +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { + StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 791d1378f3..606f4a5635 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -32,6 +32,32 @@ namespace tflite { const char* kEmptyTensorName = ""; +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter) { + switch (tensor_type) { + case TensorType_FLOAT32: + *type = kTfLiteFloat32; + break; + case TensorType_INT32: + *type = kTfLiteInt32; + break; + case TensorType_UINT8: + *type = kTfLiteUInt8; + break; + case TensorType_INT64: + *type = kTfLiteInt64; + break; + case TensorType_STRING: + *type = kTfLiteString; + break; + default: + error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor_type), tensor_type); + return kTfLiteError; + } + return kTfLiteOk; +} + // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename, @@ -307,10 +333,25 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_EXP: case BuiltinOperator_TOPK_V2: case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_CAST: case BuiltinOperator_DEQUANTIZE: case BuiltinOperator_PRELU: break; + case BuiltinOperator_CAST: { + TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); + if (auto* schema_params = op->builtin_options_as_CastOptions()) { + auto in_status = + ConvertTensorType(schema_params->in_data_type(), + ¶ms->in_data_type, error_reporter); + auto out_status = + ConvertTensorType(schema_params->out_data_type(), + ¶ms->out_data_type, error_reporter); + if (in_status != kTfLiteOk || out_status != kTfLiteOk) { + break; + } + } + builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = MallocPOD<TfLiteLSHProjectionParams>(); @@ -707,29 +748,10 @@ TfLiteStatus InterpreterBuilder::ParseTensors( } TfLiteType type; - switch (tensor->type()) { - case TensorType_FLOAT32: - type = kTfLiteFloat32; - break; - case TensorType_INT32: - type = kTfLiteInt32; - break; - case TensorType_UINT8: - type = kTfLiteUInt8; - break; - case TensorType_INT64: - type = kTfLiteInt64; - break; - case TensorType_STRING: - type = kTfLiteString; - break; - default: - // tensorType = ArrayType::NONE; - error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", - EnumNameTensorType(tensor->type()), - tensor->type()); - status = kTfLiteError; - continue; + if (ConvertTensorType(tensor->type(), &type, error_reporter_) != + kTfLiteOk) { + status = kTfLiteError; + continue; } auto get_readonly_data = [&](const char** buffer_data, size_t* buffer_size) { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index decaf9f160..bc13444dc7 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -162,7 +162,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, }; auto duplicate_state_tensor_float32 = - [interpreter, &nn_model, &augmented_inputs, &next_id](int tensor_id) { + [interpreter, &nn_model, &augmented_inputs](int tensor_id) { const TfLiteTensor* tensor = interpreter->tensor(tensor_id); CHECK_NN(ANeuralNetworksModel_setOperandValue( nn_model, tensor_id, tensor->data.raw, tensor->bytes)); diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index e70aa51298..e735062a7f 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -101,6 +101,7 @@ py_test( name = "convert_saved_model_test", srcs = ["convert_saved_model_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], visibility = ["//visibility:public"], deps = [ ":convert_saved_model", diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7d2e00fe32..c63bfb28cc 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -381,6 +381,8 @@ table LogSoftmaxOptions { } table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; } table DequantizeOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 66a97a1460..0735be5c8f 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -3702,14 +3702,30 @@ flatbuffers::Offset<LogSoftmaxOptions> CreateLogSoftmaxOptions(flatbuffers::Flat struct CastOptionsT : public flatbuffers::NativeTable { typedef CastOptions TableType; - CastOptionsT() { + TensorType in_data_type; + TensorType out_data_type; + CastOptionsT() + : in_data_type(TensorType_FLOAT32), + out_data_type(TensorType_FLOAT32) { } }; struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef CastOptionsT NativeTableType; + enum { + VT_IN_DATA_TYPE = 4, + VT_OUT_DATA_TYPE = 6 + }; + TensorType in_data_type() const { + return static_cast<TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0)); + } + TensorType out_data_type() const { + return static_cast<TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE) && + VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE) && verifier.EndTable(); } CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -3720,6 +3736,12 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct CastOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; + void add_in_data_type(TensorType in_data_type) { + fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0); + } + void add_out_data_type(TensorType out_data_type) { + fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0); + } explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3733,8 +3755,12 @@ struct CastOptionsBuilder { }; inline flatbuffers::Offset<CastOptions> CreateCastOptions( - flatbuffers::FlatBufferBuilder &_fbb) { + flatbuffers::FlatBufferBuilder &_fbb, + TensorType in_data_type = TensorType_FLOAT32, + TensorType out_data_type = TensorType_FLOAT32) { CastOptionsBuilder builder_(_fbb); + builder_.add_out_data_type(out_data_type); + builder_.add_in_data_type(in_data_type); return builder_.Finish(); } @@ -5727,6 +5753,8 @@ inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; + { auto _e = in_data_type(); _o->in_data_type = _e; }; + { auto _e = out_data_type(); _o->out_data_type = _e; }; } inline flatbuffers::Offset<CastOptions> CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5737,8 +5765,12 @@ inline flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBuffe (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _in_data_type = _o->in_data_type; + auto _out_data_type = _o->out_data_type; return tflite::CreateCastOptions( - _fbb); + _fbb, + _in_data_type, + _out_data_type); } inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index d552de313c..8a35fb9034 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -234,6 +234,7 @@ cc_library( "graph_transformations/identify_relu1.cc", "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", + "graph_transformations/merge_reshape_into_preceding_transpose.cc", "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_fixed_sizes.cc", @@ -251,7 +252,8 @@ cc_library( "graph_transformations/remove_trivial_reshape.cc", "graph_transformations/remove_trivial_slice.cc", "graph_transformations/remove_unused_op.cc", - "graph_transformations/reorder_activation_functions.cc", + "graph_transformations/reorder_elementwise_unary.cc", + "graph_transformations/reorder_reshape_transpose.cc", "graph_transformations/resolve_batch_normalization.cc", "graph_transformations/resolve_batch_to_space_nd_attributes.cc", "graph_transformations/resolve_constant_binary.cc", @@ -259,6 +261,7 @@ cc_library( "graph_transformations/resolve_constant_fake_quant.cc", "graph_transformations/resolve_constant_fill.cc", "graph_transformations/resolve_constant_gather.cc", + "graph_transformations/resolve_constant_random_uniform.cc", "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_stack.cc", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 52c789293c..39e49bc347 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -211,6 +211,7 @@ struct ParsedModelFlags { Arg<bool> allow_nonexistent_arrays = Arg<bool>(false); Arg<bool> allow_nonascii_arrays = Arg<bool>(false); Arg<string> arrays_extra_info_file; + Arg<string> model_flags_file; }; // Flags that describe the operation you would like to do (what conversion diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 22a23357b3..5d51431005 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -357,6 +357,14 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, strides.mutable_list()->add_i(src_op.stride_height); strides.mutable_list()->add_i(src_op.stride_width); strides.mutable_list()->add_i(1); + if ((src_op.dilation_width_factor != 1) || + (src_op.dilation_height_factor != 1)) { + auto& dilations = (*conv2d_op->mutable_attr())["dilations"]; + dilations.mutable_list()->add_i(1); + dilations.mutable_list()->add_i(src_op.dilation_height_factor); + dilations.mutable_list()->add_i(src_op.dilation_width_factor); + dilations.mutable_list()->add_i(1); + } string padding; if (src_op.padding.type == PaddingType::kSame) { padding = "SAME"; @@ -391,84 +399,6 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, } } -void ConvertDilatedConvOperator(const Model& model, const ConvOperator& src_op, - GraphDef* tensorflow_graph) { - CHECK((src_op.dilation_width_factor > 1) || - (src_op.dilation_height_factor > 1)) - << "Conv operator must have height or width dilation factor > 1. " - "Otherwise, use regular conv op."; - CHECK_EQ(src_op.stride_width, 1) - << "Dilated AND strided convolution is unsupported"; - CHECK_EQ(src_op.stride_height, 1) - << "Dilated AND strided convolution is unsupported"; - - // Emulate dilated convolution with a chain of SpaceToBatchND -> Conv -> - // BatchToSpaceND ops. - - // Compute padding - const auto& input_array = model.GetArray(src_op.inputs[0]); - const auto& input_shape = input_array.shape(); - CHECK_EQ(input_shape.dimensions_count(), 4); - int height_mod_dilation = input_shape.dims(1) % src_op.dilation_height_factor; - int pad_height; - if (height_mod_dilation) { - pad_height = src_op.dilation_height_factor - height_mod_dilation; - } else { - pad_height = 0; - } - int pad_width; - int width_mod_dilation = input_shape.dims(2) % src_op.dilation_width_factor; - if (width_mod_dilation) { - pad_width = src_op.dilation_width_factor - width_mod_dilation; - } else { - pad_width = 0; - } - - // SpaceToBatchND op "collapses" the spatially separated elements together - string stb_output = src_op.outputs[0] + "/dilated_conv_SpaceToBatch"; - auto* stb_op = tensorflow_graph->add_node(); - stb_op->set_op("SpaceToBatchND"); - stb_op->set_name(stb_output); - *stb_op->add_input() = src_op.inputs[0]; - (*stb_op->mutable_attr())["T"].set_type(DT_FLOAT); - string block_shape = src_op.outputs[0] + "/dilated_conv_block_shape"; - CreateIntTensorConst( - block_shape, - {src_op.dilation_height_factor, src_op.dilation_width_factor}, {2}, - tensorflow_graph); - *stb_op->add_input() = block_shape; - (*stb_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); - string stb_paddings = src_op.outputs[0] + "/dilated_conv_paddings"; - CreateIntTensorConst(stb_paddings, {0, pad_height, pad_width, 0}, {2, 2}, - tensorflow_graph); - *stb_op->add_input() = stb_paddings; - (*stb_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); - - // Perform a regular conv on the "collapsed" elements - ConvOperator conv_op; - string conv_output = src_op.outputs[0] + "/dilated_conv_Conv2D"; - conv_op.inputs = src_op.inputs; - conv_op.inputs[0] = stb_output; - conv_op.outputs = {conv_output}; - conv_op.padding.type = src_op.padding.type; - conv_op.stride_width = src_op.stride_width; - conv_op.stride_height = src_op.stride_height; - conv_op.dilation_width_factor = 1; - conv_op.dilation_height_factor = 1; - ConvertConvOperator(model, conv_op, tensorflow_graph); - - // BatchToSpaceND op restores elements to their original layout - auto* bts_op = tensorflow_graph->add_node(); - bts_op->set_op("BatchToSpaceND"); - bts_op->set_name(src_op.outputs[0]); - *bts_op->add_input() = conv_output; - (*bts_op->mutable_attr())["T"].set_type(DT_FLOAT); - *bts_op->add_input() = block_shape; - (*bts_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); - *bts_op->add_input() = stb_paddings; - (*bts_op->mutable_attr())["Tcrops"].set_type(DT_INT32); -} - void ConvertDepthwiseConvOperator(const Model& model, const DepthwiseConvOperator& src_op, GraphDef* tensorflow_graph) { @@ -1711,6 +1641,23 @@ void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, (*topk_op->mutable_attr())["sorted"].set_b(true); } +void ConvertRandomUniformOperator(const Model& model, + const RandomUniformOperator& src_op, + GraphDef* tensorflow_graph) { + CHECK(tensorflow_graph != nullptr); + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("RandomUniform"); + CHECK_EQ(src_op.inputs.size(), 1); + new_op->set_name(src_op.outputs[0]); + *new_op->add_input() = src_op.inputs[0]; + const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(shape_type); + (*new_op->mutable_attr())["dtype"].set_type( + GetTensorFlowDataType(src_op.dtype)); + (*new_op->mutable_attr())["seed"].set_i(src_op.seed); + (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1719,13 +1666,8 @@ void ConvertOperator(const Model& model, const Operator& src_op, } if (src_op.type == OperatorType::kConv) { - const ConvOperator& conv_op = static_cast<const ConvOperator&>(src_op); - if ((conv_op.dilation_width_factor != 1) || - (conv_op.dilation_height_factor != 1)) { - return ConvertDilatedConvOperator(model, conv_op, tensorflow_graph); - } else { - ConvertConvOperator(model, conv_op, tensorflow_graph); - } + ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kDepthwiseConv) { ConvertDepthwiseConvOperator( model, static_cast<const DepthwiseConvOperator&>(src_op), @@ -1897,6 +1839,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertTransposeConvOperator( model, static_cast<const TransposeConvOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kRandomUniform) { + ConvertRandomUniformOperator( + model, static_cast<const RandomUniformOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index d38db85280..0fffab574d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -33,6 +33,11 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { if (conv_op->stride_width != conv_op->stride_height) { return false; } + if ((conv_op->dilation_width_factor != 1) || + (conv_op->dilation_height_factor != 1)) { + // Depthwise conv does not support dilation + return false; + } auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 640afc7c74..27c5044bb3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -128,6 +128,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) +DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose) DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) @@ -152,7 +153,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) -DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary) +DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose) DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) @@ -173,6 +175,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 7c97ef0d31..23c9e3246b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -223,8 +223,11 @@ bool PropagateMinMaxAmongArrays(Model* model, if (array.minmax) { CHECK(*array.minmax == *reference_minmax) << "Both the following arrays have minmax, and they disagree: " - << reference_array_name << " and " << array_name - << ". Expected that either only one of them would have minmax, or at " + << reference_array_name << " (" << reference_minmax->min << "," + << reference_minmax->max << ") and " << array_name << " (" + << array.minmax->min << "," << array.minmax->max + << "). Expected that either only one of them would have minmax, or " + "at " "least that they would agree."; } else { array.GetOrCreateMinMax() = *reference_minmax; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc new file mode 100644 index 0000000000..5065004093 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -0,0 +1,190 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool OperatorReady(const Model& model, const Operator* op) { + if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) || + !model.HasArray(op->outputs[0])) { + // Arrays are missing. + return false; + } + + if (!model.GetArray(op->inputs[0]).has_shape() || + !model.GetArray(op->outputs[0]).has_shape()) { + // Input and output needs the shape. + return false; + } + + if (!model.GetArray(op->inputs[1]).buffer) { + // Buffer needs to be a constant. + return false; + } + + return true; +} + +// Returns whether the reshape could be a transpose. +std::vector<int32> ReshapeToTranspose(const Model& model, + const TensorFlowReshapeOperator* op) { + CHECK(!op->shape.empty()); + CHECK(model.HasArray(op->inputs[0])); + CHECK(model.HasArray(op->outputs[0])); + + const auto& input_array = model.GetArray(op->inputs[0]); + const auto& output_array = model.GetArray(op->outputs[0]); + + CHECK(input_array.has_shape()); + CHECK(output_array.has_shape()); + + std::vector<int> in_shape = input_array.shape().dims(); + std::vector<int> out_shape = output_array.shape().dims(); + + std::vector<int> one_indices; + std::vector<int> not_one_indices; + + // Separate into one indices and not one indices. + for (int i = 0; i < in_shape.size(); i++) { + if (in_shape[i] == 1) { + one_indices.push_back(i); + } else { + not_one_indices.push_back(i); + } + } + + // Reorder the vertices. + std::vector<int> perm; + perm.reserve(in_shape.size()); + int one_index = 0; + int not_one_index = 0; + for (const auto val : out_shape) { + if (val == 1) { + perm.push_back(one_indices[one_index]); + one_index++; + } else { + perm.push_back(not_one_indices[not_one_index]); + not_one_index++; + } + } + + return perm; +} + +} // namespace + +// When a transpose is fed into a reshape, it is possible for the two operators +// to be merged if the reshape does not affect memory ordering and does not +// affects the number of dimensions. This only occurs when only unary dimensions +// are shifting position. +bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, + std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>( + it->get(), OperatorType::kTensorFlowReshape); + + if (reshape_op == nullptr) { + return false; + } + + if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { + return false; + } + + const string intermediate_name = reshape_op->inputs[0]; + const string output_name = reshape_op->outputs[0]; + + // Guarantee the input is only consume by the reshape. + if (CountOpsWithInput(*model, intermediate_name) != 1) { + return false; + } + + // Check for the parent operator. + const auto& transpose_it = FindOpWithOutput(*model, intermediate_name); + if (transpose_it == model->operators.end()) { + return false; + } + + // Find the parent operator and guarantee it is a transpose. + TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>( + transpose_it->get(), OperatorType::kTranspose); + + if (transpose_op == nullptr) { + return false; + } + + if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { + return false; + } + + if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, + false /*allow_extra_unary_dimensions*/)) { + return false; + } + + // Check that the intermediate is not an output array. + if (!IsDiscardableArray(*model, intermediate_name)) { + AddMessageF( + "Cannot fuse %s and %s as it would invalidate the transpose " + "output array.", + LogName(*transpose_op), LogName(*reshape_op)); + return false; + } + + AddMessageF("Merging operations %s and %s", LogName(*transpose_op), + LogName(*reshape_op)); + + // const auto& intermediate_array = model->GetArray(intermediate_name); + // const auto& output_array = model->GetArray(output_name); + + auto merged_perm = ReshapeToTranspose(*model, reshape_op); + + // Combine the permutations. + const auto& transpose_perm = transpose_op->perm; + for (int i = 0; i < merged_perm.size(); i++) { + merged_perm[i] = transpose_perm[merged_perm[i]]; + } + + // Remove the reshape as passthrough operation. + if (!RemoveTrivialPassthroughOp(this, model, op_index)) { + return false; + } + + // Update transpose_op's constant buffer to contain the new permutation. + model->GetArray(transpose_op->inputs[1]) + .GetMutableBuffer<ArrayDataType::kInt32>() + .data = merged_perm; + transpose_op->perm = merged_perm; + + // transpose_ops's shape will likely has changed. + model->GetArray(transpose_op->outputs[0]).clear_shape(); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 778da39bf1..89ad58f887 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -50,78 +50,108 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { old_output_data_types[output] = model->GetArray(output).data_type; } // Do the actual output data types propagation. - if (op->type == OperatorType::kDequantize || - op->type == OperatorType::kResizeBilinear) { - // These operators unconditionally produce float outputs - SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); - } else if (op->type == OperatorType::kTensorFlowLess || - op->type == OperatorType::kTensorFlowLessEqual || - op->type == OperatorType::kTensorFlowGreater || - op->type == OperatorType::kTensorFlowGreaterEqual) { - // These operators unconditionally produce bool outputs - SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); - } else if (op->type == OperatorType::kRank || - op->type == OperatorType::kTensorFlowShape) { - // These operators only produce int32 outputs. - SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); - } else if (op->type == OperatorType::kTensorFlowSplit || - op->type == OperatorType::kTensorFlowConcat || - op->type == OperatorType::kFill) { - // These operators produce an output with the same type as their 2nd input - CHECK_GE(op->inputs.size(), 2); - const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kTransposeConv) { - // These operators produce an output with the same type as their 3rd input - CHECK_GE(op->inputs.size(), 3); - const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kCast) { - // Data type of the Cast op is specified. - CHECK_EQ(op->outputs.size(), 1); - auto* cast_op = static_cast<CastOperator*>(op); - model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; - } else if (op->type == OperatorType::kArgMax) { - // Data type of the ArgMax op is specified. - CHECK_EQ(op->outputs.size(), 1); - auto* argmax_op = static_cast<ArgMaxOperator*>(op); - model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; - } else if (op->type == OperatorType::kRange) { - auto* range_op = static_cast<RangeOperator*>(op); - // Output type of the Range op can be set via an attribute - ArrayDataType data_type; - if (range_op->dtype != ArrayDataType::kNone) { - // Use the type if specified - data_type = range_op->dtype; - } else { - // Otherwise use the first input - CHECK_GE(op->inputs.size(), 1); - data_type = model->GetArray(op->inputs[0]).data_type; + switch (op->type) { + case OperatorType::kDequantize: + case OperatorType::kResizeBilinear: + // These operators unconditionally produce float outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); + break; + case OperatorType::kTensorFlowLess: + case OperatorType::kTensorFlowLessEqual: + case OperatorType::kTensorFlowGreater: + case OperatorType::kTensorFlowGreaterEqual: + // These operators unconditionally produce bool outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); + break; + case OperatorType::kRank: + case OperatorType::kTensorFlowShape: + // These operators only produce int32 outputs. + SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); + break; + case OperatorType::kTensorFlowSplit: + case OperatorType::kTensorFlowConcat: + case OperatorType::kFill: { + // These operators produce an output with the same type as their 2nd input + CHECK_GE(op->inputs.size(), 2); + const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; } - CHECK_EQ(op->outputs.size(), 1); - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kTensorFlowUnsupported) { - auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); - // Some output tensors from the op could be eliminated by optimization. - // This can make unsupported_op->output_data_types have more elements than - // op->outputs. - if (unsupported_op->output_data_types.size() < op->outputs.size()) { + case OperatorType::kTransposeConv: { + // These operators produce an output with the same type as their 3rd input + CHECK_GE(op->inputs.size(), 3); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kCast: { + // Data type of the Cast op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* cast_op = static_cast<CastOperator*>(op); + model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; + break; + } + case OperatorType::kArgMax: { + // Data type of the ArgMax op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmax_op = static_cast<ArgMaxOperator*>(op); + model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; + break; + } + case OperatorType::kRange: { + auto* range_op = static_cast<RangeOperator*>(op); + // Output type of the Range op can be set via an attribute + ArrayDataType data_type; + if (range_op->dtype != ArrayDataType::kNone) { + // Use the type if specified + data_type = range_op->dtype; + } else { + // Otherwise use the first input + CHECK_GE(op->inputs.size(), 1); + data_type = model->GetArray(op->inputs[0]).data_type; + } + CHECK_EQ(op->outputs.size(), 1); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kRandomUniform: { + auto* rand_op = static_cast<RandomUniformOperator*>(op); + // The output type of RandomUniform is specified with an attribute + if (rand_op->dtype == ArrayDataType::kNone) { + return false; + } + CHECK_EQ(op->outputs.size(), 1); + SetDataTypeForAllOutputs(model, op, rand_op->dtype); + break; + } + case OperatorType::kTensorFlowUnsupported: { + auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); + // Some output tensors from the op could be eliminated by optimization. + // This can make unsupported_op->output_data_types have more elements than + // op->outputs. + if (unsupported_op->output_data_types.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + auto output = op->outputs[i]; + auto data_type = unsupported_op->output_data_types[i]; + model->GetArray(output).data_type = data_type; + } + break; + } + case OperatorType::kExpandDims: { + // Yield on ExpandDim until it is converted to Reshape return false; } - for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; - model->GetArray(output).data_type = data_type; + default: { + // These operators produce outputs with the same type as their 1st input + CHECK_GT(op->inputs.size(), 0); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; } - } else if (op->type == OperatorType::kExpandDims) { - // Yield on ExpandDim until it is converted to Reshape - return false; - } else { - // These operators produce outputs with the same type as their 1st input - CHECK_GT(op->inputs.size(), 0); - const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); } + // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { if (old_output_data_types[output] != model->GetArray(output).data_type) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 676736cfc5..68d6f21cf8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -38,6 +38,16 @@ void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, const int input_height = input_shape.dims(1); const int batch = input_shape.dims(0); + CHECK_GE(input_width, 1); + CHECK_GE(input_height, 1); + CHECK_GE(batch, 1); + CHECK_GE(kwidth, 1); + CHECK_GE(kheight, 1); + CHECK_GE(stride_width, 1); + CHECK_GE(stride_height, 1); + CHECK_GE(dilation_width_factor, 1); + CHECK_GE(dilation_height_factor, 1); + int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1; int dilated_kheight = dilation_height_factor * (kheight - 1) + 1; @@ -392,8 +402,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { depth * block_size * block_size})); } -void ProcessFillOperator(Model* model, FillOperator* op) { - CHECK_EQ(op->inputs.size(), 2); +void ProcessOpWithShapeInput(Model* model, Operator* op) { CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { @@ -1529,7 +1538,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast<SpaceToDepthOperator*>(op)); break; case OperatorType::kFill: - ProcessFillOperator(model, static_cast<FillOperator*>(op)); + CHECK_EQ(op->inputs.size(), 2); + ProcessOpWithShapeInput(model, op); break; case OperatorType::kFullyConnected: ProcessFullyConnectedOperator(model, @@ -1659,6 +1669,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { // transforms that remove them, so we avoid propagating shapes through // them and let things settle once they've been removed. break; + case OperatorType::kRandomUniform: + CHECK_EQ(op->inputs.size(), 1); + ProcessOpWithShapeInput(model, op); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 9fcc95e1fe..7784558b22 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -472,6 +472,44 @@ bool ChooseQuantizationForOperatorOutput( return true; } + +// Fixes array minmax info to match the quantization parameters. +// This is required for when quantization parameters change for an array during +// quantization (such as ChooseQuantizationForOperatorOutput). +void FixMinMaxPostQuantization(ArrayDataType quantized_data_type, + const QuantizationParams& quantization_params, + MinMax* minmax) { + double qmin, qmax; + switch (quantized_data_type) { + case ArrayDataType::kUint8: + qmin = 0; + qmax = 255; + break; + case ArrayDataType::kInt16: + qmin = -32768; + qmax = 32767; + break; + default: + // No update required. + return; + } + + // Compute new minmax values. + double min = + (qmin - quantization_params.zero_point) * quantization_params.scale; + double max = + (qmax - quantization_params.zero_point) * quantization_params.scale; + + // If we are close to the existing minmax values don't bother changing them. + // This prevents propagating small floating point precision errors. + constexpr double kMinMaxThreshold = 1e-5; + const double width = max - min; + if (std::abs(min - minmax->min) > kMinMaxThreshold * width || + std::abs(max - minmax->max) > kMinMaxThreshold * width) { + minmax->min = min; + minmax->max = max; + } +} } // namespace bool Quantize::Run(Model* model, std::size_t op_index) { @@ -618,12 +656,19 @@ bool Quantize::Run(Model* model, std::size_t op_index) { &quantization_params)) { changed = true; const auto& output = op.outputs[output_index]; + auto& output_array = model->GetArray(output); + + // Fix up the min/max information on the output array to match the chosen + // quantization parameters. + auto& output_minmax = output_array.GetMinMax(); + FixMinMaxPostQuantization(quantized_data_type, quantization_params, + &output_minmax); + QuantizeArray(this, model, output, quantized_data_type, quantization_params); + const auto& dequantized_output = AvailableArrayName(*model, output + "_dequantized"); - const auto& output_array = model->GetArray(output); - const auto& output_minmax = output_array.GetMinMax(); auto& dequantized_output_array = model->GetOrCreateArray(dequantized_output); dequantized_output_array.data_type = ArrayDataType::kFloat; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc index 11f8d4b6ee..bdcca5b7ca 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -72,6 +72,13 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0]; minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0]; // We always want [min, max] to contain 0. + if (minmax.min > 0 || minmax.max < 0) { + LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " + << "[" << minmax.min << ", " << minmax.max + << "] does not contain 0. " + << "Proceeding by tweaking it to contain 0, which will result " + "in poor accuracy."; + } minmax.min = std::min(minmax.min, 0.); minmax.max = std::max(minmax.max, 0.); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc deleted file mode 100644 index 9852c86c21..0000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include <memory> -#include <string> -#include <unordered_map> -#include <vector> - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { - const auto ac_it = model->operators.begin() + op_index; - std::unique_ptr<Operator>& ac_op = *ac_it; - DCHECK(ac_op); - - if (ac_op->type != OperatorType::kRelu6 && - ac_op->type != OperatorType::kRelu1 && - ac_op->type != OperatorType::kRelu) { - return false; - } - - auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]); - if (exchange_it == model->operators.end()) return false; - // Find the op producing the array passed to this activation function - std::unique_ptr<Operator>& exchange_op = *exchange_it; - DCHECK(exchange_op); - - // Allow activation functions to move up over any operator that does not - // change the values. - switch (exchange_op->type) { - case OperatorType::kExpandDims: - case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: - case OperatorType::kTranspose: - break; - default: - return false; - } - - DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]); - const auto exchange_op_input = exchange_op->inputs[0]; - const auto intermediate_array = exchange_op->outputs[0]; - const auto ac_op_output = ac_op->outputs[0]; - - int count_ops_consuming_output = - CountOpsWithInput(*model, intermediate_array); - DCHECK_GE(count_ops_consuming_output, 1); - if (count_ops_consuming_output > 1) { - AddMessageF( - "Not exchanging activation function with %s because it is consumed by " - "more than 1 other operator", - LogName(*exchange_op)); - return false; - } - - // If the ac_op was originally producing an output_array we can't trivially - // reorder as otherwise the output array name would change and break - // downstream assumptions. To work around that we perform some renaming below - // in that case at the cost of a bit more confusing array names in this rare - // case. - bool is_ac_op_output = - std::find(model->flags.output_arrays().begin(), - model->flags.output_arrays().end(), - ac_op_output) != model->flags.output_arrays().end(); - if (is_ac_op_output) { - // To preserve the output array name of the activation function we need to - // create a temporary to use to pass between ac->ex. - // - // Original: - // (a) -> EX -> (b) -> AC -> (c) - // Now: - // (a) -> AC -> (c') -> EX -> (c) - AddMessageF( - "Exchanging activation function %s with %s but renaming to preserve " - "output array %s", - LogName(*ac_op), LogName(*exchange_op), ac_op->outputs[0]); - - auto renamed_ac_op_output = - AvailableArrayName(*model, ac_op_output + "_exchange"); - ac_op->inputs[0] = exchange_op_input; - ac_op->outputs[0] = renamed_ac_op_output; - model->EraseArray(exchange_op->outputs[0]); - exchange_op->inputs[0] = renamed_ac_op_output; - exchange_op->outputs[0] = ac_op_output; - } else { - // Simply swap the order and update consumers to use the exchange_op output - // array (b). - // - // Original: - // (a) -> EX -> (b) -> AC -> (c) - // Now: - // (a) -> AC -> (c) -> EX -> (b) - AddMessageF("Exchanging activation function %s with %s", LogName(*ac_op), - LogName(*exchange_op)); - - Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); - while (consumer) { - for (int i = 0; i < consumer->inputs.size(); ++i) { - if (consumer->inputs[i] == ac_op_output) { - consumer->inputs[i] = intermediate_array; - } - } - consumer = GetFirstOpWithInput(*model, ac_op_output); - } - ac_op->inputs[0] = exchange_op_input; - exchange_op->inputs[0] = ac_op_output; - } - - // Clear shapes; this will allow shape propagation to fix the sizes for us. - model->GetOrCreateArray(ac_op->outputs[0]).clear_shape(); - model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape(); - - // Finally, reorder operators. Note that this only works when there are no - // other direct descendents of the exchange_op. - ac_op.swap(exchange_op); - - return true; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc new file mode 100644 index 0000000000..9f5b7920cb --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsElementwiseOperator(OperatorType optype) { + switch (optype) { + case OperatorType::kCast: + case OperatorType::kExp: + case OperatorType::kFloor: + case OperatorType::kNeg: + case OperatorType::kRelu: + case OperatorType::kRelu1: + case OperatorType::kRelu6: + case OperatorType::kTanh: + case OperatorType::kTensorFlowSqrt: + case OperatorType::kTensorFlowSquare: + return true; + default: + return false; + } +} + +bool IsMoveOperator(OperatorType optype) { + switch (optype) { + case OperatorType::kDepthToSpace: + case OperatorType::kExpandDims: + case OperatorType::kSpaceToDepth: + case OperatorType::kSqueeze: + case OperatorType::kTensorFlowReshape: + case OperatorType::kTranspose: + return true; + default: + return false; + } +} + +} // namespace + +// Swap elementwise operators such that all value operators occur before all +// element move operators, e.g. negation then transpose. +bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) { + const auto element_op_it = model->operators.begin() + op_index; + std::unique_ptr<Operator>& element_op = *element_op_it; + if (!IsElementwiseOperator(element_op->type)) { + return false; + } + + const string intermediate_name = element_op->inputs[0]; + auto it = FindOpWithOutput(*model, intermediate_name); + if (it == model->operators.end()) { + AddMessageF("No preceding operator"); + return false; + } + + std::unique_ptr<Operator>& move_op = *it; + if (!IsMoveOperator(move_op->type)) { + AddMessageF("Preceding operator is not a move operator"); + return false; + } + + if (CountOpsWithInput(*model, intermediate_name) != 1) { + AddMessageF("Input %s used elsewhere", intermediate_name); + return false; + } + + // Check that the intermediate is discardable. + if (!IsDiscardableArray(*model, intermediate_name)) { + AddMessageF( + "Cannot swap elementwise as it would invalidate %s which is " + "an output array.", + intermediate_name); + return false; + } + + // op->inputs may change so we need to keep a value by copy. + const string input_name = move_op->inputs[0]; + const string output_name = element_op->outputs[0]; + + AddMessageF("Swapping around operators with %s and %s", LogName(*element_op), + LogName(*move_op)); + + // If the output array is an exit node for the graph then we need to retain + // the name as an output node. This makes the naming scheme a little confusing + // but is required in this rare case. + if (!IsDiscardableArray(*model, output_name)) { + // The output name of the sequence needs to stay static, so create a new + // array new use for the intermediate. + const auto new_intermediate_name = + AvailableArrayName(*model, element_op->outputs[0] + "_reorder"); + AddMessageF("Adding new array %s to preserve output array name %s", + new_intermediate_name, output_name); + + element_op->inputs[0] = input_name; + element_op->outputs[0] = new_intermediate_name; + model->EraseArray(intermediate_name); + move_op->inputs[0] = new_intermediate_name; + move_op->outputs[0] = output_name; + } else { + // The intermediate array is now the output array. + for (int i = 0; i < model->operators.size(); i++) { + Operator* consumer = model->operators[i].get(); + for (int j = 0; j < consumer->inputs.size(); j++) { + if (consumer->inputs[j] == output_name) { + consumer->inputs[j] = intermediate_name; + } + } + } + + element_op->inputs[0] = input_name; + move_op->inputs[0] = output_name; + } + + // Reset both arrays as shape, type, min/max, etc can all change because of + // the position swap. + model->EraseArray(element_op->outputs[0]); + model->EraseArray(move_op->outputs[0]); + + // Reconstruct. + model->GetOrCreateArray(element_op->outputs[0]); + model->GetOrCreateArray(move_op->outputs[0]); + + // Swap the order of the operators. + element_op.swap(move_op); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc new file mode 100644 index 0000000000..9e7fe1b1cc --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool OperatorReady(const Model& model, const Operator* op) { + if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) || + !model.HasArray(op->outputs[0])) { + return false; + } + + if (!model.GetArray(op->inputs[0]).has_shape() || + !model.GetArray(op->outputs[0]).has_shape()) { + // Input and output needs the shape. + return false; + } + + if (!model.GetArray(op->inputs[1]).buffer) { + // Buffer needs to be a constant. + return false; + } + + return true; +} + +// Utility function to filter out a value. +void Filter(std::vector<int>* vec, int value) { + vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end()); +} + +// Computes a new permutation used to swap a reshape-transpose to a +// transpose-reshape. In this case the permutation operates on the intermediate +// shape. +std::vector<int> ComputeNewPerm(std::vector<int> input_dims, + std::vector<int> intermediate_dims, + std::vector<int> perm) { + // These are the major axis of the input. + std::vector<int> input_indices; + for (int i = 0; i < input_dims.size(); i++) { + if (input_dims[i] != 1) { + input_indices.push_back(i); + } + } + + // This maps which indices of the input produced the intermediate indices for + // non-unary dimensions. + std::unordered_map<int, int> intermediate_to_input_indices_map; + for (int i = 0; i < intermediate_dims.size(); i++) { + if (intermediate_dims[i] != 1) { + intermediate_to_input_indices_map[i] = + input_indices[intermediate_to_input_indices_map.size()]; + } + } + + // Translate the transpose permutation to a new permutation starting with the + // major indices. + std::vector<int> new_perm; + new_perm.reserve(input_dims.size()); + for (int i = 0; i < perm.size(); i++) { + if (intermediate_dims[perm[i]] == 1) continue; + + new_perm.push_back(intermediate_to_input_indices_map[perm[i]]); + } + + // Fill the rest of the transpose in with the ones. + for (int index = 0; index < input_dims.size(); index++) { + if (input_dims[index] == 1) { + new_perm.push_back(index); + } + } + + CHECK_EQ(new_perm.size(), input_dims.size()); + return new_perm; +} + +} // namespace + +// Swaps reshape-transpose to transpose-reshape whenever possible. This is +// possible when the reshape does not affect memory ordering. +bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { + auto transpose_it = model->operators.begin() + op_index; + + TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>( + transpose_it->get(), OperatorType::kTranspose); + + if (transpose_op == nullptr) { + return false; + } + + if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { + // Wait for values to propagate. + return false; + } + + // Find the operator that produces the transpose op. + auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]); + if (reshape_it == model->operators.end()) { + return false; + } + + TensorFlowReshapeOperator* reshape_op = + ConvertOperator<TensorFlowReshapeOperator*>( + reshape_it->get(), OperatorType::kTensorFlowReshape); + if (reshape_op == nullptr) { + return false; + } + + // Ignore if the reshape is uninitialized. + if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { + return false; + } + + // Need to copy to keep static if permutated. + const string input_name = reshape_op->inputs[0]; + const string intermediate_name = reshape_op->outputs[0]; + const string output_name = transpose_op->outputs[0]; + + // Intermediate should not be consumed by any other operators. + if (CountOpsWithInput(*model, intermediate_name) != 1) { + AddMessageF("Input %s used elsewhere", intermediate_name); + return false; + } + + // Check that the intermediate is not an output array. + if (!IsDiscardableArray(*model, intermediate_name)) { + AddMessageF( + "Cannot reorder reshape-transpose as it would invalidate %s which is " + "an output array.", + intermediate_name); + return false; + } + + // Get the arrays. + const auto& input_array = model->GetArray(input_name); + const auto& intermediate_array = model->GetArray(intermediate_name); + const auto& output_array = model->GetArray(output_name); + + // Get the shapes of each array. + Shape input_shape = input_array.shape(); + Shape intermediate_shape = intermediate_array.shape(); + Shape output_shape = output_array.shape(); + + // Assign ids to non-unary indices. + std::vector<int> input_dims = input_shape.dims(); + std::vector<int> intermediate_dims = intermediate_shape.dims(); + std::vector<int> output_dims = output_shape.dims(); + + // If the reshape is equivalent to a transpose with fewer/more unary + // dimensions then it can be moved between the transpose. + if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, + true /*allow_extra_unary_dims*/)) { + return false; + } + + if (!IsDiscardableArray(*model, output_name)) { + // The output name of the sequence needs to stay static, so create a new + // array new use for the intermediate. + const auto new_intermediate_name = + AvailableArrayName(*model, transpose_op->outputs[0] + "_exchange"); + AddMessageF("Adding new array %s to preserve output array name %s", + new_intermediate_name, transpose_op->outputs[0]); + transpose_op->inputs[0] = input_name; + transpose_op->outputs[0] = new_intermediate_name; + reshape_op->inputs[0] = new_intermediate_name; + reshape_op->outputs[0] = output_name; + model->EraseArray(intermediate_name); + } else { + // The intermediate array is now the output array. + for (int i = 0; i < model->operators.size(); i++) { + Operator* consumer = model->operators[i].get(); + for (int j = 0; j < consumer->inputs.size(); j++) { + if (consumer->inputs[j] == output_name) { + consumer->inputs[j] = intermediate_name; + } + } + } + + transpose_op->inputs[0] = input_name; + reshape_op->inputs[0] = output_name; + } + + // If transposes constant buffer is used elsewhere, make a new copy. + if (CountOpsWithInput(*model, transpose_op->inputs[1]) != 1) { + transpose_op->inputs[1] = + AvailableArrayName(*model, transpose_op->inputs[1] + "_copy"); + } + + // Make the new transpose permutation. + const std::vector<int> new_perm = + ComputeNewPerm(input_dims, intermediate_dims, transpose_op->perm); + CHECK_EQ(input_dims.size(), new_perm.size()); + + auto& transpose_array = model->GetOrCreateArray(transpose_op->inputs[1]); + transpose_array.GetMutableBuffer<ArrayDataType::kInt32>().data = new_perm; + *(transpose_array.mutable_shape()->mutable_dims()) = { + static_cast<int>(new_perm.size())}; + transpose_op->perm = new_perm; + + // If the reshape's constant buffer is reused, create a new one. + if (CountOpsWithInput(*model, reshape_op->inputs[1]) != 1) { + reshape_op->inputs[1] = + AvailableArrayName(*model, reshape_op->inputs[1] + "_copy"); + } + + // We need to modify the reshape input array to target the new output size. + auto& reshape_array = model->GetOrCreateArray(reshape_op->inputs[1]); + reshape_array.GetMutableBuffer<ArrayDataType::kInt32>().data = output_dims; + *(reshape_array.mutable_shape()->mutable_dims()) = { + static_cast<int>(output_shape.dimensions_count())}; + reshape_op->shape.clear(); + + AddMessageF("Swapping around operators between %s and %s", input_name, + output_name); + + model->GetOrCreateArray(transpose_op->outputs[0]).clear_shape(); + model->GetOrCreateArray(reshape_op->outputs[0]).clear_shape(); + + // Swap the order of the operators. + transpose_it->swap(*reshape_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc new file mode 100644 index 0000000000..88d06d7dc7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace toco { + +template <ArrayDataType Type> +bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) { + typedef tensorflow::random::UniformDistribution< + tensorflow::random::PhiloxRandom, DataType<Type>> + Distribution; + + // Allocate output + auto& output_array = model->GetArray(op->outputs[0]); + CHECK(output_array.data_type == Type); + std::vector<DataType<Type>>& data = + output_array.GetMutableBuffer<Type>().data; + data.resize(RequiredBufferSizeForShape(output_array.shape())); + + // We use the same random number generator and distribution as TensorFlow to + // produce the exact same values given the same seeds. See + // tensorflow::functor::FillPhiloxRandomTask<Distribution, false> in + // //third_party/tensorflow/core/kernels/random_op.cc for the implementation. + tensorflow::random::PhiloxRandom generator(op->seed, op->seed2); + Distribution dist; + + // The generator creates Distribution::kResultElementCount samples at a time. + size_t offset = 0; + size_t num_samples = Distribution::kResultElementCount; + while (offset < data.size()) { + const typename Distribution::ResultType samples = dist(&generator); + std::copy(&samples[0], + &samples[0] + std::min(num_samples, data.size() - offset), + &data[0] + offset); + offset += num_samples; + } + + return true; +} + +bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* base_op = it->get(); + if (base_op->type != OperatorType::kRandomUniform) { + return false; + } + auto* op = static_cast<RandomUniformOperator*>(base_op); + + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes + return false; + } + + if ((op->seed == 0) && (op->seed2 == 0)) { + LOG(WARNING) << "RandomUniform op outputting \"" << op->outputs[0] + << "\" is truly random (using /dev/random system entropy). " + "Therefore, cannot resolve as constant. Set \"seed\" or " + "\"seed2\" attr non-zero to fix this"; + return false; + } + + switch (output_array.data_type) { + case ArrayDataType::kFloat: + if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) { + return false; + } + break; + // For future support of double or half. + // case ArrayDataType::kDouble... + default: + LOG(FATAL) + << "Unsupported data type given to RandomUniform op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used + toco::DeleteArrayIfUsedOnce(op->inputs[0], model); + + // Erase the operator + model->operators.erase(it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index f38203c80f..2a236d3f98 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -60,6 +60,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { string input_lhs = matmul_op->inputs[0]; string input_rhs = transpose_op->outputs[0]; + // Construct the new FullyConnectedOperator. + auto* fc_op = new FullyConnectedOperator; + fc_op->outputs = matmul_op->outputs; + + // Insert the newly constructed FullyConnectedOperator. + model->operators.emplace(matmul_it, fc_op) + 1; + // Find the op producing the array passed to this MatMul auto previous_op_it = model->operators.begin(); bool found = false; @@ -76,13 +83,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } Operator* previous_op = (found) ? previous_op_it->get() : nullptr; - // Construct the new FullyConnectedOperator. - auto* fc_op = new FullyConnectedOperator; - fc_op->outputs = matmul_op->outputs; - - // Insert the newly constructed FullyConnectedOperator. - model->operators.emplace(matmul_it, fc_op) + 1; - // Refresh iterator. matmul_it = model->operators.begin(); for (; matmul_it != model->operators.end(); ++matmul_it) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index c26e4bddff..876479079b 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -74,7 +74,7 @@ const string& GetStringAttr(const NodeDef& node, const string& attr_name) { return attr.s(); } -int GetIntAttr(const NodeDef& node, const string& attr_name) { +int64 GetIntAttr(const NodeDef& node, const string& attr_name) { CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n" << node.DebugString(); const auto& attr = node.attr().at(attr_name); @@ -569,6 +569,23 @@ void ConvertBiasAddOperator(const NodeDef& node, model->operators.emplace_back(biasadd); } +void ConvertRandomUniform(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "RandomUniform"); + CheckInputsCount(node, tf_import_flags, 1); + + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32); + auto op = absl::make_unique<RandomUniformOperator>(); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype")); + op->seed = GetIntAttr(node, "seed"); + op->seed2 = GetIntAttr(node, "seed2"); + CHECK(model != nullptr); + model->operators.emplace_back(std::move(op)); +} + void ConvertReluOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1931,7 +1948,7 @@ void ConvertTopKV2Operator(const NodeDef& node, // K can be encoded as attr (TopK) convert it to a const. if (HasAttr(node, "k")) { string k_array = CreateConstArray<ArrayDataType::kInt32>( - model, node.name() + "k", {GetIntAttr(node, "k")}); + model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))}); op->inputs.push_back(k_array); } else { CheckInputsCount(node, tf_import_flags, 2); @@ -2168,6 +2185,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( } else if (node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch") { ConvertDynamicStitchOperator(node, tf_import_flags, model); + } else if (node.op() == "RandomUniform") { + ConvertRandomUniform(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 5199e292e1..9bd72e7de1 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -60,6 +60,7 @@ enum class OperatorType { kMaxPool, kFakeQuant, kMul, + kRandomUniform, kRange, kRank, kRelu, @@ -946,6 +947,13 @@ struct FloorModOperator : Operator { FloorModOperator() : Operator(OperatorType::kFloorMod) {} }; +struct RandomUniformOperator : Operator { + RandomUniformOperator() : Operator(OperatorType::kRandomUniform) {} + ArrayDataType dtype = ArrayDataType::kNone; + int64 seed; + int64 seed2; +}; + // Creates a sequence of numbers that begins at start and extends by increments // of delta up to but not including limit. // @@ -1499,7 +1507,14 @@ class Shape { // We still have that one convenience accessor to avoid // the awkward double bracket issue: shape.dims()[i]. - int dims(int i) const { return dims_[i]; } + int dims(int i) const { + // Always check for out-of-bounds accesses, even in optimized builds where + // standard assertions are disabled. Out-of-bounds access here is a common + // occurence. + CHECK_GE(i, 0); + CHECK_GT(dims_.size(), i); + return dims_[i]; + } bool operator==(const Shape& comp) const { return (this->dims_ == comp.dims()); diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4264f21c76..245eb52444 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -160,6 +160,11 @@ bool ParseModelFlagsFromCommandLineFlags( "Path to an optional file containing a serialized ArraysExtraInfo " "proto allowing to pass extra information about arrays not specified " "in the input model file, such as extra MinMax information."), + Flag("model_flags_file", parsed_flags.model_flags_file.bind(), + parsed_flags.model_flags_file.default_value(), + "Path to an optional file containing a serialized ModelFlags proto. " + "Options specified on the command line will override the values in " + "the proto."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -182,7 +187,24 @@ void ReadModelFlagsFromCommandLineFlags( const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet"); -// "batch" flag only exists internally + // Load proto containing the initial model flags. + // Additional flags specified on the command line will overwrite the values. + if (parsed_model_flags.model_flags_file.specified()) { + string model_flags_file_contents; + QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(), + &model_flags_file_contents, + port::file::Defaults()) + .ok()) + << "Specified --model_flags_file=" + << parsed_model_flags.model_flags_file.value() + << " was not found or could not be read"; + QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents, + model_flags)) + << "Specified --model_flags_file=" + << parsed_model_flags.model_flags_file.value() + << " could not be parsed"; + } + #ifdef PLATFORM_GOOGLE CHECK(!((base::SpecifiedOnCommandLine("batch") && parsed_model_flags.variable_batch.specified()))) diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index 42e0f54826..835dea49eb 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -98,8 +98,8 @@ message ArraysExtraInfo { message Entry { // Next ID to use: 7. optional string name = 1; - optional float min = 2; - optional float max = 3; + optional double min = 2; + optional double max = 3; optional IODataType data_type = 4; optional InputArrayShape shape = 5; optional float constant_float_value = 6; diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 0cb348bda5..f991529569 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -204,17 +204,22 @@ class BatchToSpaceND TocoOperator* op) const override {} }; -class Cast : public CustomOperator<CastOperator> { +class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions, + ::tflite::BuiltinOptions_CastOptions> { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Int("src_data_type", DataType::Serialize(op.src_data_type)); - fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type)); + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateCastOptions(*builder, + DataType::Serialize(op.src_data_type), + DataType::Serialize(op.dst_data_type)); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { - op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64()); - op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64()); + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->src_data_type = DataType::Deserialize(options.in_data_type()); + op->dst_data_type = DataType::Deserialize(options.out_data_type()); } }; @@ -827,9 +832,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); ops.emplace_back( new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); + ops.emplace_back( + new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); // Custom Operators. - ops.emplace_back(new Cast("CAST", OperatorType::kCast)); ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index f7a213ecfc..4783843b7f 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -131,7 +131,7 @@ TEST_F(OperatorTest, BuiltinMean) { EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); } -TEST_F(OperatorTest, CustomCast) { +TEST_F(OperatorTest, BuiltinCast) { CastOperator op; op.src_data_type = ArrayDataType::kFloat; op.dst_data_type = ArrayDataType::kUint8; diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 30dd6fab9e..76e9a27aef 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -74,11 +74,14 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); - transformations->Add(new ReorderActivationFunctions); + transformations->Add(new MergeReshapeIntoPrecedingTranspose); + transformations->Add(new ReorderElementwiseUnary); + transformations->Add(new ReorderReshapeTranspose); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); transformations->Add(new ResolveConstantGather); + transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index f3f50487ff..56fa8f4b69 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -297,6 +297,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(L2Pool) HANDLE_OPERATORTYPENAME_CASE(FakeQuant) HANDLE_OPERATORTYPENAME_CASE(Mul) + HANDLE_OPERATORTYPENAME_CASE(RandomUniform) HANDLE_OPERATORTYPENAME_CASE(Relu) HANDLE_OPERATORTYPENAME_CASE(Relu1) HANDLE_OPERATORTYPENAME_CASE(Relu6) @@ -1920,6 +1921,35 @@ bool IsDiscardableArray(const Model& model, const string& array_name) { return true; } +bool ReshapeIsEquivalentToTranspose(const Model& model, + const TensorFlowReshapeOperator* op, + bool allow_extra_unary_dims) { + CHECK(!op->shape.empty()); + CHECK(model.HasArray(op->inputs[0])); + CHECK(model.HasArray(op->outputs[0])); + + const auto& input_array = model.GetArray(op->inputs[0]); + const auto& output_array = model.GetArray(op->outputs[0]); + + CHECK(input_array.has_shape()); + CHECK(output_array.has_shape()); + + std::vector<int> in_shape = input_array.shape().dims(); + std::vector<int> out_shape = output_array.shape().dims(); + + // If the reshape changes the number of dimensions so it cannot be interpreted + // as a transpose. + if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) { + return false; + } + + in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), + in_shape.end()); + out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), + out_shape.end()); + return in_shape == out_shape; +} + void CheckFinalDataTypesSatisfied(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; @@ -1976,9 +2006,9 @@ void UseArraysExtraInfo(Model* model) { continue; } auto& array = model->GetArray(entry.name()); - auto& minmax = array.GetOrCreateMinMax(); if (entry.has_min() || entry.has_max()) { CHECK_EQ(entry.has_min(), entry.has_max()); + auto& minmax = array.GetOrCreateMinMax(); minmax.min = entry.min(); minmax.max = entry.max(); } @@ -1997,11 +2027,12 @@ void UseArraysExtraInfo(Model* model) { } if (entry.has_constant_float_value()) { CHECK(array.has_shape()); - CHECK(array.data_type == ArrayDataType::kFloat); - auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data; - data.resize(RequiredBufferSizeForShape(array.shape())); - for (float& f : data) { - f = entry.constant_float_value(); + if (array.data_type == ArrayDataType::kFloat) { + auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); + } } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index d3b7224fe3..259ee7fbd0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -169,10 +169,23 @@ void GetQuantizationParamsFromMinMax(const MinMax& minmax, ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax); } +template <typename T> +T ConvertOperator(Operator* o, OperatorType type) { + if (o != nullptr && o->type == type) { + return static_cast<T>(o); + } + + return nullptr; +} + void CheckIsReadyForQuantization(const Model& model); void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, double default_ranges_max); +bool ReshapeIsEquivalentToTranspose(const Model& model, + const TensorFlowReshapeOperator* op, + bool allow_extra_unary_dims); + inline int Offset(const Shape& shape, const std::vector<int>& indices) { DCHECK_EQ(shape.dimensions_count(), indices.size()); const int dims_count = shape.dimensions_count(); diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index 02b4f80252..f616207d46 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -46,4 +46,5 @@ tf_py_test( "//tensorflow/python:variables", ], grpc_enabled = True, + tags = ["no_windows"], # TODO: needs investigation on Windows ) diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 77c936d8c5..76428bc1d4 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -12,6 +12,7 @@ tensorflow/core/platform/posix/env.cc tensorflow/core/platform/posix/load_library.cc tensorflow/core/platform/posix/env_time.cc tensorflow/core/platform/file_system.cc +tensorflow/core/platform/file_system_helper.cc tensorflow/core/platform/env.cc tensorflow/core/platform/env_time.cc tensorflow/core/platform/setround.cc diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 6cbfd03881..334e70318d 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -31,7 +31,7 @@ tf_custom_op_library( "kernels/nccl_ops.cc", ], deps = if_cuda([ - "@nccl_archive//:nccl", + "@local_config_nccl//:nccl", "//tensorflow/core:gpu_headers_lib", ]), ) @@ -61,7 +61,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "@nccl_archive//:nccl", + "@local_config_nccl//:nccl", ], ) @@ -80,7 +80,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:proto_text", "//tensorflow/core:stream_executor", - "@nccl_archive//:nccl", + "@local_config_nccl//:nccl", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h index bb219e0edc..6ff8cea84e 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.h +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include <unordered_map> #include <vector> -#include "src/nccl.h" +#include "third_party/nccl/nccl.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 266d4f6f0d..c2b76caef3 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include <vector> -#include "src/nccl.h" +#include "third_party/nccl/nccl.h" #include "tensorflow/contrib/nccl/kernels/nccl_manager.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc index a4de46a93f..4676e937e5 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/lib/strings/str_util.h" #if GOOGLE_CUDA #include <forward_list> @@ -254,7 +255,7 @@ class NcclReplacePass : public GraphOptimizationPass { // Find reduction and broadcast ops and replace them with Send/Recv ops. for (Node* node : graph->op_nodes()) { StringPiece type = node->type_string(); - if (!type.starts_with("Nccl")) { + if (!str_util::StartsWith(type, "Nccl")) { continue; } if (type == "NcclReduce") { diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 471992fdac..25d19578ea 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -866,7 +866,7 @@ class OptimizerV2(optimizer_v1.Optimizer): raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) return distribute_lib.get_tower_context().merge_call( - self.distributed_apply, filtered, global_step=global_step, name=name) + self._distributed_apply, filtered, global_step=global_step, name=name) def _get_or_create_state(self, var_list=None): """Either looks up or creates `_OptimizerV2State`. @@ -899,7 +899,7 @@ class OptimizerV2(optimizer_v1.Optimizer): self._per_graph_state[graph_key] = per_graph_state return per_graph_state - def distributed_apply(self, distribution, grads_and_vars, global_step, name): + def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" reduced_grads = distribution.batch_reduce("sum", grads_and_vars) var_list = [v for _, v in grads_and_vars] diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 2889016a84..d53d4d7b10 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -416,7 +416,9 @@ def _InsertQuantOp(context, # name_prefix starts with 'TPUReplicate/loop/'; without dropping it # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which # breaks things later. - name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/') + name_scope = ops.get_name_scope() + if name_scope: + name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/') inputs = producer.outputs[0] # Prevent ops from being quantized multiple times. Bypass ops can sometimes diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 98f05c8bfc..8d057d3710 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -247,6 +247,27 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue(not op.name.startswith('name_scope/name_scope/'), 'Broken op: %s' % op.name) + def testWithNullNameScope(self): + self._RunTestOverParameters(self._TestWithNullNameScope) + + def _TestWithNullNameScope(self, is_training): + graph = ops.Graph() + with graph.as_default(): + with graph.name_scope(None): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Passes if Quantize() does not crash. + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD index 996b55f9b8..3aa8a14f44 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/BUILD +++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD @@ -38,7 +38,6 @@ py_test( size = "small", srcs = ["python/ops/remote_fused_graph_ops_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":remote_fused_graph_ops_py", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index faad40d335..e431c464ef 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -53,6 +53,7 @@ py_test( size = "small", srcs = ["python/saved_model/reader_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows visibility = ["//visibility:private"], deps = [ ":saved_model_py", diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 31717305e7..9c08859180 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -151,6 +151,7 @@ py_test( name = "gc_test", srcs = ["gc_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows visibility = ["//visibility:private"], deps = [ ":gc", diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD index dc12e67fc6..eef043e832 100644 --- a/tensorflow/contrib/slim/python/slim/data/BUILD +++ b/tensorflow/contrib/slim/python/slim/data/BUILD @@ -61,6 +61,7 @@ py_test( name = "dataset_data_provider_test", srcs = ["dataset_data_provider_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":dataset", ":dataset_data_provider", diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index d4096751c4..30be14c10c 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,4 +31,5 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], + tags = ["no_windows"], ) diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 11a59ec22b..136856c015 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -539,7 +539,6 @@ py_test( srcs = ["client/random_forest_test.py"], srcs_version = "PY2AND3", tags = [ - "no_windows", "nomac", # b/63258195 "notsan", ], diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index f4efd9717d..2b6a2b2f3c 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -9,6 +9,7 @@ exports_files(["LICENSE"]) # For platform specific build config load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow:tensorflow.bzl", "py_test") tf_proto_library( name = "protos_all", @@ -81,6 +82,7 @@ py_test( size = "small", srcs = ["plugins/trace/trace_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":trace", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc index c61b465596..cd3f712256 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/event.pb.h" @@ -58,7 +59,7 @@ class SummaryFileWriterTest : public ::testing::Test { TF_CHECK_OK(env_.GetChildren(testing::TmpDir(), &files)); bool found = false; for (const string& f : files) { - if (StringPiece(f).contains(test_name)) { + if (str_util::StrContains(f, test_name)) { if (found) { return errors::Unknown("Found more than one file for ", test_name); } diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 40cf9147b3..32e948a009 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -25,7 +25,10 @@ py_test( srcs = ["predict_test.py"], data = ["data/period_trend.csv"], srcs_version = "PY2AND3", - tags = ["notsan"], # b/67513579 + tags = [ + "no_windows", # TODO: needs investigation on Windows + "notsan", # b/67513579 + ], deps = [ ":predict", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 9b6c08150c..d2746032a0 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -88,10 +88,14 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:export", + "//tensorflow/python/feature_column", ], ) @@ -132,7 +136,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", - "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -141,6 +144,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:estimator_py", @@ -156,23 +160,30 @@ py_test( "head_test.py", ], srcs_version = "PY2AND3", - tags = [ - "no_pip_gpu", # b/63391119 - ], + tags = ["no_pip_gpu"], # b/63391119 deps = [ + ":estimators", ":feature_keys", ":head", + ":input_pipeline", ":model", ":state_management", + "//tensorflow/contrib/timeseries/examples:lstm", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/feature_column", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -428,6 +439,7 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_pip_gpu", # b/63391119 + "no_windows", # TODO: needs investigation on Windows ], deps = [ ":feature_keys", diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 469cea4fd2..886e1846e2 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -44,7 +44,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator): """An Estimator to fit and evaluate a time series model.""" def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None): + config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): """Initialize the Estimator. Args: @@ -55,6 +55,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator): from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. + head_type: The kind of head to use for the model (inheriting from + `TimeSeriesRegressionHead`). """ input_statistics_generator = math_utils.InputStatisticsFromMiniBatch( dtype=model.dtype, num_features=model.num_features) @@ -63,8 +65,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator): if optimizer is None: optimizer = train.AdamOptimizer(0.02) self._model = model - ts_regression_head = ts_head_lib.time_series_regression_head( - model, state_manager, optimizer, + ts_regression_head = head_type( + model=model, state_manager=state_manager, optimizer=optimizer, input_statistics_generator=input_statistics_generator) model_fn = ts_regression_head.create_estimator_spec super(TimeSeriesRegressor, self).__init__( diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index 51d0c0ca3f..9f161c1695 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import tempfile import numpy +import six from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import estimators @@ -127,6 +128,12 @@ class TimeSeriesRegressorTest(test.TestCase): session=sess) # Test cold starting + six.assertCountEqual( + self, + [feature_keys.FilteringFeatures.TIMES, + feature_keys.FilteringFeatures.VALUES], + signatures.signature_def[ + feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys()) batch_numpy_times = numpy.tile( numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1)) batch_numpy_values = numpy.ones([10, 30, 1]) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 4cf6bbcfd4..a28a5872b8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -39,27 +39,18 @@ from tensorflow.python.util import nest from tensorflow.python.summary import summary -def time_series_regression_head(model, - state_manager, - optimizer, - input_statistics_generator=None): - """Creates a `_Head` for time series regression. +class _NoStatePredictOutput(export_lib.PredictOutput): - Args: - model: A model for time series regression. - state_manager: A state manager. - optimizer: An optimizer. - input_statistics_generator: A input statistics generator. - - Returns: - An instance of `_Head` for time series regression. - """ - return _TimeSeriesRegressionHead(model, state_manager, optimizer, - input_statistics_generator) + def as_signature_def(self, receiver_tensors): + no_state_receiver_tensors = { + key: value for key, value in receiver_tensors.items() + if not key.startswith(feature_keys.State.STATE_PREFIX)} + return super(_NoStatePredictOutput, self).as_signature_def( + receiver_tensors=no_state_receiver_tensors) -class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access - """See `time_series_regression_head`.""" +class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access + """Determines input and output signatures for a time series model.""" def __init__(self, model, @@ -67,6 +58,15 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc optimizer, input_statistics_generator=None, name=None): + """Creates a `_Head` for time series regression. + + Args: + model: A model for time series regression. + state_manager: A state manager. + optimizer: An optimizer. + input_statistics_generator: A input statistics generator. + name: An optional name for the model. + """ self.model = model self.state_manager = state_manager self.optimizer = optimizer @@ -167,7 +167,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc export_lib.PredictOutput( state_to_dictionary(filtering_outputs.end_state)), feature_keys.SavedModelLabels.COLD_START_FILTER: - export_lib.PredictOutput( + _NoStatePredictOutput( state_to_dictionary(cold_filtering_outputs.end_state)) }, # Likely unused, but it is necessary to return `predictions` to satisfy @@ -255,6 +255,58 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc return self._serving_ops(features) +class OneShotPredictionHead(TimeSeriesRegressionHead): + """A time series head which exports a single stateless serving signature. + + The serving default signature exported by this head expects `times`, `values`, + and any exogenous features, but no state. `values` has shape `[batch_size, + filter_length, num_features]` and `times` has shape `[batch_size, + total_length]`, where `total_length > filter_length`. Any exogenous features + must have their shapes prefixed by the shape of the `times` feature. + + When serving, first performs filtering on the series up to `filter_length` + starting from the default start state for the model, then computes predictions + on the remainder of the series, returning them. + + Model state is neither accepted nor returned, so filtering must be performed + each time predictions are requested when using this head. + """ + + def _serving_ops(self, features): + """Add ops for serving to the graph.""" + with variable_scope.variable_scope("model", use_resource=True): + filtering_features = {} + prediction_features = {} + values_length = array_ops.shape( + features[feature_keys.FilteringFeatures.VALUES])[1] + for key, value in features.items(): + if key == feature_keys.State.STATE_TUPLE: + # Ignore state input. The model's default start state is replicated + # across the batch. + continue + if key == feature_keys.FilteringFeatures.VALUES: + filtering_features[key] = value + else: + filtering_features[key] = value[:, :values_length] + prediction_features[key] = value[:, values_length:] + cold_filtering_outputs = self.model.define_loss( + features=filtering_features, mode=estimator_lib.ModeKeys.EVAL) + prediction_features[feature_keys.State.STATE_TUPLE] = ( + cold_filtering_outputs.end_state) + with variable_scope.variable_scope("model", reuse=True): + prediction_outputs = self.model.predict( + features=prediction_features) + return estimator_lib.EstimatorSpec( + mode=estimator_lib.ModeKeys.PREDICT, + export_outputs={ + feature_keys.SavedModelLabels.PREDICT: + _NoStatePredictOutput(prediction_outputs), + }, + # Likely unused, but it is necessary to return `predictions` to satisfy + # the Estimator's error checking. + predictions={}) + + def _check_feature_shapes_compatible_with(features, compatible_with_name, compatible_with_value, diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 3415061cfd..c606db76a6 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,12 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy +import six + +from tensorflow.contrib.timeseries.examples import lstm as lstm_example +from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management +from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -31,6 +39,9 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import adam from tensorflow.python.training import coordinator as coordinator_lib from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import training as train @@ -90,7 +101,7 @@ class EvaluationMetricsTests(test.TestCase): .count_up_to(10), dtype=dtypes.float32), (1, 1, 1)) } - model_fn = ts_head_lib.time_series_regression_head( + model_fn = ts_head_lib.TimeSeriesRegressionHead( model=_TickerModel(), state_manager=state_management.PassthroughStateManager(), optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec @@ -127,7 +138,7 @@ class _StubModel(object): def _stub_model_fn(): - return ts_head_lib.time_series_regression_head( + return ts_head_lib.TimeSeriesRegressionHead( model=_StubModel(), state_manager=state_management.PassthroughStateManager(), optimizer=train.AdamOptimizer(0.001)).create_estimator_spec @@ -263,5 +274,76 @@ class PredictFeatureCheckingTests(test.TestCase): mode=estimator_lib.ModeKeys.PREDICT) +class OneShotTests(test.TestCase): + + def test_one_shot_prediction_head_export(self): + model_dir = self.get_temp_dir() + categorical_column = feature_column.categorical_column_with_hash_bucket( + key="categorical_exogenous_feature", hash_bucket_size=16) + exogenous_feature_columns = [ + feature_column.numeric_column( + "2d_exogenous_feature", shape=(2,)), + feature_column.embedding_column( + categorical_column=categorical_column, dimension=10)] + estimator = ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel( + num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + state_manager=state_management.ChainingStateManager(), + head_type=ts_head_lib.OneShotPredictionHead, + model_dir=model_dir) + train_features = { + feature_keys.TrainEvalFeatures.TIMES: numpy.arange( + 20, dtype=numpy.int64), + feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange( + 20, dtype=numpy.float32)[:, None], [1, 5]), + "2d_exogenous_feature": numpy.ones([20, 2]), + "categorical_exogenous_feature": numpy.array( + ["strkey"] * 20)[:, None] + } + train_input_fn = input_pipeline.RandomWindowInputFn( + input_pipeline.NumpyReader(train_features), shuffle_seed=2, + num_threads=1, batch_size=16, window_size=16) + estimator.train(input_fn=train_input_fn, steps=5) + input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() + export_location = estimator.export_savedmodel(self.get_temp_dir(), + input_receiver_fn) + graph = ops.Graph() + with graph.as_default(): + with session_lib.Session() as session: + signatures = loader.load( + session, [tag_constants.SERVING], export_location) + self.assertEqual([feature_keys.SavedModelLabels.PREDICT], + list(signatures.signature_def.keys())) + predict_signature = signatures.signature_def[ + feature_keys.SavedModelLabels.PREDICT] + six.assertCountEqual( + self, + [feature_keys.FilteringFeatures.TIMES, + feature_keys.FilteringFeatures.VALUES, + "2d_exogenous_feature", + "categorical_exogenous_feature"], + predict_signature.inputs.keys()) + features = { + feature_keys.TrainEvalFeatures.TIMES: numpy.tile( + numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]), + feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange( + 20, dtype=numpy.float32)[None, :, None], [2, 1, 5]), + "2d_exogenous_feature": numpy.ones([2, 35, 2]), + "categorical_exogenous_feature": numpy.tile(numpy.array( + ["strkey"] * 35)[None, :, None], [2, 1, 1]) + } + feeds = { + graph.as_graph_element(input_value.name): features[input_key] + for input_key, input_value in predict_signature.inputs.items()} + fetches = {output_key: graph.as_graph_element(output_value.name) + for output_key, output_value + in predict_signature.outputs.items()} + output = session.run(fetches, feed_dict=feeds) + self.assertAllEqual((2, 15, 5), output["mean"].shape) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index ca25ccd2b8..5d33e23a42 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -40,6 +40,7 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["state_space_model_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":state_space_model", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 3e32a7a85c..4de09dd988 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -159,6 +159,7 @@ py_library( name = "tpu_lib", srcs = [ "python/tpu/__init__.py", + "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", "python/tpu/topology.py", "python/tpu/tpu.py", @@ -214,6 +215,7 @@ tf_py_test( ":datasets", ], grpc_enabled = True, + tags = ["no_windows"], ) tf_py_test( @@ -227,6 +229,7 @@ tf_py_test( "//tensorflow/python:framework", "//tensorflow/python:layers", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) tf_py_test( @@ -241,6 +244,17 @@ tf_py_test( ) tf_py_test( + name = "bfloat16_test", + size = "small", + srcs = ["python/tpu/bfloat16_test.py"], + additional_deps = [ + ":tpu", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) + +tf_py_test( name = "tpu_infeed_test", size = "small", srcs = ["python/tpu/tpu_infeed_test.py"], diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index ea6e874f2d..bb60f3e2d7 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -53,6 +53,7 @@ from __future__ import print_function # pylint: disable=wildcard-import,unused-import from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * +from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * from tensorflow.contrib.tpu.python.tpu.topology import * from tensorflow.contrib.tpu.python.tpu.tpu import * diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 590db2c376..2a15875627 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -79,6 +79,10 @@ message StepInfoResult { optional uint64 infeed_duration_ps = 3; // The start time of this step in picoseconds. optional uint64 begin_ps = 4; + // The waiting time within this step in picoseconds. + optional uint64 wait_duration_ps = 5; + // The time spent on cross-replica-sum in picoseconds. + optional uint64 crs_duration_ps = 6; } // Result proto for a sequence of steps. diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16.py b/tensorflow/contrib/tpu/python/tpu/bfloat16.py new file mode 100644 index 0000000000..5e49af6408 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/bfloat16.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Helper context for running models with bfloat16.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_contextlib + + +def _get_custom_getter(): + """Returns a custom getter that this class's methods must be called under. + + All methods of this class must be called under a variable scope that was + passed this custom getter. Example: + + ```python + network = ConvNetBuilder(...) + with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): + network.conv(...) + # Call more methods of network here + ``` + + Currently, this custom getter only does anything if self.use_tf_layers is + True. In that case, it causes variables to be stored as dtype + self.variable_type, then casted to the requested dtype, instead of directly + storing the variable as the requested dtype. + """ + + def inner_custom_getter(getter, *args, **kwargs): + """Custom getter that forces variables to have type self.variable_type.""" + cast_to_bfloat16 = False + requested_dtype = kwargs['dtype'] + if requested_dtype == dtypes.bfloat16: + # Only change the variable dtype if doing so does not decrease variable + # precision. + kwargs['dtype'] = dtypes.float32 + cast_to_bfloat16 = True + var = getter(*args, **kwargs) + # This if statement is needed to guard the cast, because batch norm + # assigns directly to the return value of this custom getter. The cast + # makes the return value not a variable so it cannot be assigned. Batch + # norm variables are always in fp32 so this if statement is never + # triggered for them. + if cast_to_bfloat16: + var = math_ops.cast(var, dtypes.bfloat16) + return var + + return inner_custom_getter + + +@tf_contextlib.contextmanager +def bfloat16_scope(): + """Scope class for bfloat16 variables so that the model uses custom getter. + + This enables variables to be read as bfloat16 type when using get_variable. + """ + with variable_scope.variable_scope( + 'bfloat16', custom_getter=_get_custom_getter()) as varscope: + yield varscope diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py b/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py new file mode 100644 index 0000000000..48a01c7308 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for bfloat16 helper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import bfloat16 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import variable_scope + +from tensorflow.python.platform import test + + +class BFloat16ScopeTest(test.TestCase): + + def testScopeName(self): + """Test if name for the variable scope is propogated correctly. + """ + with bfloat16.bfloat16_scope() as bf: + self.assertEqual(bf.name, "bfloat16") + + def testRequestedDType(self): + """Test if requested dtype is honored in the getter. + """ + with bfloat16.bfloat16_scope() as scope: + v1 = variable_scope.get_variable("v1", []) + self.assertEqual(v1.dtype.base_dtype, dtypes.float32) + v2 = variable_scope.get_variable("v2", [], dtype=dtypes.bfloat16) + self.assertEqual(v2.dtype.base_dtype, dtypes.bfloat16) + self.assertEqual([dtypes.float32, dtypes.float32], + [v.dtype.base_dtype for v in scope.global_variables()]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index fa56708f44..6834600b79 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2019,7 +2019,8 @@ class TPUEstimator(estimator_lib.Estimator): host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator)), - ExamplesPerSecondHook(ctx.global_batch_size), + ExamplesPerSecondHook(ctx.global_batch_size, + output_dir=self.model_dir), InstallSignalHandlerHook(), training.LoggingTensorHook( { diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index eea57ed336..3ae350c7bb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -120,7 +120,8 @@ def _query_tpu_system_metadata(master_address, run_config, logging.info('*** Num TPU Workers: %d', metadata.num_hosts) logging.info('*** Num TPU Cores Per Worker: %d', metadata.num_of_cores_per_host) - logging.info('*** Available Devices: %s', metadata.devices) + for device in metadata.devices: + logging.info('*** Available Device: %s', device) else: logging.info('Failed to find TPU: %s', metadata) return metadata diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index dbdbb08a82..f305197c19 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -517,6 +518,7 @@ class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest): ops._USE_C_API = self._prev_value +@test_util.with_c_api class PaddingTest(test.TestCase): def testPaddingInvalidLengths(self): diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 7223194885..99d486b183 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1574,8 +1574,9 @@ def _padding(sequences, num_unroll): if not sequences: return 0, {} - sequences_dict = {} - for key, value in sequences.items(): + # Sort 'sequences_dict' so 'length' will have a predictable value below. + sequences_dict = collections.OrderedDict() + for key, value in sorted(sequences.items()): if not (isinstance(value, sparse_tensor.SparseTensor) or isinstance(value, sparse_tensor.SparseTensorValue)): sequences_dict[key] = ops.convert_to_tensor(value) diff --git a/tensorflow/contrib/util/loader.py b/tensorflow/contrib/util/loader.py index f4283cd9ed..dca01d26f4 100644 --- a/tensorflow/contrib/util/loader.py +++ b/tensorflow/contrib/util/loader.py @@ -42,9 +42,10 @@ def load_op_library(path): plugin. """ if os.name == 'nt': - # To avoid makeing every user_ops aware of windows, re-write - # the file extension from .so to .dll. - path = re.sub(r'\.so$', '.dll', path) + # To avoid making every user_ops aware of windows, re-write + # the file extension from .so to .dll if .so file doesn't exist. + if not os.path.exists(path): + path = re.sub(r'\.so$', '.dll', path) # Currently we have only some user_ops as dlls on windows - don't try # to load them if the dll is not found. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 21f7866abd..7d5ae1c5b5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -349,6 +349,7 @@ cc_library( "platform/env.h", "platform/env_time.h", "platform/file_system.h", + "platform/file_system_helper.h", "platform/fingerprint.h", "platform/init_main.h", "platform/logging.h", diff --git a/tensorflow/core/api_def/base_api/api_def_For.pbtxt b/tensorflow/core/api_def/base_api/api_def_For.pbtxt new file mode 100644 index 0000000000..a7cd8e1a26 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_For.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "For" + in_arg { name: "start" description: "The lower bound. An int32" } + in_arg { name: "limit" description: "The upper bound. An int32" } + in_arg { name: "delta" description: "The increment. An int32" } + in_arg { + name: "input" + description: "A list of input tensors whose types are T." + } + out_arg { + name: "output" + description: "A list of output tensors whose types are T." + } + attr { name: "T" description: "A list of dtypes." } + attr { + name: "body" + description: <<END + A function that takes a list of tensors (int32, T) and returns another + list of tensors (T). +END + } + summary: <<END + ```python + output = input; + for i in range(start, limit, delta) + output = body(i, output); + ``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_If.pbtxt b/tensorflow/core/api_def/base_api/api_def_If.pbtxt new file mode 100644 index 0000000000..7ba5a3f37e --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_If.pbtxt @@ -0,0 +1,40 @@ +op { + graph_op_name: "If" + in_arg { name: "cond" description: "The predicate." } + in_arg { + name: "cond" + description: <<END + A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + `True` and zero means False; if the scalar is a string, non-empty + means `True` and empty means `False`. If the tensor is not a scalar, + being empty means False and being non-empty means True. +END + } + in_arg { + name: "input" + description: "A list of input tensors." + } + out_arg { + name: "output" + description: "A list of return values." + } + attr { name: "Tin" description: "A list of input types." } + attr { name: "Tout" description: "A list of output types." } + attr { + name: "then_branch" + description: <<END + A function that takes 'inputs' and returns a list of tensors, whose + types are the same as what else_branch returns. +END + } + attr { + name: "else_branch" + description: <<END + A function that takes 'inputs' and returns a list of tensors, whose + types are the same as what then_branch returns. +END + } + summary: "output = cond ? then_branch(input) : else_branch(input)" +} diff --git a/tensorflow/core/api_def/base_api/api_def_While.pbtxt b/tensorflow/core/api_def/base_api/api_def_While.pbtxt new file mode 100644 index 0000000000..95a19c6dff --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_While.pbtxt @@ -0,0 +1,33 @@ +op { + graph_op_name: "While" + in_arg { + name: "input" + description: "A list of input tensors whose types are T." + } + out_arg { + name: "output" + description: "A list of output tensors whose types are T." + } + attr { name: "T" description: "dtype in use." } + attr { + name: "cond" + description: <<END + A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +END + } + attr { + name: "body" + description: <<END + A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified + by T. +END + } + summary: "output = input; While (Cond(output)) { output = Body(output) }" +} diff --git a/tensorflow/core/api_def/python_api/api_def_For.pbtxt b/tensorflow/core/api_def/python_api/api_def_For.pbtxt new file mode 100644 index 0000000000..a58ddf56fe --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_For.pbtxt @@ -0,0 +1 @@ +op { graph_op_name: "For" visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_If.pbtxt b/tensorflow/core/api_def/python_api/api_def_If.pbtxt new file mode 100644 index 0000000000..a44db5da08 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_If.pbtxt @@ -0,0 +1 @@ +op { graph_op_name: "If" visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt new file mode 100644 index 0000000000..4f5b6decf6 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ScatterAdd" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_While.pbtxt b/tensorflow/core/api_def/python_api/api_def_While.pbtxt new file mode 100644 index 0000000000..f47a9b0fce --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_While.pbtxt @@ -0,0 +1 @@ +op { graph_op_name: "While" visibility: HIDDEN } diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index ee38960618..f95cecfc66 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -155,22 +156,22 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) { Status s = session->RunCallable(handle, {}, nullptr, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("`fetch_tensors` must be provided")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "`fetch_tensors` must be provided")); TF_ASSERT_OK(session->ReleaseCallable(handle)); std::vector<Tensor> outputs; s = session->RunCallable(handle, {}, &outputs, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Attempted to run callable after handle was released")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Attempted to run callable after handle was released")); s = session->RunCallable(handle + 1, {}, &outputs, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE( - StringPiece(s.error_message()).contains("No such callable handle")); + str_util::StrContains(s.error_message(), "No such callable handle")); } } @@ -567,7 +568,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); EXPECT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once")); } TEST(DirectSessionTest, MultipleFeedTest_Callable) { @@ -650,7 +651,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) { {first_identity->name() + ":0", second_identity->name() + ":0"}, {}), &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once")); } TEST(DirectSessionTest, FetchMultipleTimes) { @@ -845,8 +846,8 @@ TEST(DirectSessionTest, PartialRunMissingFeed) { s = session->PRun(handle, {{first_const->name(), value_11}}, {third_identity->name() + ":0"}, &outputs); ASSERT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("can't be computed from the feeds")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "can't be computed from the feeds")); } TEST(DirectSessionTest, PartialRunMultiOutputFeed) { @@ -875,8 +876,8 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { // Fetch fourth_identity without feeds. s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs); ASSERT_TRUE(errors::IsInvalidArgument(s)); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("can't be computed from the feeds")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "can't be computed from the feeds")); // Feed switch_node:1 and fetch fourth_identity. s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}}, diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index b06b75d658..0c461a9ee9 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -258,6 +258,13 @@ struct NodeItem { // Return array of per-output allocator attributes. const AllocatorAttributes* output_attrs() const { return output_attr_base(); } + // Return array of expected input index from which each output should + // be forwarded: + // kNeverForward (-2) for DO NOT FORWARD (must allocate). + // kNoReservation (-1) for no expected forwarding. + // 0... for forward from that input. + const int* forward_from() const { return forward_from_base(); } + private: friend class GraphView; @@ -267,6 +274,7 @@ struct NodeItem { // AllocatorAttributes output_attr[num_outputs]; // uint8 input_type[num_inputs]; // uint8 output_type[num_outputs]; + // int forward_from[num_outputs]; // Return pointer to variable length section. char* var() const { @@ -292,6 +300,13 @@ struct NodeItem { sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs); } + int* forward_from_base() const { + return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(AllocatorAttributes) * num_outputs + + sizeof(uint8) * num_inputs + + sizeof(uint8) * num_outputs); + } + TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); }; @@ -466,7 +481,8 @@ size_t GraphView::NodeItemBytes(const Node* n) { + num_output_edges * sizeof(EdgeInfo) // output_edges[...] + num_outputs * sizeof(AllocatorAttributes) // output_attr[...] + num_inputs * sizeof(uint8) // input_type[num_inputs] - + num_outputs * sizeof(uint8); // output_type[num_outputs] + + num_outputs * sizeof(uint8) // output_type[num_outputs] + + num_outputs * sizeof(int); // forward_from[num_outputs] static constexpr size_t kItemAlignment = sizeof(NodeItem*); static_assert(kItemAlignment % alignof(NodeItem) == 0, "NodeItem must be aligned with kItemAlignment"); @@ -737,8 +753,8 @@ Status InferAllocAttr(const Node* n, const Node* dst, VLOG(2) << "node " << n->name() << " is the sink of an RPC in"; } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) && parsed_src_name.type != "CPU") { - // Value is going to be the sink of a local DMA from GPU to CPU (or other - // types of accelerators). + // Value is going to be the sink of a local DMA from GPU to CPU (or + // other types of accelerators). attr->set_gpu_compatible(true); VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy"; } else { @@ -1022,7 +1038,8 @@ class ExecutorState { int total_input_tensors = 0; std::vector<const Node*>* nodes = nullptr; - // Lock ordering: ExecutorState.mu_ < mu. + // Lock ordering: ExecutorState.mu_ < mu; + // during structured traversal: parent_frame->mu < mu. mutex mu; void InitializeFrameInfo(const string& enter_name) { @@ -1090,7 +1107,8 @@ class ExecutorState { void ActivateLoopInvs(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); - // Add a new loop invariant and make it available to all active iterations. + // Add a new loop invariant and make it available to all active + // iterations. void AddLoopInv(const NodeItem* item, const Entry& value, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); @@ -1147,8 +1165,8 @@ class ExecutorState { if (front_index_ == ready_.size()) { ready_.clear(); } else { - // Lots of unused entries at beginning of vector: move everything down - // to start of vector. + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. ready_.erase(ready_.begin(), ready_.begin() + front_index_); } front_index_ = 0; @@ -1596,6 +1614,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); params.is_input_dead = is_input_dead; params.output_attr_array = item.output_attrs(); + params.forward_from_array = nullptr; // later: item.forward_from(); if (item.kernel_is_async) { // Asynchronous computes. @@ -2333,8 +2352,9 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { FrameState* parent_frame = frame->parent_frame; const int64 parent_iter = frame->parent_iter; if (parent_frame != nullptr) { - mutex_lock paranet_frame_lock(parent_frame->mu); + mutex_lock parent_frame_lock(parent_frame->mu); // Propagate all the dead exits to the parent frame. + mutex_lock this_frame_lock(frame->mu); for (const Node* node : frame->dead_exits) { auto parent_iter_state = parent_frame->GetIteration(parent_iter); for (const Edge* e : node->out_edges()) { @@ -2603,7 +2623,7 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { (new ExecutorState(args, this))->RunAsync(std::move(done)); } -} // end namespace +} // namespace Status NewLocalExecutor(const LocalExecutorParams& params, std::unique_ptr<const Graph> graph, @@ -2629,4 +2649,4 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } -} // end namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index d17ef4d459..61b2f0e60f 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -53,8 +54,8 @@ Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } -void HasError(const Status& s, const string& substr) { - EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) +void HasError(const Status& s, StringPiece substr) { + EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) << s << ", expected substring " << substr; } @@ -240,7 +241,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status status2 = Run(flr, handle, opts, args, std::move(rets)); EXPECT_TRUE(errors::IsInvalidArgument(status2)); EXPECT_TRUE( - StringPiece(status2.error_message()).contains("remote execution.")); + str_util::StrContains(status2.error_message(), "remote execution.")); return status; } @@ -310,7 +311,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status status2 = Run(flr, handle, opts, args, std::move(rets)); EXPECT_TRUE(errors::IsInvalidArgument(status2)); EXPECT_TRUE( - StringPiece(status2.error_message()).contains("remote execution.")); + str_util::StrContains(status2.error_message(), "remote execution.")); return status; } diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index 6223a4e648..2d09e83d01 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -153,7 +154,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status status2 = Run(flr, handle, opts, args, std::move(rets)); EXPECT_TRUE(errors::IsInvalidArgument(status2)); EXPECT_TRUE( - StringPiece(status2.error_message()).contains("remote execution.")); + str_util::StrContains(status2.error_message(), "remote execution.")); return status; } diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index e128b9257f..86851c2c07 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -151,7 +152,8 @@ class ColocationGraph { if (attr_value != nullptr && attr_value->has_list()) { for (const string& class_spec : attr_value->list().s()) { StringPiece spec(class_spec); - if (spec.Consume(kColocationGroupPrefixStringPiece)) { + if (str_util::ConsumePrefix(&spec, + kColocationGroupPrefixStringPiece)) { found_spec = true; TF_RETURN_IF_ERROR( ColocateNodeToGroup(&colocation_group_root, node, spec)); diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 098024d219..5ad251c892 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -262,9 +263,9 @@ class PlacerTest : public ::testing::Test { ->attributes() \ .device_type()) -#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \ - EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \ - .contains(device_substr)) +#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \ + EXPECT_TRUE(::tensorflow::str_util::StrContains( \ + GetNodeByName((g), (name))->assigned_device_name(), device_substr)) // Test that a graph with no constraints will successfully assign nodes to the // "best available" device (i.e. prefer GPU over CPU). @@ -488,11 +489,10 @@ TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) { Status s = Place(&g); EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains( - "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' " - "does not have registered OpKernel support for TestInput")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' " + "does not have registered OpKernel support for TestInput")); } // Test that graphs with reference connections are correctly placed. @@ -541,15 +541,15 @@ TEST_F(PlacerTest, TestReferenceConnection) { { Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU"); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("no device type supports both of those nodes")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "no device type supports both of those nodes")); } TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU")); { Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU"); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("no device type supports both of those nodes")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "no device type supports both of those nodes")); } TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU")); } @@ -760,8 +760,9 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { } Status s = Place(&g); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'foo' and 'in' because no " + EXPECT_TRUE( + str_util::StrContains(s.error_message(), + "Cannot colocate nodes 'foo' and 'in' because no " "device type supports both of those nodes and the " "other nodes colocated with them")); } @@ -824,11 +825,11 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { } Status s = Place(&g); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - "nodes colocated with them.")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Cannot colocate nodes 'var3' and 'assign3' because no " + "device type supports both of those nodes and the other " + "nodes colocated with them.")); } TEST_F(PlacerTest, TestColocationAndReferenceConnections) { @@ -888,7 +889,7 @@ TEST_F(PlacerTest, TestEmptyDeviceSet) { Status s = Place(&g, &empty); EXPECT_TRUE( - StringPiece(s.error_message()).contains("No devices are registered")); + str_util::StrContains(s.error_message(), "No devices are registered")); } // Test that placement fails when the requested device forces an @@ -913,16 +914,17 @@ TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) { heterogeneous.AddDevice(cpu.get()); Status s = Place(&g, &heterogeneous); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("colocated with a group of nodes that required " + EXPECT_TRUE( + str_util::StrContains(s.error_message(), + "colocated with a group of nodes that required " "incompatible device")); // The error message should contain information that indicates which // op types have which registered device types. - EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: FakeGPU")) + EXPECT_TRUE(str_util::StrContains(s.error_message(), "VariableGPU: FakeGPU")) << s; EXPECT_TRUE( - StringPiece(s.error_message()).contains("TestAssign: FakeGPU FakeCPU")) + str_util::StrContains(s.error_message(), "TestAssign: FakeGPU FakeCPU")) << s; } @@ -937,7 +939,7 @@ TEST_F(PlacerTest, TestUnknownDevice) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo")); } // Test that placement fails when the combination of partial @@ -952,7 +954,7 @@ TEST_F(PlacerTest, TestUnknownMergedDevice) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo")); } // Test that placement fails when the previously-assigned device for a @@ -969,9 +971,9 @@ TEST_F(PlacerTest, TestUnknownAssignedDevice) { Status s = Place(&g); EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Assigned device '/job:foo' does not match any device")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Assigned device '/job:foo' does not match any device")); } // Test that placement fails when an op with no registered kernels is @@ -986,12 +988,11 @@ TEST_F(PlacerTest, TestNoKernelsRegistered) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "No OpKernel was registered to support Op 'VariableNoKernels'")); EXPECT_TRUE( - StringPiece(s.error_message()) - .contains( - "No OpKernel was registered to support Op 'VariableNoKernels'")); - EXPECT_TRUE( - StringPiece(s.error_message()).contains("<no registered kernels>")); + str_util::StrContains(s.error_message(), "<no registered kernels>")); } // Test that placement fails when a kernel is registered but no known @@ -1011,10 +1012,10 @@ TEST_F(PlacerTest, TestNoDevicesRegistered) { Status s = Place(&g, &cpu_only); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("No OpKernel was registered to support " - "Op 'VariableGPU'")); - EXPECT_TRUE(StringPiece(s.error_message()).contains("device='FakeGPU'")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "No OpKernel was registered to support Op 'VariableGPU'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "device='FakeGPU'")); } // Test that placement fails when a requested device is malformed. @@ -1028,8 +1029,8 @@ TEST_F(PlacerTest, TestMalformedDeviceSpecification) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Malformed device specification '/foo:bar'")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Malformed device specification '/foo:bar'")); } // Test that placement fails when a previously-assigned device is malformed. @@ -1045,8 +1046,8 @@ TEST_F(PlacerTest, TestMalformedAssignedDevice) { Status s = Place(&g); EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Malformed assigned device '/foo:bar'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Malformed assigned device '/foo:bar'")); } // Test that placement fails when a device was previously assigned to @@ -1063,9 +1064,8 @@ TEST_F(PlacerTest, TestNonUniqueAssignedDevice) { Status s = Place(&g); EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Assigned device '/job:a' does not match any device")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Assigned device '/job:a' does not match any device")); } // Test that ops request to be placed on non-existent devices will be relocated @@ -1099,7 +1099,7 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { SessionOptions options; Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakegpu:11")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakegpu:11")); } // Test that placement fails when a node requests an explicit device that is not @@ -1116,10 +1116,10 @@ TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { SessionOptions options; Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakecpu:0")); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("no supported kernel for fakecpu devices is available")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakecpu:0")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "no supported kernel for fakecpu devices is available")); } // Test that placement fails when a node requests an explicit device that is not @@ -1137,9 +1137,9 @@ TEST_F(PlacerTest, TestNonExistentDevice) { Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); LOG(WARNING) << s.error_message(); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("was explicitly assigned to /job:foo/replica:17 " - "but available devices")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "was explicitly assigned to /job:foo/replica:17 but available devices")); } TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { @@ -1205,8 +1205,8 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'var' and 'assign'")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Cannot colocate nodes 'var' and 'assign'")); } // Test that a generator node follows its consumers (where there are several diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index d69e8bc2a0..c7b8259f78 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -155,7 +155,10 @@ class ProcessFunctionLibraryRuntime { string target_device() { return target_device_; } - FunctionLibraryRuntime::LocalHandle local_handle() { return local_handle_; } + FunctionLibraryRuntime::LocalHandle local_handle() { + mutex_lock l(mu_); + return local_handle_; + } // Initializes the FunctionData object by potentially making an Initialize // call to the DistributedFunctionLibraryRuntime. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 2da67b084a..4fbf2abc67 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -132,7 +133,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }); done2.WaitForNotification(); EXPECT_TRUE(errors::IsNotFound(status)); - EXPECT_TRUE(StringPiece(status.error_message()).contains("not found.")); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found.")); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc index a074154450..feaf29c7bb 100644 --- a/tensorflow/core/common_runtime/session_test.cc +++ b/tensorflow/core/common_runtime/session_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/public/session.h" #include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" @@ -31,10 +32,9 @@ TEST(SessionTest, InvalidTargetReturnsNull) { Session* session; Status s = tensorflow::NewSession(options, &session); EXPECT_EQ(s.code(), error::NOT_FOUND); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains( - "No session factory registered for the given session options")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "No session factory registered for the given session options")); } // Register a fake session factory to test error handling paths in @@ -44,7 +44,7 @@ class FakeSessionFactory : public SessionFactory { FakeSessionFactory() {} bool AcceptsOptions(const SessionOptions& options) override { - return StringPiece(options.target).starts_with("fake"); + return str_util::StartsWith(options.target, "fake"); } Session* NewSession(const SessionOptions& options) override { @@ -68,9 +68,9 @@ TEST(SessionTest, MultipleFactoriesForTarget) { Status s = tensorflow::NewSession(options, &session); EXPECT_EQ(s.code(), error::INTERNAL); EXPECT_TRUE( - StringPiece(s.error_message()).contains("Multiple session factories")); - EXPECT_TRUE(StringPiece(s.error_message()).contains("FAKE_SESSION_1")); - EXPECT_TRUE(StringPiece(s.error_message()).contains("FAKE_SESSION_2")); + str_util::StrContains(s.error_message(), "Multiple session factories")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_1")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_2")); } } // namespace diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index cef50be3b1..1b7e3138ee 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -351,6 +351,11 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { } } } + if (node_context->requested_input_tensor_as_partial_shape(dst_input)) { + // The input value may have changed. Since we have no way to know if + // that's indeed the case, err on the safe side. + *refined = true; + } // Also propagate handle shape and dtype of edges which are carrying // resource handles. diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index adf5a9afff..f48638afc0 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -143,8 +144,8 @@ TEST_F(ShapeRefinerTest, BadShapes) { // an error. Status s = m.AddNode(mm.node()); ASSERT_FALSE(s.ok()); - ASSERT_TRUE(StringPiece(s.error_message()) - .contains("Dimensions must be equal, but are 1 and 2")); + ASSERT_TRUE(str_util::StrContains( + s.error_message(), "Dimensions must be equal, but are 1 and 2")); } TEST_F(ShapeRefinerTest, SetShape) { @@ -1032,8 +1033,8 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) { TF_ASSERT_OK(m.AddNode(input.node())); } TF_ASSERT_OK(m.AddNode(pack.node())); - EXPECT_TRUE( - StringPiece(m.AddNode(result).error_message()).contains("but is rank 2")); + EXPECT_TRUE(str_util::StrContains(m.AddNode(result).error_message(), + "but is rank 2")); } TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) { diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 049eec347c..bafd9bfc68 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -144,9 +144,9 @@ BaseRemoteRendezvous::~BaseRemoteRendezvous() { // Returns true if "device_name" is a valid full name of local device // of the "worker". This helper is purely based on the worker name // and device name and does no lookups in the worker->device_mgr. -static bool IsLocalDevice(const string& worker_name, +static bool IsLocalDevice(const StringPiece worker_name, const StringPiece device_name) { - return device_name.starts_with(worker_name); + return str_util::StartsWith(device_name, worker_name); } Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 120a33f17b..3e79a40683 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -402,7 +403,7 @@ Status GrpcSession::Reset(const SessionOptions& options, class GrpcSessionFactory : public SessionFactory { public: bool AcceptsOptions(const SessionOptions& options) override { - return StringPiece(options.target).starts_with(kSchemePrefix); + return str_util::StartsWith(options.target, kSchemePrefix); } Session* NewSession(const SessionOptions& options) override { diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index a382b8be95..6182f95f28 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -61,6 +61,26 @@ static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. static bool cpu_allocator_collect_full_stats = false; +// Individual allocations large than this amount will trigger a warning. +static const double kLargeAllocationWarningThreshold = 0.1; + +// If cpu_allocator_collect_stats is true, warn when the total allocated memory +// exceeds this threshold. +static const double kTotalAllocationWarningThreshold = 0.5; + +// Cache first invocation to port::AvailableRam, as it can be expensive. +static int64_t LargeAllocationWarningBytes() { + static int64_t value = static_cast<int64>(port::AvailableRam() * + kLargeAllocationWarningThreshold); + return value; +} + +static int64_t TotalAllocationWarningBytes() { + static int64_t value = static_cast<int64>(port::AvailableRam() * + kTotalAllocationWarningThreshold); + return value; +} + void EnableCPUAllocatorStats(bool enable) { cpu_allocator_collect_stats = enable; } @@ -70,7 +90,8 @@ void EnableCPUAllocatorFullStats(bool enable) { class CPUAllocator : public VisitableAllocator { public: - CPUAllocator() : allocation_begun_(false) {} + CPUAllocator() + : total_allocation_warning_triggered_(false), allocation_begun_(false) {} ~CPUAllocator() override {} @@ -81,6 +102,12 @@ class CPUAllocator : public VisitableAllocator { allocation_begun_ = true; } + if (num_bytes > LargeAllocationWarningBytes()) { + LOG(WARNING) << "Allocation of " << num_bytes << " exceeds " + << 100 * kLargeAllocationWarningThreshold + << "% of system memory."; + } + void* p = port::AlignedMalloc(num_bytes, alignment); if (cpu_allocator_collect_stats) { const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p); @@ -91,6 +118,14 @@ class CPUAllocator : public VisitableAllocator { std::max<int64>(stats_.max_bytes_in_use, stats_.bytes_in_use); stats_.max_alloc_size = std::max<int64>(stats_.max_alloc_size, alloc_size); + + if (stats_.bytes_in_use > TotalAllocationWarningBytes() && + !total_allocation_warning_triggered_) { + LOG(WARNING) << "Total allocated memory " << stats_.bytes_in_use + << "exceeds " << 100 * kTotalAllocationWarningThreshold + << "% of system memory"; + total_allocation_warning_triggered_ = true; + } } // visit each Visitor in alloc_visitors_ @@ -162,6 +197,7 @@ class CPUAllocator : public VisitableAllocator { private: mutex mu_; AllocatorStats stats_ GUARDED_BY(mu_); + bool total_allocation_warning_triggered_ GUARDED_BY(mu_); // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_. // While write access is mutually exclusive, reads may happen concurrently. diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index ebb56d525e..87c1ddd15d 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -186,7 +186,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { // check if has_list is false and some other field in attr_value is // set to flag the error. This test can be made more strict once // support for GraphDef versions <= 4 is dropped. - if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) { + if (str_util::StartsWith(type, "list(") && !attr_value.has_list()) { if (num_set) { return errors::InvalidArgument( "AttrValue missing value with expected type '", type, "'"); @@ -197,7 +197,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { } // Okay to have an empty list, but not to be missing a non-list value. - if (num_set == 0 && !StringPiece(type).starts_with("list(")) { + if (num_set == 0 && !str_util::StartsWith(type, "list(")) { return errors::InvalidArgument( "AttrValue missing value with expected type '", type, "'"); } @@ -241,29 +241,29 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { // Parse type. string field_name; - bool is_list = type.Consume("list("); - if (type.Consume("string")) { + bool is_list = str_util::ConsumePrefix(&type, "list("); + if (str_util::ConsumePrefix(&type, "string")) { field_name = "s"; - } else if (type.Consume("int")) { + } else if (str_util::ConsumePrefix(&type, "int")) { field_name = "i"; - } else if (type.Consume("float")) { + } else if (str_util::ConsumePrefix(&type, "float")) { field_name = "f"; - } else if (type.Consume("bool")) { + } else if (str_util::ConsumePrefix(&type, "bool")) { field_name = "b"; - } else if (type.Consume("type")) { + } else if (str_util::ConsumePrefix(&type, "type")) { field_name = "type"; - } else if (type.Consume("shape")) { + } else if (str_util::ConsumePrefix(&type, "shape")) { field_name = "shape"; - } else if (type.Consume("tensor")) { + } else if (str_util::ConsumePrefix(&type, "tensor")) { field_name = "tensor"; - } else if (type.Consume("func")) { + } else if (str_util::ConsumePrefix(&type, "func")) { field_name = "func"; - } else if (type.Consume("placeholder")) { + } else if (str_util::ConsumePrefix(&type, "placeholder")) { field_name = "placeholder"; } else { return false; } - if (is_list && !type.Consume(")")) { + if (is_list && !str_util::ConsumePrefix(&type, ")")) { return false; } diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 2fb17c2b02..72eeda7a43 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -504,8 +504,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); stride_planes = strides[2]; - stride_cols = strides[3]; - stride_rows = strides[4]; + stride_rows = strides[3]; + stride_cols = strides[4]; } else { stride_planes = strides[1]; stride_rows = strides[2]; diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 5f3e5ad457..13d429b895 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -140,9 +141,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); - EXPECT_TRUE( - StringPiece(s.ToString()) - .contains("Invalid argument: Shape must be rank 2 but is rank 1")); + EXPECT_TRUE(str_util::StrContains( + s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 1")); } { @@ -161,10 +161,9 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { {S({2, 5}), S({3, 4})}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); - EXPECT_TRUE( - StringPiece(s.ToString()) - .contains( - "Invalid argument: Dimensions must be equal, but are 5 and 3")); + EXPECT_TRUE(str_util::StrContains( + s.ToString(), + "Invalid argument: Dimensions must be equal, but are 5 and 3")); } { @@ -173,9 +172,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); - EXPECT_TRUE( - StringPiece(s.ToString()) - .contains("Invalid argument: Shape must be rank 2 but is rank 3")); + EXPECT_TRUE(str_util::StrContains( + s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 3")); } { diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index beaf0adbc5..9e7ffe6c0b 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -201,7 +201,7 @@ class GraphDefBuilderWrapper { // Also looks up the `op_def->name` in the global // `WhitelistedStatefulOpRegistry`. bool IsOpWhitelisted(const OpDef* op_def) const { - return (StringPiece(op_def->name()).ends_with("Dataset") && + return (str_util::EndsWith(op_def->name(), "Dataset") && op_def->output_arg_size() == 1 && op_def->output_arg(0).type() == DT_VARIANT) || dataset::WhitelistedStatefulOpRegistry::Global()->Contains( @@ -474,11 +474,11 @@ class GraphDatasetBase : public DatasetBase { } // Key for storing the Dataset graph in the serialized format. - static const char kDatasetGraphKey[]; + TF_EXPORT static const char kDatasetGraphKey[]; // Key for storing the output node of the Dataset graph in the serialized // format. - static const char kDatasetGraphOutputNodeKey[]; + TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; private: Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 3e7b89d4eb..bdc1af9fda 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -278,7 +279,7 @@ class FunctionInstantiationHelper { auto it = index_.lower_bound(node_name); while (it != index_.end() && it->first <= node_colon_bound) { if (it->first == node_name || - tensorflow::StringPiece(it->first).starts_with(node_colon)) { + tensorflow::str_util::StartsWith(it->first, node_colon)) { nid = it->second.nid; break; } @@ -502,7 +503,7 @@ string Print(const NodeDef& n) { std::vector<StringPiece> dat; std::vector<string> dep; for (StringPiece s : n.input()) { - if (s.Consume("^")) { + if (str_util::ConsumePrefix(&s, "^")) { dep.push_back(s.ToString()); } else { dat.push_back(s); diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 23685e9c53..44e1383719 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -496,7 +496,7 @@ MySelect(x:float) -> (z:float) { } static void HasError(const Status& s, const string& substr) { - EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) << ">>" << s << "<<, expected substring >>" << substr << "<<"; } diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 896cb3cd7f..f7539d37be 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb_text.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -94,7 +95,7 @@ static Status RemoveNewDefaultAttrsFromNodeDef( std::vector<string> to_remove; for (const auto& attr : node_def->attr()) { // If the attr is not in consumer_op_def and doesn't start with '_'... - if (!StringPiece(attr.first).starts_with("_") && + if (!str_util::StartsWith(attr.first, "_") && FindAttr(attr.first, *consumer_op_def) == nullptr) { const OpDef::AttrDef* producer_attr_def = FindAttr(attr.first, *producer_op_def); diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc index e836873f66..cc583df348 100644 --- a/tensorflow/core/framework/node_def_builder_test.cc +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -82,7 +83,7 @@ class NodeDefBuilderTest : public ::testing::Test { EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); if (status.ok()) return; for (const string& message : messages) { - EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), message)) << status << ", " << message; } } @@ -103,7 +104,7 @@ class NodeDefBuilderTest : public ::testing::Test { } EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); if (status.ok()) return; - EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), message)) << "Actual error: " << status.error_message() << "\nDoes not contain: " << message; } diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 95fb386314..bad92ca9b3 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" @@ -131,7 +132,7 @@ Status AttrSlice::Find(StringPiece attr_name, // Skip AttachDef for internal attrs since it is a little bit // expensive and it is common for them to correctly not be included // in a NodeDef. - if (!attr_name.starts_with("_") && ndef_ != nullptr) { + if (!str_util::StartsWith(attr_name, "_") && ndef_ != nullptr) { s = AttachDef(s, *ndef_); } return s; @@ -399,7 +400,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { size_t num_inputs = 0; // TODO(josh11b): Unify the input field validation. for (const string& input : node_def.input()) { - if (StringPiece(input).starts_with("^")) { + if (str_util::StartsWith(input, "^")) { seen_control = true; if (input.find(':') != string::npos) { return errors::InvalidArgument( @@ -425,7 +426,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { } for (const auto& attr : node_def.attr()) { // Allow internal optional attributes with names starting with "_". - if (StringPiece(attr.first).starts_with("_")) { + if (str_util::StartsWith(attr.first, "_")) { continue; } auto iter = op_attrs.find(attr.first); diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index ae3a93eafe..2a49425dba 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -65,7 +66,7 @@ void ExpectFailure(const NodeDef& bad, const OpDef& op_def, << "; OpDef: " << SummarizeOpDef(op_def); LOG(INFO) << "Message: " << status.error_message(); - EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + EXPECT_TRUE(str_util::StrContains(status.ToString(), message)) << "NodeDef: " << SummarizeNodeDef(bad) << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status << "\nDoes not contain: " << message; @@ -265,7 +266,7 @@ void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { EXPECT_TRUE(errors::IsInvalidArgument(status)) << status << "; NodeDef: " << SummarizeNodeDef(bad); - EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.ToString()), message)) << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", " << message; } diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index fc5467b3c8..5f68c59fe9 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -142,7 +143,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { out->Reserve(sorted.size()); for (const auto& item : sorted) { - if (include_internal || !StringPiece(item.first).starts_with("_")) { + if (include_internal || !str_util::StartsWith(item.first, "_")) { *out->Add() = item.second->op_def; } } diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index b57bdcb841..c782480f1f 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -96,7 +97,7 @@ class OpCompatibilityTest : public OpsTestBase { ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. " << SummarizeOpDef(new_op_def); } else { - EXPECT_TRUE(StringPiece(status.error_message()).contains(error)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), error)) << status << " does not contain " << error; } } @@ -118,7 +119,7 @@ class OpCompatibilityTest : public OpsTestBase { ADD_FAILURE() << SummarizeNodeDef(*node_def()); } else { EXPECT_TRUE( - StringPiece(status.error_message()).contains(validation_error)) + str_util::StrContains(status.error_message(), validation_error)) << status << " does not contain " << validation_error; } @@ -179,7 +180,7 @@ class OpCompatibilityTest : public OpsTestBase { << SummarizeOpDef(*new_op_def); } else { EXPECT_TRUE( - StringPiece(status.error_message()).contains(compatibility_error)) + str_util::StrContains(status.error_message(), compatibility_error)) << status << " does not contain " << compatibility_error; } } diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto index ba545a1994..ca0e5e7133 100644 --- a/tensorflow/core/framework/op_def.proto +++ b/tensorflow/core/framework/op_def.proto @@ -126,6 +126,12 @@ message OpDef { // ------------------------------------------------------------------------- // Optimization constraints. + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // // By default Ops may be moved between devices. Stateful ops should // either not be moved, or should only be moved if that state can also // be moved (e.g. via some sort of save / restore). diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 962bc11ccb..403bd0b5e2 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -112,9 +112,11 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) { bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) { auto capture_begin = sp->begin(); - if (sp->Consume("numbertype") || sp->Consume("numerictype") || - sp->Consume("quantizedtype") || sp->Consume("realnumbertype") || - sp->Consume("realnumberictype")) { + if (str_util::ConsumePrefix(sp, "numbertype") || + str_util::ConsumePrefix(sp, "numerictype") || + str_util::ConsumePrefix(sp, "quantizedtype") || + str_util::ConsumePrefix(sp, "realnumbertype") || + str_util::ConsumePrefix(sp, "realnumberictype")) { *out = StringPiece(capture_begin, sp->begin() - capture_begin); return true; } @@ -155,32 +157,32 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, bool is_list = ConsumeListPrefix(&spec); string type; StringPiece type_string; // Used if type == "type" - if (spec.Consume("string")) { + if (str_util::ConsumePrefix(&spec, "string")) { type = "string"; - } else if (spec.Consume("int")) { + } else if (str_util::ConsumePrefix(&spec, "int")) { type = "int"; - } else if (spec.Consume("float")) { + } else if (str_util::ConsumePrefix(&spec, "float")) { type = "float"; - } else if (spec.Consume("bool")) { + } else if (str_util::ConsumePrefix(&spec, "bool")) { type = "bool"; - } else if (spec.Consume("type")) { + } else if (str_util::ConsumePrefix(&spec, "type")) { type = "type"; - } else if (spec.Consume("shape")) { + } else if (str_util::ConsumePrefix(&spec, "shape")) { type = "shape"; - } else if (spec.Consume("tensor")) { + } else if (str_util::ConsumePrefix(&spec, "tensor")) { type = "tensor"; - } else if (spec.Consume("func")) { + } else if (str_util::ConsumePrefix(&spec, "func")) { type = "func"; } else if (ConsumeCompoundAttrType(&spec, &type_string)) { type = "type"; AttrValue* allowed = attr->mutable_allowed_values(); VERIFY(ProcessCompoundType(type_string, allowed), "Expected to see a compound type, saw: ", type_string); - } else if (spec.Consume("{")) { + } else if (str_util::ConsumePrefix(&spec, "{")) { // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" AttrValue* allowed = attr->mutable_allowed_values(); str_util::RemoveLeadingWhitespace(&spec); - if (spec.starts_with("\"") || spec.starts_with("'")) { + if (str_util::StartsWith(spec, "\"") || str_util::StartsWith(spec, "'")) { type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" while (true) { StringPiece escaped_string; @@ -193,11 +195,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, "Trouble unescaping \"", escaped_string, "\", got error: ", error); allowed->mutable_list()->add_s(unescaped); - if (spec.Consume(",")) { + if (str_util::ConsumePrefix(&spec, ",")) { str_util::RemoveLeadingWhitespace(&spec); - if (spec.Consume("}")) break; // Allow ending with ", }". + if (str_util::ConsumePrefix(&spec, "}")) + break; // Allow ending with ", }". } else { - VERIFY(spec.Consume("}"), + VERIFY(str_util::ConsumePrefix(&spec, "}"), "Expected , or } after strings in list, not: '", spec, "'"); break; } @@ -215,11 +218,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, "Unrecognized type string '", type_string, "'"); allowed->mutable_list()->add_type(dt); } - if (spec.Consume(",")) { + if (str_util::ConsumePrefix(&spec, ",")) { str_util::RemoveLeadingWhitespace(&spec); - if (spec.Consume("}")) break; // Allow ending with ", }". + if (str_util::ConsumePrefix(&spec, "}")) + break; // Allow ending with ", }". } else { - VERIFY(spec.Consume("}"), + VERIFY(str_util::ConsumePrefix(&spec, "}"), "Expected , or } after types in list, not: '", spec, "'"); break; } @@ -232,7 +236,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Write the type into *attr. if (is_list) { - VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); + VERIFY(str_util::ConsumePrefix(&spec, ")"), + "Expected ) to close 'list(', not: '", spec, "'"); str_util::RemoveLeadingWhitespace(&spec); attr->set_type(strings::StrCat("list(", type, ")")); } else { @@ -240,7 +245,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } // Read optional minimum constraint at the end. - if ((is_list || type == "int") && spec.Consume(">=")) { + if ((is_list || type == "int") && str_util::ConsumePrefix(&spec, ">=")) { int64 min_limit = -999; VERIFY(ConsumeAttrNumber(&spec, &min_limit), "Could not parse integer lower limit after '>=', found '", spec, @@ -250,7 +255,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } // Parse default value, if present. - if (spec.Consume("=")) { + if (str_util::ConsumePrefix(&spec, "=")) { str_util::RemoveLeadingWhitespace(&spec); VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), "Could not parse default value '", spec, "'"); diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index c80802aad3..9be0dc69d2 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -239,7 +240,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, Status ValidateOpDef(const OpDef& op_def) { using ::tensorflow::strings::Scanner; - if (!StringPiece(op_def.name()).starts_with("_")) { + if (!str_util::StartsWith(op_def.name(), "_")) { VALIDATE(Scanner(op_def.name()) .One(Scanner::UPPERLETTER) .Any(Scanner::LETTER_DIGIT) @@ -259,11 +260,11 @@ Status ValidateOpDef(const OpDef& op_def) { // Validate type StringPiece type(attr.type()); - bool is_list = type.Consume("list("); + bool is_list = str_util::ConsumePrefix(&type, "list("); bool found = false; for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", "tensor", "func"}) { - if (type.Consume(valid)) { + if (str_util::ConsumePrefix(&type, valid)) { found = true; break; } @@ -271,8 +272,9 @@ Status ValidateOpDef(const OpDef& op_def) { VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), "'"); if (is_list) { - VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", - attr.name(), "'s type ", attr.type()); + VALIDATE(str_util::ConsumePrefix(&type, ")"), + "'list(' is missing ')' in attr ", attr.name(), "'s type ", + attr.type()); } VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", attr.name(), "'s type ", attr.type()); diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index 2b9812d4fc..4514d92e38 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -57,7 +57,7 @@ class ValidateOpDefTest : public ::testing::Test { EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; if (!status.ok()) { LOG(INFO) << "message: " << status; - EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + EXPECT_TRUE(str_util::StrContains(status.ToString(), message)) << "Actual: " << status << "\nExpected to contain: " << message; } } diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 5f2eb9d99a..7f23272871 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -50,10 +50,10 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { StringPiece to_append = str.substr(0, space); str.remove_prefix(space + 1); // Remove spaces at break. - while (to_append.ends_with(" ")) { + while (str_util::EndsWith(to_append, " ")) { to_append.remove_suffix(1); } - while (str.Consume(" ")) { + while (str_util::ConsumePrefix(&str, " ")) { } // Go on to the next line. @@ -65,8 +65,9 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { } bool ConsumeEquals(StringPiece* description) { - if (description->Consume("=")) { - while (description->Consume(" ")) { // Also remove spaces after "=". + if (str_util::ConsumePrefix(description, "=")) { + while (str_util::ConsumePrefix(description, + " ")) { // Also remove spaces after "=". } return true; } @@ -98,7 +99,7 @@ static bool StartsWithFieldName(StringPiece line, const std::vector<string>& multi_line_fields) { StringPiece up_to_colon; if (!SplitAt(':', &line, &up_to_colon)) return false; - while (up_to_colon.Consume(" ")) + while (str_util::ConsumePrefix(&up_to_colon, " ")) ; // Remove leading spaces. for (const auto& field : multi_line_fields) { if (up_to_colon == field) { @@ -119,9 +120,9 @@ static bool ConvertLine(StringPiece line, StringPiece up_to_colon; StringPiece after_colon = line; SplitAt(':', &after_colon, &up_to_colon); - while (after_colon.Consume(" ")) + while (str_util::ConsumePrefix(&after_colon, " ")) ; // Remove leading spaces. - if (!after_colon.Consume("\"")) { + if (!str_util::ConsumePrefix(&after_colon, "\"")) { // We only convert string fields, so don't convert this line. return false; } @@ -181,9 +182,9 @@ string PBTxtToMultiline(StringPiece pbtxt, static bool FindMultiline(StringPiece line, size_t colon, string* end) { if (colon == StringPiece::npos) return false; line.remove_prefix(colon + 1); - while (line.Consume(" ")) { + while (str_util::ConsumePrefix(&line, " ")) { } - if (line.Consume("<<")) { + if (str_util::ConsumePrefix(&line, "<<")) { *end = line.ToString(); return true; } @@ -228,7 +229,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) { string suffix; while (!multiline_pbtxt.empty()) { SplitAt('\n', &multiline_pbtxt, &line); - if (line.Consume(end)) break; + if (str_util::ConsumePrefix(&line, end)) break; if (first) { first = false; } else { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 9ec1c213c3..cfde1e8ea3 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -365,7 +365,7 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { const Tensor& OpKernelContext::input(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, num_inputs()); + DCHECK_LT(index, num_inputs()) << " name: " << op_kernel().name(); DCHECK(!input_is_ref(index)); const Tensor& tensor = *((*params_->inputs)[index].tensor); record_tensor_reference(tensor); @@ -420,8 +420,8 @@ bool OpKernelContext::forward_input_to_output_with_shape( ? AllocatorAttributes() : output_alloc_attr(output_index); std::unique_ptr<Tensor> new_tensor = forward_input( - input_index, expected_output_dtype(output_index), output_shape, - output_memory_type(output_index), output_attr); + input_index, output_index, expected_output_dtype(output_index), + output_shape, output_memory_type(output_index), output_attr); if (new_tensor != nullptr) { // Transfer ownership to the output slot in OpKernelContext. outputs_[output_index] = TensorValue(new_tensor.release()); @@ -461,35 +461,66 @@ Status OpKernelContext::forward_input_to_output_with_shape( } std::unique_ptr<Tensor> OpKernelContext::forward_input( - int input_index, DataType output_dtype, const TensorShape& output_shape, - MemoryType output_memory_type, const AllocatorAttributes& output_attr) { + int input_index, int output_index, DataType output_dtype, + const TensorShape& output_shape, MemoryType output_memory_type, + const AllocatorAttributes& output_attr) { DCHECK_GE(input_index, 0); DCHECK_LT(input_index, num_inputs()); const TensorValue& input = (*params_->inputs)[input_index]; - // Check that input tensor exists, is not a ref, and has no other consumers. - if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) { + // Check whether at graph construction time this output was marked + // either for no forwarding or with a reservation for this input. + // If it's reserved for this input we'll skip the refcount and + // AllocatorAttribute checks. + // TODO(tucker): Maybe we should skip all of the checks? + bool never_forward = + (params_->forward_from_array != nullptr && output_index >= 0 && + params_->forward_from_array[output_index] == Params::kNeverForward); + if (never_forward) return nullptr; + bool forward_expected = + (params_->forward_from_array != nullptr && output_index >= 0 && + params_->forward_from_array[output_index] == input_index); + if (!forward_expected && params_->forward_from_array != nullptr) { + // Check for possibly conflicting forward. + for (int i = 0; i < num_outputs(); ++i) { + if (params_->forward_from_array[i] == input_index) { + // This input is reserved for output i. + return nullptr; + } + } + } + // Check that input tensor exists and is not a ref. + if (input.tensor == nullptr || input.is_ref()) { + CHECK(!forward_expected); return nullptr; } // Check that input type matches. if (input_dtype(input_index) != output_dtype) { + CHECK(!forward_expected); return nullptr; } // Check that the input and output sizes are compatible. if (input.tensor->shape().num_elements() != output_shape.num_elements()) { + CHECK(!forward_expected); return nullptr; } // Check that input and output memory types match, i.e. // that they either both live in host or both live in device memory. if (input_memory_type(input_index) != output_memory_type) { + CHECK(!forward_expected); return nullptr; } - // Check that output allocator attributes are not more restrictive than - // input allocator attributes. - const auto input_attr = params_->input_alloc_attrs == nullptr - ? AllocatorAttributes() - : input_alloc_attr(input_index); - if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) { - return nullptr; + if (!forward_expected) { + if (!input->RefCountIsOne()) { + return nullptr; + } + // Check that output allocator attributes are not more restrictive than + // input allocator attributes. + const auto input_attr = params_->input_alloc_attrs == nullptr + ? AllocatorAttributes() + : input_alloc_attr(input_index); + if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) { + return nullptr; + } } // TODO(rmlarsen): Use MakeUnique here. There is already a copy in // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of @@ -505,7 +536,8 @@ Status OpKernelContext::forward_input_or_allocate_temp( Tensor* out_temp) { for (int input_index : candidate_input_indices) { std::unique_ptr<Tensor> new_tensor = - forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr); + forward_input(input_index, Params::kNoReservation /*output_index*/, + type, shape, DEVICE_MEMORY, allocator_attr); if (new_tensor != nullptr) { *out_temp = std::move(*new_tensor); return Status::OK(); @@ -595,6 +627,14 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, Tensor** output) { DCHECK_GE(index, 0); DCHECK_LT(index, num_outputs()); + bool forward_expected = + (params_->forward_from_array != nullptr && index >= 0 && + params_->forward_from_array[index] >= 0); + if (forward_expected) { + return errors::Internal( + "Explicit allocate_output call where input forwarding required. Try " + "turning off the ScopedAllocator optimizer."); + } AllocatorAttributes attr = output_alloc_attr(index); return allocate_output(index, shape, output, attr); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 2d97160830..67943377b9 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -64,10 +64,11 @@ class AsyncOpKernel; class CallFrameInterface; class FunctionLibraryRuntime; class OpKernelConstruction; // declared below -class OpKernelContext; // declared below +class OpKernelContext; // declared below, class OpRegistryInterface; class ResourceMgr; class ScopedStepContainer; +class CollectiveExecutor; class StepStatsCollector; class OpKernel { @@ -532,6 +533,10 @@ class OpKernelContext { // computations running on other devices. Rendezvous* rendezvous = nullptr; + // Mechanism for executing a collective op that needs to coordinate + // with parallel instances runing on other devices. + CollectiveExecutor* collective_executor = nullptr; + // The session state for this op. SessionState* session_state = nullptr; @@ -565,6 +570,12 @@ class OpKernelContext { // TensorSliceReaderCache support. checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; + + // Support for forwarding reservations (used by ScopedAllocator). + static const int kNeverForward = -2; + static const int kNoReservation = -1; + // Values in [0,...) represent reservations for the indexed output. + const int* forward_from_array = nullptr; }; // params must outlive the OpKernelContext. @@ -707,14 +718,31 @@ class OpKernelContext { // input[input_index] are compatible with those given in dtype, shape, // memory_type, and attr, // * refcount on the underlying buffer is one. + // * Either there is no forwarding reservation for either input_index + // or output_index or the specified input is reserved for the specified + // output. More precisely: + // + // These cases mean neither input nor output has a reservation: + // forward_from_array = nullptr + // OR (input_index is not in forward_from_array AND + // (output_index == kNoReservation OR + // forward_from_array[output_index] == kNoReservation)) + // + // This case means that input_index is reserved for output_index: + // forward_from_array[output_index] == input_index + // + // This case means the output is reserved to always be allocated, + // never assigned a forwarded input: + // forward_from_array[output_index] == kNeverForward + // // Otherwise returns nullptr. // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic, // forwarding is only safe if there are no reads via __ldg() after writes // to the same address. std::unique_ptr<Tensor> forward_input( - int input_index, DataType dtype, const TensorShape& shape, - MemoryType memory_type, - const AllocatorAttributes& attr) TF_MUST_USE_RESULT; + int input_index, int output_index, DataType output_dtype, + const TensorShape& output_shape, MemoryType output_memory_type, + const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT; // Tries to forward one of the inputs given in input_indices to // output[output_index]. If none of the given inputs can be forwarded, calls @@ -934,6 +962,10 @@ class OpKernelContext { // Rendezvous Send() and Recv(). Rendezvous* rendezvous() const { return params_->rendezvous; } + CollectiveExecutor* collective_executor() const { + return params_->collective_executor; + } + // An op kernel can access the session state it belongs to. SessionState* session_state() const { return params_->session_state; } @@ -1102,7 +1134,7 @@ class OpKernelContext { Status status_; friend class CollectiveExecutor; // for access to params_ - Params* params_; // not owned + Params* params_; // not owned mutable mutex mu_; // mutable so const accessors can acquire the lock gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); gtl::InlinedVector<TensorValue, 4> outputs_; diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index b53b877f28..bcd409e5c5 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -546,9 +546,9 @@ TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { {"T|list(type)|[DT_FLOAT]"})); ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); - EXPECT_TRUE( - StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {})) - .contains("Invalid argument: ")); + EXPECT_TRUE(str_util::StrContains( + GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}), + "Invalid argument: ")); ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, error::INVALID_ARGUMENT); @@ -565,8 +565,8 @@ TEST_F(OpKernelBuilderTest, DuplicateKernel) { DeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("Multiple OpKernel registrations match NodeDef")); + EXPECT_TRUE(str_util::StrContains( + status.error_message(), "Multiple OpKernel registrations match NodeDef")); ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT); } @@ -585,8 +585,8 @@ TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { DeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("Multiple OpKernel registrations match NodeDef")); + EXPECT_TRUE(str_util::StrContains( + status.error_message(), "Multiple OpKernel registrations match NodeDef")); ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"}, error::INVALID_ARGUMENT); @@ -606,8 +606,9 @@ TEST_F(OpKernelBuilderTest, BadConstraint) { DeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("OpKernel 'BadConstraint' has constraint on attr " + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "OpKernel 'BadConstraint' has constraint on attr " "'T' not in NodeDef")); ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"}, diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 07272e2374..798220d4c3 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -71,7 +72,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container, } static void HasError(const Status& s, const string& substr) { - EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) << s << ", expected substring " << substr; } diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index f48a7b9c47..da103bfec9 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -152,10 +153,9 @@ TEST_F(ShapeInferenceTest, Run) { }; Status s = c.Run(fn); // Extra error message is attached when Run fails. - EXPECT_TRUE(StringPiece(s.ToString()) - .contains("Shape must be at most rank 0 but " - "is rank 1 for 'foo' (op: " - "'foo_op')")) + EXPECT_TRUE(str_util::StrContains( + s.ToString(), + "Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')")) << s; } } @@ -367,10 +367,9 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) { // WithRankAtMost on shape with known dimensionality. s1 = in1; - EXPECT_TRUE( - StringPiece(c.WithRankAtMost(in1, 2, &s1).ToString()) - .contains( - "Invalid argument: Shape must be at most rank 2 but is rank 3")); + EXPECT_TRUE(str_util::StrContains( + c.WithRankAtMost(in1, 2, &s1).ToString(), + "Invalid argument: Shape must be at most rank 2 but is rank 3")); EXPECT_FALSE(IsSet(s1)); EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok()); @@ -406,10 +405,9 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) { // WithRankAtLeast on shape with known dimensionality. s1 = in1; - EXPECT_TRUE( - StringPiece(c.WithRankAtLeast(in1, 4, &s1).ToString()) - .contains( - "Invalid argument: Shape must be at least rank 4 but is rank 3")); + EXPECT_TRUE(str_util::StrContains( + c.WithRankAtLeast(in1, 4, &s1).ToString(), + "Invalid argument: Shape must be at least rank 4 but is rank 3")); EXPECT_FALSE(IsSet(s1)); EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok()); @@ -449,12 +447,14 @@ TEST_F(ShapeInferenceTest, WithValue) { // WithValue on dimension with known size. out1 = d0; - EXPECT_TRUE(StringPiece(c.WithValue(d0, 0, &out1).ToString()) - .contains("Invalid argument: Dimension must be 0 but is 1")); + EXPECT_TRUE( + str_util::StrContains(c.WithValue(d0, 0, &out1).ToString(), + "Invalid argument: Dimension must be 0 but is 1")); EXPECT_FALSE(IsSet(out1)); out1 = d0; - EXPECT_TRUE(StringPiece(c.WithValue(d0, 2, &out1).ToString()) - .contains("Invalid argument: Dimension must be 2 but is 1")); + EXPECT_TRUE( + str_util::StrContains(c.WithValue(d0, 2, &out1).ToString(), + "Invalid argument: Dimension must be 2 but is 1")); EXPECT_FALSE(IsSet(out1)); EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok()); @@ -513,16 +513,14 @@ TEST_F(ShapeInferenceTest, MergeDim) { EXPECT_EQ(3, merged_dims.size()); // Merging unequal values is an error. - EXPECT_TRUE( - StringPiece(c.Merge(d2, d1, &out).ToString()) - .contains( - "Invalid argument: Dimensions must be equal, but are 2 and 1")); + EXPECT_TRUE(str_util::StrContains( + c.Merge(d2, d1, &out).ToString(), + "Invalid argument: Dimensions must be equal, but are 2 and 1")); EXPECT_FALSE(IsSet(out)); - EXPECT_TRUE( - StringPiece(c.Merge(d1, d2, &out).ToString()) - .contains( - "Invalid argument: Dimensions must be equal, but are 1 and 2")); + EXPECT_TRUE(str_util::StrContains( + c.Merge(d1, d2, &out).ToString(), + "Invalid argument: Dimensions must be equal, but are 1 and 2")); EXPECT_FALSE(IsSet(out)); @@ -729,26 +727,23 @@ TEST_F(ShapeInferenceTest, MergeShape) { // Incompatible merges give errors and set out to nullptr. out = s_unknown; - EXPECT_TRUE( - StringPiece(c.Merge(s_u_2, s_1_3, &out).ToString()) - .contains( - "Invalid argument: Dimension 1 in both shapes must be equal, but " - "are 2 and 3")); + EXPECT_TRUE(str_util::StrContains( + c.Merge(s_u_2, s_1_3, &out).ToString(), + "Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 2 and 3")); EXPECT_FALSE(IsSet(out)); out = s_unknown; - EXPECT_TRUE( - StringPiece(c.Merge(s_1_3, s_u_2, &out).ToString()) - .contains( - "Invalid argument: Dimension 1 in both shapes must be equal, but " - "are 3 and 2")); + EXPECT_TRUE(str_util::StrContains( + c.Merge(s_1_3, s_u_2, &out).ToString(), + "Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 3 and 2")); EXPECT_FALSE(IsSet(out)); out = s_unknown; - EXPECT_TRUE( - StringPiece(c.Merge(s_1, s_1_2, &out).ToString()) - .contains( - "Invalid argument: Shapes must be equal rank, but are 1 and 2")); + EXPECT_TRUE(str_util::StrContains( + c.Merge(s_1, s_1_2, &out).ToString(), + "Invalid argument: Shapes must be equal rank, but are 1 and 2")); EXPECT_FALSE(IsSet(out)); @@ -795,22 +790,18 @@ TEST_F(ShapeInferenceTest, MergePrefix) { // Incompatible merges give errors and set outs to nullptr. s_out = s_unknown; s_prefix_out = s_unknown; - EXPECT_TRUE( - StringPiece( - c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString()) - .contains( - "Invalid argument: Dimensions must be equal, but are 1 and 2")); + EXPECT_TRUE(str_util::StrContains( + c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(), + "Invalid argument: Dimensions must be equal, but are 1 and 2")); EXPECT_FALSE(IsSet(s_out)); EXPECT_FALSE(IsSet(s_prefix_out)); s_out = s_unknown; s_prefix_out = s_unknown; - EXPECT_TRUE( - StringPiece( - c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString()) - .contains( - "Invalid argument: Shape must be at least rank 3 but is rank 2")); + EXPECT_TRUE(str_util::StrContains( + c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(), + "Invalid argument: Shape must be at least rank 3 but is rank 2")); EXPECT_FALSE(IsSet(s_out)); EXPECT_FALSE(IsSet(s_prefix_out)); } @@ -868,24 +859,21 @@ TEST_F(ShapeInferenceTest, Subshape) { // Errors. out = unknown; - EXPECT_TRUE(StringPiece(c.Subshape(in0, 6, -3, &out).ToString()) - .contains("Invalid argument: Subshape must have computed " - "start <= end, but is 5 " - "and 2 (computed from start 6 and end -3 over " - "shape with rank 5)")); + EXPECT_TRUE(str_util::StrContains( + c.Subshape(in0, 6, -3, &out).ToString(), + "Invalid argument: Subshape must have computed start <= end, but is 5 " + "and 2 (computed from start 6 and end -3 over shape with rank 5)")); EXPECT_FALSE(IsSet(out)); out = unknown; - EXPECT_TRUE(StringPiece(c.Subshape(in0, -50, 100, &out).ToString()) - .contains("Invalid argument: Subshape start out of " - "bounds: -50, for shape with " - "rank 5")); + EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(), + "Invalid argument: Subshape start out of " + "bounds: -50, for shape with rank 5")); EXPECT_FALSE(IsSet(out)); out = unknown; - EXPECT_TRUE(StringPiece(c.Subshape(in0, 0, -50, &out).ToString()) - .contains("Invalid argument: Subshape end out of bounds: " - "-50, for shape with rank " - "5")); + EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(), + "Invalid argument: Subshape end out of " + "bounds: -50, for shape with rank 5")); EXPECT_FALSE(IsSet(out)); } @@ -1094,27 +1082,26 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { EXPECT_EQ("[]", create(&t)); t = ::tensorflow::test::AsTensor<float>({1, 2, 3}); - EXPECT_TRUE( - StringPiece(create(&t)) - .contains("Input tensor must be int32 or int64, but was float")); + EXPECT_TRUE(str_util::StrContains( + create(&t), "Input tensor must be int32 or int64, but was float")); t = ::tensorflow::test::AsScalar<int32>(1); - EXPECT_TRUE(StringPiece(create(&t)) - .contains("Input tensor must be rank 1, but was rank 0")); + EXPECT_TRUE(str_util::StrContains( + create(&t), "Input tensor must be rank 1, but was rank 0")); t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1}); - EXPECT_TRUE(StringPiece(create(&t)) - .contains("Input tensor must be rank 1, but was rank 2")); + EXPECT_TRUE(str_util::StrContains( + create(&t), "Input tensor must be rank 1, but was rank 2")); // Test negative values for the dims. t = ::tensorflow::test::AsTensor<int64>({3, -2, 1}); - EXPECT_TRUE(StringPiece(create(&t)) - .contains("Invalid value in tensor used for shape: -2")); + EXPECT_TRUE(str_util::StrContains( + create(&t), "Invalid value in tensor used for shape: -2")); // Test negative values for the dims. t = ::tensorflow::test::AsTensor<int32>({3, -2, 1}); - EXPECT_TRUE(StringPiece(create(&t)) - .contains("Invalid value in tensor used for shape: -2")); + EXPECT_TRUE(str_util::StrContains( + create(&t), "Invalid value in tensor used for shape: -2")); // Test when the input shape is wrong. { @@ -1172,9 +1159,9 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); EXPECT_EQ("?", c.DebugString(out)); proto.add_dim()->set_size(0); - EXPECT_TRUE( - StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) - .contains("An unknown shape must not have any dimensions set.")); + EXPECT_TRUE(str_util::StrContains( + c.MakeShapeFromShapeProto(proto, &out).error_message(), + "An unknown shape must not have any dimensions set.")); EXPECT_FALSE(IsSet(out)); // With known rank. @@ -1188,10 +1175,10 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { // With invalid dimension value. proto.add_dim()->set_size(-2); - EXPECT_TRUE( - StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) - .contains("Shape [0,?,1000,-2] has dimensions with values below -1 " - "(where -1 means unknown)")); + EXPECT_TRUE(str_util::StrContains( + c.MakeShapeFromShapeProto(proto, &out).error_message(), + "Shape [0,?,1000,-2] has dimensions with values below -1 " + "(where -1 means unknown)")); EXPECT_FALSE(IsSet(out)); } @@ -1257,9 +1244,10 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); EXPECT_EQ("20", c.DebugString(d)); - EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) - .contains("Dimension size, given by scalar input 1, must " - "be non-negative but is -1")); + EXPECT_TRUE( + str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(), + "Dimension size, given by scalar input 1, must be " + "non-negative but is -1")); // Same tests, with int64 values. t1 = tensorflow::test::AsScalar<int64>(20); @@ -1267,9 +1255,10 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); EXPECT_EQ("20", c.DebugString(d)); - EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) - .contains("Dimension size, given by scalar input 1, must " - "be non-negative but is -1")); + EXPECT_TRUE( + str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(), + "Dimension size, given by scalar input 1, must be " + "non-negative but is -1")); } TEST_F(ShapeInferenceTest, GetAttr) { @@ -1322,33 +1311,33 @@ TEST_F(ShapeInferenceTest, Divide) { EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok()); EXPECT_EQ("3", c.DebugString(out)); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message()) - .contains("Dimension size must be evenly divisible by 5 but is 6")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, 5, evenly_divisible, &out).error_message(), + "Dimension size must be evenly divisible by 5 but is 6")); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) - .contains("Divisor must be positive but is 0")); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message()) - .contains("Divisor must be positive but is 0")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, 0, evenly_divisible, &out).error_message(), + "Divisor must be positive but is 0")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, d_0, evenly_divisible, &out).error_message(), + "Divisor must be positive but is 0")); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) - .contains("Divisor must be positive but is -1")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, -1, evenly_divisible, &out).error_message(), + "Divisor must be positive but is -1")); // Repeat error cases above with evenly_divisible=false. evenly_divisible = false; EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok()); EXPECT_EQ("1", c.DebugString(out)); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) - .contains("Divisor must be positive but is 0")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, 0, evenly_divisible, &out).error_message(), + "Divisor must be positive but is 0")); - EXPECT_TRUE( - StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) - .contains("Divisor must be positive but is -1")); + EXPECT_TRUE(str_util::StrContains( + c.Divide(d_6, -1, evenly_divisible, &out).error_message(), + "Divisor must be positive but is -1")); } TEST_F(ShapeInferenceTest, Add) { @@ -1396,11 +1385,9 @@ TEST_F(ShapeInferenceTest, Add) { EXPECT_TRUE(c.Add(d_0, d_6, &out).ok()); EXPECT_TRUE(SameHandle(out, d_6)); - EXPECT_TRUE( - StringPiece(c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out) - .error_message()) - .contains( - "Dimension size overflow from adding 6 and 9223372036854775802")); + EXPECT_TRUE(str_util::StrContains( + c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message(), + "Dimension size overflow from adding 6 and 9223372036854775802")); } TEST_F(ShapeInferenceTest, Subtract) { @@ -1448,9 +1435,9 @@ TEST_F(ShapeInferenceTest, Subtract) { EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok()); EXPECT_TRUE(SameHandle(out, d_6)); - EXPECT_TRUE( - StringPiece(c.Subtract(d_5, d_6, &out).error_message()) - .contains("Negative dimension size caused by subtracting 6 from 5")); + EXPECT_TRUE(str_util::StrContains( + c.Subtract(d_5, d_6, &out).error_message(), + "Negative dimension size caused by subtracting 6 from 5")); } TEST_F(ShapeInferenceTest, Multiply) { diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index b4765ab0b2..b54dd220ab 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -100,7 +100,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } } - if (expected.starts_with("in")) { + if (str_util::StartsWith(expected, "in")) { if (in_index == -1) { return Unknown(err_prefix, " should have matched an input shape by " @@ -135,7 +135,9 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } // Verify the dimensions. - CHECK(expected.starts_with("[") && expected.ends_with("]")) << expected; + CHECK(str_util::StartsWith(expected, "[") && + str_util::EndsWith(expected, "]")) + << expected; expected.remove_prefix(1); expected.remove_suffix(1); @@ -176,7 +178,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, return Unknown(err_prefix, " expected to be unknown but was ", c.Value(out_dim), err_suffix); } - } else if (expected_dim.starts_with("d")) { + } else if (str_util::StartsWith(expected_dim, "d")) { // Compare the dimension values. auto v = str_util::Split(expected_dim, '|'); if (in_dim_idx.first == -1) { diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 7977841482..2a99af7659 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" @@ -83,17 +84,17 @@ class ShapeInferenceTestutil { "", ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ op, i, o) \ .error_message()) -#define INFER_ERROR(error_substring, op, i) \ - { \ - string error_message = \ - ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ - op, i, "e") \ - .error_message(); \ - const string& substring = error_substring; \ - EXPECT_NE("", error_message); \ - EXPECT_TRUE(StringPiece(error_message).contains(substring)) \ - << "Expected to see '" << substring << "' in '" << error_message \ - << "'"; \ +#define INFER_ERROR(error_substring, op, i) \ + { \ + string error_message = \ + ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ + op, i, "e") \ + .error_message(); \ + const string& substring = error_substring; \ + EXPECT_NE("", error_message); \ + EXPECT_TRUE(::tensorflow::str_util::StrContains(error_message, substring)) \ + << "Expected to see '" << substring << "' in '" << error_message \ + << "'"; \ } } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc index 20a6807064..a4405b502c 100644 --- a/tensorflow/core/framework/shape_inference_testutil_test.cc +++ b/tensorflow/core/framework/shape_inference_testutil_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -25,10 +26,11 @@ namespace shape_inference { namespace { -#define EXPECT_CONTAINS(str, substr) \ - do { \ - string s = (str); \ - EXPECT_TRUE(StringPiece(s).contains(substr)) << "String: " << s; \ +#define EXPECT_CONTAINS(str, substr) \ + do { \ + string s = (str); \ + EXPECT_TRUE(::tensorflow::str_util::StrContains(s, substr)) \ + << "String: " << s; \ } while (false) static OpShapeInferenceFn* global_fn_ptr = nullptr; @@ -97,8 +99,8 @@ TEST(ShapeInferenceTestutilTest, Failures) { auto error_message = ShapeInferenceTestutil::InferShapes( ShapeInferenceTestOp("NoSuchOp"), "", "") .error_message(); - EXPECT_TRUE(StringPiece(error_message) - .starts_with("Op type not registered 'NoSuchOp'")); + EXPECT_TRUE( + str_util::StartsWith(error_message, "Op type not registered 'NoSuchOp'")); // Wrong shape error messages. EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0), diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index adf4e1bae3..2280114de5 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -114,7 +114,7 @@ string DataTypeString(DataType dtype) { } bool DataTypeFromString(StringPiece sp, DataType* dt) { - if (sp.ends_with("_ref")) { + if (str_util::EndsWith(sp, "_ref")) { sp.remove_suffix(4); DataType non_ref; if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc index 60f2b4135a..16b069c70a 100644 --- a/tensorflow/core/framework/types_test.cc +++ b/tensorflow/core/framework/types_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -140,9 +141,8 @@ TEST(TypesTest, ComplexTypes) { TEST(TypesTest, IntegerTypes) { for (auto dt : AllTypes()) { const string name = DataTypeString(dt); - const StringPiece n = name; - EXPECT_EQ(DataTypeIsInteger(dt), - n.starts_with("int") || n.starts_with("uint")) + EXPECT_EQ(DataTypeIsInteger(dt), str_util::StartsWith(name, "int") || + str_util::StartsWith(name, "uint")) << "DataTypeInteger failed for " << name; } } diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 85e014f804..60fa7bd559 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/port.h" @@ -259,8 +260,8 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) { ClientSession session(root); std::vector<Tensor> outputs; Status s = session.Run({create_const}, &outputs); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("GPU copy from non-DMA string tensor")) + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "GPU copy from non-DMA string tensor")) << s.ToString(); } @@ -365,8 +366,9 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) { std::vector<Tensor> outputs; Status err = session.Run({create_op, identity}, &outputs); EXPECT_EQ(err.code(), errors::Code::INVALID_ARGUMENT); - EXPECT_TRUE(StringPiece(err.error_message()) - .contains("During Variant Host->Device Copy: non-DMA-copy " + EXPECT_TRUE( + str_util::StrContains(err.error_message(), + "During Variant Host->Device Copy: non-DMA-copy " "attempted of tensor type: string")) << err.error_message(); } diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 06ca211c76..7055e62c0e 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <memory> +#include "tensorflow/core/lib/strings/str_util.h" #define EIGEN_USE_THREADS @@ -130,7 +131,7 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { Variant v = vv_early_exit; Status s0 = (*shape_fn)(v, &shape); EXPECT_FALSE(s0.ok()); - EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit!")); + EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit!")); VariantValue vv_ok{false /* early_exit */}; v = vv_ok; @@ -229,7 +230,7 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE( - StringPiece(s0.error_message()).contains("early exit zeros_like")); + str_util::StrContains(s0.error_message(), "early exit zeros_like")); VariantValue vv_ok{false /* early_exit */, 0 /* value */}; v = vv_ok; @@ -254,7 +255,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE( - StringPiece(s0.error_message()).contains("early exit zeros_like")); + str_util::StrContains(s0.error_message(), "early exit zeros_like")); VariantValue vv_ok{false /* early_exit */, 0 /* value */}; v = vv_ok; @@ -299,7 +300,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { Status s0 = BinaryOpVariants<CPUDevice>( null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); EXPECT_FALSE(s0.ok()); - EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); + EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add")); VariantValue vv_ok{false /* early_exit */, 3 /* value */}; v_a = vv_ok; @@ -325,7 +326,7 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) { Status s0 = BinaryOpVariants<GPUDevice>( null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); EXPECT_FALSE(s0.ok()); - EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); + EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add")); VariantValue vv_ok{false /* early_exit */, 3 /* value */}; v_a = vv_ok; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a7af5e2312..fb8a6c39e6 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -567,6 +567,11 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { inputs[edge->dst_input()] = edge; } } + // Sort the control inputs for more predictable serialization. + std::sort(inputs.begin() + node->num_inputs(), inputs.end(), + [](const Edge* a, const Edge* b) -> bool { + return a->src()->name() < b->src()->name(); + }); node_def->clear_input(); node_def->mutable_input()->Reserve(inputs.size()); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 76ee88e684..f15e2ce9fa 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" @@ -73,7 +74,7 @@ class GraphConstructor { Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) : allow_internal_ops(false), expect_device_spec(false), - prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/") + prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/") ? in.prefix : in.prefix + "/"), uniquify_names(in.uniquify_names), @@ -436,7 +437,7 @@ Status GraphConstructor::BuildNodeIndex() { bool in_control_dependence = false; for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name = node_def.input(i); - if (!input_name.empty() && input_name.starts_with("^")) { + if (!input_name.empty() && str_util::StartsWith(input_name, "^")) { in_control_dependence = true; } else if (in_control_dependence) { return errors::InvalidArgument( @@ -484,7 +485,7 @@ Status GraphConstructor::InitFromEdges() { bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name(node_def.input(i)); - if (input_name.starts_with("^")) { + if (str_util::StartsWith(input_name, "^")) { num_control_edges++; } else { TensorId id(ParseTensorName(input_name)); @@ -534,7 +535,7 @@ Status GraphConstructor::ValidateColocationConstraints( if (iter == node_def.attr().end()) return Status::OK(); for (const string& c : iter->second.list().s()) { StringPiece s(c); - if (s.Consume(kColocationGroupPrefix) && + if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) && gdef_nodes_.find(s) == gdef_nodes_.end()) { return errors::InvalidArgument( "Node '", node_def.name(), @@ -764,7 +765,7 @@ void GraphConstructor::AddPrefixToNodeDef( // Skip remapped inputs (which already exist in g_ and are not being // imported). if (input_already_exists[i]) continue; - if (input.Consume("^")) { + if (str_util::ConsumePrefix(&input, "^")) { node_def->set_input(i, strings::StrCat("^", prefix_, input)); } else { node_def->set_input(i, strings::StrCat(prefix_, input)); @@ -776,7 +777,7 @@ void GraphConstructor::AddPrefixToNodeDef( node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); for (int i = 0; i < list->s_size(); ++i) { StringPiece v(list->s(i)); - if (v.Consume(kColocationGroupPrefix)) { + if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) { list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); } } @@ -819,7 +820,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() { bool updated = false; for (int i = 0; i < coloc_values.size(); ++i) { StringPiece val(coloc_values[i]); - if (val.Consume(kColocationGroupPrefix)) { + if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) { const auto& name_pair = uniquified_names_.find(val.ToString()); if (name_pair == uniquified_names_.end()) continue; updated = true; diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 963c1dc024..c18ccf6ce4 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -156,7 +156,9 @@ class GraphConstructorTest : public ::testing::Test { return ""; } StringPiece loc(value[0]); - return loc.Consume(kColocationGroupPrefix) ? loc.ToString() : ""; + return str_util::ConsumePrefix(&loc, kColocationGroupPrefix) + ? loc.ToString() + : ""; } string GraphDebugString() const { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 17a174101b..877e4f1b44 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/device_name_utils.h" @@ -372,7 +373,7 @@ string ControlLoopName(const string& name) { bool IsControlLoop(const Node* node) { const string& name = node->name(); - return StringPiece(name).starts_with("_cloop"); + return str_util::StartsWith(name, "_cloop"); } // An enter node for control flow. diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 6841f29149..83b24cafe2 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -120,7 +121,7 @@ void CheckLoopConstruction(const GraphDef& graph_def) { if (ndef.op() == "_Recv") { bool has_control = false; for (const string& input_name : ndef.input()) { - if (StringPiece(input_name).starts_with("^")) { + if (str_util::StartsWith(input_name, "^")) { has_control = true; break; } @@ -128,7 +129,7 @@ void CheckLoopConstruction(const GraphDef& graph_def) { EXPECT_TRUE(has_control); } // Must have a control loop - if (StringPiece(ndef.name()).starts_with("_cloop")) { + if (str_util::StartsWith(ndef.name(), "_cloop")) { if (ndef.op() == "Enter") { has_control_enter = true; } diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index e2ce0ba046..c8c2b225fe 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -408,7 +409,7 @@ TEST_F(GraphTest, NewName) { EXPECT_NE(a1, a2); EXPECT_NE(a1, b1); EXPECT_NE(a2, b1); - EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1; + EXPECT_TRUE(str_util::StartsWith(a1, "A")) << a1; } TEST_F(GraphTest, IsValidNode) { diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index cb0fc8a154..3b6e8cc233 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -259,8 +259,14 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2"); const string assign_op_name = strings::StrCat(name_prefix, "/Assign"); for (Node* var : variables) { - string new_restore_op_name = graph->NewName(restore_op_name); - string new_assign_op_name = graph->NewName(assign_op_name); + // Add an extra prefix after calling graph->NewName because the "unique" + // name may conflict with names generated for Send nodes. + // TODO(b/77547936): fix this more generally and get rid of the extra prefix + // here. + string new_restore_op_name = + strings::StrCat(graph->NewName(restore_op_name), "_qt"); + string new_assign_op_name = + strings::StrCat(graph->NewName(assign_op_name), "_qt"); string tensor_names_op_name = strings::StrCat(new_restore_op_name, "/tensor_names"); string shape_and_slices_op_name = diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc index 2ad69dbd0c..e46f92bc24 100644 --- a/tensorflow/core/graph/quantize_training_test.cc +++ b/tensorflow/core/graph/quantize_training_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -215,7 +216,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) { Node* found_node; Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"), &found_node); - EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s; + EXPECT_TRUE(str_util::StrContains(s.ToString(), "not found")) << s; // Ensure that m1 and m2's inputs were quantized. TF_ASSERT_OK( @@ -269,7 +270,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) { Node* found_node; Status s = FindNode(g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"), &found_node); - EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s; + EXPECT_TRUE(str_util::StrContains(s.ToString(), "not found")) << s; // Ensure that m1 and m2's inputs were quantized. TF_ASSERT_OK( diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index 7219d9812f..6c014a8d44 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -312,8 +312,8 @@ TEST_F(SubgraphTest, ChainOfFools) { EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); } -static bool HasSubstr(const string& base, const string& substr) { - bool ok = StringPiece(base).contains(substr); +static bool HasSubstr(StringPiece base, StringPiece substr) { + bool ok = str_util::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc index 089ea5e527..8af1936d64 100644 --- a/tensorflow/core/graph/tensor_id.cc +++ b/tensorflow/core/graph/tensor_id.cc @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -45,7 +46,7 @@ TensorId ParseTensorName(StringPiece name) { if (p > base && *p == ':' && mul > 1) { id.first = StringPiece(base, p - base); id.second = index; - } else if (name.starts_with("^")) { + } else if (str_util::StartsWith(name, "^")) { // Control edge id.first = StringPiece(base + 1); id.second = Graph::kControlSlot; diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc index cb6d107cad..d58cdc3c5b 100644 --- a/tensorflow/core/graph/validate_test.cc +++ b/tensorflow/core/graph/validate_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -60,7 +61,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) { CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr")); + EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr")); // Add the defaults. TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0)); @@ -83,7 +84,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr")); + EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr")); // Add the defaults. TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0)); @@ -91,7 +92,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { // Validation should still fail. s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr")); + EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr")); } TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) { diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 39bfca244e..8d8c6084ec 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -62,6 +62,10 @@ void Cluster::DisableOptimizer(bool disable) { options_.config.mutable_graph_options()->mutable_rewrite_options(); rewriter_config->set_layout_optimizer(RewriterConfig::OFF); rewriter_config->set_disable_model_pruning(true); + rewriter_config->set_function_optimization(RewriterConfig::OFF); + rewriter_config->set_arithmetic_optimization(RewriterConfig::OFF); + rewriter_config->set_loop_optimization(RewriterConfig::OFF); + rewriter_config->set_dependency_optimization(RewriterConfig::OFF); rewriter_config->set_constant_folding(RewriterConfig::OFF); rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); rewriter_config->mutable_auto_parallel()->set_enable(false); diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index b54b34959a..50d6e6468f 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -54,7 +54,7 @@ DeviceProperties GetLocalCPUInfo() { int64 free_mem = port::AvailableRam(); if (free_mem < INT64_MAX) { - device.set_memory_size(free_mem * 1024); + device.set_memory_size(free_mem); } (*device.mutable_environment())["cpu_instruction_set"] = diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index ae70c98608..abfa7bc48e 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -66,6 +66,7 @@ Status VirtualCluster::Run(const GraphDef& graph, } Costs node_costs; + int node_id = 0; do { OpContext op_context = scheduler.GetCurrNode(); node_costs = node_estimator_->PredictCosts(op_context); @@ -73,6 +74,7 @@ Status VirtualCluster::Run(const GraphDef& graph, CostGraphDef::Node* cost_node = metadata->mutable_cost_graph()->add_node(); const string& op_name = op_context.name; + cost_node->set_id(node_id++); cost_node->set_name(op_name); cost_node->set_device(op_context.device_name); cost_node->set_compute_cost( diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 5103098f27..8fe154dbf3 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1011,6 +1011,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { } // Skip any information that comes from fed nodes. if (fed_ports.find(node->name()) != fed_ports.end()) { + VLOG(2) << "Skipping feed node shape: " << node->name(); continue; } for (const auto& merged_shapes : node_ctx->MergedShapes()) { diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 0f6307cfdf..14e46ecdd9 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -202,12 +202,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - // TODO(76227186): re-enable with output size check & test - /* {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, {kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)}, - */ {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, @@ -817,6 +814,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( } if (!shape_found) { // Set the minimum filter size that's feasible. + input_shape.Clear(); for (int i = 0; i < 4; ++i) { input_shape.add_dim()->set_size(1); } @@ -859,6 +857,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( } if (!shape_found) { // Set the minimum filter size that's feasible. + filter_shape.Clear(); for (int i = 0; i < 4; ++i) { filter_shape.add_dim()->set_size(1); } @@ -1056,6 +1055,13 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice( // part of it. For these op the size of the output determines the memory cost. const auto& op_info = op_context.op_info; + const int inputs_needed = op_info.op() == "Slice" ? 3 : 2; + if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) { + Costs costs = Costs::ZeroCosts(); + costs.inaccurate = true; + return costs; + } + bool unknown_shapes = false; // Each output element is a copy of some element from input. @@ -1242,10 +1248,31 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad( const OpContext& op_context) const { bool found_unknown_shapes = false; const auto& op_info = op_context.op_info; - // x: op_info.inputs(0) + // x's shape: op_info.inputs(0) // y_grad: op_info.inputs(1) - ConvolutionDimensions dims = OpDimensionsFromInputs( - op_info.inputs(0).shape(), op_info, &found_unknown_shapes); + + // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0). + bool shape_found = false; + TensorShapeProto x_shape; + if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) { + const TensorProto& value = op_info.inputs(0).value(); + shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape); + } + if (!shape_found && op_info.outputs_size() > 0) { + x_shape = op_info.outputs(0).shape(); + shape_found = true; + } + if (!shape_found) { + // Set the minimum shape that's feasible. + x_shape.Clear(); + for (int i = 0; i < 4; ++i) { + x_shape.add_dim()->set_size(1); + } + found_unknown_shapes = true; + } + + ConvolutionDimensions dims = + OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes); int64 ops = 0; if (dims.kx <= dims.sx && dims.ky <= dims.sy) { diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 56915ed821..d797a8a8c1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -217,6 +217,39 @@ std::vector<int> GetPoolingOutputSize(const std::vector<int>& input, return output; } +// Helper functions for testing GetTensorShapeProtoFromTensorProto(). +void GetTensorProto(const DataType dtype, const std::vector<int64>& shape, + const std::vector<int64> values, const bool tensor_content, + TensorProto* tensor_proto) { + tensor_proto->Clear(); + TensorProto temp_tensor_proto; + temp_tensor_proto.set_dtype(dtype); + for (const auto& x : shape) { + temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x); + } + for (const auto& x : values) { + if (dtype == DT_INT64) { + temp_tensor_proto.add_int64_val(x); + } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 || + dtype == DT_UINT8) { + temp_tensor_proto.add_int_val(x); + } else if (dtype == DT_UINT32) { + temp_tensor_proto.add_uint32_val(x); + } else if (dtype == DT_UINT64) { + temp_tensor_proto.add_uint64_val(x); + } else { + CHECK(false) << "Unsupported dtype: " << dtype; + } + } + Tensor tensor(dtype); + CHECK(tensor.FromProto(temp_tensor_proto)); + if (tensor_content) { + tensor.AsProtoTensorContent(tensor_proto); + } else { + tensor.AsProtoField(tensor_proto); + } +} + OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x, const std::vector<int>& ksize, const std::vector<int>& strides, @@ -233,8 +266,11 @@ OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x, DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs()); DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs()); } else if (op_name == "AvgPoolGrad") { - // input: x, y_grad, output: x_grad. - DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs()); + // input: x's shape, y_grad, output: x_grad. + DescribeArbitraryRankInput({4}, DT_INT32, &op_info); + auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value(); + GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]}, + /*tensor_content=*/false, tensor_proto); DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs()); DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs()); } else if (op_name == "MaxPoolGrad") { @@ -365,43 +401,56 @@ class OpLevelCostEstimatorTest : public ::testing::Test { OpLevelCostEstimator estimator_; }; -// TODO(76227186): re-enable with output size check & test -/* TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) { -OpContext op_context; -SetCpuDevice(&op_context.op_info); -op_context.op_info.set_op("Gather"); + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("Gather"); -// Huge first input shouldn't affect Gather execution and memory costs. -DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); -DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info); -DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info); + // Huge first input shouldn't affect Gather execution and memory costs. + DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); + DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info); + DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info); -auto cost = estimator_.PredictCosts(op_context); -EXPECT_EQ(Costs::Duration(130), cost.memory_time); -EXPECT_EQ(Costs::Duration(16), cost.compute_time); -EXPECT_EQ(Costs::Duration(146), cost.execution_time); -EXPECT_FALSE(cost.inaccurate); + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(130), cost.memory_time); + EXPECT_EQ(Costs::Duration(16), cost.compute_time); + EXPECT_EQ(Costs::Duration(146), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); } -TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) { -OpContext op_context; -SetCpuDevice(&op_context.op_info); -op_context.op_info.set_op("Slice"); +TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("Gather"); -// Huge first input shouldn't affect Slice execution and memory costs. -DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); -DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); -DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); -DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info); + // Huge first input shouldn't affect Gather execution and memory costs. + DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); + DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info); + + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(0), cost.memory_time); + EXPECT_EQ(Costs::Duration(0), cost.compute_time); + EXPECT_EQ(Costs::Duration(0), cost.execution_time); + EXPECT_TRUE(cost.inaccurate); +} -auto cost = estimator_.PredictCosts(op_context); -EXPECT_EQ(Costs::Duration(81), cost.memory_time); -EXPECT_EQ(Costs::Duration(10), cost.compute_time); -EXPECT_EQ(Costs::Duration(91), cost.execution_time); -EXPECT_FALSE(cost.inaccurate); +TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("Slice"); + + // Huge first input shouldn't affect Slice execution and memory costs. + DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); + DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); + DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info); + DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info); + + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(81), cost.memory_time); + EXPECT_EQ(Costs::Duration(10), cost.compute_time); + EXPECT_EQ(Costs::Duration(91), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); } -*/ TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) { auto cost = PredictCosts(DescribeBiasAdd(1000, 10)); @@ -510,39 +559,6 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) { EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate); } -// Helper functions for testing GetTensorShapeProtoFromTensorProto(). -void GetTensorProto(const DataType dtype, const std::vector<int64>& shape, - const std::vector<int64> values, const bool tensor_content, - TensorProto* tensor_proto) { - tensor_proto->Clear(); - TensorProto temp_tensor_proto; - temp_tensor_proto.set_dtype(dtype); - for (const auto& x : shape) { - temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x); - } - for (const auto& x : values) { - if (dtype == DT_INT64) { - temp_tensor_proto.add_int64_val(x); - } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 || - dtype == DT_UINT8) { - temp_tensor_proto.add_int_val(x); - } else if (dtype == DT_UINT32) { - temp_tensor_proto.add_uint32_val(x); - } else if (dtype == DT_UINT64) { - temp_tensor_proto.add_uint64_val(x); - } else { - CHECK(false) << "Unsupported dtype: " << dtype; - } - } - Tensor tensor(dtype); - CHECK(tensor.FromProto(temp_tensor_proto)); - if (tensor_content) { - tensor.AsProtoTensorContent(tensor_proto); - } else { - tensor.AsProtoField(tensor_proto); - } -} - void ExpectTensorShape(const std::vector<int64>& expected, const TensorShapeProto& tensor_shape_proto) { TensorShape tensor_shape_expected(expected); @@ -746,25 +762,25 @@ TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) { { // Typical 3xz3 window with 2x2 stride. auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME"); - EXPECT_EQ(Costs::Duration(1920000), costs.execution_time); + EXPECT_EQ(Costs::Duration(1305602), costs.execution_time); EXPECT_EQ(Costs::Duration(537600), costs.compute_time); - EXPECT_EQ(Costs::Duration(1382400), costs.memory_time); + EXPECT_EQ(Costs::Duration(768002), costs.memory_time); EXPECT_FALSE(costs.inaccurate); } { // 1x1 window with 2x2 stride: used for shortcut in resnet-50. auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME"); - EXPECT_EQ(Costs::Duration(1574400), costs.execution_time); + EXPECT_EQ(Costs::Duration(960002), costs.execution_time); EXPECT_EQ(Costs::Duration(192000), costs.compute_time); - EXPECT_EQ(Costs::Duration(1382400), costs.memory_time); + EXPECT_EQ(Costs::Duration(768002), costs.memory_time); EXPECT_FALSE(costs.inaccurate); } { // 2x2 window with 3x3 stride. auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID"); - EXPECT_EQ(Costs::Duration(1476480), costs.execution_time); + EXPECT_EQ(Costs::Duration(862082), costs.execution_time); EXPECT_EQ(Costs::Duration(172416), costs.compute_time); - EXPECT_EQ(Costs::Duration(1304064), costs.memory_time); + EXPECT_EQ(Costs::Duration(689666), costs.memory_time); EXPECT_FALSE(costs.inaccurate); } } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 3ac3ae0f8f..0e5c654acf 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -44,6 +44,8 @@ Costs CombineCosts(const Costs& left, const Costs& right) { Costs result = left; result.execution_time += right.execution_time; + result.compute_time += right.compute_time; + result.memory_time += right.memory_time; if (right.inaccurate) { result.inaccurate = true; } @@ -841,6 +843,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { Costs VirtualScheduler::Summary() const { // Print out basic execution summary. VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count(); + VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count(); + VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count(); VLOG(1) << "Expected max memory: " << graph_costs_.max_memory; VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers; VLOG(1) << "Expected max per-op streaming buffers: " diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index c31ac9b59c..a24d2dbd9f 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace grappler { @@ -68,6 +69,10 @@ bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; } bool IsCast(const NodeDef& node) { return node.op() == "Cast"; } +bool IsCheckNumerics(const NodeDef& node) { + return node.op() == "CheckNumerics"; +} + bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } @@ -360,6 +365,8 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } +bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; } + bool IsVariable(const NodeDef& node) { const auto& op = node.op(); return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || @@ -404,8 +411,18 @@ bool IsFreeOfSideEffect(const NodeDef& node) { bool ModifiesInputsInPlace(const NodeDef& node) { // Some nodes do in-place updates on regular tensor inputs. string op_name = node.op(); + + // Ops that modify resource variables effectively modify one of their inputs. + if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" || + op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" || + op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" || + op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" || + op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") { + return false; + } + std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); - if (StringPiece(op_name).contains("inplace")) { + if (str_util::StrContains(op_name, "inplace")) { return true; } return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace"); diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 39affcbc24..8667f72c7e 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -37,6 +37,7 @@ bool IsBiasAdd(const NodeDef& node); bool IsBiasAddGrad(const NodeDef& node); bool IsBitcast(const NodeDef& node); bool IsCast(const NodeDef& node); +bool IsCheckNumerics(const NodeDef& node); bool IsComplex(const NodeDef& node); bool IsComplexAbs(const NodeDef& node); bool IsConj(const NodeDef& node); @@ -139,6 +140,7 @@ bool IsTile(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsTruncateDiv(const NodeDef& node); bool IsTruncateMod(const NodeDef& node); +bool IsUnpack(const NodeDef& node); bool IsVariable(const NodeDef& node); bool IsZeta(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 2c365c467c..122fd48584 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -251,6 +251,7 @@ cc_library( ":constant_folding", ":graph_optimizer", ":graph_optimizer_stage", + ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -260,6 +261,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:frame", + "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -272,6 +274,11 @@ tf_cuda_cc_test( ":constant_folding", ":model_pruner", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -501,6 +508,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:colocation", "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -630,6 +638,7 @@ cc_library( tf_cuda_cc_test( name = "debug_stripper_test", + size = "small", srcs = ["debug_stripper_test.cc"], deps = [ ":debug_stripper", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d155e0b289..59a5695af0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include <algorithm> +#include <deque> #include <limits> #include <unordered_map> #include <unordered_set> @@ -31,8 +32,9 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h" +#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -197,39 +199,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } -// Shape is symbolically defined if it has a known rank, and each dimension is -// defined, or is an unknown symbol (dim.size <= -2). -bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) { - return !shape.unknown_rank() && - std::all_of( - shape.dim().begin(), shape.dim().end(), - [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; }); -} - -bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) { - return ShapeIsSymbolicallyDefined(properties.shape()); -} - -bool ShapesSymbolicallyEqual(const TensorShapeProto& left, - const TensorShapeProto& right) { - if (left.unknown_rank() || right.unknown_rank() || - left.dim_size() != right.dim_size()) { - return false; - } - for (int i = 0; i < left.dim_size(); ++i) { - if (left.dim(i).size() == -1 || right.dim(i).size() == -1 || - left.dim(i).size() != right.dim(i).size()) { - return false; - } - } - return true; -} - -bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left, - const OpInfo::TensorProperties& right) { - return ShapesSymbolicallyEqual(left.shape(), right.shape()); -} - // Returns whether `reshape` is an identity op. The tensor that `reshape` // reshapes is the `output_pos`-th output of node `input`. bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, @@ -320,21 +289,16 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all // optimizations will be migrated to stages - void AddFrameControlDeps(const NodeDef* old_node, - const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep) { - const auto frame_it = ctx_.frame_map->find(old_node); - if (frame_it != ctx_.frame_map->end()) { - for (auto node : new_nodes) { - ctx_.frame_map->emplace(node, frame_it->second); - } - if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { - const string ctrl_dep = ConstantFolding::AddControlDependency( - source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map); - for (auto node : sinks_for_control_dep) { - MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph, - ctx_.node_map); + void ForwardControlDependencies( + NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) { + for (const auto& src : src_nodes) { + for (int i = src->input_size() - 1; i >= 0; --i) { + if (IsControlInput(src->input(i))) { + *target_node->add_input() = src->input(i); + ctx_.node_map->AddOutput(NodeName(src->input(i)), + target_node->name()); + } else { + break; } } } @@ -348,17 +312,30 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the // original inputs of absorbed nodes. // -// All nodes in a Add/AddN subgraph must have symbolically equal shape. All -// nodes must have the same device placement. +// 1) All nodes must have the same device placement. +// +// 2) If All nodes in a Add/AddN subgraph have symbolically equal shape, tree is +// optimized to a single AddN node. // -// Example: // AddN_1 // / | \ -// Add_1 z Add_2 -> AddN(z, y, z, w, q, e) +// Add_1 z Add_2 -> AddN(x, y, z, w, q, e) // / \ / \ // x y w Add_3 // / \ // q e +// +// 3) If some nodes have different shape (it needs to be broadcastable to the +// shape of a "root), tree is optimized to AddNs for symbolically equal +// shapes, and a tree of Add ops, that minimize broadcasts. +// +// AddN_1 Add +// / | \ / \ +// Add_1 z Add_2 -> Add w +// / \ / \ / \ +// x y w Add_3 AddN(x, y, q, e) z +// / \ +// q e class AddOpsRewriteStage : public ArithmeticOptimizerStage { public: explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx, @@ -379,7 +356,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { OpInfo::TensorProperties properties; Status has_properties = GetTensorProperties(node->name(), &properties); return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) && - HasAllInputsOfSymbolicallyEqualShape(*node, properties); + HasAllInputsOfBroadcastableShape(*node, properties); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { @@ -387,7 +364,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { AddOpsGroup group; TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group)); - if (!group.absorbed_nodes.empty() && !IsRewritten(group)) { + if (!group.absorbed_nodes.empty()) { *simplified_node_name = RewriteAddOpsGroup(group); } @@ -395,6 +372,14 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { } private: + // Input name with a statically inferred shape from GraphProperties + struct InputAndShape { + InputAndShape(const string& input, const TensorShapeProto& shape) + : input(input), shape(shape) {} + string input; + TensorShapeProto shape; + }; + // Holds together an add ops subgraph that we want to rewrite together. // // For the graph above the AddOpsGroup will be: @@ -406,12 +391,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { TensorShapeProto root_shape; // Add/AddN operations below the root level that were absorbed by this group std::vector<NodeDef*> absorbed_nodes; - // Inputs of absorbed nodes that will be forwarded to rewritten AddN node - std::vector<string> inputs; + // Inputs of absorbed nodes that will be forwarded to optimized AddN ops + std::vector<InputAndShape> inputs; }; - // Check if all inputs have symbolically equal shapes - bool HasAllInputsOfSymbolicallyEqualShape( + // Check if all inputs can be broadcasted to the same shape + bool HasAllInputsOfBroadcastableShape( const NodeDef& node, const OpInfo::TensorProperties& properties) const { const AddOpsRewriteStage* self = this; return std::all_of( @@ -421,7 +406,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { Status has_input_properties = self->GetTensorProperties(input, &input_properties); return has_input_properties.ok() && - ShapesSymbolicallyEqual(properties, input_properties); + ShapesBroadcastable(properties, input_properties); }); } @@ -467,11 +452,11 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { if (node->device() != group.root_node->device()) { return false; } - // All input shapes must be symbolically defined and equal to the node shape + // All input shapes must be broadcastable to the node shape OpInfo::TensorProperties properties; Status has_properties = GetTensorProperties(name, &properties); return has_properties.ok() && - HasAllInputsOfSymbolicallyEqualShape(*node, properties); + HasAllInputsOfBroadcastableShape(*node, properties); } // Node requirements both for a root node and an absorbed node @@ -490,18 +475,16 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) { return false; } + // it must not be created by this stage at any of previous optimization runs + if (StringPiece(node->name()).contains(stage_name_)) { + return false; + } // should not drive or be driven by control dependency // TODO(ezhulenev): relax this condition for root node return !(IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)); } - // Check that optimized group node name doesn't exists. It might happen if - // graph optimized multiple times without pruning between invocations. - bool IsRewritten(const AddOpsGroup& group) const { - return ctx_.node_map->NodeExists(AddOpsGroupName(group)); - } - // Create an AddOpsGroup with a root in a given node Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) { OpInfo::TensorProperties root_node_output_properties; @@ -513,7 +496,10 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { group->absorbed_nodes.reserve(root_node->input_size()); for (int i = 0; i < root_node->input_size(); ++i) { - TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(root_node->input(i), group)); + const string& input_i = root_node->input(i); + if (!IsControlInput(input_i)) { + TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group)); + } } return Status::OK(); @@ -526,71 +512,159 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { if (IsAbsorbableByAddOpsGroup(input, *group)) { group->absorbed_nodes.push_back(node); for (int i = 0; i < node->input_size(); ++i) { - TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(node->input(i), group)); + const string& input_i = node->input(i); + if (!IsControlInput(input)) { + TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group)); + } } } else { // If node can't be absorbed, add it to AddOpsGroup input - group->inputs.push_back(input); + OpInfo::TensorProperties properties; + TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties)); + group->inputs.emplace_back(input, properties.shape()); } return Status::OK(); } - // New node for AddOpsGroup is added to the same scope as a root_node. All - // absorbed nodes are stripped of their scope, and only names are used in a - // new node name. - // - // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"]) - // node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add - string AddOpsGroupName(const AddOpsGroup& group) const { - CHECK_NOTNULL(group.root_node); - - auto root = ParseNodeScopeAndName(group.root_node->name()); + // Rewrite an add ops group into a single AddN if all input shapes are + // symbolically equal. If not, create AddN for equal shapes first, and then + // build an Add tree, minimizing the cost of broadcasts. + string RewriteAddOpsGroup(const AddOpsGroup& group) { + // all new nodes will be placed under the scope of a root node + auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name()); + + auto shape_sig = [](const TensorShapeProto& shape) { + string name = strings::StrCat("r:", shape.dim_size(), ":d"); + for (int i = 0; i < shape.dim_size(); ++i) + strings::StrAppend(&name, ":", shape.dim(i).size()); + return name; + }; + + // Find what shapes are present in the inputs of absorbed nodes + std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs; + for (const auto& input : group.inputs) { + shape_sig_to_inputs[shape_sig(input.shape)].push_back(input); + } - std::vector<string> absorbed_node_names(group.absorbed_nodes.size()); - std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(), - absorbed_node_names.begin(), - [](const NodeDef* node) { return node->name(); }); + // Collect all the shapes from representative elements + std::vector<TensorShapeProto> shapes; + shapes.reserve(shape_sig_to_inputs.size()); + for (const auto& el : shape_sig_to_inputs) + shapes.push_back(el.second[0].shape); + + // If all inputs have the same shape, rewrite whole group with a single AddN + if (shapes.size() == 1) { + string node_name = OptimizedNodeName(root_scope_and_name); + AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name, + group.inputs); + // keep track of nodes that were created or absorbed as a part of rewrite + rewritten_nodes_.insert(node_name); + return node_name; + } - return OptimizedNodeName(root, absorbed_node_names); - } + // For inputs of different shapes: + // 1. Rewrite inputs of the same shape using AddN (leaf nodes) + // 2. Build a tree of Add nodes, minimizing cost of broadcast + std::sort(shapes.begin(), shapes.end(), + [](const TensorShapeProto& left, const TensorShapeProto& right) { + return CompareSymbolicallyShapedTensorSizes(left, right); + }); + + // optimized name for leaf AddN nodes + auto leaf_node_name = [&root_scope_and_name, this](int i) { + return OptimizedNodeName(root_scope_and_name, + strings::StrCat("Leaf_", i)); + }; + // optimized name for internal nodes of a tree built up from AddN leaves + auto internal_node_name = [&root_scope_and_name, this](int i) { + return OptimizedNodeName(root_scope_and_name, + strings::StrCat("Internal_", i)); + }; + + // Add/AddN nodes that must be added to the tree + std::deque<InputAndShape> add_ops; + + // Prepare leaf AddN nodes for inputs of equal shape + for (int i = 0; i < shapes.size(); ++i) { + const auto node_name = leaf_node_name(i); + const auto& inputs = shape_sig_to_inputs[shape_sig(shapes[i])]; + add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node, + node_name, inputs)); + } - // Create a new node for a AddOpsGroup and return it's name. - string RewriteAddOpsGroup(const AddOpsGroup& group) { - CHECK_GT(group.absorbed_nodes.size(), 0) - << "AddOpsGroup must have non empty absorbed nodes"; + // Build up a tree of Add ops + int internal_nodes = 0; + do { + const InputAndShape lhs = add_ops.front(); + add_ops.pop_front(); + const InputAndShape rhs = add_ops.front(); + add_ops.pop_front(); + string name = add_ops.empty() ? OptimizedNodeName(root_scope_and_name) + : internal_node_name(internal_nodes++); + InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs); + add_ops.push_front(add); + } while (add_ops.size() > 1); + + InputAndShape optimized_root_node = add_ops.front(); + return optimized_root_node.input; + } + + // Add 'AddN' node to aggregate inputs of symbolically equal shape + InputAndShape AddInputsOfSymbolicallyEqualShape( + const NodeDef& root_node, const string& node_name, + const std::vector<InputAndShape>& inputs) { + CHECK(!inputs.empty()) << "Inputs must be non-empty"; + + // Do not create redundant AddN nodes + if (inputs.size() == 1) { + return inputs[0]; + } - // name for a new node constructed from AddOpsGroup - string node_name = AddOpsGroupName(group); + // get shape from representative element + auto shape = inputs[0].shape; // copy attributes from a root node - DataType dtype = group.root_node->attr().at("T").type(); + DataType dtype = root_node.attr().at("T").type(); // add new AddN node - NodeDef* added_node = AddEmptyNode(node_name); - added_node->set_op("AddN"); - added_node->set_device(group.root_node->device()); - (*added_node->mutable_attr())["T"].set_type(dtype); - (*added_node->mutable_attr())["N"].set_i(group.inputs.size()); - - // all inputs of absorbed nodes are added to the new node - for (const string& input : group.inputs) { - ctx_.node_map->AddOutput(input, node_name); - added_node->add_input(input); + NodeDef* node = AddEmptyNode(node_name); + node->set_op("AddN"); + node->set_device(root_node.device()); + (*node->mutable_attr())["T"].set_type(dtype); + (*node->mutable_attr())["N"].set_i(inputs.size()); + + for (const auto& inputAndShape : inputs) { + ctx_.node_map->AddOutput(inputAndShape.input, node_name); + node->add_input(inputAndShape.input); } - // Add frame dependencies that the original node might have had. - AddFrameControlDeps(group.root_node, {added_node}, "", {}); + rewritten_nodes_.insert(node_name); + return InputAndShape(node_name, shape); + } + + // Add a single 'Add' node to sum two inputs + InputAndShape AddAggregatedInputs(const NodeDef& root_node, + const string& node_name, + const InputAndShape& left, + const InputAndShape& right) { + // copy attributes from a root node + DataType dtype = root_node.attr().at("T").type(); - VLOG(1) << "Absorbed " << group.absorbed_nodes.size() - << " Add/AddN nodes from the graph"; + // add new Add node + NodeDef* node = AddEmptyNode(node_name); + node->set_op("Add"); + node->set_device(root_node.device()); + (*node->mutable_attr())["T"].set_type(dtype); - // keep track of nodes that were created or absorbed as a part of rewrite - rewritten_nodes_.insert(node_name); - for (const NodeDef* absorbed : group.absorbed_nodes) { - rewritten_nodes_.insert(absorbed->name()); - } + ctx_.node_map->AddOutput(left.input, node_name); + ctx_.node_map->AddOutput(right.input, node_name); + + node->add_input(left.input); + node->add_input(right.input); - return node_name; + rewritten_nodes_.insert(node_name); + return InputAndShape( + node_name, TensorShapeProto()); // shape is not important at this point } // keep nodes that were added or absorbed as a part of AddOpsGroup rewrite @@ -623,7 +697,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { CHECK(IsSupported(node)); std::set<string> common_factors; - TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors)); + std::vector<string> ctrl_deps; + TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps)); if (common_factors.size() == 1) { const string& common_factor = *common_factors.begin(); @@ -655,9 +730,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { new_add_node->set_input(i, unique_factors[i]); } - // Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, - {new_add_node}); + // Add control deps on add node + for (const string& ctrl_dep : ctrl_deps) { + *new_add_node->add_input() = ctrl_dep; + ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name()); + } // optimize new inner aggregation node AddToOptimizationQueue(new_add_node); @@ -683,14 +760,16 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { } // Determine the set of common factors if the input nodes are all Mul nodes. - Status GetCommonFactors(const NodeDef* node, - std::set<string>* common_factors) const { + Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors, + std::vector<string>* ctrl_deps) const { CHECK(common_factors->empty()); for (int i = 0; i < node->input_size(); ++i) { if (i > 0 && common_factors->empty()) break; - if (IsControlInput(node->input(i))) break; - + if (IsControlInput(node->input(i))) { + ctrl_deps->push_back(node->input(i)); + continue; + } NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input)); @@ -710,6 +789,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { std::inserter(intersection, intersection.begin())); std::swap(*common_factors, intersection); } + for (int i = 2; i < input->input_size(); ++i) { + ctrl_deps->push_back(input->input(i)); + } } return Status::OK(); } @@ -1195,20 +1277,15 @@ void ArithmeticOptimizer::DedupComputations() { } } -void ArithmeticOptimizer::AddFrameControlDeps( - const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep) { - const auto frame_it = frame_map_.find(old_node); - if (frame_it != frame_map_.end()) { - for (auto node : new_nodes) { - frame_map_.emplace(node, frame_it->second); - } - if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { - const string ctrl_dep = ConstantFolding::AddControlDependency( - source_for_ctrl_dep, optimized_graph_, node_map_.get()); - for (auto node : sinks_for_control_dep) { - MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get()); +void ArithmeticOptimizer::ForwardControlDependencies( + NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) { + for (const auto& src : src_nodes) { + for (int i = src->input_size() - 1; i >= 0; --i) { + if (IsControlInput(src->input(i))) { + *target_node->add_input() = src->input(i); + node_map_->AddOutput(NodeName(src->input(i)), target_node->name()); + } else { + break; } } } @@ -1264,19 +1341,18 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( int output_pos = 0; string input_node_name = ParseNodeName(node->input(0), &output_pos); const NodeDef* input = node_map_->GetNode(input_node_name); - if (input->op() == "Reshape") { + if (input->op() == "Reshape" && !HasControlInputs(*input)) { reshape->set_input(0, input->input(0)); node_map_->UpdateInput(reshape->name(), input->name(), input->input(0)); nodes_to_simplify->PushBack(reshape); return reshape->name(); } - // If the reshape is a no-op, forward its input to its consumers. This is - // considered aggressive, because users may state that the placeholder - // outputs tensors of shape [M, N] while feeding it with tensors of shape - // [M*N] (or worse). The reshape nodes are then necessary to update the - // tensor metadata to the required shape. - if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) { + // If the reshape is a no-op, forward its input to its consumers, unless it + // anchors a control dependency since we want to make sure that control + // dependency is triggered. + if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) && + !HasControlInputs(*reshape)) { return reshape->input(0); } } @@ -1329,10 +1405,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map_->AddOutput(new_transpose->name(), new_cast->name()); nodes_to_simplify->PushBack(new_transpose); - // Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_transpose, new_cast}, - new_transpose->input(0), {new_transpose}); - + ForwardControlDependencies(new_transpose, {cast, node}); return new_cast->name(); } } @@ -1406,7 +1479,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map_->AddOutput(weights->name(), scaled_weights->name()); scaled_weights->add_input(mul->input(1)); node_map_->AddOutput(scale->name(), scaled_weights->name()); - AddFrameControlDeps(node, {scaled_weights}, "", {}); + ForwardControlDependencies(scaled_weights, {source}); // Update `conv`'s weights to `scaled_weights`. conv->set_input(1, scaled_weights->name()); @@ -1442,7 +1515,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { - // Discard aggregate nodes with a single input. + // Discard aggregate nodes with a single input and no control dependencies. if (node->input_size() == 1) { return node->input(0); } @@ -1488,6 +1561,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( return ""; } new_const_node->set_device(node->device()); + MaybeAddControlInput(NodeName(node->input(0)), new_const_node, + optimized_graph_, node_map_.get()); nodes_to_simplify->PushBack(new_const_node); // 2. Replace the aggregate node with Mul(Const(N), x). @@ -1500,9 +1575,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->add_input(node->input(0)); node_map_->AddOutput(node->input(0), new_mul_node->name()); - CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get()); - AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0), - {new_const_node}); + ForwardControlDependencies(new_mul_node, {node}); return new_mul_node->name(); } } @@ -1535,7 +1608,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( FlipBooleanAttr(attr_a, new_op); new_op->set_input(0, a->input(0)); node_map_->UpdateInput(new_op->name(), a->name(), a->input(0)); - AddFrameControlDeps(node, {new_op}, a->input(0), {new_op}); } if (b_is_foldable) { const string attr_b = @@ -1543,10 +1615,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( FlipBooleanAttr(attr_b, new_op); new_op->set_input(1, b->input(0)); node_map_->UpdateInput(new_op->name(), b->name(), b->input(0)); - if (!a_is_foldable) { - AddFrameControlDeps(node, {new_op}, b->input(0), {new_op}); - } } + std::vector<const NodeDef*> deps_to_forward({node}); + if (a_is_foldable) { + deps_to_forward.push_back(a); + } + if (b_is_foldable) { + deps_to_forward.push_back(b); + } + ForwardControlDependencies(new_op, deps_to_forward); } } @@ -1568,7 +1645,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( : "Transpose"); new_op->set_input(0, input->input(0)); node_map_->UpdateInput(new_op->name(), node->name(), input->input(0)); - AddFrameControlDeps(node, {new_op}, "", {}); + ForwardControlDependencies(new_op, {node, input}); return new_op->name(); } } @@ -1584,38 +1661,27 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { } const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, - graph_properties_.get(), node_map_.get(), - &frame_map_); + graph_properties_.get(), node_map_.get()); const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify); - std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages; - - if (options_.combine_add_to_addn) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new AddOpsRewriteStage(ctx, ctx_ext))); - } - if (options_.hoist_common_factor_out_of_aggregation) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new HoistCommonFactorOutOfAggregation(ctx, ctx_ext))); - } - if (options_.remove_identity_transpose) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new RemoveIdentityTranspose(ctx, ctx_ext))); - } - if (options_.remove_redundant_bitcast) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new RemoveRedundantBitcastStage(ctx, ctx_ext))); - } - if (options_.remove_redundant_cast) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new RemoveRedundantCastStage(ctx, ctx_ext))); - } - if (options_.remove_negation) { - stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>( - new RemoveNegationStage(ctx, ctx_ext))); - } - - VLOG(1) << "Simplify arithmetic ops using " << stages.size() + // Stop pipeline after first stage returning non-empty simplified tensor name. + const auto stop = [](const string& result) { return !result.empty(); }; + GraphOptimizerStagePipeline<string> pipeline(stop); + + if (options_.combine_add_to_addn) + pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext); + if (options_.hoist_common_factor_out_of_aggregation) + pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext); + if (options_.remove_identity_transpose) + pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext); + if (options_.remove_redundant_bitcast) + pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext); + if (options_.remove_redundant_cast) + pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext); + if (options_.remove_negation) + pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); + + VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages() << " arithmetic optimization stages"; while (!nodes_to_simplify.Empty()) { @@ -1628,22 +1694,13 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { } // if it was not simplified try to run it through all configured stages - if (simplified_tensor.empty()) { - for (auto& stage : stages) { - if (stage->IsSupported(node)) { - TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor)); - if (!simplified_tensor.empty()) { - break; - } - } + if (!stop(simplified_tensor)) { + bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor); + if (!optimized) { + continue; } } - // if it's still empty go to the next Node - if (simplified_tensor.empty()) { - continue; - } - // re-wire consumers of an old node to the new one if (NodeName(simplified_tensor) != node->name()) { // Always consider simplified_tensor for further optimizations. @@ -1686,24 +1743,28 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, GraphDef* optimized_graph) { - optimized_graph_ = optimized_graph; - *optimized_graph_ = item.graph; + GrapplerItem optimized_item(item); + optimized_graph_ = &optimized_item.graph; // Set up helper data structures. nodes_to_preserve_ = item.NodesToPreserve(); fetch_nodes_known_ = !item.fetch.empty(); node_map_.reset(new NodeMap(optimized_graph_)); - int num_frames; - TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, - &frame_map_, &num_frames)); + + DedupComputations(); + + // Perform topological sort on the graph in order to help AddOpsRewrite to + // optimize larger subgraphs starting from the roots with more inputs. + TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_)); + // Shapes are only needed in aggressive mode. graph_properties_.reset(new GraphProperties(item)); TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); // Perform the optimizations. - DedupComputations(); TF_RETURN_IF_ERROR(SimplifyArithmeticOps()); + optimized_graph->Swap(optimized_graph_); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 965f0e9ea2..7e81ed0a1f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -69,7 +68,13 @@ class ArithmeticOptimizer : public GraphOptimizer { // optimization level by default. static ArithmeticOptimizerOptions Default( RewriterConfig::Toggle opt_level) { - return ArithmeticOptimizerOptions(); + ArithmeticOptimizerOptions options; + // TODO(ezhulenev): enable combine_add_to_addn by default after 1.8 + // release cut + if (opt_level == RewriterConfig::AGGRESSIVE) { + options.combine_add_to_addn = true; + } + return options; } }; @@ -94,13 +99,9 @@ class ArithmeticOptimizer : public GraphOptimizer { // Dedup redundant nodes in the graph. void DedupComputations(); - // Fix frame dependencies by adding control dependencies from old_input to - // nodes in new_nodes_for_control_dep, and update frame_map for all nodes in - // new_nodes. - void AddFrameControlDeps(const NodeDef* old_node, - const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep); + // Forward the control dependencies anchored on src_nodes to the target_nodes. + void ForwardControlDependencies(NodeDef* target_node, + const std::vector<const NodeDef*>& src_nodes); // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. @@ -129,7 +130,6 @@ class ArithmeticOptimizer : public GraphOptimizer { bool fetch_nodes_known_ = false; std::unordered_set<string> nodes_to_preserve_; std::unique_ptr<NodeMap> node_map_; - FrameMap frame_map_; std::unique_ptr<GraphProperties> graph_properties_; GraphDef* optimized_graph_ = nullptr; // Not owned. }; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index ad3edc144a..e117341ba3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -156,27 +156,24 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"div"}; - ArithmeticOptimizer optimizer; - GraphDef output; - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + GraphDef output; + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(2, output.node_size()); - const NodeDef& new_c1 = output.node(0); - EXPECT_EQ("c1", new_c1.name()); - const NodeDef& new_div = output.node(1); - EXPECT_EQ("div", new_div.name()); - EXPECT_EQ(2, new_div.input_size()); - EXPECT_EQ("c1", new_div.input(0)); - EXPECT_EQ("c1", new_div.input(1)); - - auto tensors = EvaluateNodes(output, item.fetch, {}); + const NodeDef* new_c1 = node_map.GetNode("c1"); + ASSERT_NE(new_c1, nullptr); + + const NodeDef* new_div = node_map.GetNode("div"); + ASSERT_NE(new_div, nullptr); + EXPECT_EQ(2, new_div->input_size()); + EXPECT_EQ("c1", new_div->input(0)); + EXPECT_EQ("c1", new_div->input(1)); + + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6); } @@ -195,23 +192,30 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"div"}; + Tensor bool_t(DT_BOOL, TensorShape({})); + bool_t.scalar<bool>().setConstant(true); + auto tensors_expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}}); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(5, output.node_size()); - const NodeDef& new_div = output.node(3); - EXPECT_EQ(4, new_div.input_size()); - EXPECT_EQ("check1", new_div.input(0)); - EXPECT_EQ("check1", new_div.input(1)); - EXPECT_EQ("^assert1", new_div.input(2)); - EXPECT_EQ("^assert1", new_div.input(3)); + const NodeDef* new_div = node_map.GetNode("div"); + ASSERT_NE(new_div, nullptr); + EXPECT_EQ(4, new_div->input_size()); + EXPECT_EQ("check1", new_div->input(0)); + EXPECT_EQ("check1", new_div->input(1)); + EXPECT_EQ("^assert1", new_div->input(2)); + EXPECT_EQ("^assert1", new_div->input(3)); + + auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { @@ -223,32 +227,34 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"div"}; + item.fetch = {"div1"}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(4, output.node_size()); - const NodeDef& new_c1 = output.node(0); - EXPECT_EQ("c1", new_c1.name()); - const NodeDef& new_c2 = output.node(1); - EXPECT_EQ("c2", new_c2.name()); - const NodeDef& new_mul1 = output.node(2); - EXPECT_EQ("mul1", new_mul1.name()); - EXPECT_EQ(2, new_mul1.input_size()); - EXPECT_EQ("c1", new_mul1.input(0)); - EXPECT_EQ("c2", new_mul1.input(1)); - const NodeDef& new_div1 = output.node(3); - EXPECT_EQ("div1", new_div1.name()); - EXPECT_EQ(2, new_div1.input_size()); - EXPECT_EQ("mul1", new_div1.input(0)); - EXPECT_EQ("mul1", new_div1.input(1)); + const NodeDef* new_c1 = node_map.GetNode("c1"); + ASSERT_NE(new_c1, nullptr); + const NodeDef* new_c2 = node_map.GetNode("c2"); + ASSERT_NE(new_c2, nullptr); + const NodeDef* new_mul1 = node_map.GetNode("mul1"); + ASSERT_NE(new_mul1, nullptr); + EXPECT_EQ(2, new_mul1->input_size()); + EXPECT_EQ("c1", new_mul1->input(0)); + EXPECT_EQ("c2", new_mul1->input(1)); + const NodeDef* new_div1 = node_map.GetNode("div1"); + ASSERT_NE(new_div1, nullptr); + EXPECT_EQ(2, new_div1->input_size()); + EXPECT_EQ("mul1", new_div1->input(0)); + EXPECT_EQ("mul1", new_div1->input(1)); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, MulToSquare) { @@ -259,6 +265,9 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) { Output id = ops::Identity(s.WithOpName("id"), mul); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; @@ -273,6 +282,10 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) { EXPECT_EQ(2, output.node(4).input_size()); EXPECT_EQ("c", output.node(4).input(0)); EXPECT_EQ("^d", output.node(4).input(1)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { @@ -285,6 +298,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { Output id = ops::Identity(s.WithOpName("id"), recip2); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; @@ -295,6 +311,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { EXPECT_EQ("c", output.node(1).input(0)); EXPECT_EQ("c", output.node(3).input(0)); EXPECT_EQ("c", output.node(5).input(0)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { @@ -307,6 +327,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { Output id2 = ops::Identity(s.WithOpName("id2"), recip2); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id2"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; @@ -320,6 +343,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { EXPECT_EQ(6, output.node_size()); EXPECT_EQ("squeeze", output.node(5).input(0)); EXPECT_EQ("c", output.node(2).input(0)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { @@ -334,6 +361,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id2"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); + ArithmeticOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); @@ -351,6 +382,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { EXPECT_EQ(original.input(j), optimized.input(j)); } } + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { @@ -362,28 +397,35 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); + ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(5, output.node_size()); - const NodeDef& new_const = output.node(3); - EXPECT_EQ(OptimizedName("add_const"), new_const.name()); - EXPECT_EQ("^x", new_const.input(0)); + + const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + ASSERT_NE(new_const, nullptr); + EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), - new_const.attr().at("value").tensor().tensor_content()); - const NodeDef& new_mul = output.node(4); - EXPECT_EQ(OptimizedName("add_mul"), new_mul.name()); - EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0)); - EXPECT_EQ("x", new_mul.input(1)); - const NodeDef& new_id = output.node(2); - EXPECT_EQ("id", new_id.name()); - EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0)); + new_const->attr().at("value").tensor().tensor_content()); + + const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + ASSERT_NE(new_mul, nullptr); + EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ("x", new_mul->input(1)); + + const NodeDef* new_id = node_map.GetNode("id"); + ASSERT_NE(new_id, nullptr); + EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { @@ -396,29 +438,36 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); + ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(6, output.node_size()); - const NodeDef& new_const = output.node(4); - EXPECT_EQ(OptimizedName("add_const"), new_const.name()); - EXPECT_EQ("^x", new_const.input(0)); + + const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + ASSERT_NE(new_const, nullptr); + EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), - new_const.attr().at("value").tensor().tensor_content()); - const NodeDef& new_mul = output.node(5); - EXPECT_EQ(OptimizedName("add_mul"), new_mul.name()); - EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0)); - EXPECT_EQ("x", new_mul.input(1)); - EXPECT_EQ("^y", new_mul.input(2)); - const NodeDef& new_id = output.node(3); - EXPECT_EQ("id", new_id.name()); - EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0)); + new_const->attr().at("value").tensor().tensor_content()); + + const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + ASSERT_NE(new_mul, nullptr); + EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ("x", new_mul->input(1)); + EXPECT_EQ("^y", new_mul->input(2)); + + const NodeDef* new_id = node_map.GetNode("id"); + ASSERT_NE(new_id, nullptr); + EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { @@ -434,6 +483,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + const std::vector<string> devices{ "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1", "/device:CPU:0", "/device:CPU:0", "/device:CPU:0", @@ -458,48 +508,45 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { EXPECT_EQ(17, output.node_size()); const NodeDef* id_node = node_map.GetNode("id"); - ASSERT_TRUE(id_node != nullptr); + ASSERT_NE(id_node, nullptr); EXPECT_EQ(1, id_node->input_size()); EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0)); const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6")); - ASSERT_TRUE(mul_node != nullptr); + ASSERT_NE(mul_node, nullptr); EXPECT_EQ(2, mul_node->input_size()); EXPECT_EQ("Placeholder", mul_node->input(0)); EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1)); const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6")); - ASSERT_TRUE(add_6_node != nullptr); - EXPECT_EQ(3, add_6_node->input_size()); + ASSERT_NE(add_6_node, nullptr); + EXPECT_EQ(2, add_6_node->input_size()); EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0)); EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1)); - EXPECT_EQ("^Placeholder", add_6_node->input(2)); const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4")); - ASSERT_TRUE(add_4_node != nullptr); + ASSERT_NE(add_4_node, nullptr); EXPECT_EQ("Add", add_4_node->op()); - EXPECT_EQ(3, add_4_node->input_size()); + EXPECT_EQ(2, add_4_node->input_size()); EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0)); EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1)); - EXPECT_EQ("^Placeholder", add_4_node->input(2)); const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5")); - ASSERT_TRUE(add_5_node != nullptr); + ASSERT_NE(add_5_node, nullptr); EXPECT_EQ("Add", add_5_node->op()); - EXPECT_EQ(3, add_5_node->input_size()); + EXPECT_EQ(2, add_5_node->input_size()); EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0)); EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1)); - EXPECT_EQ("^Placeholder", add_5_node->input(2)); const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const")); - ASSERT_TRUE(add_const_node != nullptr); + ASSERT_NE(add_const_node, nullptr); EXPECT_EQ("Const", add_const_node->op()); EXPECT_EQ(1, add_const_node->input_size()); EXPECT_EQ("^Placeholder", add_const_node->input(0)); const NodeDef* add_1_const_node = node_map.GetNode(OptimizedName("Add_1_const")); - ASSERT_TRUE(add_1_const_node != nullptr); + ASSERT_NE(add_1_const_node, nullptr); EXPECT_EQ("Const", add_1_const_node->op()); EXPECT_EQ(1, add_1_const_node->input_size()); EXPECT_EQ("^Placeholder", add_1_const_node->input(0)); @@ -525,7 +572,8 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { GrapplerItem item; item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; EnableOnlyHoistCommonFactor(&optimizer); @@ -550,55 +598,63 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { EXPECT_EQ(9, output.node_size()); const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add")); - ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found"; + ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found"; EXPECT_EQ("y1", new_add_node->input(0)); EXPECT_EQ("y2", new_add_node->input(1)); const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add")); - ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found"; + ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found"; EXPECT_EQ("x", new_mul_node->input(0)); EXPECT_EQ(new_add_node->name(), new_mul_node->input(1)); const NodeDef* id_node = node_map.GetNode("id"); - ASSERT_TRUE(id_node != nullptr) << "Id node not found"; + ASSERT_NE(id_node, nullptr) << "Id node not found"; EXPECT_EQ("id", id_node->name()); EXPECT_EQ(HoistMulName("add"), id_node->input(0)); } + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } } } TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); - Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); Output z = ops::Complex(s.WithOpName("z"), re, im); Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); Output conj = ops::Conj(s.WithOpName("conj"), z); Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - + std::vector<string> fetch = {"trans"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(7, output.node_size()); - EXPECT_EQ(OptimizedName("trans_fused"), output.node(6).name()); - EXPECT_EQ("ConjugateTranspose", output.node(6).op()); - EXPECT_EQ("z", output.node(6).input(0)); - EXPECT_EQ("perm", output.node(6).input(1)); + + const NodeDef* trans_fused_node = + node_map.GetNode(OptimizedName("trans_fused")); + ASSERT_NE(trans_fused_node, nullptr); + EXPECT_EQ("ConjugateTranspose", trans_fused_node->op()); + EXPECT_EQ("z", trans_fused_node->input(0)); + EXPECT_EQ("perm", trans_fused_node->input(1)); + + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); - Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); Output z = ops::Complex(s.WithOpName("z"), re, im); Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); Output conj = ops::Conj(s.WithOpName("conj"), z); @@ -606,44 +662,56 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) { ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"conjugate_trans"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(7, output.node_size()); - EXPECT_EQ(OptimizedName("conjugate_trans_fused"), output.node(6).name()); - EXPECT_EQ("Transpose", output.node(6).op()); - EXPECT_EQ("z", output.node(6).input(0)); - EXPECT_EQ("perm", output.node(6).input(1)); + + const NodeDef* conjugate_trans_fused_node = + node_map.GetNode(OptimizedName("conjugate_trans_fused")); + EXPECT_EQ("Transpose", conjugate_trans_fused_node->op()); + EXPECT_EQ("z", conjugate_trans_fused_node->input(0)); + EXPECT_EQ("perm", conjugate_trans_fused_node->input(1)); + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); - Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); Output z = ops::Complex(s.WithOpName("z"), re, im); Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); Output trans = ops::Transpose(s.WithOpName("trans"), z, perm); Output conj = ops::Conj(s.WithOpName("conj"), trans); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"conj"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(7, output.node_size()); - EXPECT_EQ(OptimizedName("conj_fused"), output.node(6).name()); - EXPECT_EQ("ConjugateTranspose", output.node(6).op()); - EXPECT_EQ("z", output.node(6).input(0)); - EXPECT_EQ("perm", output.node(6).input(1)); + + const NodeDef* conj_fused_node = + node_map.GetNode(OptimizedName("conj_fused")); + EXPECT_EQ("ConjugateTranspose", conj_fused_node->op()); + EXPECT_EQ("z", conj_fused_node->input(0)); + EXPECT_EQ("perm", conj_fused_node->input(1)); + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { @@ -665,27 +733,32 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { } GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"matmul"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); EXPECT_EQ(7, output.node_size()); - EXPECT_EQ(OptimizedName("matmul_fused"), output.node(6).name()); - EXPECT_EQ("a", output.node(6).input(0)); - EXPECT_EQ("b", output.node(6).input(1)); + + const NodeDef* matmul_fused_node = + node_map.GetNode(OptimizedName("matmul_fused")); + ASSERT_NE(matmul_fused_node, nullptr); + EXPECT_EQ("a", matmul_fused_node->input(0)); + EXPECT_EQ("b", matmul_fused_node->input(1)); if (matmul_type == "BatchMatMul") { - EXPECT_TRUE(output.node(6).attr().at("adj_x").b()); - EXPECT_TRUE(output.node(6).attr().at("adj_y").b()); + EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b()); + EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b()); } else { - EXPECT_TRUE(output.node(6).attr().at("transpose_a").b()); - EXPECT_TRUE(output.node(6).attr().at("transpose_b").b()); + EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b()); + EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b()); } + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } } @@ -707,6 +780,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + std::vector<string> fetch = {"matmul"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; GraphDef output; @@ -719,6 +795,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { EXPECT_EQ("b", output.node(10).input(1)); EXPECT_TRUE(output.node(10).attr().at("adj_x").b()); EXPECT_TRUE(output.node(10).attr().at("adj_y").b()); + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, IdentityReshape) { @@ -739,7 +818,10 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28})); + auto tensors_expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors_expected.size()); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); @@ -747,6 +829,9 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { @@ -761,7 +846,10 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28})); + item.feed = {{"Placeholder", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); @@ -769,6 +857,9 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { @@ -781,7 +872,6 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); @@ -812,7 +902,10 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - + auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4})); + item.feed = {{"nchw_vect_c", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); GraphDef output; TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); @@ -820,6 +913,9 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) { @@ -1322,8 +1418,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) { // check add tree was replaced with AddN const NodeDef* collapsed_add = - node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab"); - ASSERT_TRUE(collapsed_add != nullptr); + node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc"); + ASSERT_NE(collapsed_add, nullptr); EXPECT_EQ("AddN", collapsed_add->op()); EXPECT_EQ(3, collapsed_add->input_size()); @@ -1333,7 +1429,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) { // check output was re-wired to new node const NodeDef* updated_outputs = node_map.GetNode("outputs"); - ASSERT_TRUE(updated_outputs != nullptr); + ASSERT_NE(updated_outputs, nullptr); EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0)); } @@ -1381,8 +1477,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) { // check left Add subtree replaced with AddN const NodeDef* collapsed_left = - node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab"); - ASSERT_TRUE(collapsed_left != nullptr); + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc"); + ASSERT_NE(collapsed_left, nullptr); EXPECT_EQ("AddN", collapsed_left->op()); EXPECT_EQ(3, collapsed_left->input_size()); @@ -1392,8 +1488,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) { // check right Add subtree replaced with AddN const NodeDef* collapsed_right = - node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy"); - ASSERT_TRUE(collapsed_right != nullptr); + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz"); + ASSERT_NE(collapsed_right, nullptr); EXPECT_EQ("AddN", collapsed_right->op()); EXPECT_EQ(3, collapsed_right->input_size()); @@ -1403,7 +1499,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) { // check that Mul inputs re-wired to new Nodes const NodeDef* updated_mul = node_map.GetNode("Mul"); - ASSERT_TRUE(updated_mul != nullptr); + ASSERT_NE(updated_mul, nullptr); EXPECT_EQ("Mul", updated_mul->op()); EXPECT_EQ(2, updated_mul->input_size()); @@ -1444,9 +1540,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) { NodeMap node_map(&output); // check Add tree replaced with AddN - const NodeDef* collapsed_add = node_map.GetNode( - "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc"); - ASSERT_TRUE(collapsed_add != nullptr); + const NodeDef* collapsed_add = + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all"); + ASSERT_NE(collapsed_add, nullptr); EXPECT_EQ("AddN", collapsed_add->op()); EXPECT_EQ(4, collapsed_add->input_size()); @@ -1496,8 +1592,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) { // check add tree was replaced with AddN const NodeDef* collapsed_add = - node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab"); - ASSERT_TRUE(collapsed_add != nullptr); + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc"); + ASSERT_NE(collapsed_add, nullptr); EXPECT_EQ("AddN", collapsed_add->op()); EXPECT_EQ(3, collapsed_add->input_size()); EXPECT_EQ("a", collapsed_add->input(0)); @@ -1506,10 +1602,173 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) { // check output was re-wired to new node const NodeDef* updated_outputs = node_map.GetNode("outputs"); - ASSERT_TRUE(updated_outputs != nullptr); + ASSERT_NE(updated_outputs, nullptr); EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0)); } +TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT); + auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT); + auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b); + auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c); + + auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT); + auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT); + auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT); + auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y); + auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z); + + auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz); + auto outputs = ops::Identity(s.WithOpName("outputs"), add_all); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyAddToAddNCombining(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: + // 1) [a, x], [b, y], [c, z] - aggregate same shapes first + // 2) Build an aggregation tree minimizing cost of broadcast + // + // + + + // / \ / \ + // + + + AddN(c, z) + // / \ / \ / \ + // + c x + --> AddN(a, x) AddN(b, y) + // / \ / \ + // a b y z + EXPECT_EQ(12, output.node_size()); + NodeMap node_map(&output); + + // expected names of outer and inner nodes + string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll"; + string outer_0_add_name = + "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll"; + string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll"; + string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll"; + string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll"; + + // Add [a, x] first + const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name); + ASSERT_NE(add_ax_node, nullptr); + EXPECT_EQ("AddN", add_ax_node->op()); + EXPECT_EQ(2, add_ax_node->input_size()); + EXPECT_EQ("a", add_ax_node->input(0)); + EXPECT_EQ("x", add_ax_node->input(1)); + + // Then add [b, y] + const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name); + ASSERT_NE(add_by_node, nullptr); + EXPECT_EQ("AddN", add_by_node->op()); + EXPECT_EQ(2, add_by_node->input_size()); + EXPECT_EQ("b", add_by_node->input(0)); + EXPECT_EQ("y", add_by_node->input(1)); + + // Then add [c, z] + const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name); + ASSERT_NE(add_cz_node, nullptr); + EXPECT_EQ("AddN", add_cz_node->op()); + EXPECT_EQ(2, add_cz_node->input_size()); + EXPECT_EQ("c", add_cz_node->input(0)); + EXPECT_EQ("z", add_cz_node->input(1)); + + // Then add results together starting from smaller shapes [a, x] + [b, y] + const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name); + ASSERT_NE(outer_0_node, nullptr); + EXPECT_EQ("Add", outer_0_node->op()); + EXPECT_EQ(2, outer_0_node->input_size()); + EXPECT_EQ(inner_0_add_name, outer_0_node->input(0)); + EXPECT_EQ(inner_1_add_name, outer_0_node->input(1)); + + // And finally top level Add node + const NodeDef* outer_node = node_map.GetNode(outer_add_name); + ASSERT_NE(outer_node, nullptr); + EXPECT_EQ("Add", outer_node->op()); + EXPECT_EQ(2, outer_node->input_size()); + EXPECT_EQ(outer_0_add_name, outer_node->input(0)); + EXPECT_EQ(inner_2_add_name, outer_node->input(1)); + + // And outputs reading new top level Add node + const NodeDef* updated_outputs = node_map.GetNode("outputs"); + ASSERT_NE(updated_outputs, nullptr); + EXPECT_EQ(outer_add_name, updated_outputs->input(0)); +} + +TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + // We have a small input with one unknown dimension + auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_FLOAT); + + // And second input which is larger, but has the same unknown dimension + // device spec prevents this node from rewriting + auto d = "/job:do_not_rewrite_me"; + auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_FLOAT); + auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v); + + // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32} + auto a = ops::Sqrt(s.WithOpName("a"), small); + auto b = ops::Square(s.WithOpName("b"), large); + auto c = ops::Round(s.WithOpName("c"), small); + + // [add_ab, add_abc] shape must be inferred from inputs + auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b); + auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c); + + auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyAddToAddNCombining(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: it's much cheaper to add small + // tensors, and do the broadcast just once + // + // + + + // / \ / \ + // + c --> + b + // / \ / \ + // a b a c + EXPECT_EQ(9, output.node_size()); + NodeMap node_map(&output); + + // expected names of outer and inner nodes + string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc"; + string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc"; + + // outer Add node + const NodeDef* outer_add = node_map.GetNode(outer_add_name); + ASSERT_NE(outer_add, nullptr); + EXPECT_EQ("Add", outer_add->op()); + EXPECT_EQ(inner_add_name, outer_add->input(0)); + EXPECT_EQ("b", outer_add->input(1)); + + // inner AddN node + const NodeDef* inner_add = node_map.GetNode(inner_add_name); + ASSERT_NE(inner_add, nullptr); + EXPECT_EQ(2, inner_add->input_size()); + EXPECT_EQ("a", inner_add->input(0)); + EXPECT_EQ("c", inner_add->input(1)); + + // check output was re-wired to new node + const NodeDef* updated_outputs = node_map.GetNode("outputs"); + ASSERT_NE(updated_outputs, nullptr); + EXPECT_EQ(outer_add_name, updated_outputs->input(0)); +} + TEST_F(ArithmeticOptimizerTest, RemoveNegation) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7de544de52..d941a0b3f9 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -747,10 +747,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (op.find("Quantized") != string::npos || op.find("Sparse") == 0) { return false; } - if (node.attr().count("_XlaCompile") > 0 && - node.attr().at("_XlaCompile").b()) { - return false; - } const OpDef* op_def = nullptr; Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); @@ -777,7 +773,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { // the case of a merge node that propagate the first inputs that becomes // available, and therefore only requires a single constant input to be // foldable. - bool has_constant_input = false; + bool merge_has_constant_input = false; const bool is_merge = IsMerge(node); for (const auto& input : node.input()) { if (IsControlInput(input)) { @@ -788,21 +784,20 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } bool is_const = IsReallyConstant(*input_node); - if (!is_const && !is_merge) { - return false; - } - // Don't fold strings constants for now since this causes problems with - // checkpointing. - if (is_const && input_node->attr().at("dtype").type() == DT_STRING) { + if (is_const) { + // Don't fold strings constants for now since this causes problems with + // checkpointing. + if (input_node->attr().at("dtype").type() == DT_STRING) { + return false; + } + // Special case: If a Merge node has at least one constant input that + // does not depend on a control input, we can fold it. + merge_has_constant_input |= !HasControlInputs(*input_node); + } else if (!is_merge) { return false; } - has_constant_input |= is_const; - } - if (is_merge) { - return has_constant_input; } - - return true; + return !is_merge || merge_has_constant_input; } namespace { @@ -1542,6 +1537,16 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, for (int i = 0; i < optimized_graph->node_size(); ++i) { NodeDef* node = optimized_graph->mutable_node(i); + if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(1, node, optimized_graph); + continue; + } + + if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(0, node, optimized_graph); + continue; + } + // Remove Shuffle or Reverse op over scalar values. if (use_shape_info && !properties->GetInputProperties(node->name()).empty() && @@ -1708,9 +1713,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, } // Move constants past Enter. - // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022 - if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) && - node->input_size() > 0) { + if (IsEnter(*node) && node->input_size() > 0) { + if (node->attr().count("is_constant") == 0 || + !node->attr().at("is_constant").b()) { + continue; + } const string& node_name = node->name(); const NodeDef* input = node_map_->GetNode(node->input(0)); if (input != nullptr && IsReallyConstant(*input) && @@ -1739,7 +1746,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, node_map_->AddOutput(node_name, new_node->name()); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name) { + if (NodeName(consumer->input(i)) == node_name) { node_map_->UpdateInput(consumer->name(), node_name, new_node->name()); consumer->set_input(i, new_node->name()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 1db4fb9de7..71ee81dfde 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -83,14 +83,6 @@ class ConstantFoldingTest : public GrapplerTest { } }; -template <DataType DTYPE> -Tensor GetRandomTensor(const TensorShape& shape) { - typedef typename EnumToDataType<DTYPE>::Type T; - Tensor tensor(DTYPE, shape); - tensor.flat<T>() = tensor.flat<T>().random(); - return tensor; -} - TEST_F(ConstantFoldingTest, SimpleFolding) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -380,11 +372,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ(2, t.tensor_shape().dim(1).size()); } } - auto a_t = GetRandomTensor<DT_FLOAT>(TensorShape({3, 2})); - auto b_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 3})); - auto x_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2})); - auto y_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2})); - auto bias_t = GetRandomTensor<DT_FLOAT>(TensorShape({2})); + auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2})); + auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3})); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2})); + auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2})); + auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2})); auto tensors_expected = EvaluateNodes( item.graph, item.fetch, @@ -1264,6 +1256,10 @@ TEST_F(ConstantFoldingTest, MergeNodes) { ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2}); ops::Merge m2(scope.WithOpName("m2"), {const1, const3}); ops::Merge m3(scope.WithOpName("m3"), {x, y}); + // m4 is not foldable because the only constant input + // has a control input, so we cannot know if it will be + // triggered. + ops::Merge m4(scope.WithOpName("m4"), {x, const1}); ops::Identity out1(scope.WithOpName("out1"), m1.output); ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index); @@ -1271,9 +1267,11 @@ TEST_F(ConstantFoldingTest, MergeNodes) { ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index); ops::Identity out3(scope.WithOpName("out3"), m3.output); ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index); + ops::Identity out4(scope.WithOpName("out4"), m4.output); + ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index); GrapplerItem item; - item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"}; + item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"}; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); ConstantFolding optimizer(nullptr /* cpu_device */); @@ -1281,6 +1279,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + EXPECT_EQ(19, output.node_size()); int found_nodes = 0; for (const auto& node : output.node()) { if (node.name() == "out1") { @@ -1317,10 +1316,18 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("m3:1", node.input(0)); ++found_nodes; + } else if (node.name() == "out4") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m4", node.input(0)); + ++found_nodes; + } else if (node.name() == "idx4") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m4:1", node.input(0)); + ++found_nodes; } } // Make sure the graph contains all the nodes we're expecting. - EXPECT_EQ(6, found_nodes); + EXPECT_EQ(8, found_nodes); std::vector<string> fetch = {"out1", "idx1"}; auto tensors = EvaluateNodes(output, fetch); @@ -1335,6 +1342,82 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(2, out_idx.flat<int32>()(0)); } +TEST_F(ConstantFoldingTest, SplitRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT); + auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {}); + ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1); + ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2); + + ops::Add out(scope.WithOpName("out"), s1[0], s2[0]); + + GrapplerItem item; + item.fetch = {"out"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("split_dim", "Const", {}, {}, &want); + AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {}, + &want); + AddNode("s2", "Split", {"in2", "split_dim"}, {}, &want); + AddNode("out", "Add", {"s1", "s2"}, {}, &want); + + CompareGraphs(want, got); +} + +TEST_F(ConstantFoldingTest, SplitVRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT); + auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {}); + auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1}); + auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2}); + ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1); + ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2); + + LOG(INFO) << s1.output.size(); + LOG(INFO) << s2.output.size(); + ops::Add out(scope.WithOpName("out"), s1[0], s2[0]); + + GrapplerItem item; + item.fetch = {"out"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("split_dim", "Const", {}, {}, &want); + AddNode("size_splits1", "Const", {}, {}, &want); + AddNode("size_splits2", "Const", {}, {}, &want); + AddNode("s1", "Identity", + {"in1", AsControlDependency("size_splits1"), + AsControlDependency("split_dim")}, + {}, &want); + AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want); + AddNode("out", "Add", {"s1", "s2"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); @@ -2252,6 +2335,10 @@ TEST_F(ConstantFoldingTest, Enter) { GrapplerItem item; AttrValue frame_name; frame_name.set_s("foo"); + AttrValue is_constant_true; + is_constant_true.set_b(true); + AttrValue is_constant_false; + is_constant_false.set_b(false); AttrValue type; type.set_type(DT_FLOAT); AttrValue value; @@ -2262,19 +2349,31 @@ TEST_F(ConstantFoldingTest, Enter) { GraphDef& graph = item.graph; AddNode("x", "Placeholder", {}, {{"T", type}}, &graph); AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph); - AddNode("enter1", "Enter", {"x"}, {{"T", type}, {"frame_name", frame_name}}, + AddNode("enter1", "Enter", {"x"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_true}}, &graph); - AddNode("enter2", "Enter", {"c1"}, {{"T", type}, {"frame_name", frame_name}}, + AddNode("enter2", "Enter", {"c1"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_true}}, + &graph); + AddNode("enter3", "Enter", {"c1"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_false}}, &graph); AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph); AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph); AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph); + AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph); item.fetch.push_back("id1"); item.fetch.push_back("id2"); item.fetch.push_back("id3"); + item.fetch.push_back("id4"); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -2283,7 +2382,7 @@ TEST_F(ConstantFoldingTest, Enter) { status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(7, output.node_size()); + EXPECT_EQ(9, output.node_size()); for (const NodeDef& node : output.node()) { if (node.name() == "id1") { EXPECT_EQ("Identity", node.op()); @@ -2295,6 +2394,11 @@ TEST_F(ConstantFoldingTest, Enter) { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^enter2", node.input(0)); } + if (node.name() == "id4") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("enter3", node.input(0)); + } } } diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc index 0e058e3435..8bd10171f1 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/debug_stripper.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -39,6 +40,10 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, inp = AsControlDependency(inp); } } + } else if (IsCheckNumerics(node)) { + // Replace with Identity op which will be pruned later. + node.set_op("Identity"); + node.mutable_attr()->erase("message"); } } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc index aacd55f136..3f11febc64 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -29,14 +29,13 @@ namespace { class DebugStripperTest : public GrapplerTest {}; TEST_F(DebugStripperTest, OutputEqualToInput) { - constexpr char device[] = "/device:CPU:0"; + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({})); + Output add = ops::Add(s, x, y); + Output result = ops::Identity(s, add); GrapplerItem item; - item.graph = test::function::GDef( - {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("y", "XTimesTwo", {"x"}, {}, device), - test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)}, - {}); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); DebugStripper optimizer; GraphDef output; @@ -45,19 +44,17 @@ TEST_F(DebugStripperTest, OutputEqualToInput) { } TEST_F(DebugStripperTest, StripAssertFromGraph) { - constexpr char device[] = "/device:CPU:0"; + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto assert = ops::Assert(s.WithOpName("Assert"), greaterequal, {x, y}); + Output add = ops::Add( + s.WithOpName("z").WithControlDependencies({assert.operation}), x, y); GrapplerItem item; - item.graph = test::function::GDef( - {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}}, - device), - test::function::NDef("GreaterEqual", "GreaterEqual", {"x", "y"}, - {{"T", DT_FLOAT}}, device), - test::function::NDef("Assert", "Assert", {"GreaterEqual"}, - {{"T", DT_FLOAT}}, device), - test::function::NDef("z", "Add", {"x", "y", "^Assert"}, {}, device)}, - {}); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); DebugStripper optimizer; GraphDef output; @@ -68,31 +65,27 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) { if (node.name() == "x") { count++; EXPECT_EQ("Placeholder", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "y") { count++; EXPECT_EQ("Placeholder", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(0, node.input_size()); } else if (node.name() == "GreaterEqual") { count++; EXPECT_EQ("GreaterEqual", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); } else if (node.name() == "Assert") { count++; EXPECT_EQ("NoOp", node.op()); - EXPECT_EQ(device, node.device()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(3, node.input_size()); EXPECT_EQ("^GreaterEqual", node.input(0)); - EXPECT_EQ(0, node.attr_size()); + EXPECT_EQ("^x", node.input(1)); + EXPECT_EQ("^y", node.input(2)); } else if (node.name() == "z") { count++; EXPECT_EQ("Add", node.op()); - EXPECT_EQ(device, node.device()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); @@ -100,6 +93,75 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) { } } EXPECT_EQ(5, count); + + Tensor x_t(DT_FLOAT, TensorShape({})); + Tensor y_t(DT_FLOAT, TensorShape({})); + x_t.flat<float>()(0) = 1.0f; + y_t.flat<float>()(0) = 0.5f; + std::vector<Tensor> expected = + EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}}); + std::vector<Tensor> optimized = + EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}}); + test::ExpectTensorEqual<float>(expected[0], optimized[0]); +} + +TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo"); + auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo"); + Output add = ops::Add(s.WithOpName("z"), check1, check2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "y") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "CheckNumerics1") { + count++; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ(1, node.attr_size()); + } else if (node.name() == "CheckNumerics2") { + count++; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ(1, node.attr_size()); + } else if (node.name() == "z") { + count++; + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("CheckNumerics1", node.input(0)); + EXPECT_EQ("CheckNumerics2", node.input(1)); + } + } + EXPECT_EQ(5, count); + + Tensor x_t(DT_FLOAT, TensorShape({})); + Tensor y_t(DT_FLOAT, TensorShape({})); + x_t.flat<float>()(0) = 1.0f; + y_t.flat<float>()(0) = 0.5f; + std::vector<Tensor> expected = + EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}}); + std::vector<Tensor> optimized = + EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}}); + test::ExpectTensorEqual<float>(expected[0], optimized[0]); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 57b3118245..6a297da52d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -678,6 +678,50 @@ TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) { } } +TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto noop = + ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal)); + Output add = ops::Add( + s.WithOpName("z").WithControlDependencies({noop.operation}), x, y); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DependencyOptimizer optimizer; + GraphDef output; + item.fetch.push_back("z"); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "y") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "GreaterEqual") { + count++; + } else if (node.name() == "NoOp") { + count++; + } else if (node.name() == "z") { + count++; + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("y", node.input(1)); + } + } + EXPECT_EQ(3, count); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 2a6b8a325f..f1da469a6c 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -32,16 +32,129 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { + +class FunctionInliningContext { + public: + explicit FunctionInliningContext(const GrapplerItem& item) + : library_(&item.graph.library()), functions_(InliningCandidates(item)) {} + + const FunctionDefLibrary& Library() const { return *library_; } + + bool HasInlinedFunctions() const { return !functions_.empty(); } + + // Find inlining candidate by name. Return nullptr if not found. + const FunctionDef* FindInlinedFunction(const string& name) const { + auto it = functions_.find(name); + if (it != functions_.end()) { + return it->second; + } else { + return nullptr; + } + } + + private: + std::unordered_map<string, const FunctionDef*> InliningCandidates( + const GrapplerItem& item) const { + std::unordered_map<string, const FunctionDef*> functions; + for (const FunctionDef& func : item.graph.library().function()) { + // Don't inline functions marked as noinline + if (func.attr().count("_noinline") != 0) { + continue; + } + // Don't touch anything marked XLA to prevent XLA failures further down + // the road. + if (func.attr().count("_XlaCompile") > 0 && + func.attr().at("_XlaCompile").b()) { + continue; + } + // Can't create IdentityN nodes with no input or output: skip these + // functions for now. + if (func.signature().input_arg_size() == 0 || + func.signature().output_arg_size() == 0) { + continue; + } + functions[func.signature().name()] = &func; + } + return functions; + } + + const FunctionDefLibrary* library_; + std::unordered_map<string, const FunctionDef*> functions_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); +}; + +// Copy input/output argument type to the type_list. Return error if argument +// type is not explicitly defined, and not specified in function attributes. +Status CopyArgType(const NodeDef& func_node, + const std::unordered_map<string, AttrValue>& func_attr, + const string& arg_kind, const OpDef::ArgDef& arg, + AttrValue::ListValue* type_list) { + if (arg.type() != DT_INVALID) { + type_list->add_type(arg.type()); + } else { + auto it = func_attr.find(arg.type_attr()); + if (it == func_attr.end() || it->second.type() == DT_INVALID) { + return errors::InvalidArgument( + "Invalid ", arg_kind, " argument ", arg.name(), " for function ", + func_node.op(), " instantiated by ", func_node.name()); + } + type_list->add_type(it->second.type()); + } + return Status::OK(); +} + +// Add an IdentityN op to hook the function inputs to: this ensures that +// they're all evaluated before the evaluation of the function body starts. +Status HookInlinedFunctionInputs( + const NodeDef& func_node, const FunctionDef& func, + const std::unordered_map<string, AttrValue>& func_attr, NodeDef* inputs) { + inputs->set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs")); + inputs->set_op("IdentityN"); + inputs->set_device(func_node.device()); + *inputs->mutable_input() = func_node.input(); + AttrValue::ListValue* type_list = + (*inputs->mutable_attr())["T"].mutable_list(); + for (const OpDef::ArgDef& arg : func.signature().input_arg()) { + TF_RETURN_IF_ERROR( + CopyArgType(func_node, func_attr, "input", arg, type_list)); + } + return Status::OK(); +} + +// Add an IdentityN op to hook the function outputs to: this ensures that the +// function body is fully evaluated before its fanout gets scheduled. +Status HookInlinedFunctionOutputs( + const NodeDef& func_node, const FunctionDef& func, + const std::unordered_map<string, AttrValue>& func_attr, + const gtl::ArraySlice<string> fetch, NodeDef* outputs) { + outputs->set_name(func_node.name()); + outputs->set_op("IdentityN"); + outputs->set_device(func_node.device()); + AttrValue::ListValue* type_list = + (*outputs->mutable_attr())["T"].mutable_list(); + for (int i = 0; i < func.signature().output_arg_size(); ++i) { + const OpDef::ArgDef& arg = func.signature().output_arg(i); + TF_RETURN_IF_ERROR( + CopyArgType(func_node, func_attr, "output", arg, type_list)); + // Use the fetch names since they take into account the output mapping. + outputs->add_input(strings::StrCat(func_node.name(), "/", fetch[i])); + } + return Status::OK(); +} + +Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, + const FunctionInliningContext& ctx, + GraphDef* optimized_graph) { + const std::unordered_map<string, AttrValue> func_attr( + func_node.attr().begin(), func_node.attr().end()); -Status InlineFunction(const NodeDef& node, const FunctionDef& func, - const FunctionDefLibrary& library, GraphDef* graph) { - const std::unordered_map<string, AttrValue> attr(node.attr().begin(), - node.attr().end()); std::unique_ptr<GrapplerItem> item = - GrapplerItemFromFunctionDef(func, attr, library); + GrapplerItemFromFunctionDef(func, func_attr, ctx.Library()); if (!item) { - return errors::InvalidArgument("Failed to inline function ", node.op(), - " instantiated by ", node.name()); + return errors::InvalidArgument("Failed to inline function ", func_node.op(), + " instantiated by ", func_node.name()); } std::unordered_map<string, int> input_nodes; @@ -50,43 +163,25 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func, input_nodes[arg.name()] = i; } - // Add an IdentityN op to hook the function inputs to: this ensures that - // they're all evaluated before the evaluation of the function body starts. - NodeDef* func_inputs = graph->add_node(); - func_inputs->set_name(strings::StrCat(node.name(), "/", "inlined_inputs")); - func_inputs->set_op("IdentityN"); - func_inputs->set_device(node.device()); - *func_inputs->mutable_input() = node.input(); - AttrValue::ListValue* type_list = - (*func_inputs->mutable_attr())["T"].mutable_list(); - for (const OpDef::ArgDef& arg : func.signature().input_arg()) { - if (arg.type() != DT_INVALID) { - type_list->add_type(arg.type()); - } else { - auto it = attr.find(arg.type_attr()); - if (it == attr.end()) { - return errors::InvalidArgument("Invalid input argument ", arg.name(), - " for function ", node.op(), - " instantiated by ", node.name()); - } - type_list->add_type(it->second.type()); - } - } + // Hook inlined function inputs to IdentityN node + NodeDef* func_inputs = optimized_graph->add_node(); + TF_RETURN_IF_ERROR( + HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs)); for (NodeDef& func_body_node : *item->graph.mutable_node()) { if (input_nodes.find(func_body_node.name()) != input_nodes.end()) { + CHECK_EQ(0, func_body_node.input_size()); // Turn input placeholders into identity nodes if (IsPlaceholder(func_body_node)) { func_body_node.set_op("Identity"); } - CHECK_EQ(0, func_body_node.input_size()); int input_id = input_nodes[func_body_node.name()]; func_body_node.add_input( strings::StrCat(func_inputs->name(), ":", input_id)); } else { // Update the input names if any. for (string& input : *func_body_node.mutable_input()) { - input = AddPrefixToNodeName(input, node.name()); + input = AddPrefixToNodeName(input, /*prefix=*/func_node.name()); } // If the node has no input, make hook it up to the func_inputs node to // ensure it runs in the same frame as the other nodes of the function @@ -98,39 +193,29 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func, // Add the node name as a prefix to avoid collisions after inlining func_body_node.set_name( - strings::StrCat(node.name(), "/", func_body_node.name())); + strings::StrCat(func_node.name(), "/", func_body_node.name())); // Make sure the node is placed - func_body_node.set_device(node.device()); - - // Move the node to the main graph - graph->add_node()->Swap(&func_body_node); - } - - // Add an IdentityN op to hook the function outputs to: this ensures that the - // function body is fully evaluated before its fanout gets scheduled. - NodeDef* func_outputs = graph->add_node(); - func_outputs->set_name(node.name()); - func_outputs->set_op("IdentityN"); - func_outputs->set_device(node.device()); - type_list = (*func_outputs->mutable_attr())["T"].mutable_list(); - for (int i = 0; i < func.signature().output_arg_size(); ++i) { - const OpDef::ArgDef& arg = func.signature().output_arg(i); - if (arg.type() != DT_INVALID) { - type_list->add_type(arg.type()); + func_body_node.set_device(func_node.device()); + + // Check if a body node is itself a function + const FunctionDef* func_body_node_func = + ctx.FindInlinedFunction(func_body_node.op()); + if (func_body_node_func != nullptr) { + // Recursively inline function calls + TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func, + ctx, optimized_graph)); } else { - auto it = attr.find(arg.type_attr()); - if (it == attr.end()) { - return errors::InvalidArgument("Invalid output argument ", arg.name(), - " for function ", node.op(), - " instantiated by ", node.name()); - } - type_list->add_type(it->second.type()); + // Move the node to the main graph + optimized_graph->add_node()->Swap(&func_body_node); } - // Use the fetch names since they take into account the output mapping. - func_outputs->add_input(strings::StrCat(node.name(), "/", item->fetch[i])); } + // Hook inlined function outputs to IdentityN node + NodeDef* func_outputs = optimized_graph->add_node(); + TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr, + item->fetch, func_outputs)); + return Status::OK(); } @@ -278,31 +363,14 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env, return Status::OK(); } +} // namespace + Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { - std::unordered_map<string, const FunctionDef*> functions; - for (const FunctionDef& func : item.graph.library().function()) { - // Don't inline functions marked as noinline - if (func.attr().count("_noinline") != 0) { - continue; - } - // Don't touch anything marked XLA to prevent XLA failures further down the - // road. - if (func.attr().count("_XlaCompile") > 0 && - func.attr().at("_XlaCompile").b()) { - continue; - } - // Can't create IdentityN nodes with no input or output: skip these - // functions for now. - if (func.signature().input_arg_size() == 0 || - func.signature().output_arg_size() == 0) { - continue; - } - functions[func.signature().name()] = &func; - } + FunctionInliningContext function_inlining_ctx(item); - // Nothing to do. - if (functions.empty()) { + // Nothing to do here. + if (!function_inlining_ctx.HasInlinedFunctions()) { *optimized_graph = item.graph; return Status::OK(); } @@ -315,12 +383,14 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph)); continue; } - auto it = functions.find(node.op()); - if (it == functions.end()) { - *optimized_graph->add_node() = node; + + const FunctionDef* func = + function_inlining_ctx.FindInlinedFunction(node.op()); + if (func != nullptr) { + TF_RETURN_IF_ERROR( + InlineFunction(node, *func, function_inlining_ctx, optimized_graph)); } else { - TF_RETURN_IF_ERROR(InlineFunction(node, *it->second, item.graph.library(), - optimized_graph)); + *optimized_graph->add_node() = node; } } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index deb2fabded..c804d75756 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -26,7 +26,22 @@ namespace tensorflow { namespace grappler { namespace { -class FunctionOptimizerTest : public GrapplerTest {}; +constexpr char kDevice[] = "/device:CPU:0"; + +class FunctionOptimizerTest : public GrapplerTest { + protected: + Tensor MakeScalarTensor(float value) { + Tensor tensor(DT_FLOAT, {}); + tensor.scalar<float>()() = value; + return tensor; + } + + Tensor MakeScalarTensor(int value) { + Tensor tensor(DT_INT32, {}); + tensor.scalar<int>()() = value; + return tensor; + } +}; TEST_F(FunctionOptimizerTest, SimpleFunction) { // Build a graph to compute y = XTimesTwo(x) @@ -94,9 +109,8 @@ TEST_F(FunctionOptimizerTest, SimpleFunction) { } EXPECT_EQ(7, count); + Tensor pi = MakeScalarTensor(3.14f); item.fetch = {"z"}; - Tensor pi(DT_FLOAT, {}); - pi.flat<float>()(0) = 3.14f; item.feed.emplace_back("x", pi); auto tensors_expected = EvaluateFetchNodes(item); GrapplerItem optimized(item, std::move(output)); @@ -183,9 +197,8 @@ TEST_F(FunctionOptimizerTest, FixedTypeFunction) { } EXPECT_EQ(6, count); + Tensor pi = MakeScalarTensor(3.14f); item.fetch = {"z"}; - Tensor pi(DT_FLOAT, {}); - pi.flat<float>()(0) = 3.14f; item.feed.emplace_back("x", pi); auto tensors_expected = EvaluateFetchNodes(item); GrapplerItem optimized(item, std::move(output)); @@ -268,9 +281,8 @@ TEST_F(FunctionOptimizerTest, FunctionWithOutputMapping) { } EXPECT_EQ(6, count); + Tensor pi = MakeScalarTensor(3.14f); item.fetch = {"z"}; - Tensor pi(DT_FLOAT, {}); - pi.flat<float>()(0) = 3.14f; item.feed.emplace_back("x", pi); auto tensors_expected = EvaluateFetchNodes(item); GrapplerItem optimized(item, std::move(output)); @@ -325,18 +337,11 @@ TEST_F(FunctionOptimizerTest, FunctionWithInputForwarding) { TF_EXPECT_OK(status); item.fetch = {"z0", "z1", "z2"}; - Tensor in(DT_FLOAT, {}); - in.flat<float>()(0) = 3.14f; - item.feed.emplace_back("x0", in); - in.flat<float>()(0) = 2.7f; - item.feed.emplace_back("x1", in); - in.flat<float>()(0) = 1.0f; - item.feed.emplace_back("x2", in); - in.flat<float>()(0) = -1.0f; - item.feed.emplace_back("x4", in); - Tensor in_int(DT_INT32, {}); - in_int.flat<int>()(0) = 1234; - item.feed.emplace_back("x3", in_int); + item.feed.emplace_back("x0", MakeScalarTensor(3.14f)); + item.feed.emplace_back("x1", MakeScalarTensor(2.7f)); + item.feed.emplace_back("x2", MakeScalarTensor(1.0f)); + item.feed.emplace_back("x4", MakeScalarTensor(-1.0f)); + item.feed.emplace_back("x3", MakeScalarTensor(1234)); auto tensors_expected = EvaluateFetchNodes(item); GrapplerItem optimized(item, std::move(output)); auto tensors = EvaluateFetchNodes(optimized); @@ -379,6 +384,100 @@ TEST_F(FunctionOptimizerTest, FunctionWithoutInput) { EXPECT_EQ(item.graph.DebugString(), output.DebugString()); } +TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) { + // Define square via function library: + // MySquare(x) = MyMul(x, x) + + FunctionDef mul_func = FunctionDefHelper::Create( + "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + + FunctionDef square_func = FunctionDefHelper::Create( + "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + + GrapplerItem item; + item.graph = test::function::GDef( + {test::function::NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, + kDevice), + test::function::NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, + kDevice), + test::function::NDef("outputs", "Identity", {"square:0"}, + {{"T", DT_FLOAT}}, kDevice)}, + // FunctionLib + {mul_func, square_func}); + + GraphDef output; + FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "square/inlined_inputs" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("a", node.input(0)); + } else if (node.name() == "square/x" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square/inlined_inputs:0", node.input(0)); + } else if (node.name() == "square/output/inlined_inputs" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("square/x", node.input(0)); + EXPECT_EQ("square/x", node.input(1)); + } else if (node.name() == "square/output/x" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square/output/inlined_inputs:0", node.input(0)); + } else if (node.name() == "square/output/y" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square/output/inlined_inputs:1", node.input(0)); + } else if (node.name() == "square/output/output" && count++) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("square/output/x", node.input(0)); + EXPECT_EQ("square/output/y", node.input(1)); + } else if (node.name() == "square/output" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square/output/output:0", node.input(0)); + } else if (node.name() == "square" && count++) { + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square/output:0", node.input(0)); + } else if (node.name() == "outputs" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square:0", node.input(0)); + } + } + EXPECT_EQ(9, count); + + item.fetch = {"outputs"}; + item.feed.emplace_back("a", MakeScalarTensor(2.0f)); + auto tensors_expected = EvaluateFetchNodes(item); + + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + + test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); +} + TEST_F(FunctionOptimizerTest, SymbolicGradients) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index be95c00d2d..7ed0474861 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" namespace tensorflow { namespace grappler { @@ -45,21 +44,16 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name); struct GraphOptimizerContext { GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve, GraphDef* optimized_graph, - GraphProperties* graph_properties, NodeMap* node_map, - FrameMap* frame_map) + GraphProperties* graph_properties, NodeMap* node_map) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), graph_properties(graph_properties), - node_map(node_map), - frame_map(frame_map) {} + node_map(node_map) {} const std::unordered_set<string>* nodes_to_preserve; GraphDef* optimized_graph; GraphProperties* graph_properties; NodeMap* node_map; - // TODO(ezhulenev): it seems that frame_map is only relevant for loop - // optimizer? Move it to loop-optimizer specific context extension. - FrameMap* frame_map; }; Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, @@ -117,6 +111,9 @@ class GraphOptimizerStage { : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {} virtual ~GraphOptimizerStage() = default; + const string& stage_name() const { return stage_name_; } + const string& optimizer_name() const { return optimizer_name_; } + // Check if we should try to simplify node. Returning true doesn't // guarantee that node will be simplified. // @@ -179,6 +176,64 @@ class GraphOptimizerStage { const GraphOptimizerContext ctx_; }; +template <typename Result> +class GraphOptimizerStagePipeline { + public: + // Break predicate specifies if a pipeline should stop early, and not pass + // a node to the next registered optimizer stage, typically that should be the + // case when a stage successfully optimized a node, and it wants to yield + // control to the optimizer. + explicit GraphOptimizerStagePipeline( + const std::function<bool(const Result&)> break_predicate) + : break_predicate_(break_predicate) {} + + // Add a stage to the pipeline. It should be called with the arguments for the + // stage constructor: + // + // pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2); + // + // Returns a reference to the added stage. + template <typename T, typename... Args> + T& AddStage(Args&&... args) { + auto stage = new T(std::forward<Args>(args)...); + stages_.push_back(std::unique_ptr<T>(stage)); + return *stage; + } + + // Pass a node through all registered optimizer stages, until break predicate + // is true. + // + // Return true, if pipeline exited after a break predicate was evaluated as + // 'true', which typically means that a node was optimized by one of the + // registered stages. + // + // Return false, if node was not optimized by any of registered stages. + bool PassThroughAllStages(NodeDef* node, Result* result) { + for (auto& stage : stages_) { + if (stage->IsSupported(node)) { + const Status stage_status = stage->TrySimplify(node, result); + // Each stage must be "error safe" (just like exception safe). In + // case of any error it must leave optimized graph unmodified. + if (!stage_status.ok()) { + LOG(WARNING) << "Failed to run optimizer " << stage->optimizer_name() + << ", stage " << stage->stage_name() + << ". Error: " << stage_status.error_message(); + } + if (break_predicate_(*result)) return true; + } + } + return false; + } + + std::size_t NumStages() { return stages_.size(); } + + private: + std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_; + std::function<bool(const Result&)> break_predicate_; + + TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline); +}; + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc index 416327e622..3f5ab87a5a 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc @@ -58,8 +58,8 @@ TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScope_InScope) { TEST_F(GraphOptimizerStageTest, OptimizedNodeName) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ nullptr, - /*graph_properties*/ nullptr, /*node_name*/ nullptr, - /*frame_map*/ nullptr); + /*graph_properties*/ nullptr, + /*node_name*/ nullptr); FakeOptimizerStage stage("my_opt", "my_stg", ctx); const auto node = ParseNodeScopeAndName("a/b/c/Add"); @@ -94,8 +94,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map, - /*frame_map*/ nullptr); + /*node_name*/ &node_map); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; @@ -134,8 +133,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map, - /*frame_map*/ nullptr); + /*node_name*/ &node_map); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; @@ -165,4 +163,4 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { } // namespace } // end namespace grappler -} // end namespace tensorflow
\ No newline at end of file +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 254c1edf7b..308eecd420 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -2119,6 +2119,10 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item, Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { + if (cluster == nullptr) { + return errors::InvalidArgument("cluster == nullptr"); + } + if (GetNumGPUs(*cluster) < 1) { // LayoutOptimizer is currently only tuned for GPU. *output = item.graph; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index ad655db727..5723e397ab 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils/colocation.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" @@ -44,16 +45,15 @@ int64 NumEdges(const GraphDef& graph) { } string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) { - return strings::StrCat("Graph size before: ", before.node_size(), " nodes, ", - NumEdges(before), - " edges. Graph size after: ", after.node_size(), - " nodes, ", NumEdges(after), " edges."); + return strings::StrCat("Graph size after: ", after.node_size(), " nodes (", + after.node_size() - before.node_size(), "), ", + NumEdges(after), " edges (", + NumEdges(after) - NumEdges(before), ")"); } } // namespace std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer( const string& optimizer) { - VLOG(1) << "Adding graph optimization pass: " << optimizer; std::unique_ptr<GraphOptimizer> graph_optimizer; if (optimizer == "pruning") { graph_optimizer.reset(new ModelPruner()); @@ -171,46 +171,58 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } + // Some optimizers should be run only once. + const std::set<string> run_once_optimizers = {"layout"}; bool already_optimized = false; - for (const auto& optimizer : optimizers) { - if (!already_optimized) { - Status status = optimizer->Optimize(cluster, item, optimized_graph); - string result; - if (!status.ok()) { - VLOG(1) << "Not able to apply optimizer " << optimizer->name() - << ". Return status: " << status.ToString(); - result = status.ToString(); - } else { - already_optimized = true; - result = strings::StrCat( - "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph)); + const int num_iterations = + cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS + ? 1 + : cfg_.meta_optimizer_iterations(); + for (int iteration = 0; iteration < num_iterations; ++iteration) { + VLOG(1) << "Starting optimization iteration " << iteration + 1; + for (const auto& optimizer : optimizers) { + if (iteration > 0 && run_once_optimizers.count(optimizer->name())) { + continue; } - result_.push_back(std::make_pair(optimizer->name(), result)); - VLOG(1) << "Optimizer " << optimizer->name() - << " return status: " << result; - } else { - GrapplerItem optimized_item(item, std::move(*optimized_graph)); - Status status = - optimizer->Optimize(cluster, optimized_item, optimized_graph); - string result; - if (!status.ok()) { - VLOG(1) << "Not able to apply optimizer " << optimizer->name() - << ". Return status: " << status.ToString(); - optimized_graph->Swap(&optimized_item.graph); - result = status.ToString(); + if (!already_optimized) { + Status status = optimizer->Optimize(cluster, item, optimized_graph); + string result; + if (!status.ok()) { + VLOG(1) << "Not able to apply optimizer " << optimizer->name() + << ". Return status: " << status.ToString(); + result = status.ToString(); + } else { + already_optimized = true; + result = strings::StrCat( + "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph)); + } + result_.push_back(std::make_pair(optimizer->name(), result)); + VLOG(1) << "Optimizer " << optimizer->name() + << " return status: " << result; } else { - result = strings::StrCat( - "OK. ", - PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph)); + GrapplerItem optimized_item(item, std::move(*optimized_graph)); + Status status = + optimizer->Optimize(cluster, optimized_item, optimized_graph); + string result; + if (!status.ok()) { + VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": " + << status.ToString(); + optimized_graph->Swap(&optimized_item.graph); + result = status.ToString(); + } else { + result = strings::StrCat( + optimizer->name(), ": ", + PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph)); + } + result_.push_back(std::make_pair(optimizer->name(), result)); + VLOG(1) << result; } - result_.push_back(std::make_pair(optimizer->name(), result)); - VLOG(1) << "Optimizer " << optimizer->name() - << " return status: " << result; } } if (already_optimized) { TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph)); + ReassignColocation(optimized_graph); // Make sure that the optimizers preserved the graph version and library. DCHECK_GE(optimized_graph->library().function_size(), item.graph.library().function_size()); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 536347d834..d9a386b9be 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -72,6 +72,20 @@ TEST(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST(MetaOptimizerTest, RunOptimizersTwice) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/optimizers/symbolic_shapes.h index a9dcf44e23..eb79bab314 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h +++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.h @@ -31,8 +31,8 @@ bool IsUnknown(const TensorShapeProto::Dim& dim); bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape); bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties); -// Shapes are symbolically equal, if they have the same rank, they are -// they are known or symbolically defined, and have matching dimensions. +// Shapes are symbolically equal, if they have the same rank, they are known or +// symbolically defined, and have matching dimensions. bool ShapesSymbolicallyEqual(const TensorShapeProto& left, const TensorShapeProto& right); bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left, diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 86a6d5000d..5893f286ed 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -255,6 +255,14 @@ int NumOutputs(const NodeDef& node, GraphDef* graph) { return num_outputs; } +bool HasControlInputs(const NodeDef& node) { + int num_inputs = node.input_size(); + if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) { + return true; + } + return false; +} + int NumNonControlInputs(const NodeDef& node) { int num_inputs = node.input_size(); for (const string& input : node.input()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 7aa31939f5..11555d712a 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -138,6 +138,9 @@ string AsControlDependency(const string& node); // some of the outputs may be unconnected. int NumOutputs(const NodeDef& node, GraphDef* graph); +// Returns true iff the node has at least one control input. +bool HasControlInputs(const NodeDef& node); + // Number of connected non-control inputs. int NumNonControlInputs(const NodeDef& node); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index baf24c2505..7419c26dff 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -181,3 +181,28 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +cc_library( + name = "colocation", + srcs = ["colocation.cc"], + hdrs = ["colocation.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:utils", + ], +) + +tf_cc_test( + name = "colocation_test", + size = "small", + srcs = ["colocation_test.cc"], + deps = [ + ":colocation", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) diff --git a/tensorflow/core/grappler/utils/colocation.cc b/tensorflow/core/grappler/utils/colocation.cc new file mode 100644 index 0000000000..0573e0a830 --- /dev/null +++ b/tensorflow/core/grappler/utils/colocation.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/colocation.h" + +#include <cstring> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { + +namespace { + +// Find root node of the colocation group. +// The map is mapping from one node name to its parent. node_name is the +// starting node to search. By iteratively following the path from child to +// parent, we can find the root node for the colocation group that node_name +// belongs to. +string GetColocationGroupRoot(std::unordered_map<string, string>* map, + const string& node_name) { + if (map->find(node_name) == map->end()) { + // If node_name is not in the map, we create a new root node which points + // to itself. + map->insert({node_name, node_name}); + return node_name; + } + string cur = node_name; + while ((*map)[cur] != cur) { + // Backtracing the map until we reach the root node. + cur = (*map)[cur]; + } + return cur; +} + +// Merge two colocation groups into one. +// left and right is the root node of two colocation groups respectively. +void MergeColocationGroup(std::unordered_map<string, string>* map, + const string& left, const string& right) { + // Do nothing if left or right node is not in the map. + if (map->find(left) == map->end() || map->find(right) == map->end()) { + return; + } + if (left != right) { + // Make the right node a child of the left node, which merges the two + // groups. + map->at(right) = left; + } +} +} // namespace + +// Use of disjoint set algorithm to build the colocation groups from the input +// graph. The core data structure in use is a hash map from one node to its +// parent node. Whenever we see two nodes colocate with each other, we merge +// their colocation groups together. After we traverse all colocation pairs +// in the graph, we will have several disjoint sets. Then we pick the root node +// of each disjoint set as the representative node, and let all other nodes in +// the group colocate with the representative node. +void ReassignColocation(GraphDef* graph) { + constexpr char kClassAttr[] = "_class"; + constexpr char kColocPrefix[] = "loc:@"; + + // A hashmap that maps from a node name to its parent node name. + std::unordered_map<string, string> coloc_groups; + NodeMap node_map(graph); + for (const auto& node : graph->node()) { + auto iter = node.attr().find(kClassAttr); + if (iter != node.attr().end() && iter->second.has_list()) { + for (const auto& str : iter->second.list().s()) { + size_t pos = str.find(kColocPrefix); + if (pos == 0) { + // After we find a colocation, update the colocation groups. + string colocate_node = str.substr(pos + strlen(kColocPrefix)); + MergeColocationGroup( + &coloc_groups, GetColocationGroupRoot(&coloc_groups, node.name()), + GetColocationGroupRoot(&coloc_groups, colocate_node)); + } + } + } + } + + // We use the root node of each colocation groups as its representative + // node. For each node in one group, colocate with the representative node + // if the node is in the graph. + for (const auto& pair : coloc_groups) { + if (pair.first != pair.second) { + // This is a child node. + NodeDef* node = node_map.GetNode(pair.first); + if (node) { + // Colocate this node with the root node. + AttrValue new_value; + new_value.mutable_list()->add_s( + kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first)); + node->mutable_attr()->erase(kClassAttr); + node->mutable_attr()->insert({kClassAttr, new_value}); + } + } else { + // This is a root node. Clear the _class attribute. + NodeDef* node = node_map.GetNode(pair.first); + if (node) { // root node should always exist in the graph as guaranteed + // by order of merging. Just put check here to ensure safety. + node->mutable_attr()->erase(kClassAttr); + } + } + } +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/colocation.h b/tensorflow/core/grappler/utils/colocation.h new file mode 100644 index 0000000000..6062db6102 --- /dev/null +++ b/tensorflow/core/grappler/utils/colocation.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ + +#include <unordered_map> +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace grappler { + +// Evaluates the colocation relation in the graph and rewrites the new +// colocation relation in the graph. We scan the graph nodes sequentially, and +// builds a disjoint-sets of nodes (within each disjoint-set the nodes are +// colocated with each other). We then select the root node of each set as a +// representative node, and then colocate each node within the set (should also +// exist in graph) with the representative node. +// Note that there is current one situation this function can't handle: +// Node A colocates with X, node B colocates with Y, X colocates with Y but +// X, Y are removed from graph. In this case we can't know A colocates with B. +void ReassignColocation(GraphDef* graph); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ diff --git a/tensorflow/core/grappler/utils/colocation_test.cc b/tensorflow/core/grappler/utils/colocation_test.cc new file mode 100644 index 0000000000..6638364240 --- /dev/null +++ b/tensorflow/core/grappler/utils/colocation_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/colocation.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class ColocationTest : public ::testing::Test {}; + +bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) { + if (ndef.attr().empty()) { + return false; + } + if (ndef.attr().find("_class") == ndef.attr().end()) { + return false; + } + return ndef.attr().at("_class").list().s(0) == coloc; +} + +TEST(ColocationTest, ReassignColocation_SingleNode) { + // Node A colocates with B, but node B is not in the graph. + // A + // | + // | + // [B] + + NodeDef ndef; + const Status status = + NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&ndef); + TF_EXPECT_OK(status); + GraphDef gdef = test::function::GDef({ndef}); + + EXPECT_EQ(1, gdef.node_size()); + EXPECT_EQ(1, gdef.node(0).attr_size()); + + ReassignColocation(&gdef); + + // Validates that node A's colocation info is cleared. + EXPECT_EQ(1, gdef.node_size()); + EXPECT_EQ(0, gdef.node(0).attr_size()); +} + +TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) { + // Node A, B, C colocate with X. D colocates with C. E colocates with D. + // Node X is not in the graph. + // A B C---D---E + // | | | + // | | | + // +--[X]--+ + // After re-assign of colocation, A, B, C, D should colocate with E. + // A B C D + // | | | | + // | | | | + // +---+-E-+---+ + + NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e; + Status status = + NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e); + TF_EXPECT_OK(status); + GraphDef gdef = + test::function::GDef({ndef_a, ndef_b, ndef_c, ndef_d, ndef_e}); + + EXPECT_EQ(5, gdef.node_size()); + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E + + ReassignColocation(&gdef); + + EXPECT_EQ(5, gdef.node_size()); + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D + EXPECT_EQ(0, gdef.node(4).attr_size()); // E +} + +TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) { + // Before re-assign: + // Node A, B, C colocate with X. D colocates with C. E colocates with D. + // Node U, V colocates with W. Node X, W are not in the graph: + // A B C---D---E + // | | | + // | | | + // +--[X]--+ + // + // U V + // | | + // | | + // +--[W]--+ + // + // After re-assign: + // A, B, C, D should colocate with E. U should colocate with V. + // A B C D + // | | | | + // | | | | + // +---+-E-+---+ + // + // U + // | + // | + // V + + NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v; + Status status = + NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_u); + TF_EXPECT_OK(status); + status = + NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_v); + TF_EXPECT_OK(status); + GraphDef gdef = test::function::GDef( + {ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v}); + + EXPECT_EQ(7, gdef.node_size()); + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@W")); // U + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(6), "loc:@W")); // V + + ReassignColocation(&gdef); + + EXPECT_EQ(7, gdef.node_size()); + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D + EXPECT_EQ(0, gdef.node(4).attr_size()); // E + EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@V")); // U + EXPECT_EQ(0, gdef.node(6).attr_size()); // V +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 3bc7bea454..e1394b9c35 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -57,6 +57,15 @@ class GrapplerTest : public ::testing::Test { // Count nodes of the given op-type in a graph. int CountOpNodes(const GraphDef& graph, const string& op); + // Get a random tansor with given shape. + template <DataType DTYPE> + Tensor GenerateRandomTensor(const TensorShape& shape) const { + typedef typename EnumToDataType<DTYPE>::Type T; + Tensor tensor(DTYPE, shape); + tensor.flat<T>() = tensor.flat<T>().random(); + return tensor; + } + private: SessionOptions options_; }; diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h index a312e8e8a4..2ed1628bf1 100644 --- a/tensorflow/core/kernels/assign_op.h +++ b/tensorflow/core/kernels/assign_op.h @@ -77,7 +77,8 @@ class AssignOp : public OpKernel { // 1. Try to reuse the rhs. std::unique_ptr<Tensor> input_alias = context->forward_input( - 1, old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr); + 1, OpKernelContext::Params::kNoReservation /*output_index*/, + old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr); if (input_alias != nullptr) { // Transfer ownership to the ref. context->replace_ref_input(0, *input_alias.release(), diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc index 39ef8ee3ac..4485152e96 100644 --- a/tensorflow/core/kernels/data_format_ops.cc +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -37,25 +37,37 @@ class DataFormatDimMapOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); + OP_REQUIRES(context, src_format.size() == 4, + errors::InvalidArgument(strings::StrCat( + "Source format must of length 4, received src_format = ", + src_format))); OP_REQUIRES( - context, src_format == "NHWC", + context, dst_format.size() == 4, errors::InvalidArgument(strings::StrCat( - "Current implementation doesn't support source data format ", - src_format))); - OP_REQUIRES(context, dst_format == "NCHW", - errors::InvalidArgument(strings::StrCat( - "Current implementation doesn't support dst data format ", - dst_format))); + "Destination format must of length 4, received dst_format = ", + dst_format))); + dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())}); + for (int i = 0; i < src_format.size(); ++i) { + for (int j = 0; j < dst_format.size(); ++j) { + if (dst_format[j] == src_format[i]) { + dst_idx_.vec<int>()(i) = j; + break; + } + } + } } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - Tensor* output = nullptr; + Tensor* output; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), - input.flat<T>(), output->flat<T>()); + input.flat<T>(), output->flat<T>(), + dst_idx_.vec<int>()); } + + Tensor dst_idx_; }; template <typename Device, typename T> @@ -147,11 +159,11 @@ TF_CALL_int64(REGISTER_KERNEL); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void DataFormatDimMap<GPUDevice, T>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ - typename TTypes<T>::Flat y); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void DataFormatDimMap<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ + typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \ extern template struct DataFormatDimMap<GPUDevice, T>; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); TF_CALL_int32(DECLARE_GPU_SPECS); diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h index 2ccc919586..1ca144cb40 100644 --- a/tensorflow/core/kernels/data_format_ops.h +++ b/tensorflow/core/kernels/data_format_ops.h @@ -27,15 +27,25 @@ namespace functor { template <typename Device, typename T> struct DataFormatDimMap { void operator()(const Device& d, typename TTypes<T>::ConstFlat x, - typename TTypes<T>::Flat y) { + typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) { auto zero = x.constant(0); auto one = x.constant(1); - auto three = x.constant(3); + auto two = x.constant(2); + + auto f_zero = x.constant(dst(0)); + auto f_one = x.constant(dst(1)); + auto f_two = x.constant(dst(2)); + auto f_three = x.constant(dst(3)); + auto four = x.constant(4); auto x_mod = (x + four) % 4; + auto is_zero = (x_mod == zero); - auto is_three = (x_mod == three); - y.device(d) = is_zero.select(zero, is_three.select(one, x_mod + one)); + auto is_one = (x_mod == one); + auto is_two = (x_mod == two); + + y.device(d) = is_zero.select( + f_zero, is_one.select(f_one, is_two.select(f_two, f_three))); } }; diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index b687088db1..911aa3a78f 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -21,10 +20,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/mutex.h" -namespace tensorflow { +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA +namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice; typedef FunctionLibraryRuntime::Handle FHandle; @@ -106,11 +107,9 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, opts->runner = ctx->runner(); } -} // end namespace - -class FunctionalIf : public AsyncOpKernel { +class IfOp : public AsyncOpKernel { public: - explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { auto lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); const NameAttrList* func; @@ -120,7 +119,7 @@ class FunctionalIf : public AsyncOpKernel { OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); } - ~FunctionalIf() override {} + ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { bool cond; @@ -134,8 +133,7 @@ class FunctionalIf : public AsyncOpKernel { class State { public: - State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond, - DoneCallback done) + State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done) : kernel_(kernel), ctx_(ctx), cond_(cond), @@ -168,7 +166,7 @@ class FunctionalIf : public AsyncOpKernel { } private: - FunctionalIf* const kernel_; + IfOp* const kernel_; OpKernelContext* const ctx_; const bool cond_; const DoneCallback done_; @@ -179,18 +177,22 @@ class FunctionalIf : public AsyncOpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf); +// TODO(drpng): remove this. +REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp); REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), - FunctionalIf); + IfOp); + +REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp); +REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp); -class FunctionalWhile : public AsyncOpKernel { +class WhileOp : public AsyncOpKernel { public: - explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_)); } - ~FunctionalWhile() override {} + ~WhileOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { auto lib = ctx->function_library(); @@ -234,7 +236,7 @@ class FunctionalWhile : public AsyncOpKernel { class State { public: - State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle, + State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, FHandle body_handle, DoneCallback done) : kernel_(kernel), ctx_(ctx), @@ -253,7 +255,7 @@ class FunctionalWhile : public AsyncOpKernel { void Start() { EvalCond(); } private: - FunctionalWhile* const kernel_; + WhileOp* const kernel_; OpKernelContext* const ctx_; const FHandle cond_handle_; const FHandle body_handle_; @@ -316,7 +318,152 @@ class FunctionalWhile : public AsyncOpKernel { } }; }; -REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile); -REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile); +// TODO(drpng): remove these. +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp); + +REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp); + +Status GetScalar(OpKernelContext* ctx, int index, int32* value, + const char* label) { + Tensor t = ctx->input(index); + if (!TensorShapeUtils::IsScalar(t.shape())) { + return errors::InvalidArgument(label, " must be a scalar, but ", + t.shape().DebugString()); + } + *value = t.scalar<int32>()(); + return Status::OK(); +} + +class ForOp : public AsyncOpKernel { + public: + explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + auto lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func)); + OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_)); + } + + ~ForOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + (new State(this, ctx, done))->Start(); + } + + private: + FHandle body_handle_; + + class State { + public: + State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done) + : kernel_(kernel), + ctx_(ctx), + done_(std::move(done)), + lib_(CHECK_NOTNULL(ctx_->function_library())), + args_(1 + ctx_->num_inputs() - 3) { + args_[0] = Tensor(DT_INT32, {}); + iter_ = &args_[0].scalar<int32>()(); + + const int32 num_loop_inputs = ctx_->num_inputs() - 3; + rets_.reserve(num_loop_inputs); + for (int i = 0; i < num_loop_inputs; ++i) { + rets_.push_back(ctx_->input(3 + i)); + } + } + + ~State() {} + + void Start() { + Status s = StartLoop(); + if (!s.ok()) Finish(s); + } + + private: + ForOp* const kernel_; + OpKernelContext* const ctx_; + const DoneCallback done_; + FunctionLibraryRuntime* const lib_; + FunctionLibraryRuntime::Options opts_; + TensorVec args_; + TensorVec rets_; + + int32* iter_; // points to args_[0]. + int32 limit_; + int32 delta_; + + // If an error e is returned, caller must call Finish(e). + // If OK is returned, the async loop execution has been started. + Status StartLoop() { + SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); + + TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start")); + TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit")); + TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta")); + + if ((delta_ > 0 && *iter_ <= limit_) || + (delta_ < 0 && *iter_ >= limit_) || + (delta_ == 0 && *iter_ == limit_)) { + RunNext(); + return Status::OK(); + } else { + return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_, + " ", limit_, " ", delta_); + } + } + + void RunNext() { + bool done_loop; + if (delta_ > 0) { + done_loop = *iter_ >= limit_; + } else { + done_loop = *iter_ <= limit_; + } + if (done_loop) { + Finish(Status::OK()); + return; + } + + if (rets_.size() >= args_.size()) { + Finish(errors::InvalidArgument( + "For loop body returned ", rets_.size(), + " arguments. Expected: ", args_.size() - 1)); + return; + } + for (int i = 0; i < rets_.size(); ++i) { + args_[1 + i] = std::move(rets_[i]); + } + rets_.clear(); + lib_->Run(opts_, kernel_->body_handle_, args_, &rets_, + [this](const Status& s) { + if (s.ok()) { + *iter_ += delta_; + RunNext(); + } else { + Finish(s); + } + }); + } + + void Finish(Status s) { + if (s.ok()) { + s = SetOutputs(kernel_, ctx_, rets_); + } + ctx_->SetStatus(s); + done_(); + delete this; + } + }; +}; + +REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp); +REGISTER_KERNEL_BUILDER(Name("For") + .Device(DEVICE_GPU) + .HostMemory("start") + .HostMemory("limit") + .HostMemory("delta"), + ForOp); +} // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index e3872fee0e..57b7798ba0 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/kernels/initializable_lookup_table.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" @@ -62,8 +63,7 @@ class MutableHashTableOfScalars final : public LookupInterface { mutex_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( - table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), - default_val); + table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); @@ -78,9 +78,8 @@ class MutableHashTableOfScalars final : public LookupInterface { table_.clear(); } for (int64 i = 0; i < key_values.size(); ++i) { - gtl::InsertOrUpdate(&table_, - SubtleMustCopyUnlessStringOrFloat(key_values(i)), - SubtleMustCopyUnlessStringOrFloat(value_values(i))); + gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), + SubtleMustCopyIfIntegral(value_values(i))); } return Status::OK(); } @@ -172,8 +171,8 @@ class MutableHashTableOfTensors final : public LookupInterface { mutex_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { - ValueArray* value_vec = gtl::FindOrNull( - table_, SubtleMustCopyUnlessStringOrFloat(key_values(i))); + ValueArray* value_vec = + gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); if (value_vec != nullptr) { for (int64 j = 0; j < value_dim; j++) { value_values(i, j) = value_vec->at(j); @@ -203,8 +202,8 @@ class MutableHashTableOfTensors final : public LookupInterface { V value = value_values(i, j); value_vec.push_back(value); } - gtl::InsertOrUpdate( - &table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), value_vec); + gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), + value_vec); } return Status::OK(); } @@ -379,15 +378,14 @@ class MutableDenseHashTable final : public LookupInterface { for (int64 j = 0; j < value_size; ++j) { // TODO(andreasst): check if we can get rid of SubtleMustCopy // here and elsewhere in this file. - value_matrix(i, j) = SubtleMustCopyUnlessStringOrFloat( - value_buckets_matrix(bucket_index, j)); + value_matrix(i, j) = + SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j)); } break; } if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) { for (int64 j = 0; j < value_size; ++j) { - value_matrix(i, j) = - SubtleMustCopyUnlessStringOrFloat(default_flat(j)); + value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j)); } break; } @@ -531,7 +529,7 @@ class MutableDenseHashTable final : public LookupInterface { if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { for (int64 j = 0; j < value_size; ++j) { value_buckets_matrix(bucket_index, j) = - SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j)); + SubtleMustCopyIfIntegral(value_matrix(i, j)); } break; } @@ -539,11 +537,11 @@ class MutableDenseHashTable final : public LookupInterface { ++num_entries_; for (int64 j = 0; j < key_size; ++j) { key_buckets_matrix(bucket_index, j) = - SubtleMustCopyUnlessStringOrFloat(key_matrix(i, j)); + SubtleMustCopyIfIntegral(key_matrix(i, j)); } for (int64 j = 0; j < value_size; ++j) { value_buckets_matrix(bucket_index, j) = - SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j)); + SubtleMustCopyIfIntegral(value_matrix(i, j)); } break; } @@ -849,6 +847,7 @@ REGISTER_KERNEL(string, int64); REGISTER_KERNEL(int64, string); REGISTER_KERNEL(string, bool); REGISTER_KERNEL(int64, float); +REGISTER_KERNEL(int64, Variant); #undef REGISTER_KERNEL @@ -899,6 +898,7 @@ REGISTER_KERNEL(int64, double); REGISTER_KERNEL(string, float); REGISTER_KERNEL(string, bool); REGISTER_KERNEL(int64, bool); +REGISTER_KERNEL(int64, Variant); #undef REGISTER_KERNEL diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index 3657fd5b6a..29a0cc91fe 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -125,19 +125,21 @@ namespace lookup { // integral types. However non-integer variables are not allowed and therefore // the local copy is unnecessary. template <typename T> -T SubtleMustCopyUnlessStringOrFloat(const T& value) { +T SubtleMustCopyIfIntegral(const T& value) { return internal::SubtleMustCopy(value); } -inline const string& SubtleMustCopyUnlessStringOrFloat(const string& value) { +inline const string& SubtleMustCopyIfIntegral(const string& value) { return value; } -inline const float SubtleMustCopyUnlessStringOrFloat(const float value) { +inline const float SubtleMustCopyIfIntegral(const float value) { return value; } + +inline const double SubtleMustCopyIfIntegral(const double value) { return value; } -inline const double SubtleMustCopyUnlessStringOrFloat(const double value) { +inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) { return value; } @@ -204,8 +206,8 @@ class HashTable : public InitializableLookupTable { const auto key_values = keys.flat<K>(); const auto value_values = values.flat<V>(); for (int64 i = 0; i < key_values.size(); ++i) { - const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i)); - const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i)); + const K key = SubtleMustCopyIfIntegral(key_values(i)); + const V value = SubtleMustCopyIfIntegral(value_values(i)); const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value); if (previous_value != value) { return errors::FailedPrecondition( @@ -224,8 +226,7 @@ class HashTable : public InitializableLookupTable { for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( - *table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), - default_val); + *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); } diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index ad606803ee..6c19f9841c 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -43,6 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> { void Compute(OpKernelContext* context) override { ResourceOpKernel<QueueInterface>::Compute(context); + mutex_lock l(mu_); if (resource_ && context->track_allocations()) { context->record_persistent_memory_allocation(resource_->MemoryUsed()); } diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index d1675f27dd..f49a05c70a 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -250,8 +250,9 @@ class AssignVariableOp : public OpKernel { // Copying is unnecessary if we are the last user of the value // tensor, we can just adopt the input tensor's buffer instead. - std::unique_ptr<Tensor> input_alias = - context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr); + std::unique_ptr<Tensor> input_alias = context->forward_input( + 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_, + value.shape(), DEVICE_MEMORY, attr); mutex_lock ml(*variable->mu()); variable->is_initialized = true; if (input_alias) { @@ -363,9 +364,36 @@ class AssignVariableOp<Device, Variant> : public OpKernel { DataTypeString(variable->tensor()->dtype()), " got ", DataTypeString(DT_VARIANT))); + AllocatorAttributes attr; + attr.set_on_host(true); + + // Copying is unnecessary if we are the last user of the value + // tensor, we can just adopt the input tensor's buffer instead. + // Note that Variant objects themselves always reside on host. + std::unique_ptr<Tensor> input_alias = context->forward_input( + 1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT, + value.shape(), HOST_MEMORY, attr); + mutex_lock ml(*variable->mu()); variable->is_initialized = true; *variable->tensor() = Tensor(DT_VARIANT, value.shape()); + + if (input_alias) { + *variable->tensor() = *input_alias; + return; + } + + // Need to copy, but maybe we can re-use variable's buffer? + if (!variable->tensor()->RefCountIsOne() || + !variable->tensor()->shape().IsSameSize(value.shape())) { + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_VARIANT, value.shape(), + &unused, &tmp, attr)); + *variable->tensor() = *tmp; + } + const auto elements_in = value.flat<Variant>(); auto elements_out = variable->tensor()->flat<Variant>(); auto copy_fn = std::bind(&VariantCopyFn<Device>, context, @@ -577,7 +605,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); #if GOOGLE_CUDA #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GATHER_GPU); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 4b5df7aff0..4ebb7fbcc7 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -419,7 +419,7 @@ class SparseCrossOp : public OpKernel { context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()), errors::InvalidArgument( "Dense inputs should be a matrix but received shape ", - indices_list_in[i].shape().DebugString(), " at position ", i)); + dense_list_in[i].shape().DebugString(), " at position ", i)); OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size, errors::InvalidArgument("Expected batch size ", batch_size, " got ", dense_list_in[i].dim_size(0), diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc index 5bd79778a6..0b006fa2b4 100644 --- a/tensorflow/core/lib/core/stringpiece.cc +++ b/tensorflow/core/lib/core/stringpiece.cc @@ -55,6 +55,4 @@ StringPiece StringPiece::substr(size_t pos, size_t n) const { return StringPiece(data_ + pos, n); } -const StringPiece::size_type StringPiece::npos = size_type(-1); - } // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index 79409cce4b..835b938cbf 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -65,7 +65,7 @@ class StringPiece { iterator begin() const { return data_; } iterator end() const { return data_ + size_; } - static const size_t npos; + static const size_t npos = size_type(-1); // Return the ith byte in the referenced data. // REQUIRES: n < size() diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc index 64852943ad..0c24c660a2 100644 --- a/tensorflow/core/lib/io/format.cc +++ b/tensorflow/core/lib/io/format.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <limits> + #include "tensorflow/core/lib/io/format.h" #include "tensorflow/core/lib/core/coding.h" @@ -84,6 +86,11 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle, // Read the block contents as well as the type/crc footer. // See table_builder.cc for the code that built this structure. size_t n = static_cast<size_t>(handle.size()); + + if (kBlockTrailerSize > std::numeric_limits<size_t>::max() - n) { + return errors::DataLoss("handle.size() too big"); + } + char* buf = new char[n + kBlockTrailerSize]; StringPiece contents; Status s = file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf); diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index 516decc3c0..8f34baa7de 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -23,6 +23,7 @@ limitations under the License. #include <locale> #include <unordered_map> +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -203,7 +204,7 @@ bool safe_strto64(StringPiece str, int64* value) { int64 vlimit = kint64max; int sign = 1; - if (str.Consume("-")) { + if (str_util::ConsumePrefix(&str, "-")) { sign = -1; // Different limit for positive and negative integers. vlimit = kint64min; @@ -265,7 +266,7 @@ bool safe_strto32(StringPiece str, int32* value) { int64 vmax = kint32max; int sign = 1; - if (str.Consume("-")) { + if (str_util::ConsumePrefix(&str, "-")) { sign = -1; // Different max for positive and negative integers. ++vmax; diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc index fee8a6f93e..ede9f4d390 100644 --- a/tensorflow/core/lib/strings/ordered_code_test.cc +++ b/tensorflow/core/lib/strings/ordered_code_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -128,7 +129,7 @@ void TestWriteAppends(T first, U second) { string encoded_first_only = encoded; OCWriteToString<U>(&encoded, second); EXPECT_NE(encoded, encoded_first_only); - EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only)); + EXPECT_TRUE(str_util::StartsWith(encoded, encoded_first_only)); } template <typename T> diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h index d3b63357ee..c82e771368 100644 --- a/tensorflow/core/lib/strings/scanner.h +++ b/tensorflow/core/lib/strings/scanner.h @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -75,14 +76,14 @@ class Scanner { // Consume the next s.size() characters of the input, if they match <s>. If // they don't match <s>, this is a no-op. Scanner& ZeroOrOneLiteral(StringPiece s) { - cur_.Consume(s); + str_util::ConsumePrefix(&cur_, s); return *this; } // Consume the next s.size() characters of the input, if they match <s>. If // they don't match <s>, then GetResult will ultimately return false. Scanner& OneLiteral(StringPiece s) { - if (!cur_.Consume(s)) { + if (!str_util::ConsumePrefix(&cur_, s)) { error_ = true; } return *this; diff --git a/tensorflow/core/lib/wav/wav_io_test.cc b/tensorflow/core/lib/wav/wav_io_test.cc index d8a83fc464..9e41da6a20 100644 --- a/tensorflow/core/lib/wav/wav_io_test.cc +++ b/tensorflow/core/lib/wav/wav_io_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -203,7 +204,7 @@ TEST(WavIO, ChunkSizeOverflow) { wav_data_string, &decoded_audio, &decoded_sample_count, &decoded_channel_count, &decoded_sample_rate); EXPECT_FALSE(decode_status.ok()); - EXPECT_TRUE(StringPiece(decode_status.error_message()).contains("too large")) + EXPECT_TRUE(str_util::StrContains(decode_status.error_message(), "too large")) << decode_status.error_message(); } diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 7cdf36f423..10b24c2d34 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -20672,6 +20672,38 @@ op { is_stateful: true } op { + name: "For" + input_arg { + name: "start" + type: DT_INT32 + } + input_arg { + name: "limit" + type: DT_INT32 + } + input_arg { + name: "delta" + type: DT_INT32 + } + input_arg { + name: "input" + type_list_attr: "T" + } + output_arg { + name: "output" + type_list_attr: "T" + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "body" + type: "func" + } +} +op { name: "FractionalAvgPool" input_arg { name: "value" @@ -22755,6 +22787,45 @@ op { is_stateful: true } op { + name: "If" + input_arg { + name: "cond" + type_attr: "Tcond" + } + input_arg { + name: "input" + type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tcond" + type: "type" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "then_branch" + type: "func" + } + attr { + name: "else_branch" + type: "func" + } +} +op { name: "Igamma" input_arg { name: "a" @@ -68076,6 +68147,31 @@ op { } } op { + name: "While" + input_arg { + name: "input" + type_list_attr: "T" + } + output_arg { + name: "output" + type_list_attr: "T" + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "cond" + type: "func" + } + attr { + name: "body" + type: "func" + } + is_stateful: true +} +op { name: "WholeFileReader" output_arg { name: "reader_handle" diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 4b21fac80a..792686cae1 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -50,6 +50,7 @@ REGISTER_OP("RemoteCall") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape); +// TODO(drpng): remove this. REGISTER_OP("_If") .Input("cond: Tcond") .Input("input: Tin") @@ -76,8 +77,18 @@ else_branch: A function that takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. +REGISTER_OP("If") + .Input("cond: Tcond") + .Input("input: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("Tin: list(type)") + .Attr("Tout: list(type)") + .Attr("then_branch: func") + .Attr("else_branch: func") + .SetShapeFn(shape_inference::UnknownShape); + +// TODO(drpng): remove this. REGISTER_OP("_While") .Input("input: T") .Output("output: T") @@ -108,4 +119,30 @@ body: A function that takes a list of tensors and returns another by T. )doc"); +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("While") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } + return Status::OK(); + }); + +REGISTER_OP("For") + .Input("start: int32") + .Input("limit: int32") + .Input("delta: int32") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("body: func") + .SetShapeFn(shape_inference::UnknownShape); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 8dcd3e815f..da38a6bc24 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -362,7 +363,7 @@ class MathGradTest : public ::testing::Test { }; void HasError(const Status& s, const string& substr) { - EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) << s << ", expected substring " << substr; } diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index ca3772e6f8..8f974d5367 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -239,20 +240,21 @@ TEST(MathOpsTest, Select_ShapeFn) { // Expect an error when the shapes can't be merged. handle_data[2]->at(0).first = shape_proto({2, 2}); - EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) - .contains("must be equal, but are 1 and 2")); + EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(), + "must be equal, but are 1 and 2")); handle_data[2]->at(0).first = i1; // restore to valid // Expect an error when the types can't be merged. handle_data[2]->at(1).second = DT_INT64; - EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) - .contains("pointing to different dtypes")); + EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(), + "pointing to different dtypes")); handle_data[2]->at(1).second = DT_INT32; // restore to valid // Expect an error when different numbers of tensors are merged. handle_data[2]->push_back({i1, DT_FLOAT}); - EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) - .contains("pointing to different numbers of tensors")); + EXPECT_TRUE( + str_util::StrContains(run_inference_for_handles().error_message(), + "pointing to different numbers of tensors")); handle_data[2]->pop_back(); // restore to valid. } diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 42a68cb712..5764976aee 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -9780,6 +9780,38 @@ op { is_stateful: true } op { + name: "For" + input_arg { + name: "start" + type: DT_INT32 + } + input_arg { + name: "limit" + type: DT_INT32 + } + input_arg { + name: "delta" + type: DT_INT32 + } + input_arg { + name: "input" + type_list_attr: "T" + } + output_arg { + name: "output" + type_list_attr: "T" + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "body" + type: "func" + } +} +op { name: "FractionalAvgPool" input_arg { name: "value" @@ -11184,6 +11216,45 @@ op { is_stateful: true } op { + name: "If" + input_arg { + name: "cond" + type_attr: "Tcond" + } + input_arg { + name: "input" + type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tcond" + type: "type" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "then_branch" + type: "func" + } + attr { + name: "else_branch" + type: "func" + } +} +op { name: "Igamma" input_arg { name: "a" @@ -32937,6 +33008,31 @@ op { } } op { + name: "While" + input_arg { + name: "input" + type_list_attr: "T" + } + output_arg { + name: "output" + type_list_attr: "T" + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "cond" + type: "func" + } + attr { + name: "body" + type: "func" + } + is_stateful: true +} +op { name: "WholeFileReader" output_arg { name: "reader_handle" diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc index 4df62734e9..e597a490d6 100644 --- a/tensorflow/core/platform/abi.cc +++ b/tensorflow/core/platform/abi.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/platform/abi.h" -#if defined(PLATFORM_WINDOWS) +#if defined(_MSC_VER) #include <windows.h> #include <cstring> #else @@ -26,19 +26,19 @@ limitations under the License. #include <memory> #include <string> -#if defined(PLATFORM_WINDOWS) +#if defined(_MSC_VER) extern "C" char* __unDName(char* output_string, const char* name, int max_string_length, void* (*p_alloc)(std::size_t), void (*p_free)(void*), unsigned short disable_flags); -#endif // defined(PLATFORM_WINDOWS) +#endif // defined(_MSC_VER) namespace tensorflow { namespace port { std::string MaybeAbiDemangle(const char* name) { -#if defined(PLATFORM_WINDOWS) +#if defined(_MSC_VER) std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc, std::free, static_cast<unsigned short>(0))}; diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 3ee7be3c4e..be84316c48 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -85,6 +85,7 @@ cc_library( ":retrying_utils", ":time_util", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@jsoncpp_git//:jsoncpp", ], @@ -263,6 +264,7 @@ tf_cc_test( deps = [ ":gcs_file_system", ":http_request_fake", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 1691826483..3c0dc13d75 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -172,7 +172,7 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } - objectp.Consume("/"); + str_util::ConsumePrefix(&objectp, "/"); *object = objectp.ToString(); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", @@ -535,7 +535,8 @@ class GcsWritableFile : public WritableFile { *uploaded = 0; } else { StringPiece range_piece(received_range); - range_piece.Consume("bytes="); // May or may not be present. + str_util::ConsumePrefix(&range_piece, + "bytes="); // May or may not be present. std::vector<int64> range_parts; if (!str_util::SplitAndParseAsInts(range_piece, '-', &range_parts) || range_parts.size() != 2) { @@ -1172,7 +1173,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, // 'object_prefix', which is part of 'dirname', should be removed from // the beginning of 'name'. StringPiece relative_path(name); - if (!relative_path.Consume(object_prefix)) { + if (!str_util::ConsumePrefix(&relative_path, object_prefix)) { return errors::Internal(strings::StrCat( "Unexpected response: the returned file name ", name, " doesn't match the prefix ", object_prefix)); @@ -1201,7 +1202,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, } const string& prefix_str = prefix.asString(); StringPiece relative_path(prefix_str); - if (!relative_path.Consume(object_prefix)) { + if (!str_util::ConsumePrefix(&relative_path, object_prefix)) { return errors::Internal( "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 8516421614..2fbde9b6a7 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/gcs_file_system.h" #include <fstream> #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/test.h" @@ -584,8 +585,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { TF_EXPECT_OK(file->Append("content2")); const auto& status = file->Close(); EXPECT_EQ(errors::Code::ABORTED, status.code()); - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("All 10 retry attempts failed. The last failure: " + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "All 10 retry attempts failed. The last failure: " "Unavailable: important HTTP error 503")) << status; } @@ -641,13 +643,12 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { const auto& status = file->Close(); EXPECT_EQ(errors::Code::UNAVAILABLE, status.code()); EXPECT_TRUE( - StringPiece(status.error_message()) - .contains( - "Upload to gs://bucket/path/writeable.txt failed, caused by: " - "Not found: important HTTP error 410")) + str_util::StrContains(status.error_message(), + "Upload to gs://bucket/path/writeable.txt failed, " + "caused by: Not found: important HTTP error 410")) << status; - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("when uploading gs://bucket/path/writeable.txt")) + EXPECT_TRUE(str_util::StrContains( + status.error_message(), "when uploading gs://bucket/path/writeable.txt")) << status; } diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index d3f763bb3c..ee6886fef7 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include <fstream> #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -245,7 +246,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) { char scratch[10]; const auto& status = random_access_file->Read(0, 10, &result, scratch); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -399,7 +400,7 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) { // Use it and check the results. const auto& status = writable_file->Sync(); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -428,7 +429,7 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) { const auto& status = fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -454,7 +455,7 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) { std::vector<string> result; const auto& status = fs.GetChildren("gs://path", &result); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -481,7 +482,7 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) { std::vector<string> result; const auto& status = fs.GetMatchingPaths("gs://path/dir", &result); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -506,7 +507,7 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) { std::vector<string> result; const auto& status = fs.DeleteFile("gs://path/file.txt"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -531,7 +532,7 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) { std::vector<string> result; const auto& status = fs.CreateDir("gs://path/newdir"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -556,7 +557,7 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) { std::vector<string> result; const auto& status = fs.DeleteDir("gs://path/dir"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -582,7 +583,7 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) { uint64 size; const auto& status = fs.GetFileSize("gs://path/file.txt", &size); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -605,7 +606,7 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) { const auto& status = fs.RenameFile("old_name", "new_name"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -630,7 +631,7 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) { FileStatistics stat; const auto& status = fs.Stat("file_name", &stat); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -642,7 +643,7 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) { const auto& status = fs.FileExists("file_name"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -677,7 +678,7 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) { const auto& status = fs.IsDirectory("gs://path/dir"); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } @@ -706,7 +707,7 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) { const auto& status = fs.DeleteRecursively("gs://path/dir", &undeleted_files, &undeleted_dirs); EXPECT_TRUE( - StringPiece(status.error_message()).contains("Retriable error #10")) + str_util::StrContains(status.error_message(), "Retriable error #10")) << status; } diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc index 6eb340e094..1b6527618a 100644 --- a/tensorflow/core/platform/cloud/retrying_utils_test.cc +++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/retrying_utils.h" #include <fstream> #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -31,10 +32,9 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) { const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep); EXPECT_EQ(errors::Code::ABORTED, status.code()); - EXPECT_TRUE(StringPiece(status.error_message()) - .contains("All 10 retry attempts " - "failed. The last failure: " - "Unavailable: Failed.")) + EXPECT_TRUE(str_util::StrContains( + status.error_message(), + "All 10 retry attempts failed. The last failure: Unavailable: Failed.")) << status; EXPECT_EQ(10, requested_delays.size()); diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h index e813e4a17a..7834548896 100644 --- a/tensorflow/core/platform/default/tracing_impl.h +++ b/tensorflow/core/platform/default/tracing_impl.h @@ -22,7 +22,6 @@ limitations under the License. // IWYU pragma: friend third_party/tensorflow/core/platform/tracing.h #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/tracing.h" diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc index 3631d9ddf9..82cbc43b4f 100644 --- a/tensorflow/core/platform/denormal.cc +++ b/tensorflow/core/platform/denormal.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <tuple> + #include "tensorflow/core/platform/denormal.h" -#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index a2f42f44ac..b55e94d552 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -18,7 +18,6 @@ limitations under the License. #include <deque> #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -28,28 +27,6 @@ limitations under the License. namespace tensorflow { -namespace { - -constexpr int kNumThreads = 8; - -// Run a function in parallel using a ThreadPool, but skip the ThreadPool -// on the iOS platform due to its problems with more than a few threads. -void ForEach(int first, int last, const std::function<void(int)>& f) { -#if TARGET_OS_IPHONE - for (int i = first; i < last; i++) { - f(i); - } -#else - int num_threads = std::min(kNumThreads, last - first); - thread::ThreadPool threads(Env::Default(), "ForEach", num_threads); - for (int i = first; i < last; i++) { - threads.Schedule([f, i] { f(i); }); - } -#endif -} - -} // anonymous namespace - FileSystem::~FileSystem() {} string FileSystem::TranslateName(const string& name) const { @@ -94,76 +71,6 @@ bool FileSystem::FilesExist(const std::vector<string>& files, return result; } -Status FileSystem::GetMatchingPaths(const string& pattern, - std::vector<string>* results) { - results->clear(); - // Find the fixed prefix by looking for the first wildcard. - string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - string eval_pattern = pattern; - std::vector<string> all_files; - string dir = io::Dirname(fixed_prefix).ToString(); - // If dir is empty then we need to fix up fixed_prefix and eval_pattern to - // include . as the top level directory. - if (dir.empty()) { - dir = "."; - fixed_prefix = io::JoinPath(dir, fixed_prefix); - eval_pattern = io::JoinPath(dir, pattern); - } - - // Setup a BFS to explore everything under dir. - std::deque<string> dir_q; - dir_q.push_back(dir); - Status ret; // Status to return. - // children_dir_status holds is_dir status for children. It can have three - // possible values: OK for true; FAILED_PRECONDITION for false; CANCELLED - // if we don't calculate IsDirectory (we might do that because there isn't - // any point in exploring that child path). - std::vector<Status> children_dir_status; - while (!dir_q.empty()) { - string current_dir = dir_q.front(); - dir_q.pop_front(); - std::vector<string> children; - Status s = GetChildren(current_dir, &children); - ret.Update(s); - if (children.empty()) continue; - // This IsDirectory call can be expensive for some FS. Parallelizing it. - children_dir_status.resize(children.size()); - ForEach(0, children.size(), - [this, ¤t_dir, &children, &fixed_prefix, - &children_dir_status](int i) { - const string child_path = io::JoinPath(current_dir, children[i]); - // In case the child_path doesn't start with the fixed_prefix then - // we don't need to explore this path. - if (!str_util::StartsWith(child_path, fixed_prefix)) { - children_dir_status[i] = Status(tensorflow::error::CANCELLED, - "Operation not needed"); - } else { - children_dir_status[i] = IsDirectory(child_path); - } - }); - for (int i = 0; i < children.size(); ++i) { - const string child_path = io::JoinPath(current_dir, children[i]); - // If the IsDirectory call was cancelled we bail. - if (children_dir_status[i].code() == tensorflow::error::CANCELLED) { - continue; - } - // If the child is a directory add it to the queue. - if (children_dir_status[i].ok()) { - dir_q.push_back(child_path); - } - all_files.push_back(child_path); - } - } - - // Match all obtained files to the input pattern. - for (const auto& f : all_files) { - if (Env::Default()->MatchPath(f, eval_pattern)) { - results->push_back(f); - } - } - return ret; -} - Status FileSystem::DeleteRecursively(const string& dirname, int64* undeleted_files, int64* undeleted_dirs) { diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index 8f99766e15..077b1d79cf 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -138,10 +138,8 @@ class FileSystem { /// * OK - no errors /// * UNIMPLEMENTED - Some underlying functions (like GetChildren) are not /// implemented - /// The default implementation uses a combination of GetChildren, MatchPath - /// and IsDirectory. virtual Status GetMatchingPaths(const string& pattern, - std::vector<string>* results); + std::vector<string>* results) = 0; /// \brief Obtains statistics for the given path. virtual Status Stat(const string& fname, FileStatistics* stat) = 0; diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc new file mode 100644 index 0000000000..22c5057281 --- /dev/null +++ b/tensorflow/core/platform/file_system_helper.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/file_system_helper.h" + +#include <deque> +#include <string> +#include <vector> + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/platform.h" + +namespace tensorflow { +namespace internal { + +namespace { + +constexpr int kNumThreads = 8; + +// Run a function in parallel using a ThreadPool, but skip the ThreadPool +// on the iOS platform due to its problems with more than a few threads. +void ForEach(int first, int last, const std::function<void(int)>& f) { +#if TARGET_OS_IPHONE + for (int i = first; i < last; i++) { + f(i); + } +#else + int num_threads = std::min(kNumThreads, last - first); + thread::ThreadPool threads(Env::Default(), "ForEach", num_threads); + for (int i = first; i < last; i++) { + threads.Schedule([f, i] { f(i); }); + } +#endif +} + +} // namespace + +Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, + std::vector<string>* results) { + results->clear(); + // Find the fixed prefix by looking for the first wildcard. + string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); + string eval_pattern = pattern; + std::vector<string> all_files; + string dir = io::Dirname(fixed_prefix).ToString(); + // If dir is empty then we need to fix up fixed_prefix and eval_pattern to + // include . as the top level directory. + if (dir.empty()) { + dir = "."; + fixed_prefix = io::JoinPath(dir, fixed_prefix); + eval_pattern = io::JoinPath(dir, pattern); + } + + // Setup a BFS to explore everything under dir. + std::deque<string> dir_q; + dir_q.push_back(dir); + Status ret; // Status to return. + // children_dir_status holds is_dir status for children. It can have three + // possible values: OK for true; FAILED_PRECONDITION for false; CANCELLED + // if we don't calculate IsDirectory (we might do that because there isn't + // any point in exploring that child path). + std::vector<Status> children_dir_status; + while (!dir_q.empty()) { + string current_dir = dir_q.front(); + dir_q.pop_front(); + std::vector<string> children; + Status s = fs->GetChildren(current_dir, &children); + ret.Update(s); + if (children.empty()) continue; + // This IsDirectory call can be expensive for some FS. Parallelizing it. + children_dir_status.resize(children.size()); + ForEach(0, children.size(), + [fs, ¤t_dir, &children, &fixed_prefix, + &children_dir_status](int i) { + const string child_path = io::JoinPath(current_dir, children[i]); + // In case the child_path doesn't start with the fixed_prefix then + // we don't need to explore this path. + if (!str_util::StartsWith(child_path, fixed_prefix)) { + children_dir_status[i] = Status(tensorflow::error::CANCELLED, + "Operation not needed"); + } else { + children_dir_status[i] = fs->IsDirectory(child_path); + } + }); + for (int i = 0; i < children.size(); ++i) { + const string child_path = io::JoinPath(current_dir, children[i]); + // If the IsDirectory call was cancelled we bail. + if (children_dir_status[i].code() == tensorflow::error::CANCELLED) { + continue; + } + // If the child is a directory add it to the queue. + if (children_dir_status[i].ok()) { + dir_q.push_back(child_path); + } + all_files.push_back(child_path); + } + } + + // Match all obtained files to the input pattern. + for (const auto& f : all_files) { + if (env->MatchPath(f, eval_pattern)) { + results->push_back(f); + } + } + return ret; +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/core/platform/file_system_helper.h b/tensorflow/core/platform/file_system_helper.h new file mode 100644 index 0000000000..8d812b0e38 --- /dev/null +++ b/tensorflow/core/platform/file_system_helper.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ +#define TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ + +#include <string> +#include <vector> + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FileSystem; +class Env; + +namespace internal { + +// Given a pattern, stores in 'results' the set of paths (in the given file +// system) that match that pattern. +// +// This helper may be used by implementations of FileSystem::GetMatchingPaths() +// in order to provide parallel scanning of subdirectories (except on iOS). +// +// Arguments: +// fs: may not be null and will be used to identify directories and list +// their contents. +// env: may not be null and will be used to check if a match has been found. +// pattern: see FileSystem::GetMatchingPaths() for details. +// results: will be cleared and may not be null. +// +// Returns an error status if any call to 'fs' failed. +Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, + std::vector<string>* results); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 74863293a3..9a71fbe2b7 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/posix/error.h" @@ -396,6 +397,11 @@ Status HadoopFileSystem::GetChildren(const string& dir, return Status::OK(); } +Status HadoopFileSystem::GetMatchingPaths(const string& pattern, + std::vector<string>* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status HadoopFileSystem::DeleteFile(const string& fname) { hdfsFS fs = nullptr; TF_RETURN_IF_ERROR(Connect(fname, &fs)); diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h index 5f2b222622..6af7a698ff 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.h +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h @@ -49,6 +49,9 @@ class HadoopFileSystem : public FileSystem { Status GetChildren(const string& dir, std::vector<string>* result) override; + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override; + Status DeleteFile(const string& fname) override; Status CreateDir(const string& name) override; diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc index 6ba2f04d0f..b207d34749 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/test.h" @@ -197,7 +198,7 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) { // Skip the test if we're not testing on HDFS. Hadoop's local filesystem // implementation makes no guarantees that writable files are readable while // being written. - if (!StringPiece(fname).starts_with("hdfs://")) { + if (!str_util::StartsWith(fname, "hdfs://")) { return; } diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h index 7bb9fc264f..fca3a2332d 100644 --- a/tensorflow/core/platform/mem.h +++ b/tensorflow/core/platform/mem.h @@ -59,7 +59,7 @@ void MallocExtension_ReleaseToSystem(std::size_t num_bytes); // routine, this routine returns 0. std::size_t MallocExtension_GetAllocatedSize(const void* p); -// Returns the amount of RAM available in kB, or INT64_MAX if unknown. +// Returns the amount of RAM available in bytes, or INT64_MAX if unknown. int64 AvailableRam(); } // namespace port diff --git a/tensorflow/core/platform/null_file_system.h b/tensorflow/core/platform/null_file_system.h index 008e6d54d0..420abc1ada 100644 --- a/tensorflow/core/platform/null_file_system.h +++ b/tensorflow/core/platform/null_file_system.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/file_system_helper.h" namespace tensorflow { @@ -65,6 +66,11 @@ class NullFileSystem : public FileSystem { return errors::Unimplemented("GetChildren unimplemented"); } + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); + } + Status DeleteFile(const string& fname) override { return errors::Unimplemented("DeleteFile unimplemented"); } diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 494acde803..8e316472fe 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -177,7 +177,7 @@ int64 AvailableRam() { struct sysinfo info; int err = sysinfo(&info); if (err == 0) { - return info.freeram / 1024; + return info.freeram; } #endif return INT64_MAX; diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc index 9a8021565c..47bfa020ce 100644 --- a/tensorflow/core/platform/posix/posix_file_system.cc +++ b/tensorflow/core/platform/posix/posix_file_system.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/posix/error.h" #include "tensorflow/core/platform/posix/posix_file_system.h" @@ -225,6 +226,11 @@ Status PosixFileSystem::GetChildren(const string& dir, return Status::OK(); } +Status PosixFileSystem::GetMatchingPaths(const string& pattern, + std::vector<string>* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status PosixFileSystem::DeleteFile(const string& fname) { Status result; if (unlink(TranslateName(fname).c_str()) != 0) { diff --git a/tensorflow/core/platform/posix/posix_file_system.h b/tensorflow/core/platform/posix/posix_file_system.h index 98ffa43b8a..e8898d0a97 100644 --- a/tensorflow/core/platform/posix/posix_file_system.h +++ b/tensorflow/core/platform/posix/posix_file_system.h @@ -47,6 +47,9 @@ class PosixFileSystem : public FileSystem { Status Stat(const string& fname, FileStatistics* stats) override; + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override; + Status DeleteFile(const string& fname) override; Status CreateDir(const string& name) override; diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index 301fcb9dbf..ee423699b2 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/s3/aws_logging.h" #include "tensorflow/core/platform/s3/s3_crypto.h" @@ -497,6 +498,11 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { return Status::OK(); } +Status S3FileSystem::GetMatchingPaths(const string& pattern, + std::vector<string>* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status S3FileSystem::DeleteFile(const string& fname) { string bucket, object; TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object)); diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index 31264be621..5d0565b378 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -46,6 +46,9 @@ class S3FileSystem : public FileSystem { Status Stat(const string& fname, FileStatistics* stat) override; + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override; + Status DeleteFile(const string& fname) override; Status CreateDir(const string& name) override; diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index 8f7bff1bb0..3c6e7b0db5 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -103,7 +103,9 @@ class Tracing { friend class ScopedAnnotation; friend class TraceMe; - static std::atomic<Tracing::Engine*> tracing_engine_; + // TODO: TF_EXPORT is for building //tensorflow/contrib/data:_dataset_ops.so + // on Windows. Figure out a way to remove TF_EXPORT here. + TF_EXPORT static std::atomic<Tracing::Engine*> tracing_engine_; static Tracing::Engine* engine() { return tracing_engine_.load(std::memory_order_acquire); } diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index f3b27ea394..174f41a993 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -166,7 +166,7 @@ int64 AvailableRam() { MEMORYSTATUSEX statex; statex.dwLength = sizeof(statex); if (GlobalMemoryStatusEx(&statex)) { - return statex.ullAvailPhys / 1024; + return statex.ullAvailPhys; } return INT64_MAX; } diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 682e46e0fc..dc2efbeaf5 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/posix/error.h" #include "tensorflow/core/platform/windows/error.h" @@ -494,7 +495,8 @@ Status WindowsFileSystem::GetMatchingPaths(const string& pattern, // but no code appears to rely on this behavior. string converted_pattern(pattern); std::replace(converted_pattern.begin(), converted_pattern.end(), '\\', '/'); - TF_RETURN_IF_ERROR(FileSystem::GetMatchingPaths(converted_pattern, results)); + TF_RETURN_IF_ERROR(internal::GetMatchingPaths(this, Env::Default(), + converted_pattern, results)); for (string& result : *results) { std::replace(result.begin(), result.end(), '/', '\\'); } diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc index e968b9c97e..96b6cc30bd 100644 --- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc +++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -82,8 +83,8 @@ TEST_F(TFProfAdvisorTest, OperationChecker) { (*options.mutable_checkers())[kCheckers[1]]; AdviceProto advice = advisor_->Advise(options); EXPECT_EQ(advice.checkers().at(kCheckers[1]).reports_size(), 1); - EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[1]).reports(0)) - .contains("NCHW")); + EXPECT_TRUE(str_util::StrContains( + advice.checkers().at(kCheckers[1]).reports(0), "NCHW")); } TEST_F(TFProfAdvisorTest, UtilizationChecker) { @@ -91,16 +92,17 @@ TEST_F(TFProfAdvisorTest, UtilizationChecker) { (*options.mutable_checkers())[kCheckers[0]]; AdviceProto advice = advisor_->Advise(options); EXPECT_EQ(advice.checkers().at(kCheckers[0]).reports_size(), 1); - EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[0]).reports(0)) - .contains("low utilization")); + EXPECT_TRUE(str_util::StrContains( + advice.checkers().at(kCheckers[0]).reports(0), "low utilization")); } TEST_F(TFProfAdvisorTest, ExpensiveOperationChecker) { AdvisorOptionsProto options; (*options.mutable_checkers())[kCheckers[2]]; AdviceProto advice = advisor_->Advise(options); - EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[2]).reports(0)) - .contains("top 1 operation type: Conv2D")); + EXPECT_TRUE( + str_util::StrContains(advice.checkers().at(kCheckers[2]).reports(0), + "top 1 operation type: Conv2D")); } } // namespace tfprof diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index bb772460b0..9b6202e7b4 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -29,6 +29,14 @@ message RewriterConfig { AGGRESSIVE = 3; } + // Enum controling the number of times to run optimizers. The default is to + // run them once. + enum NumIterationsType { + DEFAULT_NUM_ITERS = 0; + ONE = 1; + TWO = 2; + } + // Optimize tensor layouts (default is ON) // e.g. This will try to use NCHW layout on GPU which is faster. Toggle layout_optimizer = 1; @@ -51,6 +59,10 @@ message RewriterConfig { // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; + // Controls how many times we run the optimizers in meta optimizer (default + // is once). + NumIterationsType meta_optimizer_iterations = 12; + enum MemOptType { // The default setting (SCHEDULING and SWAPPING HEURISTICS only) DEFAULT_MEM_OPT = 0; diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 3efc703faf..480ce94fca 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -17,6 +17,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -28,7 +29,9 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, const std::function<bool(string)>& hook, bool* value_parsing_ok) { *value_parsing_ok = true; - if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + if (str_util::ConsumePrefix(&arg, "--") && + str_util::ConsumePrefix(&arg, flag) && + str_util::ConsumePrefix(&arg, "=")) { *value_parsing_ok = hook(arg.ToString()); return true; } @@ -40,7 +43,9 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, const std::function<bool(int32)>& hook, bool* value_parsing_ok) { *value_parsing_ok = true; - if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + if (str_util::ConsumePrefix(&arg, "--") && + str_util::ConsumePrefix(&arg, flag) && + str_util::ConsumePrefix(&arg, "=")) { char extra; int32 parsed_int32; if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) { @@ -60,7 +65,9 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, const std::function<bool(int64)>& hook, bool* value_parsing_ok) { *value_parsing_ok = true; - if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + if (str_util::ConsumePrefix(&arg, "--") && + str_util::ConsumePrefix(&arg, flag) && + str_util::ConsumePrefix(&arg, "=")) { char extra; int64 parsed_int64; if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) { @@ -80,7 +87,8 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, const std::function<bool(bool)>& hook, bool* value_parsing_ok) { *value_parsing_ok = true; - if (arg.Consume("--") && arg.Consume(flag)) { + if (str_util::ConsumePrefix(&arg, "--") && + str_util::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *value_parsing_ok = hook(true); return true; @@ -107,7 +115,9 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, const std::function<bool(float)>& hook, bool* value_parsing_ok) { *value_parsing_ok = true; - if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + if (str_util::ConsumePrefix(&arg, "--") && + str_util::ConsumePrefix(&arg, flag) && + str_util::ConsumePrefix(&arg, "=")) { char extra; float parsed_float; if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) { diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index c1bc0f3378..ff9c108f10 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -408,7 +408,7 @@ static void MergeDevNamesError(const string& name_a, const string& name_b, DeviceNameUtils::ParsedName target_a = Name(name_a); Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b)); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); - EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr)) + EXPECT_TRUE(str_util::StrContains(s.error_message(), expected_error_substr)) << s; } diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index f1ec497a67..b87dce0dff 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" @@ -144,7 +145,7 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, int first_control_input = actual.input_size(); for (int i = 0; i < actual.input_size(); ++i) { - if (StringPiece(actual.input(i)).starts_with("^")) { + if (str_util::StartsWith(actual.input(i), "^")) { first_control_input = i; break; } @@ -240,7 +241,7 @@ uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) { // Normal inputs. Order important. int first_control_input = ndef.input_size(); for (int i = 0; i < ndef.input_size(); ++i) { - if (StringPiece(ndef.input(i)).starts_with("^")) { + if (str_util::StartsWith(ndef.input(i), "^")) { first_control_input = i; break; } diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc index a0f43d2d4a..1fa6b8bec0 100644 --- a/tensorflow/core/util/memmapped_file_system.cc +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/util/memmapped_file_system.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/memmapped_file_system.pb.h" @@ -157,6 +158,12 @@ Status MemmappedFileSystem::GetChildren(const string& filename, return errors::Unimplemented("memmapped format doesn't support GetChildren"); } +Status MemmappedFileSystem::GetMatchingPaths(const string& pattern, + std::vector<string>* results) { + return errors::Unimplemented( + "memmapped format doesn't support GetMatchingPaths"); +} + Status MemmappedFileSystem::DeleteFile(const string& filename) { return errors::Unimplemented("memmapped format doesn't support DeleteFile"); } @@ -236,7 +243,7 @@ Status MemmappedFileSystem::InitializeFromFile(Env* env, } bool MemmappedFileSystem::IsMemmappedPackageFilename(const string& filename) { - return StringPiece(filename).starts_with(kMemmappedPackagePrefix); + return str_util::StartsWith(filename, kMemmappedPackagePrefix); } namespace { diff --git a/tensorflow/core/util/memmapped_file_system.h b/tensorflow/core/util/memmapped_file_system.h index 541587aeab..76cc4911f5 100644 --- a/tensorflow/core/util/memmapped_file_system.h +++ b/tensorflow/core/util/memmapped_file_system.h @@ -85,6 +85,8 @@ class MemmappedFileSystem : public FileSystem { Status NewAppendableFile(const string& fname, std::unique_ptr<WritableFile>* result) override; Status GetChildren(const string& dir, std::vector<string>* r) override; + Status GetMatchingPaths(const string& pattern, + std::vector<string>* results) override; Status DeleteFile(const string& f) override; Status CreateDir(const string& d) override; Status DeleteDir(const string& d) override; diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc index 575c27d4ef..90ea09876e 100644 --- a/tensorflow/core/util/reporter_test.cc +++ b/tensorflow/core/util/reporter_test.cc @@ -29,7 +29,7 @@ namespace { // Tests of all the error paths in log_reader.cc follow: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << s << " does not contain " << expected; } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 08f1aa7125..7f166f0ec0 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -293,7 +294,7 @@ void VersionTest(const VersionDef& version, StringPiece expected_error) { BundleReader reader(Env::Default(), path); EXPECT_TRUE(errors::IsInvalidArgument(reader.status())); EXPECT_TRUE( - StringPiece(reader.status().error_message()).starts_with(expected_error)); + str_util::StartsWith(reader.status().error_message(), expected_error)); } } // namespace @@ -588,7 +589,7 @@ TEST(TensorBundleTest, Error) { TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f))); EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok()); EXPECT_TRUE( - StringPiece(writer.status().ToString()).contains("duplicate key")); + str_util::StrContains(writer.status().ToString(), "duplicate key")); EXPECT_FALSE(writer.Finish().ok()); } { // Double finish @@ -598,7 +599,7 @@ TEST(TensorBundleTest, Error) { } { // Not found. BundleReader reader(Env::Default(), Prefix("nonexist")); - EXPECT_TRUE(StringPiece(reader.status().ToString()).contains("Not found")); + EXPECT_TRUE(str_util::StrContains(reader.status().ToString(), "Not found")); } } @@ -629,7 +630,7 @@ TEST(TensorBundleTest, Checksum) { BundleReader reader(Env::Default(), Prefix(prefix)); Status status = reader.Lookup(key, &val); EXPECT_TRUE(errors::IsDataLoss(status)); - EXPECT_TRUE(StringPiece(status.ToString()).contains(expected_msg)); + EXPECT_TRUE(str_util::StrContains(status.ToString(), expected_msg)); }; // Corrupts a float tensor. @@ -680,8 +681,8 @@ TEST(TensorBundleTest, Endianness) { BundleReader reader(Env::Default(), Prefix("end")); EXPECT_TRUE(errors::IsUnimplemented(reader.status())); - EXPECT_TRUE(StringPiece(reader.status().ToString()) - .contains("different endianness from the reader")); + EXPECT_TRUE(str_util::StrContains(reader.status().ToString(), + "different endianness from the reader")); } TEST(TensorBundleTest, TruncatedTensorContents) { diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc index 010cc36823..3c9590e488 100644 --- a/tensorflow/core/util/tensor_slice_reader_test.cc +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -422,7 +423,7 @@ static void VersionTest(const VersionDef& versions, const string& error) { // Read it back in and verify that we get the expected error TensorSliceReader reader(path, OpenTableTensorSliceReader); EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT && - StringPiece(reader.status().error_message()).starts_with(error)) + str_util::StartsWith(reader.status().error_message(), error)) << "Expected error starting with '" << errors::InvalidArgument(error) << "', got '" << reader.status() << "'"; } diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc index ff5bfd65ae..31397f11b6 100644 --- a/tensorflow/core/util/tensor_slice_writer_test.cc +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -333,8 +334,8 @@ TEST(TensorSliceWriteTest, SizeErrors) { const std::vector<int8> data(300000000, -1); Status s = writer.Add("test1", shape, slice, data.data()); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Tensor slice is too large to serialize")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Tensor slice is too large to serialize")); } // Add a large string tensor slice, which will fail. @@ -344,8 +345,8 @@ TEST(TensorSliceWriteTest, SizeErrors) { const std::vector<string> data(256 * 1024, std::string(8192, 'f')); Status s = writer.Add("test2", shape, slice, data.data()); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Tensor slice is too large to serialize")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), "Tensor slice is too large to serialize")); } } diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md index bdff60b39e..1ab0340ad9 100644 --- a/tensorflow/docs_src/extend/index.md +++ b/tensorflow/docs_src/extend/index.md @@ -16,9 +16,10 @@ TensorFlow: for your own file and record formats. Python is currently the only language supported by TensorFlow's API stability -promises. However, TensorFlow also provides functionality in C++, Java, and Go, +promises. However, TensorFlow also provides functionality in C++, Go, Java and +[JavaScript](https://js.tensorflow.org), plus community support for [Haskell](https://github.com/tensorflow/haskell) and -[Rust](https://github.com/tensorflow/rust). If you'd like to create or +[Rust](https://github.com/tensorflow/rust). If you'd like to create or develop TensorFlow features in a language other than these languages, read the following guide: diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/docs_src/mobile/tflite/devguide.md index 5b521dca7b..96392a3c9b 100644 --- a/tensorflow/docs_src/mobile/tflite/devguide.md +++ b/tensorflow/docs_src/mobile/tflite/devguide.md @@ -88,7 +88,7 @@ Tensorflow Lite format. This process uses several model formats: extracted from a `SavedModel`. * *TensorFlow Lite model* (.tflite) —A serialized [FlatBuffer](https://google.github.io/flatbuffers/) that contains TensorFlow - Lite operators and tensors for the TensorFlow Lite interpreter, similiar to a + Lite operators and tensors for the TensorFlow Lite interpreter, similar to a `FrozenGraphDef`. ### Freeze Graph diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md index 8db65737dc..dc5b403428 100644 --- a/tensorflow/docs_src/programmers_guide/eager.md +++ b/tensorflow/docs_src/programmers_guide/eager.md @@ -1,35 +1,34 @@ # Eager Execution TensorFlow's eager execution is an imperative programming environment that -evaluates operations immediately, without an extra graph-building step. -Operations return concrete values instead of constructing a computational graph -to run later. This makes it easy to get started with TensorFlow, debug models, -reduce boilerplate code, and is fun! To follow along with this guide, run the -code samples below in an interactive `python` interpreter. - -Eager execution supports most TensorFlow operations and GPU acceleration. -Automatic differentiation uses a dynamically-constructed tape instead of a static -graph to compute gradients. Eager execution is a flexible machine learning -platform for research and experimentation that provides: - -* *An intuitive interface* —Structure your code naturally and use Python data +evaluates operations immediately, without building graphs: operations return +concrete values instead of constructing a computational graph to run later. This +makes it easy to get started with TensorFlow and debug models, and it +reduces boilerplate as well. To follow along with this guide, run the code +samples below in an interactive `python` interpreter. + +Eager execution is a flexible machine learning platform for research and +experimentation, providing: + +* *An intuitive interface*—Structure your code naturally and use Python data structures. Quickly iterate on small models and small data. -* *Easier debugging* —Call ops directly to inspect running models and test +* *Easier debugging*—Call ops directly to inspect running models and test changes. Use standard Python debugging tools for immediate error reporting. -* *Natural control flow* —Use Python control flow instead of graph control flow, - including support for dynamic models. +* *Natural control flow*—Use Python control flow instead of graph control + flow, simplifying the specification of dynamic models. -For a collection of examples running in eager execution, see: +Eager execution supports most TensorFlow operations and GPU acceleration. For a +collection of examples running in eager execution, see: [tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples). -Note: Some models may experience increased overhead with eager execution enabled. -Performance improvements are ongoing, but please +Note: Some models may experience increased overhead with eager execution +enabled. Performance improvements are ongoing, but please [file a bug](https://github.com/tensorflow/tensorflow/issues) if you find a problem and share your benchmarks. ## Setup and basic usage -Upgrade to TensorFlow 1.7 to include updates for eager execution: +Upgrade to the latest version of TensorFlow: ``` $ pip install --upgrade tensorflow @@ -110,9 +109,106 @@ environments and is useful for writing code to [work with graphs](#work_with_gra import tensorflow.contrib.eager as tfe ``` +## Dynamic control flow + +A major benefit of eager execution is that all the functionality of the host +language is available while your model is executing. So, for example, +it is easy to write [fizzbuzz](https://en.wikipedia.org/wiki/Fizz_buzz): + +```py +def fizzbuzz(max_num): + counter = tf.constant(0) + for num in range(max_num): + num = tf.constant(num) + if num % 3 == 0 and num % 5 == 0: + print('FizzBuzz') + elif num % 3 == 0: + print('Fizz') + elif num % 5 == 0: + print('Buzz') + else: + print(num) + counter += 1 + return counter +``` + +This has conditionals that depend on tensor values and it prints these values +at runtime. + +## Build a model + +Many machine learning models are represented by composing layers. When +using TensorFlow with eager execution you can either write your own layers or +use a layer provided in the `tf.keras.layers` package. + +While you can use any Python object to represent a layer, +TensorFlow has `tf.keras.layers.Layer` as a convenient base class. Inherit from +it to implement your own layer: + +```py +class MySimpleLayer(tf.keras.layers.Layer): + def __init__(self, output_units): + self.output_units = output_units + + def build(self, input): + # The build method gets called the first time your layer is used. + # Creating variables on build() allows you to make their shape depend + # on the input shape and hence remove the need for the user to specify + # full shapes. It is possible to create variables during __init__() if + # you already know their full shapes. + self.kernel = self.add_variable( + "kernel", [input.shape[-1], self.output_units]) + + def call(self, input): + # Override call() instead of __call__ so we can perform some bookkeeping. + return tf.matmul(input, self.kernel) +``` + +Use `tf.keras.layers.Dense` layer instead of `MySimpleLayer` above as it has +a superset of its functionality (it can also add a bias). + +When composing layers into models you can use `tf.keras.Sequential` to represent +models which are a linear stack of layers. It is easy to use for basic models: + +```py +model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(784,)), # must declare input shape + tf.keras.layers.Dense(10) +]) +``` + +Alternatively, organize models in classes by inheriting from `tf.keras.Model`. +This is a container for layers that is a layer itself, allowing `tf.keras.Model` +objects to contain other `tf.keras.Model` objects. + +```py +class MNISTModel(tf.keras.Model): + def __init__(self): + super(MNISTModel, self).__init__() + self.dense1 = tf.keras.layers.Dense(units=10) + self.dense2 = tf.keras.layers.Dense(units=10) + + def call(self, input): + """Run the model.""" + result = self.dense1(input) + result = self.dense2(result) + result = self.dense2(result) # reuse variables from dense2 layer + return result + +model = MNISTModel() +``` + +It's not required to set an input shape for the `tf.keras.Model` class since +the parameters are set the first time input is passed to the layer. + +`tf.keras.layers` classes create and contain their own model variables that +are tied to the lifetime of their layer objects. To share layer variables, share +their objects. + + ## Eager training -### Automatic differentiation +### Computing gradients [Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is useful for implementing machine learning algorithms such as @@ -124,7 +220,7 @@ operations for computing gradients later. not tracing. Since different operations can occur during each call, all forward-pass operations get recorded to a "tape". To compute the gradient, play the tape backwards and then discard. A particular `tfe.GradientTape` can only -be computed once, subsequent calls throw a runtime error. +compute one gradient; subsequent calls throw a runtime error. ```py w = tfe.Variable([[1.0]]) @@ -216,189 +312,12 @@ for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): global_step=tf.train.get_or_create_global_step()) ``` -#### Dynamic models - -`tfe.GradientTape` can also be used in dynamic models. This example for a -[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search) -algorithm looks like normal NumPy code, except there are gradients and is -differentiable, despite the complex control flow: - -```py -def line_search_step(fn, init_x, rate=1.0): - with tfe.GradientTape() as tape: - # Variables are automatically recorded, but manually watch a tensor - tape.watch(init_x) - value = fn(init_x) - grad, = tape.gradient(value, [init_x]) - grad_norm = tf.reduce_sum(grad * grad) - init_value = value - while value > init_value - rate * grad_norm: - x = init_x - rate * grad - value = fn(x) - rate /= 2.0 - return x, value -``` - -#### Additional functions to compute gradients - -`tfe.GradientTape` is a powerful interface for computing gradients, but there -is another [Autograd](https://github.com/HIPS/autograd)-style API available for -automatic differentiation. These functions are useful if writing math code with -only tensors and gradient functions, and without `tfe.Variables`: - -* `tfe.gradients_function` —Returns a function that computes the derivatives - of its input function parameter with respect to its arguments. The input - function parameter must return a scalar value. When the returned function is - invoked, it returns a list of `tf.Tensor` objects: one element for each - argument of the input function. Since anything of interest must be passed as a - function parameter, this becomes unwieldy if there's a dependency on many - trainable parameters. -* `tfe.value_and_gradients_function` —Similar to - `tfe.gradients_function`, but when the returned function is invoked, it - returns the value from the input function in addition to the list of - derivatives of the input function with respect to its arguments. - -In the following example, `tfe.gradients_function` takes the `square` -function as an argument and returns a function that computes the partial -derivatives of `square` with respect to its inputs. To calculate the derivative -of `square` at `3`, `grad(3.0)` returns `6`. - -```py -def square(x): - return tf.multiply(x, x) - -grad = tfe.gradients_function(square) - -square(3.) # => 9.0 -grad(3.) # => [6.0] - -# The second-order derivative of square: -gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) -gradgrad(3.) # => [2.0] - -# The third-order derivative is None: -gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0]) -gradgradgrad(3.) # => [None] - - -# With flow control: -def abs(x): - return x if x > 0. else -x - -grad = tfe.gradients_function(abs) - -grad(3.) # => [1.0] -grad(-3.) # => [-1.0] -``` - -### Custom gradients - -Custom gradients are an easy way to override gradients in eager and graph -execution. Within the forward function, define the gradient with respect to the -inputs, outputs, or intermediate results. For example, here's an easy way to clip -the norm of the gradients in the backward pass: - -```py -@tf.custom_gradient -def clip_gradient_by_norm(x, norm): - y = tf.identity(x) - def grad_fn(dresult): - return [tf.clip_by_norm(dresult, norm), None] - return y, grad_fn -``` - -Custom gradients are commonly used to provide a numerically stable gradient for a -sequence of operations: - -```py -def log1pexp(x): - return tf.log(1 + tf.exp(x)) -grad_log1pexp = tfe.gradients_function(log1pexp) - -# The gradient computation works fine at x = 0. -grad_log1pexp(0.) # => [0.5] - -# However, x = 100 fails because of numerical instability. -grad_log1pexp(100.) # => [nan] -``` - -Here, the `log1pexp` function can be analytically simplified with a custom -gradient. The implementation below reuses the value for `tf.exp(x)` that is -computed during the forward pass—making it more efficient by eliminating -redundant calculations: - -```py -@tf.custom_gradient -def log1pexp(x): - e = tf.exp(x) - def grad(dy): - return dy * (1 - 1 / (1 + e)) - return tf.log(1 + e), grad - -grad_log1pexp = tfe.gradients_function(log1pexp) - -# As before, the gradient computation works fine at x = 0. -grad_log1pexp(0.) # => [0.5] - -# And the gradient computation also works at x = 100. -grad_log1pexp(100.) # => [1.0] -``` - - -## Build and train models - -There are many parameters to optimize when calculating derivatives. TensorFlow -code is easier to read when structured into reusable classes and objects instead -of a single top-level function. Eager execution encourages the use of the -Keras-style layer classes in the `tf.keras.layers` module. Additionally, the -`tf.train.Optimizer` classes provide sophisticated techniques to calculate -parameter updates. The following example creates a multi-layer model that classifies the standard [MNIST handwritten digits](https://www.tensorflow.org/tutorials/layers). It demonstrates the optimizer and layer APIs to build trainable graphs in an eager execution environment. -### Build a model - -The `tf.keras.Sequential` model is a linear stack of layers. It is easy to -use for basic models: - -```py -model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, input_shape=(784,)), # must declare input shape - tf.keras.layers.Dense(10) -]) -``` - -Alternatively, organize models in classes by inheriting from `tf.keras.Model`. -This is a container for layers that is a layer itself, allowing `tf.keras.Model` -objects to contain other `tf.keras.Model` objects. - -```py -class MNISTModel(tf.keras.Model): - def __init__(self): - super(MNISTModel, self).__init__() - self.dense1 = tf.keras.layers.Dense(units=10) - self.dense2 = tf.keras.layers.Dense(units=10) - - def call(self, input): - """Run the model.""" - result = self.dense1(input) - result = self.dense2(result) - result = self.dense2(result) # reuse variables from dense2 layer - return result - -model = MNISTModel() -``` - -It's not required to set an input shape for the `tf.keras.Model` class since -the parameters are set the first time input is passed to the layer. - -`tf.keras.layers` classes create and contain their own model variables that -are tied to the lifetime of their layer objects. To share layer variables, share -their objects. - ### Train a model Even without training, call the model and inspect the output in eager execution: @@ -415,7 +334,7 @@ result = model(batch) This example uses the [dataset.py module](https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py) from the -[TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist), +[TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist); download this file to your local directory. Run the following to download the MNIST data files to your working directory and prepare a `tf.data.Dataset` for training: @@ -662,11 +581,141 @@ for _ in range(iterations): ... ``` +## Advanced automatic differentiation topics + +### Dynamic models + +`tfe.GradientTape` can also be used in dynamic models. This example for a +[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search) +algorithm looks like normal NumPy code, except there are gradients and is +differentiable, despite the complex control flow: + +```py +def line_search_step(fn, init_x, rate=1.0): + with tfe.GradientTape() as tape: + # Variables are automatically recorded, but manually watch a tensor + tape.watch(init_x) + value = fn(init_x) + grad, = tape.gradient(value, [init_x]) + grad_norm = tf.reduce_sum(grad * grad) + init_value = value + while value > init_value - rate * grad_norm: + x = init_x - rate * grad + value = fn(x) + rate /= 2.0 + return x, value +``` + +### Additional functions to compute gradients + +`tfe.GradientTape` is a powerful interface for computing gradients, but there +is another [Autograd](https://github.com/HIPS/autograd)-style API available for +automatic differentiation. These functions are useful if writing math code with +only tensors and gradient functions, and without `tfe.Variables`: + +* `tfe.gradients_function` —Returns a function that computes the derivatives + of its input function parameter with respect to its arguments. The input + function parameter must return a scalar value. When the returned function is + invoked, it returns a list of `tf.Tensor` objects: one element for each + argument of the input function. Since anything of interest must be passed as a + function parameter, this becomes unwieldy if there's a dependency on many + trainable parameters. +* `tfe.value_and_gradients_function` —Similar to + `tfe.gradients_function`, but when the returned function is invoked, it + returns the value from the input function in addition to the list of + derivatives of the input function with respect to its arguments. + +In the following example, `tfe.gradients_function` takes the `square` +function as an argument and returns a function that computes the partial +derivatives of `square` with respect to its inputs. To calculate the derivative +of `square` at `3`, `grad(3.0)` returns `6`. + +```py +def square(x): + return tf.multiply(x, x) + +grad = tfe.gradients_function(square) + +square(3.) # => 9.0 +grad(3.) # => [6.0] + +# The second-order derivative of square: +gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) +gradgrad(3.) # => [2.0] + +# The third-order derivative is None: +gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0]) +gradgradgrad(3.) # => [None] + + +# With flow control: +def abs(x): + return x if x > 0. else -x + +grad = tfe.gradients_function(abs) + +grad(3.) # => [1.0] +grad(-3.) # => [-1.0] +``` + +### Custom gradients + +Custom gradients are an easy way to override gradients in eager and graph +execution. Within the forward function, define the gradient with respect to the +inputs, outputs, or intermediate results. For example, here's an easy way to clip +the norm of the gradients in the backward pass: + +```py +@tf.custom_gradient +def clip_gradient_by_norm(x, norm): + y = tf.identity(x) + def grad_fn(dresult): + return [tf.clip_by_norm(dresult, norm), None] + return y, grad_fn +``` + +Custom gradients are commonly used to provide a numerically stable gradient for a +sequence of operations: + +```py +def log1pexp(x): + return tf.log(1 + tf.exp(x)) +grad_log1pexp = tfe.gradients_function(log1pexp) + +# The gradient computation works fine at x = 0. +grad_log1pexp(0.) # => [0.5] + +# However, x = 100 fails because of numerical instability. +grad_log1pexp(100.) # => [nan] +``` + +Here, the `log1pexp` function can be analytically simplified with a custom +gradient. The implementation below reuses the value for `tf.exp(x)` that is +computed during the forward pass—making it more efficient by eliminating +redundant calculations: + +```py +@tf.custom_gradient +def log1pexp(x): + e = tf.exp(x) + def grad(dy): + return dy * (1 - 1 / (1 + e)) + return tf.log(1 + e), grad + +grad_log1pexp = tfe.gradients_function(log1pexp) + +# As before, the gradient computation works fine at x = 0. +grad_log1pexp(0.) # => [0.5] + +# And the gradient computation also works at x = 100. +grad_log1pexp(100.) # => [1.0] +``` + ## Performance -Computation is not automatically offloaded to GPUs during eager execution. To -explicitly direct a computation to a GPU, enclose it in a -`tf.device('/gpu:0')` block: +Computation is automatically offloaded to GPUs during eager execution. If you +want control over where a computation runs you can enclose it in a +`tf.device('/gpu:0')` block (or the CPU equivalent): ```py import time diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index e8c2fa6990..017db0e8cb 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -5,6 +5,7 @@ works. The units are as follows: ## High Level APIs + * @{$programmers_guide/eager}, which is the easiest way to use tensorflow. * @{$programmers_guide/estimators}, which introduces a high-level TensorFlow API that greatly simplifies ML programming. * @{$programmers_guide/datasets}, which explains how to diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 63bc39de6c..baa65d3243 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -137,15 +138,15 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, // Now try to figure out what kind of file it is and decode it. const int wanted_channels = 3; tensorflow::Output image_reader; - if (tensorflow::StringPiece(file_name).ends_with(".png")) { + if (tensorflow::str_util::EndsWith(file_name, ".png")) { image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, DecodePng::Channels(wanted_channels)); - } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) { + } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) { // gif decoder returns 4-D tensor, remove the first dim image_reader = Squeeze(root.WithOpName("squeeze_first_dim"), DecodeGif(root.WithOpName("gif_reader"), file_reader)); - } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) { + } else if (tensorflow::str_util::EndsWith(file_name, ".bmp")) { image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader); } else { // Assume if it's neither a PNG nor a GIF then it must be a JPEG. diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc index e38704fd98..96ea525a4e 100644 --- a/tensorflow/examples/multibox_detector/main.cc +++ b/tensorflow/examples/multibox_detector/main.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -84,10 +85,10 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, // Now try to figure out what kind of file it is and decode it. const int wanted_channels = 3; tensorflow::Output image_reader; - if (tensorflow::StringPiece(file_name).ends_with(".png")) { + if (tensorflow::str_util::EndsWith(file_name, ".png")) { image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, DecodePng::Channels(wanted_channels)); - } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) { + } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) { image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader); } else { // Assume if it's neither a PNG nor a GIF then it must be a JPEG. @@ -131,7 +132,7 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, Status SaveImage(const Tensor& tensor, const string& file_path) { LOG(INFO) << "Saving image to " << file_path; - CHECK(tensorflow::StringPiece(file_path).ends_with(".png")) + CHECK(tensorflow::str_util::EndsWith(file_path, ".png")) << "Only saving of png files is supported."; auto root = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index a33703ad6f..0fd2177df7 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -1720,6 +1720,131 @@ func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output return op.Output(0) } +// Returns the rank of a tensor. +// +// This operation returns an integer representing the rank of `input`. +// +// For example: +// +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// # shape of tensor 't' is [2, 2, 3] +// rank(t) ==> 3 +// ``` +// +// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank +// of a tensor is the number of indices required to uniquely select each element +// of the tensor. Rank is also known as "order", "degree", or "ndims." +func Rank(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Rank", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ReverseSequenceAttr is an optional argument to ReverseSequence. +type ReverseSequenceAttr func(optionalAttr) + +// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. +// +// value: The dimension along which reversal is performed. +// If not specified, defaults to 0 +func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { + return func(m optionalAttr) { + m["batch_dim"] = value + } +} + +// Reverses variable length slices. +// +// This op first slices `input` along the dimension `batch_dim`, and for each +// slice `i`, reverses the first `seq_lengths[i]` elements along +// the dimension `seq_dim`. +// +// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, +// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. +// +// The output slice `i` along dimension `batch_dim` is then given by input +// slice `i`, with the first `seq_lengths[i]` slices along dimension +// `seq_dim` reversed. +// +// For example: +// +// ``` +// # Given this: +// batch_dim = 0 +// seq_dim = 1 +// input.dims = (4, 8, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] +// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] +// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] +// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] +// +// # while entries past seq_lens are copied through: +// output[0, 7:, :, ...] = input[0, 7:, :, ...] +// output[1, 2:, :, ...] = input[1, 2:, :, ...] +// output[2, 3:, :, ...] = input[2, 3:, :, ...] +// output[3, 2:, :, ...] = input[3, 2:, :, ...] +// ``` +// +// In contrast, if: +// +// ``` +// # Given this: +// batch_dim = 2 +// seq_dim = 0 +// input.dims = (8, ?, 4, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] +// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] +// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] +// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] +// +// # while entries past seq_lens are copied through: +// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] +// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] +// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] +// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] +// ``` +// +// Arguments: +// input: The input to reverse. +// seq_lengths: 1-D with length `input.dims(batch_dim)` and +// `max(seq_lengths) <= input.dims(seq_dim)` +// seq_dim: The dimension which is partially reversed. +// +// Returns The partially reversed input. It has the same shape as `input`. +func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"seq_dim": seq_dim} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ReverseSequence", + Input: []tf.Input{ + input, seq_lengths, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns the complex conjugate of a complex number. // // Given a tensor `input` of complex numbers, this operation returns a tensor of @@ -5128,102 +5253,6 @@ func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { return op.Output(0) } -// ReverseSequenceAttr is an optional argument to ReverseSequence. -type ReverseSequenceAttr func(optionalAttr) - -// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. -// -// value: The dimension along which reversal is performed. -// If not specified, defaults to 0 -func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { - return func(m optionalAttr) { - m["batch_dim"] = value - } -} - -// Reverses variable length slices. -// -// This op first slices `input` along the dimension `batch_dim`, and for each -// slice `i`, reverses the first `seq_lengths[i]` elements along -// the dimension `seq_dim`. -// -// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, -// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. -// -// The output slice `i` along dimension `batch_dim` is then given by input -// slice `i`, with the first `seq_lengths[i]` slices along dimension -// `seq_dim` reversed. -// -// For example: -// -// ``` -// # Given this: -// batch_dim = 0 -// seq_dim = 1 -// input.dims = (4, 8, ...) -// seq_lengths = [7, 2, 3, 5] -// -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] -// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] -// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] -// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] -// -// # while entries past seq_lens are copied through: -// output[0, 7:, :, ...] = input[0, 7:, :, ...] -// output[1, 2:, :, ...] = input[1, 2:, :, ...] -// output[2, 3:, :, ...] = input[2, 3:, :, ...] -// output[3, 2:, :, ...] = input[3, 2:, :, ...] -// ``` -// -// In contrast, if: -// -// ``` -// # Given this: -// batch_dim = 2 -// seq_dim = 0 -// input.dims = (8, ?, 4, ...) -// seq_lengths = [7, 2, 3, 5] -// -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] -// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] -// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] -// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] -// -// # while entries past seq_lens are copied through: -// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] -// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] -// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] -// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] -// ``` -// -// Arguments: -// input: The input to reverse. -// seq_lengths: 1-D with length `input.dims(batch_dim)` and -// `max(seq_lengths) <= input.dims(seq_dim)` -// seq_dim: The dimension which is partially reversed. -// -// Returns The partially reversed input. It has the same shape as `input`. -func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"seq_dim": seq_dim} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ReverseSequence", - Input: []tf.Input{ - input, seq_lengths, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. type DepthwiseConv2dNativeAttr func(optionalAttr) @@ -5808,35 +5837,6 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Returns the rank of a tensor. -// -// This operation returns an integer representing the rank of `input`. -// -// For example: -// -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// # shape of tensor 't' is [2, 2, 3] -// rank(t) ==> 3 -// ``` -// -// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank -// of a tensor is the number of indices required to uniquely select each element -// of the tensor. Rank is also known as "order", "degree", or "ndims." -func Rank(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Rank", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Transforms a Tensor into a serialized TensorProto proto. // // Arguments: diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 0b69a8cbe5..c99d04869a 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 541876f7f5..4561c2c8ad 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index d8933e5238..82a2b8e769 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni_gpu</artifactId> diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 6286fd73df..4c1ec0cc80 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ <modelVersion>4.0.0</modelVersion> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <packaging>pom</packaging> <url>https://www.tensorflow.org</url> diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 4e881f5a63..fcd8236bad 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <relativePath>../</relativePath> </parent> <artifactId>proto</artifactId> diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index d512a7eda9..241581713a 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.7.0-rc1</version> + <version>1.7.0</version> <relativePath>../</relativePath> </parent> <artifactId>tensorflow</artifactId> diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a3e79d46d8..6ec8a1cdab 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -28,6 +28,8 @@ load("//tensorflow:tensorflow.bzl", "py_tests") load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") @@ -58,9 +60,10 @@ py_library( "//tensorflow/tools/api/generator:__pkg__", "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed ], - deps = [":no_contrib"] + if_not_windows([ + deps = [ + ":no_contrib", "//tensorflow/contrib:contrib_py", - ]), + ], ) py_library( @@ -284,6 +287,17 @@ cc_library( ) cc_library( + name = "py_exception_registry", + srcs = ["lib/core/py_exception_registry.cc"], + hdrs = ["lib/core/py_exception_registry.h"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:lib", + "//util/python:python_headers", + ], +) + +cc_library( name = "kernel_registry", srcs = ["util/kernel_registry.cc"], hdrs = ["util/kernel_registry.h"], @@ -413,6 +427,7 @@ tf_cc_shared_object( "-lm", ], "//tensorflow:darwin": [], + "//tensorflow:windows": [], }), deps = [ "//tensorflow/core:framework_headers_lib", @@ -960,7 +975,6 @@ py_test( srcs = ["framework/contrib_test.py"], main = "framework/contrib_test.py", srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", @@ -1330,7 +1344,6 @@ py_test( srcs = ["framework/dtypes_test.py"], main = "framework/dtypes_test.py", srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -1706,7 +1719,6 @@ py_test( size = "small", srcs = ["ops/clip_ops_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":client_testlib", ":clip_ops", @@ -2775,7 +2787,6 @@ cuda_py_test( ], data = ["//tensorflow/core:image_testdata"], shard_count = 5, - tags = ["no_windows"], ) cuda_py_test( @@ -3313,6 +3324,7 @@ tf_py_wrap_cc( "grappler/model_analyzer.i", "grappler/tf_optimizer.i", "lib/core/bfloat16.i", + "lib/core/py_exception_registry.i", "lib/core/py_func.i", "lib/core/strings.i", "lib/io/file_io.i", @@ -3340,6 +3352,7 @@ tf_py_wrap_cc( ":kernel_registry", ":numpy_lib", ":safe_ptr", + ":py_exception_registry", ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", @@ -3374,6 +3387,65 @@ tf_py_wrap_cc( tf_additional_gdr_deps()), ) +# ** Targets for Windows build (start) ** +# We need the following targets to expose symbols from _pywrap_tensorflow.dll + +# Build a cc_binary from tf_custom_op_library_additional_deps_impl, +# it contains all object code from its dependencies. +tf_native_cc_binary( + name = "tf_custom_op_library_additional_deps.so", + linkshared = 1, + linkstatic = 1, + deps = tf_custom_op_library_additional_deps_impl(), +) + +# Get a DEF file generated by parsing all object files +# of tf_custom_op_library_additional_deps.so +filegroup( + name = "pywrap_tensorflow_def_file", + srcs = [":tf_custom_op_library_additional_deps.so"], + output_group = "def_file", +) + +# Filter the DEF file to reduce the number of symbols to 64K or less. +# Note that we also write the name of the pyd file into DEF file so that +# the dynamic libraries of custom ops can find it at runtime. +genrule( + name = "pywrap_tensorflow_filtered_def_file", + srcs = [":pywrap_tensorflow_def_file"], + outs = ["pywrap_tensorflow_filtered_def_file.def"], + cmd = select({ + "//tensorflow:windows": """ + $(location @local_config_def_file_filter//:def_file_filter) \\ + --input $(location :pywrap_tensorflow_def_file) \\ + --output $@ \\ + --target _pywrap_tensorflow_internal.pyd + """, + "//conditions:default": "touch $@", # Just a placeholder for Unix platforms + }), + tools = ["@local_config_def_file_filter//:def_file_filter"], +) + +# Get the import library of _pywrap_tensorflow_internal.dll +filegroup( + name = "pywrap_tensorflow_import_lib_file", + srcs = [":_pywrap_tensorflow_internal.so"], + output_group = "interface_library", +) + +# Create a cc_import rule for the import library of _pywrap_tensorflow_internal.dll +# so that custom ops' dynamic libraries can link against it. +cc_import( + name = "pywrap_tensorflow_import_lib", + interface_library = select({ + "//tensorflow:windows": ":pywrap_tensorflow_import_lib_file", + "//conditions:default": "not_exsiting_on_unix.lib", # Just a placeholder for Unix platforms + }), + system_provided = 1, +) + +# ** Targets for Windows build (end) ** + py_library( name = "lib", srcs = [ @@ -3707,6 +3779,7 @@ cuda_py_test( ":math_ops", "//tensorflow/core:protos_all_py", ], + tags = ["no_windows"], ) cuda_py_test( @@ -3746,7 +3819,6 @@ py_test( size = "small", srcs = ["lib/core/bfloat16_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":client_testlib", ":lib", @@ -3939,6 +4011,7 @@ py_test( srcs = ["training/saver_large_partitioned_variable_test.py"], srcs_version = "PY2AND3", tags = [ + "no_windows", "noasan", # http://b/30782289 "notsan", # http://b/30782289 ], @@ -4054,7 +4127,6 @@ py_test( size = "small", srcs = ["training/checkpoint_ops_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":checkpoint_ops_gen", ":client", @@ -4095,10 +4167,7 @@ py_test( size = "medium", srcs = ["training/monitored_session_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_windows", - "notsan", # b/67945581 - ], + tags = ["notsan"], # b/67945581 deps = [ ":array_ops", ":client_testlib", @@ -4710,6 +4779,7 @@ py_test( ":client_testlib", ":framework_for_generated_wrappers", ":math_ops", + ":tf_item", ":tf_optimizer", "//tensorflow/core:protos_all_py", "//third_party/py/numpy", @@ -4771,6 +4841,29 @@ py_test( ) cuda_py_test( + name = "constant_folding_test", + size = "medium", + srcs = [ + "grappler/constant_folding_test.py", + ], + additional_deps = [ + ":client_testlib", + ":framework_for_generated_wrappers", + ":array_ops", + ":control_flow_ops", + ":dtypes", + ":functional_ops", + ":math_ops", + ":ops", + "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", + ], + tags = [ + "grappler", + ], +) + +cuda_py_test( name = "layout_optimizer_test", size = "medium", srcs = [ diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5c9ed9ccaf..4c84d78f2e 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -27,7 +27,6 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -629,14 +628,12 @@ class BaseSession(SessionInterface): self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, - status) - # pylint: enable=protected-access - else: - self._session = tf_session.TF_NewDeprecatedSession(opts, status) + if self._created_with_new_api: + # pylint: disable=protected-access + self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + # pylint: enable=protected-access + else: + self._session = tf_session.TF_NewDeprecatedSession(opts) finally: tf_session.TF_DeleteSessionOptions(opts) @@ -663,22 +660,20 @@ class BaseSession(SessionInterface): Returns: A list of devices in the session. """ - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - raw_device_list = tf_session.TF_SessionListDevices( - self._session, status) - else: - raw_device_list = tf_session.TF_DeprecatedSessionListDevices( - self._session, status) - device_list = [] - size = tf_session.TF_DeviceListCount(raw_device_list) - for i in range(size): - name = tf_session.TF_DeviceListName(raw_device_list, i, status) - device_type = tf_session.TF_DeviceListType(raw_device_list, i, status) - memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status) - device_list.append(_DeviceAttributes(name, device_type, memory)) - tf_session.TF_DeleteDeviceList(raw_device_list) - return device_list + if self._created_with_new_api: + raw_device_list = tf_session.TF_SessionListDevices(self._session) + else: + raw_device_list = tf_session.TF_DeprecatedSessionListDevices( + self._session) + device_list = [] + size = tf_session.TF_DeviceListCount(raw_device_list) + for i in range(size): + name = tf_session.TF_DeviceListName(raw_device_list, i) + device_type = tf_session.TF_DeviceListType(raw_device_list, i) + memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) + device_list.append(_DeviceAttributes(name, device_type, memory)) + tf_session.TF_DeleteDeviceList(raw_device_list) + return device_list def close(self): """Closes this session. @@ -692,15 +687,13 @@ class BaseSession(SessionInterface): if self._created_with_new_api: if self._session and not self._closed: self._closed = True - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseSession(self._session, status) + tf_session.TF_CloseSession(self._session) else: with self._extend_lock: if self._opened and not self._closed: self._closed = True - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseDeprecatedSession(self._session, status) + tf_session.TF_CloseDeprecatedSession(self._session) def __del__(self): # cleanly ignore all exceptions @@ -710,11 +703,10 @@ class BaseSession(SessionInterface): pass if self._session is not None: try: - status = c_api_util.ScopedTFStatus() if self._created_with_new_api: - tf_session.TF_DeleteSession(self._session, status) + tf_session.TF_DeleteSession(self._session) else: - tf_session.TF_DeleteDeprecatedSession(self._session, status) + tf_session.TF_DeleteDeprecatedSession(self._session) except AttributeError: # At shutdown, `c_api_util` or `tf_session` may have been garbage # collected, causing the above method calls to fail. In this case, @@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface): # Set up a graph with feeds and fetches for partial run. def _setup_fn(session, feed_list, fetch_list, target_list): self._extend_graph() - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionPRunSetup_wrapper( - session, feed_list, fetch_list, target_list, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionPRunSetup_wrapper( + session, feed_list, fetch_list, target_list) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_PRunSetup(session, feed_list, fetch_list, target_list, status) @@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface): def _extend_graph(self): if self._created_with_new_api: with self._graph._lock: # pylint: disable=protected-access - with errors.raise_exception_on_not_ok_status() as status: - tf_session.ExtendSession(self._session, status) + tf_session.ExtendSession(self._session) else: # Ensure any changes to the graph are reflected in the runtime. with self._extend_lock: @@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface): def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper( - self._session, options, feed_dict, fetch_list, target_list, - run_metadata, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionRun_wrapper( + self._session, options, feed_dict, fetch_list, target_list, + run_metadata) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_Run( self._session, options, feed_dict, fetch_list, target_list, status, run_metadata) def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionPRun_wrapper( - self._session, handle, feed_dict, fetch_list, status) - else: + if self._created_with_new_api: + return tf_session.TF_SessionPRun_wrapper( + self._session, handle, feed_dict, fetch_list) + else: + with errors.raise_exception_on_not_ok_status() as status: return tf_session.TF_PRun( self._session, handle, feed_dict, fetch_list, status) diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py index 5a7413c12e..38a3acb2dc 100644 --- a/tensorflow/python/client/session_list_devices_test.py +++ b/tensorflow/python/client/session_list_devices_test.py @@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.client import session -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object): def testInvalidDeviceNumber(self): opts = tf_session.TF_NewSessionOptions() - with errors.raise_exception_on_not_ok_status() as status: - c_session = tf_session.TF_NewSession( - ops.get_default_graph()._c_graph, opts, status) - raw_device_list = tf_session.TF_SessionListDevices( - c_session, status) + c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts) + raw_device_list = tf_session.TF_SessionListDevices(c_session) size = tf_session.TF_DeviceListCount(raw_device_list) - # Test that invalid device numbers return -1 rather than a Swig-wrapped - # pointer. - status_no_exception = c_api_util.ScopedTFStatus() - memory = tf_session.TF_DeviceListMemoryBytes( - raw_device_list, size, status_no_exception) - self.assertEqual(memory, -1) + with self.assertRaises(errors.InvalidArgumentError): + tf_session.TF_DeviceListMemoryBytes(raw_device_list, size) tf_session.TF_DeleteDeviceList(raw_device_list) - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_CloseSession(c_session, status) + tf_session.TF_CloseSession(c_session) def testListDevicesGrpcSession(self): server = server_lib.Server.create_local_server() diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 77ce9195ee..b82182d5d3 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -18,11 +18,11 @@ limitations under the License. %{ #include "tensorflow/c/python_api.h" -#include "tensorflow/python/client/tf_session_helper.h" #include "tensorflow/core/framework/session_state.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/python/client/tf_session_helper.h" // Helper function to convert a Python list of Tensors to a C++ vector of // TF_Outputs. @@ -72,7 +72,7 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) { int size = PySequence_Fast_GET_SIZE(py_int_seq); for (int i = 0; i < size; ++i) { PyObject* item = PySequence_Fast_GET_ITEM(py_int_seq, i); - vec->push_back(PyInt_AsLong(item)); + vec->push_back(PyLong_AsLongLong(item)); } } @@ -157,6 +157,25 @@ tensorflow::ImportNumpy(); } } +// We use TF_OperationGetControlOutputs_wrapper instead of +// TF_OperationGetControlOutputs +%ignore TF_OperationGetControlOutputs; +%unignore TF_OperationGetControlOutputs_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_OperationGetControlOutputs_wrapper; + +// Build a Python list of TF_Operation* and return it. +%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i])); + } +} + %ignore TF_OperationOutputConsumers; %unignore TF_OperationOutputConsumers_wrapper; // See comment for "%noexception TF_SessionRun_wrapper;" @@ -438,6 +457,11 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ $1 = PyLong_AsLongLong($input); } +// Override default py3 behavior of attempting to encode into Unicode. +%typemap(out) std::string tensorflow::ResourceHandleShapeAndType { + $result = PyBytes_FromStringAndSize($1.data(), $1.size()); +} + // TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams, // skip for now %ignore TF_WhileParams; @@ -499,9 +523,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ _TF_SetTarget(opts, target) if config is not None: from tensorflow.python.framework import errors - with errors.raise_exception_on_not_ok_status() as status: - config_str = config.SerializeToString() - _TF_SetConfig(opts, config_str, status) + config_str = config.SerializeToString() + _TF_SetConfig(opts, config_str) return opts %} diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index ca57abd712..b48d758e4a 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -550,6 +550,15 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( return control_inputs; } +std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper) { + std::vector<TF_Operation*> control_outputs( + TF_OperationNumControlOutputs(oper)); + TF_OperationGetControlOutputs(oper, control_outputs.data(), + control_outputs.size()); + return control_outputs; +} + std::vector<const char*> TF_OperationOutputConsumers_wrapper( TF_Output oper_out) { int num_consumers = TF_OperationOutputNumConsumers(oper_out); diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 603d03e315..d2b4abc476 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected); // // If shape is unknown, sets unknown_shape to true. tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( - TF_Graph* graph, TF_Output output, TF_Status* out_status, - bool* unknown_shape); + TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape); // Runs the graph associated with the session starting with the supplied inputs. // On success, `py_outputs` is populated with a numpy ndarray for each output @@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, const std::vector<PyObject*>& input_ndarrays, const std::vector<TF_Output>& outputs, const std::vector<TF_Operation*>& targets, - TF_Buffer* run_metadata, TF_Status* out_status, + TF_Buffer* run_metadata, TF_Status* status, std::vector<PyObject*>* py_outputs); // Set up the graph with the intended feeds (inputs) and fetches (output) for @@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session, const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, const std::vector<TF_Operation*>& targets, - const char** out_handle, - TF_Status* out_status); + const char** out_handle, TF_Status* status); // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. @@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, const std::vector<TF_Output>& inputs, const std::vector<PyObject*>& input_ndarrays, const std::vector<TF_Output>& outputs, - TF_Status* out_status, + TF_Status* status, std::vector<PyObject*>* py_outputs); // Retrieves the inputs of this operation. @@ -192,6 +190,10 @@ std::vector<TF_Output> GetOperationInputs(TF_Operation* oper); std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( TF_Operation* oper); +// Retrieves the control outputs of this operation. +std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper); + // Retrieves the op names of the consumers of `oper_out`. The returned strings // have the lifetime of the underlying TF_Graph. std::vector<const char*> TF_OperationOutputConsumers_wrapper( @@ -204,7 +206,7 @@ TF_Function* TF_GraphToFunction_wrapper( const std::vector<TF_Operation*>* opers, const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, const NameVector& output_names, const TF_FunctionOptions* opts, - const char* description, TF_Status* out_status); + const char* description, TF_Status* status); // Set the shapes and types for the output's handle. // diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index 4a14a915bd..0af282a024 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -717,6 +718,14 @@ class IteratorTest(test.TestCase): self.assertTrue( iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE in str(warning.message)) + def testEagerIteratorAsync(self): + with context.eager_mode(), context.execution_mode(context.ASYNC): + val = 0 + dataset = dataset_ops.Dataset.range(10) + for foo in dataset: + self.assertEqual(val, foo.numpy()) + val += 1 + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index d79b9d6011..0c76afd29d 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -488,23 +488,27 @@ class EagerIterator(object): def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - with ops.device(self._device): - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` - # because in eager mode this code will run synchronously on the calling - # thread. Therefore we do not need to make a defensive context switch - # to a background thread, and can achieve a small constant performance - # boost by invoking the iterator synchronously. - ret = gen_dataset_ops.iterator_get_next_sync( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + with ops.device(self._device): + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` + # because in eager mode this code will run synchronously on the calling + # thread. Therefore we do not need to make a defensive context switch + # to a background thread, and can achieve a small constant performance + # boost by invoking the iterator synchronously. + ret = gen_dataset_ops.iterator_get_next_sync( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) def next(self): """Returns a nested structure of `tf.Tensor`s containing the next element. diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 4195586313..250b4b1b6a 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -913,6 +913,7 @@ cuda_py_test( "//tensorflow/python:util", "//tensorflow/python:variables", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -920,6 +921,7 @@ py_test( size = "small", srcs = ["cli/profile_analyzer_cli_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":debugger_cli_common", ":profile_analyzer_cli", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 209b012621..92774d4d50 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -50,12 +49,10 @@ def op_attr_type(op_type, attr_name): try: return _op_attr_type_cache[(op_type, attr_name)] except KeyError: - with errors.raise_exception_on_not_ok_status() as status: - h = context.context()._handle # pylint: disable=protected-access - attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType( - h, op_type, attr_name, status) - _op_attr_type_cache[(op_type, attr_name)] = attr_type - return attr_type + h = context.context()._handle # pylint: disable=protected-access + attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name) + _op_attr_type_cache[(op_type, attr_name)] = attr_type + return attr_type def make_attr(attr_type, value): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 9ca5041c38..7ad37058fd 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -201,6 +201,9 @@ class MicroBenchmarks(test.Benchmark): m = self._m_2 self._run(lambda: gen_array_ops.identity(m), 30000) + def benchmark_slowpath_tf_identity(self): + self._run(lambda: gen_array_ops.identity(1), 30000) + def benchmark_tfe_py_execute_identity(self): m = self._m_2 ctx_handle = context.context()._handle diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 8c1bb06bc3..9e146f021e 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib @@ -86,6 +85,7 @@ class _EagerContext(threading.local): self.device_spec = pydev.DeviceSpec.from_string("") self.device_name = self.device_spec.to_string() self.mode = _default_mode + self.is_eager = _default_mode == EAGER_MODE self.scope_name = "" self.recording_summaries = False self.summary_writer_resource = None @@ -223,34 +223,27 @@ class Context(object): assert self._context_devices is None opts = pywrap_tensorflow.TFE_NewContextOptions() try: - with errors.raise_exception_on_not_ok_status() as status: - if self._config is not None: - config_str = self._config.SerializeToString() - pywrap_tensorflow.TFE_ContextOptionsSetConfig( - opts, config_str, len(config_str), status) - if self._device_policy is not None: - pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( - opts, self._device_policy) - if self._execution_mode == ASYNC: - pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) - self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) + if self._config is not None: + config_str = self._config.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str) + if self._device_policy is not None: + pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( + opts, self._device_policy) + if self._execution_mode == ASYNC: + pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) # Store list of devices self._context_devices = [] - with errors.raise_exception_on_not_ok_status() as status: - device_list = pywrap_tensorflow.TFE_ContextListDevices( - self._context_handle, status) + device_list = pywrap_tensorflow.TFE_ContextListDevices( + self._context_handle) try: self._num_gpus = 0 for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): - with errors.raise_exception_on_not_ok_status() as status: - dev_name = pywrap_tensorflow.TF_DeviceListName( - device_list, i, status) + dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) self._context_devices.append(pydev.canonical_name(dev_name)) - with errors.raise_exception_on_not_ok_status() as status: - dev_type = pywrap_tensorflow.TF_DeviceListType( - device_list, i, status) + dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) if dev_type == "GPU": self._num_gpus += 1 @@ -287,9 +280,12 @@ class Context(object): @tf_contextlib.contextmanager def _mode(self, mode): + """A context manager to allow setting the mode to EAGER/GRAPH.""" ctx = self._eager_context old_mode = ctx.mode + old_is_eager = ctx.is_eager ctx.mode = mode + ctx.is_eager = mode == EAGER_MODE if mode == EAGER_MODE: # Entering graph mode does not provide us with sufficient information to # record a context switch; graph-based context switches are only logged @@ -298,13 +294,14 @@ class Context(object): try: yield finally: + ctx.is_eager = old_is_eager ctx.mode = old_mode if mode == EAGER_MODE: self.context_switches.pop() def executing_eagerly(self): """Returns True if current thread has eager executing enabled.""" - return self._eager_context.mode == EAGER_MODE + return self._eager_context.is_eager def scalar_cache(self): """Per-device cache for scalars.""" @@ -411,9 +408,7 @@ class Context(object): if mode is None: mode = SYNC self._eager_context.execution_mode = mode - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, - mode == ASYNC, status) + pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC) @tf_contextlib.contextmanager def execution_mode(self, mode): @@ -427,8 +422,7 @@ class Context(object): def async_wait(self): """Waits for ops dispatched in ASYNC mode to finish.""" - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAsyncWait(self._handle, status) + pywrap_tensorflow.TFE_ContextAsyncWait(self._handle) def async_clear_error(self): """Clears errors raised during ASYNC execution.""" @@ -448,11 +442,9 @@ class Context(object): Args: fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). """ - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunction( - self._handle, # pylint: disable=protected-access - fn, - status) + pywrap_tensorflow.TFE_ContextAddFunction( + self._handle, # pylint: disable=protected-access + fn) def add_function_def(self, fdef): """Add a function definition to the context. @@ -464,12 +456,10 @@ class Context(object): fdef: A FunctionDef protocol buffer message. """ fdef_string = fdef.SerializeToString() - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunctionDef( - self._handle, # pylint: disable=protected-access - fdef_string, - len(fdef_string), - status) + pywrap_tensorflow.TFE_ContextAddFunctionDef( + self._handle, # pylint: disable=protected-access + fdef_string, + len(fdef_string)) def add_post_execution_callback(self, callback): """Add a post-execution callback to the context. @@ -512,23 +502,19 @@ class Context(object): To retrieve the accumulated metadata call context.export_run_metadata() and to stop tracing call context.disable_run_metadata(). """ - if not self._context_handle: - self._initialize_handle_and_devices() - pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) + pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle) @tf_contextlib.contextmanager def device_policy(self, policy): - if not self._context_handle: - self._initialize_handle_and_devices() - old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - self._context_handle) + handle = self._handle + old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle) pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( - self._handle, policy) + handle, policy) try: yield finally: pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( - self._handle, old) + handle, old) def disable_run_metadata(self): """Disables tracing of op execution via RunMetadata.""" @@ -548,9 +534,8 @@ class Context(object): if not self._context_handle: return None with c_api_util.tf_buffer() as buffer_: - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextExportRunMetadata( - self._context_handle, buffer_, status) + pywrap_tensorflow.TFE_ContextExportRunMetadata( + self._context_handle, buffer_) proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) run_metadata = config_pb2.RunMetadata() run_metadata.ParseFromString(compat.as_bytes(proto_data)) @@ -579,6 +564,10 @@ def context(): return _context +def context_safe(): + return _context + + # TODO(agarwal): remove this. def get_default_context(): """Same as context.""" diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 343012e552..711eddcec1 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.framework import c_api_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name): ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] shapes = [[d.size for d in s.dim] if not s.unknown_rank else None for s in shapes] - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - captured_value._op._graph._c_graph, # pylint: disable=protected-access - captured_value._as_tf_output(), # pylint: disable=protected-access - shapes, - ranks, - types, - status) + pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( + captured_value._op._graph._c_graph, # pylint: disable=protected-access + captured_value._as_tf_output(), # pylint: disable=protected-access + shapes, ranks, types) tensor_map[ops.tensor_id(value)] = (value, captured_value) else: @@ -275,23 +270,20 @@ class _EagerDefinedFunction(object): inputs: the tensors in the graph to be used as inputs to the function outputs: the tensors in the graph which will be outputs to the function """ - with errors.raise_exception_on_not_ok_status() as status: - fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( - graph._c_graph, # pylint: disable=protected-access - compat.as_str(name), - False, - [o._c_op for o in operations], # pylint: disable=protected-access - [t._as_tf_output() for t in inputs], # pylint: disable=protected-access - [t._as_tf_output() for t in outputs], # pylint: disable=protected-access - [], - None, - compat.as_str(""), - status) + fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( + graph._c_graph, # pylint: disable=protected-access + compat.as_str(name), + False, + [o._c_op for o in operations], # pylint: disable=protected-access + [t._as_tf_output() for t in inputs], # pylint: disable=protected-access + [t._as_tf_output() for t in outputs], # pylint: disable=protected-access + [], + None, + compat.as_str("")) # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. with c_api_util.tf_buffer() as buffer_: - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) + pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_) proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) function_def = function_pb2.FunctionDef() function_def.ParseFromString(compat.as_bytes(proto_data)) diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 837cad974a..000152855d 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -21,7 +21,6 @@ from __future__ import print_function import collections from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import errors VSpace = collections.namedtuple( @@ -60,6 +59,5 @@ def imperative_grad( or if only non-differentiable functions of the source were used in the computation of target. """ - with errors.raise_exception_on_not_ok_status() as status: - return pywrap_tensorflow.TFE_Py_TapeGradient( - tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access + return pywrap_tensorflow.TFE_Py_TapeGradient( + tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index c2ce8efd7f..9afab0077b 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -117,7 +117,7 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { const string& function_name) : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) { op_name_ = function_name_; - op_name_.Consume("_"); + str_util::ConsumePrefix(&op_name_, "_"); } ~GenEagerPythonOp() override {} @@ -366,8 +366,8 @@ string GenEagerPythonOp::Code() { void GenEagerPythonOp::HandleGraphMode(const string& function_setup) { // Handle graph-mode case strings::StrAppend(&result_, - " _ctx = _context.context()\n" - " if not _ctx.executing_eagerly():\n", + " _ctx = _context._context\n" + " if _ctx is None or not _ctx._eager_context.is_eager:\n", function_setup, " _, _, _op = _op_def_lib._apply_op_helper(\n"); AddBodyNoReturn(" "); @@ -492,7 +492,7 @@ bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation, strings::StrAppend(function_setup, indentation, " ", attr_api_name, " = ", default_value, "\n"); } - if (attr_type.starts_with("list(")) { + if (str_util::StartsWith(attr_type, "list(")) { ExpectListArg(indentation, attr_api_name, function_setup); } @@ -683,13 +683,14 @@ bool GenEagerPythonOp::AddEagerFallbackCode( return true; } - AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters); + AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), + strings::StrCat(parameters, ", ctx=None")); strings::StrAppend( &result_, " r\"\"\"This is the slowpath function for Eager mode.\n"); strings::StrAppend(&result_, " This is for function ", function_name_, "\n \"\"\"\n"); - strings::StrAppend(&result_, " _ctx = _context.context()\n"); + strings::StrAppend(&result_, " _ctx = ctx if ctx else _context.context()\n"); string function_setup; if (!GetEagerFunctionSetup(" ", &function_setup)) { @@ -712,9 +713,9 @@ bool GenEagerPythonOp::AddEagerFallbackCode( } void GenEagerPythonOp::AddEagerFastPathExecute() { - string fastpath_execute_params = - strings::StrCat("_ctx._handle, _ctx.device_name, \"", op_def_.name(), - "\", ", "name, _ctx._post_execution_callbacks"); + string fastpath_execute_params = strings::StrCat( + "_ctx._context_handle, _ctx._eager_context.device_name, \"", + op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks"); string fallback_params; for (int i = 0; i < api_def_.in_arg_size(); i++) { @@ -755,6 +756,8 @@ void GenEagerPythonOp::AddEagerFastPathExecute() { strings::StrAppend(&result_, " ", "return _result\n"); // Handle fallback. + if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); + strings::StrAppend(&fallback_params, "ctx=_ctx"); strings::StrAppend(&result_, " ", "except _core._FallbackException:\n"); strings::StrAppend( &result_, " ", "return ", function_name_, kEagerFallbackSuffix, diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 8a398f6447..d99bd0b0ff 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1844,6 +1844,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { op_exec_info.ctx = reinterpret_cast<TFE_Context*>( PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); + + if (op_exec_info.ctx == nullptr) { + // The context hasn't been initialized. It will be in the slow path. + RaiseFallbackException( + "This function does not handle the case of the path where " + "all inputs are not already EagerTensors."); + return nullptr; + } + op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); op_exec_info.op_name = PyTuple_GET_ITEM(args, 2); op_exec_info.op_def = GetOpDef(op_exec_info.op_name); diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index f93bc221cc..5d8b19223f 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -966,5 +966,6 @@ cuda_py_test( tags = [ "multi_gpu", "noasan", # flaky time outs + "notsan", # flaky ], ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index ab69a093a2..4d3eff71ad 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -188,7 +188,7 @@ class Estimator(object): self._config = config # The distribute field contains an instance of DistributionStrategy. - self._distribution = self._config.distribute + self._distribution = self._config.train_distribute # Model directory. model_dir = compat_internal.path_to_str(model_dir) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 41415b89e9..f62c9cece6 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -44,7 +44,7 @@ _DEFAULT_REPLACEABLE_LIST = [ 'keep_checkpoint_max', 'keep_checkpoint_every_n_hours', 'log_step_count_steps', - 'distribute' + 'train_distribute' ] _SAVE_CKPT_ERR = ( @@ -302,7 +302,7 @@ class RunConfig(object): keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, - distribute=None): + train_distribute=None): """Constructs a RunConfig. All distributed training related properties `cluster_spec`, `is_chief`, @@ -426,10 +426,10 @@ class RunConfig(object): the feature. log_step_count_steps: The frequency, in number of global steps, that the global step/sec and the loss will be logged during training. - distribute: an optional instance of + train_distribute: an optional instance of `tf.contrib.distribute.DistributionStrategy`. If specified, - then Estimator will distribute the user's model according to the policy - specified by that strategy. + then Estimator will distribute the user's model during training, + according to the policy specified by that strategy. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -466,7 +466,7 @@ class RunConfig(object): keep_checkpoint_max=keep_checkpoint_max, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, log_step_count_steps=log_step_count_steps, - distribute=distribute) + train_distribute=train_distribute) self._init_distributed_setting_from_environment_var(tf_config) @@ -678,10 +678,10 @@ class RunConfig(object): return self._service @property - def distribute(self): + def train_distribute(self): """Returns the optional `tf.contrib.distribute.DistributionStrategy` object. """ - return self._distribute + return self._train_distribute def replace(self, **kwargs): """Returns a new instance of `RunConfig` replacing specified properties. @@ -697,7 +697,7 @@ class RunConfig(object): - `keep_checkpoint_max`, - `keep_checkpoint_every_n_hours`, - `log_step_count_steps`, - - `distribute`. + - `train_distribute`. In addition, either `save_checkpoints_steps` or `save_checkpoints_secs` can be set (should not be both). diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 219105d386..295d4ca094 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -43,6 +43,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/keras", "@six_archive//:six", ], ) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 92c6ff21c4..3a315e5c2e 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -139,6 +139,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -460,6 +462,154 @@ def linear_model(features, return predictions +class _FCLinearWrapper(base.Layer): + """Wraps a _FeatureColumn in a layer for use in a linear model. + + See `linear_model` above. + """ + + def __init__(self, + feature_column, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + name=None, + **kwargs): + super(_FCLinearWrapper, self).__init__( + trainable=trainable, name=name, **kwargs) + self._feature_column = feature_column + self._units = units + self._sparse_combiner = sparse_combiner + self._weight_collections = weight_collections + self._state = {} + + def build(self, _): + self._state = self._feature_column._create_state( # pylint: disable=protected-access + self._weight_collections, self.add_variable) + + if isinstance(self._feature_column, _CategoricalColumn): + weight = self.add_variable( + name='weights', + shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + else: + num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access + weight = self.add_variable( + name='weights', + shape=[num_elements, self._units], + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + ops.add_to_collections(self._weight_collections, weight) + self._weight_var = weight + self.built = True + + def call(self, builder): + weighted_sum = _create_weighted_sum( + column=self._feature_column, + builder=builder, + units=self._units, + sparse_combiner=self._sparse_combiner, + weight_collections=self._weight_collections, + trainable=self.trainable, + weight_var=self._weight_var, + state=self._state) + return weighted_sum + + +class _BiasLayer(base.Layer): + """A layer for the bias term. + """ + + def __init__(self, + units=1, + trainable=True, + weight_collections=None, + name=None, + **kwargs): + super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs) + self._units = units + self._weight_collections = weight_collections + + def build(self, _): + self._bias_variable = self.add_variable( + 'bias_weights', + shape=[self._units], + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + ops.add_to_collections(self._weight_collections, self._bias_variable) + self.built = True + + def call(self, _): + return self._bias_variable + + +class _LinearModel(training.Model): + """Creates a linear model using feature columns. + """ + + def __init__(self, + feature_columns, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + name=None, + **kwargs): + super(_LinearModel, self).__init__(name=name, **kwargs) + self._feature_columns = _clean_feature_columns(feature_columns) + self._weight_collections = list(weight_collections or []) + if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) + + column_layers = {} + for column in sorted(self._feature_columns, key=lambda x: x.name): + with variable_scope.variable_scope( + None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access + column_name = vs.name + column_layer = _FCLinearWrapper(column, units, sparse_combiner, + self._weight_collections, trainable, + column_name, **kwargs) + column_layers[column_name] = column_layer + self._column_layers = self._add_layers(column_layers) + self._bias_layer = _BiasLayer( + units=units, + trainable=trainable, + weight_collections=self._weight_collections, + name='bias_layer', + **kwargs) + + def call(self, features): + for column in self._feature_columns: + if not isinstance(column, (_DenseColumn, _CategoricalColumn)): + raise ValueError( + 'Items of feature_columns must be either a ' + '_DenseColumn or _CategoricalColumn. Given: {}'.format(column)) + weighted_sums = [] + ordered_columns = [] + builder = _LazyBuilder(features) + for layer in sorted(self._column_layers.values(), key=lambda x: x.name): + ordered_columns.append(layer._feature_column) # pylint: disable=protected-access + weighted_sum = layer(builder) + weighted_sums.append(weighted_sum) + + _verify_static_batch_size_equality(weighted_sums, ordered_columns) + predictions_no_bias = math_ops.add_n( + weighted_sums, name='weighted_sum_no_bias') + predictions = nn_ops.bias_add( + predictions_no_bias, self._bias_layer(builder), name='weighted_sum') # pylint: disable=not-callable + return predictions + + def _add_layers(self, layers): + # "Magic" required for keras.Model classes to track all the variables in + # a list of layers.Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for name, layer in layers.items(): + setattr(self, 'layer-%s' % name, layer) + return layers + + def _transform_features(features, feature_columns): """Returns transformed features based on features columns passed in. @@ -1643,6 +1793,19 @@ class _FeatureColumn(object): """ pass + def _create_state(self, weight_collections=None, creator=None): + """Returns an object that captures the state of the column. + + Args: + weight_collections: Collections to add the variable to + creator: Variable creator method called, if provided. + + Returns: + An object that encapsulates the state of the column. Can return None. + """ + del weight_collections, creator # Unused + return None + class _DenseColumn(_FeatureColumn): """Represents a column which can be represented as `Tensor`. @@ -1662,7 +1825,11 @@ class _DenseColumn(_FeatureColumn): pass @abc.abstractmethod - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None, + state=None): """Returns a `Tensor`. The output of this function will be used by model-builder-functions. For @@ -1680,6 +1847,9 @@ class _DenseColumn(_FeatureColumn): will be created) are added. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}). + state: An object encapsulating the state of the column. Columns that + create state using the _create_state method would have that state + passed in to this method. Returns: `Tensor` of shape [batch_size] + `_variable_shape`. @@ -1687,13 +1857,14 @@ class _DenseColumn(_FeatureColumn): pass -def _create_weighted_sum( - column, - builder, - units, - sparse_combiner, - weight_collections, - trainable): +def _create_weighted_sum(column, + builder, + units, + sparse_combiner, + weight_collections, + trainable, + weight_var=None, + state=None): """Creates a weighted sum for a dense or sparse column for linear_model.""" if isinstance(column, _CategoricalColumn): return _create_categorical_column_weighted_sum( @@ -1702,32 +1873,50 @@ def _create_weighted_sum( units=units, sparse_combiner=sparse_combiner, weight_collections=weight_collections, - trainable=trainable) + trainable=trainable, + weight_var=weight_var) else: return _create_dense_column_weighted_sum( column=column, builder=builder, units=units, weight_collections=weight_collections, - trainable=trainable) + trainable=trainable, + weight_var=weight_var, + state=state) -def _create_dense_column_weighted_sum( - column, builder, units, weight_collections, trainable): +def _create_dense_column_weighted_sum(column, + builder, + units, + weight_collections, + trainable, + weight_var=None, + state=None): """Create a weighted sum of a dense column for linear_model.""" - tensor = column._get_dense_tensor( # pylint: disable=protected-access - builder, - weight_collections=weight_collections, - trainable=trainable) + if state is not None: + tensor = column._get_dense_tensor( # pylint: disable=protected-access + builder, + weight_collections=weight_collections, + trainable=trainable, + state=state) + else: + tensor = column._get_dense_tensor( # pylint: disable=protected-access + builder, + weight_collections=weight_collections, + trainable=trainable) num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access batch_size = array_ops.shape(tensor)[0] tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements)) - weight = variable_scope.get_variable( - name='weights', - shape=[num_elements, units], - initializer=init_ops.zeros_initializer(), - trainable=trainable, - collections=weight_collections) + if weight_var is not None: + weight = weight_var + else: + weight = variable_scope.get_variable( + name='weights', + shape=[num_elements, units], + initializer=init_ops.zeros_initializer(), + trainable=trainable, + collections=weight_collections) return math_ops.matmul(tensor, weight, name='weighted_sum') @@ -1777,8 +1966,13 @@ class _CategoricalColumn(_FeatureColumn): pass -def _create_categorical_column_weighted_sum( - column, builder, units, sparse_combiner, weight_collections, trainable): +def _create_categorical_column_weighted_sum(column, + builder, + units, + sparse_combiner, + weight_collections, + trainable, + weight_var=None): """Create a weighted sum of a categorical column for linear_model.""" sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access builder, @@ -1792,12 +1986,15 @@ def _create_categorical_column_weighted_sum( weight_tensor = sparse_ops.sparse_reshape( weight_tensor, [array_ops.shape(weight_tensor)[0], -1]) - weight = variable_scope.get_variable( - name='weights', - shape=(column._num_buckets, units), # pylint: disable=protected-access - initializer=init_ops.zeros_initializer(), - trainable=trainable, - collections=weight_collections) + if weight_var is not None: + weight = weight_var + else: + weight = variable_scope.get_variable( + name='weights', + shape=(column._num_buckets, units), # pylint: disable=protected-access + initializer=init_ops.zeros_initializer(), + trainable=trainable, + collections=weight_collections) return _safe_embedding_lookup_sparse( weight, id_tensor, @@ -2195,8 +2392,33 @@ class _EmbeddingColumn( self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor_internal( - self, inputs, weight_collections=None, trainable=None): + def _create_state(self, weight_collections=None, creator=None): + variables_map = {} + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if creator is not None: + embedding_weights = creator( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable) + ops.add_to_collections(weight_collections, embedding_weights) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable, + collections=weight_collections) + variables_map['embedding_weights'] = embedding_weights + return variables_map + + def _get_dense_tensor_internal(self, + inputs, + weight_collections=None, + trainable=None, + state=None): """Private method that follows the signature of _get_dense_tensor.""" # Get sparse IDs and weights. sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access @@ -2204,14 +2426,10 @@ class _EmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=self.initializer, - trainable=self.trainable and trainable, - collections=weight_collections) + if state is None: + state = self._create_state(weight_collections) + embedding_weights = state['embedding_weights'] + if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -2229,7 +2447,11 @@ class _EmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None, + state=None): if isinstance(self.categorical_column, _SequenceCategoricalColumn): raise ValueError( 'In embedding_column: {}. ' @@ -2242,8 +2464,10 @@ class _EmbeddingColumn( self.name, type(self.categorical_column), self.categorical_column)) return self._get_dense_tensor_internal( - inputs=inputs, weight_collections=weight_collections, - trainable=trainable) + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable, + state=state) def _get_sequence_dense_tensor( self, inputs, weight_collections=None, trainable=None): @@ -2299,7 +2523,39 @@ class _SharedEmbeddingColumn( self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _create_state(self, weight_collections=None, creator=None): + variables_map = {} + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if not shared_embedding_collection: + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if creator is not None: + embedding_weights = creator( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable) + ops.add_to_collections(weight_collections, embedding_weights) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable, + collections=weight_collections) + ops.add_to_collection(self.shared_embedding_collection_name, + embedding_weights) + variables_map['embedding_weights'] = embedding_weights + + return variables_map + + def _get_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None, + state=None): # This method is called from a variable_scope with name _var_scope_name, # which is shared among all shared embeddings. Open a name_scope here, so # that the ops for different columns have distinct names. diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 6f366e7722..07588af37e 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -34,6 +34,7 @@ from tensorflow.python.feature_column.feature_column import _CategoricalColumn from tensorflow.python.feature_column.feature_column import _DenseColumn from tensorflow.python.feature_column.feature_column import _FeatureColumn from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.feature_column.feature_column import _LinearModel from tensorflow.python.feature_column.feature_column import _transform_features from tensorflow.python.feature_column.feature_column import InputLayer from tensorflow.python.framework import constant_op @@ -339,6 +340,20 @@ class NumericColumnTest(test.TestCase): sess.run(price_var.assign([[10.]])) self.assertAllClose([[10.], [50.]], predictions.eval()) + def test_keras_linear_model(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[10.], [50.]], predictions.eval()) + class BucketizedColumnTest(test.TestCase): @@ -561,6 +576,62 @@ class BucketizedColumnTest(test.TestCase): sess.run(bias.assign([1.])) self.assertAllClose([[81.], [141.]], predictions.eval()) + def test_keras_linear_model_one_input_value(self): + """Tests _LinearModel for input with shape=[1].""" + price = fc.numeric_column('price', shape=[1]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1.], [1.], [5.], [6.]]} + predictions = get_keras_linear_model_predictions(features, + [bucketized_price]) + bias = get_keras_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight variable per bucket, all initialized to zero. + self.assertAllClose([[0.], [0.], [0.], [0.], [0.]], + bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval()) + sess.run( + bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]])) + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 1st bucket, whose weight is 20. + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 4th bucket, whose weight is 50. + self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval()) + + def test_keras_linear_model_two_input_values(self): + """Tests _LinearModel for input with shape=[2].""" + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1., 1.], [5., 6.]]} + predictions = get_keras_linear_model_predictions(features, + [bucketized_price]) + bias = get_keras_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight per bucket per input column, all initialized to zero. + self.assertAllClose( + [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], + bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run( + bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.], + [60.], [70.], [80.], [90.], [100.]])) + # 1st example: + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 6th bucket, whose weight is 70. + # 2nd example: + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 9th bucket, whose weight is 100. + self.assertAllClose([[80.], [140.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[81.], [141.]], predictions.eval()) + class HashedCategoricalColumnTest(test.TestCase): @@ -767,6 +838,28 @@ class HashedCategoricalColumnTest(test.TestCase): # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 self.assertAllClose(((4.,), (6.,)), predictions.eval()) + def test_keras_linear_model(self): + wire_column = fc.categorical_column_with_hash_bucket('wire', 4) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_keras_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 3: wire_var[3] = 4 + # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 + self.assertAllClose(((4.,), (6.,)), predictions.eval()) + class CrossedColumnTest(test.TestCase): @@ -1060,6 +1153,96 @@ class CrossedColumnTest(test.TestCase): dense_shape=(2, 2)), }, (crossed,)) + def test_keras_linear_model(self): + """Tests _LinearModel. + + Uses data from test_get_sparse_tesnsors_simple. + """ + a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,)) + b = fc.bucketized_column(a, boundaries=(0, 1)) + crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'a': + constant_op.constant(((-1., .5), (.5, 1.))), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + bias = get_keras_linear_model_bias() + crossed_var = get_linear_model_column_var(crossed) + with _initialized_session() as sess: + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)), + crossed_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,)))) + # Expected ids after cross = (1, 0, 1, 3, 4, 2) + self.assertAllClose(((3.,), (14.,)), predictions.eval()) + sess.run(bias.assign((.1,))) + self.assertAllClose(((3.1,), (14.1,)), predictions.eval()) + + def test_keras_linear_model_with_weights(self): + + class _TestColumnWithWeights(_CategoricalColumn): + """Produces sparse IDs and sparse weights.""" + + @property + def name(self): + return 'test_column' + + @property + def _parse_example_spec(self): + return { + self.name: + parsing_ops.VarLenFeature(dtypes.int32), + '{}_weights'.format(self.name): + parsing_ops.VarLenFeature(dtypes.float32), + } + + @property + def _num_buckets(self): + return 5 + + def _transform_feature(self, inputs): + return (inputs.get(self.name), + inputs.get('{}_weights'.format(self.name))) + + def _get_sparse_tensors(self, + inputs, + weight_collections=None, + trainable=None): + """Populates both id_tensor and weight_tensor.""" + ids_and_weights = inputs.get(self) + return _CategoricalColumn.IdWeightPair( + id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1]) + + t = _TestColumnWithWeights() + crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, + 'crossed_column does not support weight_tensor.*{}'.format(t.name)): + get_keras_linear_model_predictions({ + t.name: + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[0, 1, 2], + dense_shape=(2, 2)), + '{}_weights'.format(t.name): + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[1., 10., 2.], + dense_shape=(2, 2)), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + def get_linear_model_bias(): with variable_scope.variable_scope('linear_model', reuse=True): @@ -1071,6 +1254,28 @@ def get_linear_model_column_var(column): 'linear_model/' + column.name)[0] +def get_keras_linear_model_bias(): + with variable_scope.variable_scope('linear_model', reuse=True): + with variable_scope.variable_scope('bias_layer', reuse=True): + return variable_scope.get_variable('bias_weights') + + +def get_keras_linear_model_predictions(features, + feature_columns, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True): + keras_linear_model = _LinearModel( + feature_columns, + units, + sparse_combiner, + weight_collections, + trainable, + name='linear_model') + return keras_linear_model(features) # pylint: disable=not-callable + + @test_util.with_c_api class LinearModelTest(test.TestCase): @@ -1698,6 +1903,629 @@ class LinearModelTest(test.TestCase): sess.run(net, feed_dict={features['price']: np.array(1)}) +@test_util.with_c_api +class _LinearModelTest(test.TestCase): + + def test_raises_if_empty_feature_columns(self): + with self.assertRaisesRegexp(ValueError, + 'feature_columns must not be empty'): + get_keras_linear_model_predictions(features={}, feature_columns=[]) + + def test_should_be_feature_column(self): + with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, feature_columns='NotSupported') + + def test_should_be_dense_or_categorical_column(self): + + class NotSupportedColumn(_FeatureColumn): + + @property + def name(self): + return 'NotSupportedColumn' + + def _transform_feature(self, cache): + pass + + @property + def _parse_example_spec(self): + pass + + with self.assertRaisesRegexp( + ValueError, 'must be either a _DenseColumn or _CategoricalColumn'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, feature_columns=[NotSupportedColumn()]) + + def test_does_not_support_dict_columns(self): + with self.assertRaisesRegexp( + ValueError, 'Expected feature_columns to be iterable, found dict.'): + fc.linear_model( + features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')}) + + def test_raises_if_duplicate_name(self): + with self.assertRaisesRegexp( + ValueError, 'Duplicate feature column name found for columns'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, + feature_columns=[fc.numeric_column('a'), + fc.numeric_column('a')]) + + def test_dense_bias(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + sess.run(price_var.assign([[10.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions.eval()) + + def test_sparse_bias(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions(features, [wire_cast]) + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval()) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_and_sparse_bias(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + price = fc.numeric_column('price') + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, + [wire_cast, price]) + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[1015.], [10065.]], predictions.eval()) + + def test_dense_and_sparse_column(self): + """When the column is both dense and sparse, uses sparse tensors.""" + + class _DenseAndSparseColumn(_DenseColumn, _CategoricalColumn): + + @property + def name(self): + return 'dense_and_sparse_column' + + @property + def _parse_example_spec(self): + return {self.name: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.name) + + @property + def _variable_shape(self): + raise ValueError('Should not use this method.') + + def _get_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + raise ValueError('Should not use this method.') + + @property + def _num_buckets(self): + return 4 + + def _get_sparse_tensors(self, + inputs, + weight_collections=None, + trainable=None): + sp_tensor = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[2, 0, 3], + dense_shape=[2, 2]) + return _CategoricalColumn.IdWeightPair(sp_tensor, None) + + dense_and_sparse_column = _DenseAndSparseColumn() + with ops.Graph().as_default(): + sp_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {dense_and_sparse_column.name: sp_tensor} + predictions = get_keras_linear_model_predictions( + features, [dense_and_sparse_column]) + bias = get_keras_linear_model_bias() + dense_and_sparse_column_var = get_linear_model_column_var( + dense_and_sparse_column) + with _initialized_session() as sess: + sess.run( + dense_and_sparse_column_var.assign([[10.], [100.], [1000.], + [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_multi_output(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions( + features, [price], units=3) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((1, 3)), price_var.eval()) + sess.run(price_var.assign([[10., 100., 1000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]], + predictions.eval()) + + def test_sparse_multi_output(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions( + features, [wire_cast], units=3) + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval()) + sess.run( + wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], + [1000., 1100., + 1200.], [10000., 11000., 12000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]], + predictions.eval()) + + def test_dense_multi_dimension(self): + price = fc.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([[0.], [0.]], price_var.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_sparse_multi_rank(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = array_ops.sparse_placeholder(dtypes.string) + wire_value = sparse_tensor.SparseTensorValue( + values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2] + indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]], + dense_shape=[2, 2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions(features, [wire_cast]) + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval()) + self.assertAllClose( + np.zeros((2, 1)), + predictions.eval(feed_dict={wire_tensor: wire_value})) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + self.assertAllClose( + [[1010.], [11000.]], + predictions.eval(feed_dict={wire_tensor: wire_value})) + + def test_sparse_combiner(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions( + features, [wire_cast], sparse_combiner='mean') + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [5010.]], predictions.eval()) + + def test_dense_multi_dimension_multi_output(self): + price = fc.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = get_keras_linear_model_predictions( + features, [price], units=3) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((2, 3)), price_var.eval()) + sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]])) + sess.run(bias.assign([2., 3., 4.])) + self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]], + predictions.eval()) + + def test_raises_if_shape_mismatch(self): + price = fc.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + if ops._USE_C_API: + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + predictions = get_keras_linear_model_predictions(features, [price]) + else: + predictions = get_keras_linear_model_predictions(features, [price]) + with _initialized_session(): + with self.assertRaisesRegexp(Exception, 'requested shape has 4'): + predictions.eval() + + def test_dense_reshaping(self): + price = fc.numeric_column('price', shape=[1, 2]) + with ops.Graph().as_default(): + features = {'price': [[[1., 2.]], [[5., 6.]]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_dense_multi_column(self): + price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2') + with ops.Graph().as_default(): + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + bias = get_keras_linear_model_bias() + price1_var = get_linear_model_column_var(price1) + price2_var = get_linear_model_column_var(price2) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price1_var.eval()) + self.assertAllClose([[0.]], price2_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price1_var.assign([[10.], [100.]])) + sess.run(price2_var.assign([[1000.]])) + sess.run(bias.assign([7.])) + self.assertAllClose([[3217.], [4657.]], predictions.eval()) + + def test_dense_collection(self): + price = fc.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions( + features, [price], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + self.assertIn(bias, my_vars) + self.assertIn(price_var, my_vars) + + def test_sparse_collection(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions( + features, [wire_cast], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, my_vars) + self.assertIn(wire_cast_var, my_vars) + + def test_dense_trainable_default(self): + price = fc.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions(features, [price]) + bias = get_keras_linear_model_bias() + price_var = get_linear_model_column_var(price) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertIn(bias, trainable_vars) + self.assertIn(price_var, trainable_vars) + + def test_sparse_trainable_default(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions(features, [wire_cast]) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + bias = get_keras_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, trainable_vars) + self.assertIn(wire_cast_var, trainable_vars) + + def test_dense_trainable_false(self): + price = fc.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions(features, [price], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_sparse_trainable_false(self): + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions(features, [wire_cast], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_column_order(self): + price_a = fc.numeric_column('price_a') + price_b = fc.numeric_column('price_b') + wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + get_keras_linear_model_predictions( + features, [price_a, wire_cast, price_b], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + get_keras_linear_model_predictions( + features, [wire_cast, price_b, price_a], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + def test_static_batch_size_mismatch(self): + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1.], [5.], [7.]], # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + get_keras_linear_model_predictions(features, [price1, price2]) + + def test_subset_of_static_batch_size_mismatch(self): + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') + price3 = fc.numeric_column('price3') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]], # batchsize = 2 + 'price3': [[3.], [4.], [5.]] # batchsize = 3 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + get_keras_linear_model_predictions(features, [price1, price2, price3]) + + def test_runtime_batch_size_mismatch(self): + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + with _initialized_session() as sess: + with self.assertRaisesRegexp(errors.OpError, + 'must have the same size and shape'): + sess.run( + predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]}) + + def test_runtime_batch_size_matches(self): + price1 = fc.numeric_column('price1') + price2 = fc.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + } + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + with _initialized_session() as sess: + sess.run( + predictions, + feed_dict={ + features['price1']: [[1.], [5.]], + features['price2']: [[1.], [5.]], + }) + + def test_with_numpy_input_fn(self): + price = fc.numeric_column('price') + price_buckets = fc.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([-1., 2., 13., 104.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = get_keras_linear_model_predictions(features, + [price_buckets, body_style]) + # self.assertEqual(1 + 3 + 5, net.shape[1]) + with _initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + bias = get_keras_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_with_1d_sparse_tensor(self): + price = fc.numeric_column('price') + price_buckets = fc.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': + constant_op.constant([ + -1., + 12., + ]), + 'body-style': + sparse_tensor.SparseTensor( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)), + } + self.assertEqual(1, features['price'].shape.ndims) + self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) + + net = get_keras_linear_model_predictions(features, + [price_buckets, body_style]) + with _initialized_session() as sess: + bias = get_keras_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net)) + + def test_with_1d_unknown_shape_sparse_tensor(self): + price = fc.numeric_column('price') + price_buckets = fc.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + country = fc.categorical_column_with_vocabulary_list( + 'country', vocabulary_list=['US', 'JP', 'CA']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': array_ops.placeholder(dtypes.float32), + 'body-style': array_ops.sparse_placeholder(dtypes.string), + 'country': array_ops.placeholder(dtypes.string), + } + self.assertIsNone(features['price'].shape.ndims) + self.assertIsNone(features['body-style'].get_shape().ndims) + + price_data = np.array([-1., 12.]) + body_style_data = sparse_tensor.SparseTensorValue( + indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,)) + country_data = np.array(['US', 'CA']) + + net = get_keras_linear_model_predictions( + features, [price_buckets, body_style, country]) + bias = get_keras_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + with _initialized_session() as sess: + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], + sess.run( + net, + feed_dict={ + features['price']: price_data, + features['body-style']: body_style_data, + features['country']: country_data + })) + + def test_with_rank_0_feature(self): + price = fc.numeric_column('price') + features = { + 'price': constant_op.constant(0), + } + self.assertEqual(0, features['price'].shape.ndims) + + # Static rank 0 should fail + with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): + get_keras_linear_model_predictions(features, [price]) + + # Dynamic rank 0 should fail + features = { + 'price': array_ops.placeholder(dtypes.float32), + } + net = get_keras_linear_model_predictions(features, [price]) + self.assertEqual(1, net.shape[1]) + with _initialized_session() as sess: + with self.assertRaisesOpError('Feature .* cannot have rank 0'): + sess.run(net, feed_dict={features['price']: np.array(1)}) + + class InputLayerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() @@ -2715,6 +3543,32 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 self.assertAllClose(((3.,), (5.,)), predictions.eval()) + def test_keras_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_file( + key='wire', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_keras_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + class VocabularyListCategoricalColumnTest(test.TestCase): @@ -3082,6 +3936,31 @@ class VocabularyListCategoricalColumnTest(test.TestCase): # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 self.assertAllClose(((3.,), (5.,)), predictions.eval()) + def test_keras_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_keras_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + class IdentityCategoricalColumnTest(test.TestCase): @@ -3306,6 +4185,28 @@ class IdentityCategoricalColumnTest(test.TestCase): # weight_var[2] + weight_var[1] = 3+2 = 5 self.assertAllClose(((1.,), (5.,)), predictions.eval()) + def test_keras_linear_model(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual(3, column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_keras_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] = 1 + # weight_var[2] + weight_var[1] = 3+2 = 5 + self.assertAllClose(((1.,), (5.,)), predictions.eval()) + class TransformFeaturesTest(test.TestCase): @@ -3537,6 +4438,25 @@ class IndicatorColumnTest(test.TestCase): weight_var.assign([[1.], [2.], [3.], [4.]]).eval() self.assertAllClose([[2. + 3.]], predictions.eval()) + def test_keras_linear_model(self): + animal = fc.indicator_column( + fc.categorical_column_with_identity('animal', num_buckets=4)) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + + predictions = get_keras_linear_model_predictions(features, [animal]) + weight_var = get_linear_model_column_var(animal) + with _initialized_session(): + # All should be zero-initialized. + self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval()) + self.assertAllClose([[0.]], predictions.eval()) + weight_var.assign([[1.], [2.], [3.], [4.]]).eval() + self.assertAllClose([[2. + 3.]], predictions.eval()) + def test_input_layer(self): animal = fc.indicator_column( fc.categorical_column_with_identity('animal', num_buckets=4)) @@ -3727,6 +4647,72 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + def test_get_dense_tensor_with_state(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + # Create embedding_weights variable. + weight_collections = [ + ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES + ] + state = embedding_column._create_state(weight_collections) + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column._get_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input + }), state=state) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) with _initialized_session(): @@ -4023,6 +5009,82 @@ class EmbeddingColumnTest(test.TestCase): # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) + def test_keras_linear_model(self): + # Inputs. + batch_size = 4 + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(batch_size, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + categorical_column.name: sparse_input + }, (embedding_column,)) + expected_var_names = ( + 'linear_model/bias_layer/bias_weights:0', + 'linear_model/aaa_embedding/weights:0', + 'linear_model/aaa_embedding/embedding_weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v + for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_layer/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_embedding/embedding_weights:0'] + linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # example 2, ids [], embedding[2] = [0, 0] + # example 3, ids [1], embedding[3] = [3, 5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] + self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) + def test_input_layer(self): # Inputs. vocabulary_size = 3 @@ -4445,6 +5507,80 @@ class SharedEmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) + self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + + def test_get_dense_tensor_with_state(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + # Create state. + weight_collections = [ + ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES + ] + state = embedding_column_a._create_state(weight_collections) + + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a._get_dense_tensor( + _LazyBuilder(input_features), state=state) + embedding_lookup_b = embedding_column_b._get_dense_tensor( + _LazyBuilder(input_features), state=state) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) embedding_var = global_vars[0] @@ -4595,6 +5731,97 @@ class SharedEmbeddingColumnTest(test.TestCase): # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] self.assertAllClose([[94. + 13.], [29.]], predictions.eval()) + def test_keras_linear_model(self): + # Inputs. + batch_size = 2 + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + categorical_column_a.name: input_a, + categorical_column_b.name: input_b, + }, (embedding_column_a, embedding_column_b)) + # Linear weights do not follow the column name. But this is a rare use + # case, and fixing it would add too much complexity to the code. + expected_var_names = ( + 'linear_model/bias_layer/bias_weights:0', + 'linear_model/aaa_bbb_shared_embedding/weights:0', + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0', + 'linear_model/aaa_bbb_shared_embedding_1/weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v + for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_layer/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0'] + linear_weights_a = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/weights:0'] + linear_weights_b = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding_1/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_a.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_b.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights_a.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29] + linear_weights_b.assign(((3.,), (5.,))).eval() + # example 0, ids [0], embedding[0] = [1, 2] + # example 1, ids [], embedding[1] = 0, 0] + # sum(embeddings * linear_weights) + # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] + self.assertAllClose([[94. + 13.], [29.]], predictions.eval()) + def _test_input_layer(self, trainable=True): # Inputs. vocabulary_size = 3 @@ -4880,6 +6107,101 @@ class WeightedCategoricalColumnTest(test.TestCase): dense_shape=(2, 2)), weight_tensor.eval()) + def test_keras_linear_model(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(.5, 1., .1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_keras_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + + def test_keras_linear_model_mismatched_shape(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, + r'Dimensions.*are not compatible'): + get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0), (1, 1)), + values=(.5, 11., 1., .1), + dense_shape=(2, 2)) + }, (column,)) + + def test_keras_linear_model_mismatched_dense_values(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,)) + }, (column,)) + with _initialized_session(): + with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): + predictions.eval() + + def test_keras_linear_model_mismatched_dense_shape(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,), (.1,)) + }, (column,)) + bias = get_keras_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + def test_linear_model(self): column = fc.weighted_categorical_column( categorical_column=fc.categorical_column_with_identity( diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py index 6c522de452..4356a534b4 100644 --- a/tensorflow/python/framework/c_api_util.py +++ b/tensorflow/python/framework/c_api_util.py @@ -33,7 +33,7 @@ class ScopedTFStatus(object): def __del__(self): # Note: when we're destructing the global context (i.e when the process is # terminating) we can have already deleted other modules. - if c_api.TF_DeleteStatus is not None: + if c_api is not None and c_api.TF_DeleteStatus is not None: c_api.TF_DeleteStatus(self.status) @@ -46,7 +46,7 @@ class ScopedTFGraph(object): def __del__(self): # Note: when we're destructing the global context (i.e when the process is # terminating) we can have already deleted other modules. - if c_api.TF_DeleteGraph is not None: + if c_api is not None and c_api.TF_DeleteGraph is not None: c_api.TF_DeleteGraph(self.graph) @@ -59,7 +59,7 @@ class ScopedTFImportGraphDefOptions(object): def __del__(self): # Note: when we're destructing the global context (i.e when the process is # terminating) we can have already deleted other modules. - if c_api.TF_DeleteImportGraphDefOptions is not None: + if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None: c_api.TF_DeleteImportGraphDefOptions(self.options) diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 2a40316d51..84106c32c6 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = { DATA_LOSS: DataLossError, } +c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) + _EXCEPTION_CLASS_TO_CODE = dict(( (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) @@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code): # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. +# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this. @tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name class raise_exception_on_not_ok_status(object): """Context manager to check for C API status.""" diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 82dd2a3356..c5caf9ebc0 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.eager import context from tensorflow.python.framework import c_api_util from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -275,8 +274,7 @@ class _DefinedFunction(object): self._create_definition_if_needed() if self._c_func: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_FunctionToFunctionDef(self._c_func, buf, status) + c_api.TF_FunctionToFunctionDef(self._c_func, buf) fdef = function_pb2.FunctionDef() proto_data = c_api.TF_GetBuffer(buf) fdef.ParseFromString(compat.as_bytes(proto_data)) @@ -399,18 +397,16 @@ class _DefinedFunction(object): if self._out_names else []) description = self._func.__doc__ or None # pylint: disable=protected-access - with errors.raise_exception_on_not_ok_status() as status: - self._c_func = c_api.TF_GraphToFunction_wrapper( - temp_graph._c_graph, - base_func_name, - self._func_name is None, # append_hash_to_fn_name - None, # opers - [t._as_tf_output() for t in inputs], - [t._as_tf_output() for t in outputs], - output_names, - None, # opts - description, - status) + self._c_func = c_api.TF_GraphToFunction_wrapper( + temp_graph._c_graph, + base_func_name, + self._func_name is None, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in inputs], + [t._as_tf_output() for t in outputs], + output_names, + None, # opts + description) # pylint: enable=protected-access self._set_c_attrs(kwargs_attr) @@ -433,9 +429,8 @@ class _DefinedFunction(object): serialized = attr_value.SerializeToString() # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), - serialized, status) + c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name), + serialized) def _create_hash_str(self, input_arg, output_arg, node_def): """Creates an 8-character string unique to this input. @@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None): # pylint: disable=protected-access if ops._USE_C_API: serialized = fdef.SerializeToString() - with errors.raise_exception_on_not_ok_status() as status: - result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status) + result._c_func = c_api.TF_FunctionImportFunctionDef(serialized) result._extra_inputs = [] else: result._definition = fdef diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 4ea34d7bb2..23f529b988 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -485,9 +485,8 @@ def import_graph_def(graph_def, with graph._lock: # pylint: disable=protected-access with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: try: - with errors.raise_exception_on_not_ok_status() as status: - results = c_api.TF_GraphImportGraphDefWithResults( - graph._c_graph, serialized, options, status) # pylint: disable=protected-access + results = c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, options) # pylint: disable=protected-access except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 369669c2e6..2c913d1e02 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -219,6 +219,23 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(outer_inner.name, "outer/inner_1") self.assertEqual(outer_inner_c.name, "outer/inner/c_1") + def testEmptyNameScope(self): + with ops.Graph().as_default(): + # Create name scope but don't create any ops with it + with ops.name_scope("foo"): + pass + + # Import graph def that uses name scope name + op, = importer.import_graph_def( + self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"), + return_elements=["foo"], + name="") + + if ops._USE_C_API: + self.assertEqual(op.name, "foo") + else: + self.assertEqual(op.name, "foo_1") + def testInputMap(self): with ops.Graph().as_default(): feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 1f2aa264c1..535c6017f5 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import from tensorflow.core.framework import op_def_pb2 from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.python import pywrap_tensorflow as py_tf -from tensorflow.python.framework import errors_impl from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -54,8 +53,7 @@ def load_op_library(library_filename): Raises: RuntimeError: when unable to load the library or get the python wrappers. """ - with errors_impl.raise_exception_on_not_ok_status() as status: - lib_handle = py_tf.TF_LoadLibrary(library_filename, status) + lib_handle = py_tf.TF_LoadLibrary(library_filename) op_list_str = py_tf.TF_GetOpList(lib_handle) op_list = op_def_pb2.OpList() @@ -99,5 +97,4 @@ def load_file_system_library(library_filename): Raises: RuntimeError: when unable to load the library. """ - with errors_impl.raise_exception_on_not_ok_status() as status: - lib_handle = py_tf.TF_LoadLibrary(library_filename, status) + py_tf.TF_LoadLibrary(library_filename) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 6930737a0c..2d55f98a1c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -63,7 +63,7 @@ from tensorflow.python.util.tf_export import tf_export # calls to the C API. Currently disabled by default but can be manually enabled # in code or via the environment variable. This will be removed once all # functionality is supported and there's no performance penalty with it enabled. -_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0" +_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "1") is not "0" _USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" @@ -373,15 +373,12 @@ class Tensor(_TensorLike): """ graph = self._op._graph._c_graph # pylint: disable=protected-access if graph and _USE_C_SHAPES: - with errors.raise_exception_on_not_ok_status() as status: - num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), - status) + num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output()) if num_dims == -1: dim_list = None else: - with errors.raise_exception_on_not_ok_status() as status: - dim_list = c_api.TF_GraphGetTensorShape_wrapper( - graph, self._as_tf_output(), num_dims, status) + dim_list = c_api.TF_GraphGetTensorShape_wrapper( + graph, self._as_tf_output(), num_dims) dim_list = [None if i == -1 else i for i in dim_list] return tensor_shape.TensorShape(dim_list) return self._shape_val @@ -489,13 +486,11 @@ class Tensor(_TensorLike): else: dim_list.append(dim.value) try: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphSetTensorShape_wrapper( - self._op._graph._c_graph, # pylint: disable=protected-access - self._as_tf_output(), - dim_list, - unknown_shape, - status) + c_api.TF_GraphSetTensorShape_wrapper( + self._op._graph._c_graph, # pylint: disable=protected-access + self._as_tf_output(), + dim_list, + unknown_shape) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) @@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs): serialized = attr_value.SerializeToString() # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_SetAttrValueProto(op_desc, - compat.as_str(name), serialized, status) + c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized) try: - with errors.raise_exception_on_not_ok_status() as status: - c_op = c_api.TF_FinishOperation(op_desc, status) + c_op = c_api.TF_FinishOperation(op_desc) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) @@ -1943,12 +1935,10 @@ class Operation(object): if self._c_op: # Reset cached inputs. self._inputs_val = None - with errors.raise_exception_on_not_ok_status() as status: - c_api.UpdateEdge( - self._graph._c_graph, # pylint: disable=protected-access - tensor._as_tf_output(), # pylint: disable=protected-access - self._tf_input(index), - status) + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index)) else: self._inputs_val[index].consumers().remove(self) self._inputs_val[index] = tensor @@ -2124,6 +2114,30 @@ class Operation(object): return self._control_inputs_val @property + def _control_outputs(self): + """The `Operation` objects which have a control dependency on this op. + + Before any of the ops in self._control_outputs can execute tensorflow will + ensure self has finished executing. + + Returns: + A list of `Operation` objects. + + """ + if self._c_op: + control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access + else: + # TODO(apassos) this should be less inefficient. + return [o for o in self._graph.get_operations() + if self in o.control_inputs] + + @property def _control_inputs(self): logging.warning("Operation._control_inputs is private, use " "Operation.control_inputs instead. " @@ -2169,8 +2183,7 @@ class Operation(object): # pylint: enable=line-too-long if self._c_op: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_OperationToNodeDef(self._c_op, buf, status) + c_api.TF_OperationToNodeDef(self._c_op, buf) data = c_api.TF_GetBuffer(buf) node_def = node_def_pb2.NodeDef() node_def.ParseFromString(compat.as_bytes(data)) @@ -2228,11 +2241,9 @@ class Operation(object): buf = c_api.TF_NewBufferFromString( compat.as_bytes(attr_value.SerializeToString())) try: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, - status) - # pylint: enable=protected-access + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) + # pylint: enable=protected-access finally: c_api.TF_DeleteBuffer(buf) else: @@ -2254,8 +2265,7 @@ class Operation(object): if self._c_op: try: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status) + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) data = c_api.TF_GetBuffer(buf) except errors.InvalidArgumentError as e: # Convert to ValueError for backwards compatibility. @@ -2469,11 +2479,10 @@ def _set_shapes_for_outputs_c_api(op): # The C API computes the shapes when the TF_Operation is created. Fetch the # output shapes from the C object. for output in op.outputs: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( - op._graph._c_graph, output._as_tf_output(), status) - # pylint: enable=protected-access + # pylint: disable=protected-access + shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper( + op._graph._c_graph, output._as_tf_output()) + # pylint: enable=protected-access if unknown_shape: output.set_shape(tensor_shape.unknown_shape()) elif not shape_vector: @@ -2994,8 +3003,7 @@ class Graph(object): # pylint: enable=line-too-long if self._c_graph: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphVersions(self._c_graph, buf, status) + c_api.TF_GraphVersions(self._c_graph, buf) data = c_api.TF_GetBuffer(buf) version_def = versions_pb2.VersionDef() version_def.ParseFromString(compat.as_bytes(data)) @@ -3098,8 +3106,7 @@ class Graph(object): if self._c_graph: with self._lock: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_GraphToGraphDef(self._c_graph, buf, status) + c_api.TF_GraphToGraphDef(self._c_graph, buf) data = c_api.TF_GetBuffer(buf) graph = graph_pb2.GraphDef() graph.ParseFromString(compat.as_bytes(data)) @@ -3208,14 +3215,10 @@ class Graph(object): # remove this when all functions are generated using the C API by default # as this will be unnecessary. if not function._c_func: - with errors.raise_exception_on_not_ok_status() as status: - serialized = function.definition.SerializeToString() - function._c_func = c_api.TF_FunctionImportFunctionDef( - serialized, status) - with errors.raise_exception_on_not_ok_status() as status: - gradient = function._grad_func._c_func if function._grad_func else None - c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient, - status) + serialized = function.definition.SerializeToString() + function._c_func = c_api.TF_FunctionImportFunctionDef(serialized) + gradient = function._grad_func._c_func if function._grad_func else None + c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient) else: # If there is already a function with the same name, raise an error # if bodies are different. Else, do nothing. The C API version above @@ -3365,8 +3368,12 @@ class Graph(object): """ self._check_not_finalized() ret = Operation(c_op, self) - assert ret.name not in self._names_in_use - self._names_in_use[ret.name] = 1 + # If a name_scope was created with ret.name but no nodes were created in it, + # the name will still appear in _names_in_use even though the name hasn't + # been used. This is ok, just leave _names_in_use as-is in this case. + # TODO(skyewm): make the C API guarantee no name conflicts. + if ret.name not in self._names_in_use: + self._names_in_use[ret.name] = 1 self._create_op_helper(ret, compute_device=compute_device) return ret @@ -3732,11 +3739,9 @@ class Graph(object): """Returns the `OpDef` proto for `type`. `type` is a string.""" if self._c_graph: with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._c_graph, - compat.as_bytes(type), buf, status) - # pylint: enable=protected-access + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) + # pylint: enable=protected-access data = c_api.TF_GetBuffer(buf) op_def = op_def_pb2.OpDef() op_def.ParseFromString(compat.as_bytes(data)) @@ -4512,6 +4517,22 @@ class Graph(object): return tf.matmul(tensor, tensor) ``` + Also note that though execution of ops created under this scope will trigger + execution of the dependencies, the ops created under this scope might still + be pruned from a normal tensorflow graph. For example, in the following + snippet of code the dependencies are never executed: + + ```python + loss = model.loss() + with tf.control_dependencies(dependencies): + loss = loss + tf.constant(1) # note: dependencies ignored in the + # backward pass + return tf.gradients(loss, model.variables) + ``` + + This is because evaluating the gradient graph does not require evaluating + the constant(1) op created in the forward pass. + Args: control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations @@ -5350,6 +5371,10 @@ def enable_eager_execution(config=None, device_policy=None, raise ValueError( "tf.enable_eager_execution must be called at program startup.") + # Monkey patch to get rid of an unnecessary conditional since the context is + # now initialized. + context.context = context.context_safe + def eager_run(main=None, argv=None): """Runs the program with an optional main function and argv list. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index aa51391871..58bead91ed 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -473,6 +473,7 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(z.control_inputs, [x, x]) z._add_control_inputs([x, y, y]) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x, x, x, y, y]) + self.assertEqual(x._control_outputs, [z]) def testAddControlInputC(self): # The C API dedups redundant control edges, pure Python does not @@ -487,6 +488,7 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(z.control_inputs, [x]) z._add_control_inputs([x, y, y]) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x, y]) + self.assertEqual(x._control_outputs, [z]) def testRemoveAllControlInputs(self): a = constant_op.constant(1) diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py index c7ff23e4ff..48a834392b 100644 --- a/tensorflow/python/framework/smart_cond.py +++ b/tensorflow/python/framework/smart_cond.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow as c_api -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops @@ -83,9 +82,8 @@ def smart_constant_value(pred): # wanted to limit the change hidden behind _USE_C_API). # pylint: disable=protected-access if pred_value is None and ops._USE_C_API: - with errors.raise_exception_on_not_ok_status() as status: - pred_value = c_api.TF_TryEvaluateConstant_wrapper( - pred.graph._c_graph, pred._as_tf_output(), status) + pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph, + pred._as_tf_output()) # pylint: enable=protected-access else: diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py index 06955b8858..d08b4bf48a 100644 --- a/tensorflow/python/framework/versions.py +++ b/tensorflow/python/framework/versions.py @@ -29,7 +29,7 @@ __cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__ __monolithic_build__ = pywrap_tensorflow.__monolithic_build__ VERSION = __version__ -tf_export("VERSION").export_constant(__name__, "VERSION") +tf_export("VERSION", "__version__").export_constant(__name__, "VERSION") GIT_VERSION = __git_version__ tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION") COMPILER_VERSION = __compiler_version__ diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py new file mode 100644 index 0000000000..ab1d0ed25b --- /dev/null +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Grappler Constant Folding.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ConstantFoldingTest(test.TestCase): + + # See b/76008022. + def testScanInsideWhile(self): + + def loop_cond(idx_step, *unused_args): + return idx_step < 1 + + def loop_body(idx_step, y): + x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32) + x = functional_ops.scan( + math_ops.add, + x, + initializer=array_ops.zeros([20, 30], dtype=dtypes.float32), + back_prop=False, + parallel_iterations=1) + + with ops.device('/cpu:0'): + y = array_ops.identity(x) + + return idx_step + 1, y + + if test.is_gpu_available(cuda_only=True): + init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32) + _, y = control_flow_ops.while_loop( + loop_cond, + loop_body, + loop_vars=[0, init_y], + back_prop=False, + parallel_iterations=1) + with session.Session() as sess: + y_v = sess.run(y) + self.assertAllEqual(np.zeros([10, 20, 30]), y_v) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py index 4a083849bd..1748efdd13 100644 --- a/tensorflow/python/grappler/item.py +++ b/tensorflow/python/grappler/item.py @@ -51,9 +51,7 @@ class Item(object): self._BuildTFItem() def IdentifyImportantOps(self, sort_topologically=False): - with errors.raise_exception_on_not_ok_status() as status: - return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically, - status) + return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically) def GetOpProperties(self): ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item) diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 3ee4d7807e..1c0f072dd3 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -17,12 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.grappler import item as gitem from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -74,6 +78,47 @@ class PyWrapOptimizeGraphTest(test.TestCase): self.assertEqual(a2.op.name, optimized_graph.node[3].name) self.assertEqual('Variable/Assign', optimized_graph.node[4].name) + def testLoops(self): + g = ops.Graph() + with g.as_default(): + + def _Cond(_, counter): + return counter < end + + def _Body(buf, counter): + buf = array_ops.concat([buf, [counter]], 0) + counter += 1 + return [buf, counter] + + start = array_ops.placeholder(shape=[], dtype=dtypes.int32) + end = array_ops.placeholder(shape=[], dtype=dtypes.int32) + init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32) + loop_vars = [init_buf, start] + shape_inv = [ + tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([]) + ] + buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv) + + f = -array_ops.ones_like(buf, optimize=False) + buf_shape = array_ops.shape(buf) + f_shape = array_ops.shape(f) + ops.add_to_collection('train_op', buf_shape) + ops.add_to_collection('train_op', f_shape) + + # Optimize the graph. + mg = meta_graph.create_meta_graph_def(graph=g) + rewriter_config = rewriter_config_pb2.RewriterConfig() + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) + mg.graph_def.CopyFrom(optimized_graph) + + # Check that the nodes referenced in various collections have been preserved + item = gitem.Item(mg) + props = item.GetOpProperties() + buf_prop = props[buf.op.name] + f_prop = props[f.op.name] + self.assertEqual(buf_prop, f_prop) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 2a06907f49..57f5097639 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -637,7 +637,10 @@ py_test( size = "small", srcs = ["_impl/keras/utils/io_utils_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = [ + "no_windows", # TODO: needs investigation on Windows + "notsan", + ], deps = [ ":keras", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py index 236e17653e..b518898ad8 100644 --- a/tensorflow/python/keras/_impl/keras/activations.py +++ b/tensorflow/python/keras/_impl/keras/activations.py @@ -23,6 +23,8 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.layers.base import Layer +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -43,10 +45,10 @@ def softmax(x, axis=-1): """ ndim = K.ndim(x) if ndim == 2: - return K.softmax(x) + return nn.softmax(x) elif ndim > 2: - e = K.exp(x - K.max(x, axis=axis, keepdims=True)) - s = K.sum(e, axis=axis, keepdims=True) + e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True)) + s = math_ops.reduce_sum(e, axis=axis, keepdims=True) return e / s else: raise ValueError('Cannot apply softmax to a tensor that is 1D') @@ -79,12 +81,12 @@ def selu(x): @tf_export('keras.activations.softplus') def softplus(x): - return K.softplus(x) + return nn.softplus(x) @tf_export('keras.activations.softsign') def softsign(x): - return K.softsign(x) + return nn.softsign(x) @tf_export('keras.activations.relu') @@ -94,12 +96,12 @@ def relu(x, alpha=0., max_value=None): @tf_export('keras.activations.tanh') def tanh(x): - return K.tanh(x) + return nn.tanh(x) @tf_export('keras.activations.sigmoid') def sigmoid(x): - return K.sigmoid(x) + return nn.sigmoid(x) @tf_export('keras.activations.hard_sigmoid') diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py index c26a28ed40..d928a7afdc 100644 --- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py +++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py @@ -22,8 +22,10 @@ import json import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -151,11 +153,11 @@ def _preprocess_symbolic_input(x, data_format, mode): std = None if _IMAGENET_MEAN is None: - _IMAGENET_MEAN = K.constant(-np.array(mean)) + _IMAGENET_MEAN = constant_op.constant(-np.array(mean), dtype=K.floatx()) # Zero-center by mean pixel if K.dtype(x) != K.dtype(_IMAGENET_MEAN): - x = K.bias_add(x, K.cast(_IMAGENET_MEAN, K.dtype(x)), data_format) + x = K.bias_add(x, math_ops.cast(_IMAGENET_MEAN, K.dtype(x)), data_format) else: x = K.bias_add(x, _IMAGENET_MEAN, data_format) if std is not None: diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 7baf27642a..3aac6a9065 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -2795,6 +2796,8 @@ class Function(object): else: feed_dict = {} + session = get_session() + data_tensors_to_feed = [] for tensor, value in zip(self.inputs, inputs): if value is None: continue @@ -2803,9 +2806,20 @@ class Function(object): indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(sparse_coo.col, 1)), 1) value = (indices, sparse_coo.data, sparse_coo.shape) - feed_dict[tensor] = value + elif tensor_util.is_tensor(value): + data_tensors_to_feed.append((tensor, value)) + else: + feed_dict[tensor] = value + + if data_tensors_to_feed: + # This is a *temporary* workaround (i.e. hack) to feed a symbolic tensor + # to `feed_dict`. It is very inefficient. It will be removed as soon + # as it becomes possible to pass symbolic tensors to `feed_dict`. + data_tensor_values = session.run([x[1] for x in data_tensors_to_feed]) + for i, v in enumerate(data_tensor_values): + feed_dict[data_tensors_to_feed[i][0]] = v + fetches = self.outputs + [self.updates_op] + self.fetches - session = get_session() updated = session.run( fetches=fetches, feed_dict=feed_dict, **self.session_kwargs) return updated[:len(self.outputs)] diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py index 271fbbb63d..abe95d8e0c 100644 --- a/tensorflow/python/keras/_impl/keras/constraints.py +++ b/tensorflow/python/keras/_impl/keras/constraints.py @@ -24,6 +24,7 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -65,7 +66,8 @@ class MaxNorm(Constraint): self.axis = axis def __call__(self, w): - norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True)) + norms = K.sqrt( + math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True)) desired = K.clip(norms, 0, self.max_value) return w * (desired / (K.epsilon() + norms)) @@ -79,7 +81,7 @@ class NonNeg(Constraint): """ def __call__(self, w): - return w * K.cast(K.greater_equal(w, 0.), K.floatx()) + return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx()) @tf_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm') @@ -105,7 +107,9 @@ class UnitNorm(Constraint): def __call__(self, w): return w / ( - K.epsilon() + K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True))) + K.epsilon() + K.sqrt( + math_ops.reduce_sum( + math_ops.square(w), axis=self.axis, keepdims=True))) def get_config(self): return {'axis': self.axis} @@ -148,7 +152,8 @@ class MinMaxNorm(Constraint): self.axis = axis def __call__(self, w): - norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True)) + norms = K.sqrt( + math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True)) desired = ( self.rate * K.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms) diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py index b50277c8ff..9ab4b6fdcf 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py @@ -783,7 +783,7 @@ class TopologyConstructionTest(test.TestCase): def test_activity_regularization_with_model_composition(self): def reg(x): - return keras.backend.sum(x) + return math_ops.reduce_sum(x) net_a_input = keras.Input((2,)) net_a = net_a_input diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 971245c162..71de657da8 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -1181,6 +1181,9 @@ class Model(Network): batch_size=batch_size) elif validation_split and 0. < validation_split < 1.: + if training_utils.has_symbolic_tensors(x): + raise ValueError('If your data is in the form of symbolic tensors, ' + 'you cannot use `validation_split`.') if hasattr(x[0], 'shape'): split_at = int(x[0].shape[0] * (1. - validation_split)) else: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py index 67858a578c..4cdb5f108a 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -31,9 +31,8 @@ from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras.engine import training_utils -from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.keras._impl.keras.utils import generic_utils +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -173,6 +172,41 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): return outs, total_loss, loss_metrics +def slice_arrays(arrays, indices, contiguous=True): + """Slices batches out of provided arrays (workaround for eager tensors). + + Unfortunately eager tensors don't have the same slicing behavior as + Numpy arrays (they folow the same slicing behavior as symbolic TF tensors), + hence we cannot use `generic_utils.slice_arrays` directly + and we have to implement this workaround based on `concat`. This has a + performance cost. + + Arguments: + arrays: Single array or list of arrays. + indices: List of indices in the array that should be included in the output + batch. + contiguous: Boolean flag indicating whether the indices are contiguous. + + Returns: + Slice of data (either single array or list of arrays). + """ + if any(tensor_util.is_tensor(x) for x in arrays): + converted_to_list = False + if not isinstance(arrays, list): + converted_to_list = True + arrays = [arrays] + if not contiguous: + entries = [[x[i:i + 1] for i in indices] for x in arrays] + slices = [array_ops.concat(x, axis=0) for x in entries] + else: + slices = [x[indices[0]:indices[-1] + 1] for x in arrays] + if converted_to_list: + slices = slices[0] + return slices + else: + return generic_utils.slice_arrays(arrays, indices) + + def _process_single_batch(model, inputs, targets, @@ -270,9 +304,8 @@ def test_on_batch(model, inputs, targets, sample_weights=None): model, inputs, targets, sample_weights=sample_weights, training=False) if not isinstance(outs, list): outs = [outs] - metric_names, metrics_results = _eager_metrics_fn( + _, metrics_results = _eager_metrics_fn( model, outs, targets) - model.metrics_names.append(metric_names) if not isinstance(loss, list): loss = [loss] return loss + loss_metrics + metrics_results @@ -328,6 +361,12 @@ def fit_loop( Raises: ValueError: In case of invalid argument values. """ + if not batch_size: + raise ValueError('With eager execution, `batch_size` should be specified.') + if steps_per_epoch or validation_steps: + raise ValueError('With eager execution, `steps_per_epoch` and ' + '`validation_steps` are not valid arguments ' + '(set `batch_size` instead).') # Required for Eager mode with backend.learning_phase_scope(1): do_validation = False @@ -410,15 +449,18 @@ def fit_loop( elif shuffle: np.random.shuffle(index_array) - batches = make_batches(num_train_samples, batch_size) + batches = generic_utils.make_batches(num_train_samples, batch_size) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] try: - inputs_batch = slice_arrays(inputs, batch_ids) - targets_batch = slice_arrays(targets, batch_ids) + inputs_batch = slice_arrays(inputs, batch_ids, + contiguous=not shuffle) + targets_batch = slice_arrays(targets, batch_ids, + contiguous=not shuffle) if sample_weights: - sample_weights_batch = slice_arrays(sample_weights, batch_ids) + sample_weights_batch = slice_arrays(sample_weights, batch_ids, + contiguous=not shuffle) else: sample_weights_batch = None except TypeError: @@ -539,8 +581,8 @@ def test_loop(model, inputs, targets, feed_data, batch_size=batch_size, steps=steps, steps_name='steps') outs = [] if verbose == 1: - progbar = Progbar(target=num_samples) - batches = make_batches(num_samples, batch_size) + progbar = generic_utils.Progbar(target=num_samples) + batches = generic_utils.make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] @@ -620,12 +662,12 @@ def predict_loop(model, inputs, inputs, batch_size, steps, 'steps') if verbose == 1: if steps is not None: - progbar = Progbar(target=steps) + progbar = generic_utils.Progbar(target=steps) else: - progbar = Progbar(target=num_samples) + progbar = generic_utils.Progbar(target=num_samples) outs = [] - batches = make_batches(num_samples, batch_size) + batches = generic_utils.make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py index 8848b393d5..6cdb6b0753 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import numpy as np from tensorflow.python.framework import ops @@ -308,6 +307,100 @@ class TrainingTest(test.TestCase): model.compile(loss=None, optimizer='rms') + def test_model_methods_with_eager_tensors_multi_io(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile( + optimizer, + loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=None) + + input_a = keras.backend.zeros(shape=(10, 3)) + input_b = keras.backend.zeros(shape=(10, 3)) + target_d = keras.backend.zeros(shape=(10, 4)) + target_e = keras.backend.zeros(shape=(10, 4)) + + model.fit( + [input_a, input_b], [target_d, target_e], + epochs=1, + batch_size=5, + verbose=0) + # Test: no shuffle. + model.fit( + [input_a, input_b], [target_d, target_e], + epochs=1, + batch_size=5, + verbose=0, + shuffle=False) + # Test: validation data. + model.fit([input_a, input_b], [target_d, target_e], + epochs=1, batch_size=2, verbose=0, + validation_data=([input_a, input_b], [target_d, target_e])) + model.train_on_batch([input_a, input_b], [target_d, target_e]) + model.predict([input_a, input_b], batch_size=5) + model.evaluate([input_a, input_b], [target_d, target_e], + batch_size=2, verbose=0) + model.test_on_batch([input_a, input_b], [target_d, target_e]) + + # Test: mix np and tensors. + input_b = np.zeros(shape=(10, 3)).astype('float32') + target_e = np.zeros(shape=(10, 4)).astype('float32') + model.fit( + [input_a, input_b], [target_d, target_e], + epochs=1, + batch_size=5, + verbose=0) + model.fit([input_a, input_b], [target_d, target_e], + epochs=1, batch_size=2, verbose=0, + validation_data=([input_a, input_b], [target_d, target_e])) + model.fit( + [input_a, input_b], [target_d, target_e], + epochs=1, + batch_size=5, + verbose=0, + shuffle=False) + model.train_on_batch([input_a, input_b], [target_d, target_e]) + model.predict([input_a, input_b], batch_size=5) + model.evaluate([input_a, input_b], [target_d, target_e], + batch_size=2, verbose=0) + model.test_on_batch([input_a, input_b], [target_d, target_e]) + + def test_model_methods_with_eager_tensors_single_io(self): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = keras.backend.zeros(shape=(10, 3)) + targets = keras.backend.zeros(shape=(10, 4)) + + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0) + model.fit(inputs, targets, epochs=1, batch_size=3, verbose=0, shuffle=False) + model.fit(inputs, targets, epochs=1, batch_size=4, verbose=0, + validation_data=(inputs, targets)) + model.evaluate(inputs, targets, batch_size=2, verbose=0) + model.predict(inputs, batch_size=2) + model.train_on_batch(inputs, targets) + model.test_on_batch(inputs, targets) + class LossWeightingTest(test.TestCase): @@ -533,14 +626,5 @@ class LossWeightingTest(test.TestCase): if __name__ == '__main__': - # Bazel sets these environment variables to very long paths. - # Tempfile uses them to create long paths, and in turn multiprocessing - # library tries to create sockets named after paths. Delete whatever bazel - # writes to these to avoid tests failing due to socket addresses being too - # long. - for var in ('TMPDIR', 'TMP', 'TEMP'): - if var in os.environ: - del os.environ[var] - ops.enable_eager_execution() test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index fd91dbba52..08fd26dd18 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -1117,6 +1117,121 @@ class TestTrainingUtils(test.TestCase): class TestTrainingWithDataTensors(test.TestCase): + def test_training_and_eval_methods_on_symbolic_tensors_single_io(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = 'rmsprop' + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = keras.backend.zeros(shape=(10, 3)) + targets = keras.backend.zeros(shape=(10, 4)) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2, verbose=0) + model.evaluate(inputs, targets, steps=2, verbose=0) + model.predict(inputs, steps=2) + model.train_on_batch(inputs, targets) + model.test_on_batch(inputs, targets) + model.fit(inputs, targets, + epochs=1, steps_per_epoch=2, verbose=0, + validation_data=(inputs, targets), validation_steps=2) + + def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self): + with self.test_session(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = 'rmsprop' + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + + input_a_tf = keras.backend.zeros(shape=(10, 3)) + input_b_tf = keras.backend.zeros(shape=(10, 3)) + + output_d_tf = keras.backend.zeros(shape=(10, 4)) + output_e_tf = keras.backend.zeros(shape=(10, 4)) + + model.fit( + [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + epochs=1, + steps_per_epoch=2, + verbose=0) + with self.assertRaisesRegexp(ValueError, + 'should specify the `steps_per_epoch`'): + model.fit( + [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + epochs=1, + batch_size=5, + verbose=0) + model.train_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) + + # Test with dictionary inputs + model.fit( + {'input_a': input_a_tf, + 'input_b': input_b_tf}, + {'dense': output_d_tf, + 'dropout': output_e_tf}, + epochs=1, + steps_per_epoch=2, + verbose=0) + model.fit( + {'input_a': input_a_tf, + 'input_b': input_b_tf}, + {'dense': output_d_tf, + 'dropout': output_e_tf}, + validation_data=({'input_a': input_a_tf, + 'input_b': input_b_tf}, + {'dense': output_d_tf, + 'dropout': output_e_tf}), + epochs=1, + steps_per_epoch=2, + validation_steps=2, + verbose=0) + model.train_on_batch( + {'input_a': input_a_tf, + 'input_b': input_b_tf}, + {'dense': output_d_tf, + 'dropout': output_e_tf}) + + # Test with validation data + model.fit( + [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + validation_data=([input_a_tf, input_b_tf], + [output_d_tf, output_e_tf]), + epochs=1, + steps_per_epoch=2, + validation_steps=2, + verbose=0) + # Test with validation split + with self.assertRaisesRegexp(ValueError, + 'you cannot use `validation_split`'): + model.fit( + [input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + epochs=2, + steps_per_epoch=2, + verbose=0, + validation_split=0.2, + validation_steps=2) + + # Test evaluation / prediction methods + model.evaluate([input_a_tf, input_b_tf], [output_d_tf, output_e_tf], + steps=2, verbose=0) + model.predict([input_a_tf, input_b_tf], steps=2) + model.test_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf]) + def test_model_with_input_feed_tensor(self): """We test building a model with a TF variable as input. diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py index 105638ce10..a3fc8ef2a0 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py @@ -22,9 +22,11 @@ import copy import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import losses +from tensorflow.python.ops import math_ops def check_num_samples(ins, @@ -64,15 +66,29 @@ def check_num_samples(ins, if batch_size is not None: raise ValueError( 'If ' + steps_name + ' is set, the `batch_size` must be None.') - elif ins and hasattr(ins[0], 'shape'): - num_samples = ins[0].shape[0] - else: + if has_symbolic_tensors(ins) and steps is None: + raise ValueError('If your data is in the form of symbolic tensors, ' + 'you should specify the `' + steps_name + '` argument ' + '(instead of the `batch_size` argument).') + if ins and hasattr(ins[0], 'shape'): + num_samples = int(ins[0].shape[0]) + elif steps is None: raise ValueError( 'Either the input data should have ' 'a defined shape, or ' + steps_name + ' should be specified.') return num_samples +def standardize_single_array(x): + if x is None: + return None + elif tensor_util.is_tensor(x): + return x + elif x.ndim == 1: + x = np.expand_dims(x, 1) + return x + + def standardize_input_data(data, names, shapes=None, @@ -130,9 +146,7 @@ def standardize_input_data(data, else: data = data.values if data.__class__.__name__ == 'DataFrame' else data data = [data] - data = [ - np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data - ] + data = [standardize_single_array(x) for x in data] if len(data) != len(names): if data and hasattr(data[0], 'shape'): @@ -158,7 +172,7 @@ def standardize_input_data(data, # Check shapes compatibility. if shapes: for i in range(len(names)): - if shapes[i] is not None: + if shapes[i] is not None and not tensor_util.is_tensor(data[i]): data_shape = data[i].shape shape = shapes[i] if data[i].ndim != len(shape): @@ -245,12 +259,13 @@ def check_array_lengths(inputs, targets, weights=None): """ def set_of_lengths(x): - # return a set with the variation between + # Returns a set with the variation between # different shapes, with None => 0 if x is None: return {} else: - return set([y.shape[0] for y in x if y is not None]) + return set([y.shape[0] for y in x + if y is not None and not tensor_util.is_tensor(y)]) set_x = set_of_lengths(inputs) set_y = set_of_lengths(targets) @@ -422,7 +437,7 @@ def weighted_masked_objective(fn): score_array = fn(y_true, y_pred) if mask is not None: # Cast the mask to floatX to avoid float64 upcasting in theano - mask = K.cast(mask, K.floatx()) + mask = math_ops.cast(mask, K.floatx()) # mask should have the same shape as score_array score_array *= mask # the loss per batch should be proportional @@ -436,7 +451,8 @@ def weighted_masked_objective(fn): weight_ndim = K.ndim(weights) score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim))) score_array *= weights - score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx())) + score_array /= K.mean( + math_ops.cast(math_ops.not_equal(weights, 0), K.floatx())) return K.mean(score_array) return weighted @@ -532,3 +548,8 @@ def standardize_weights(y, return weights else: return None + + +def has_symbolic_tensors(ls): + return (any(tensor_util.is_tensor(v) for v in ls) + and not context.executing_eagerly()) diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 8426d84df9..5d370ebbb5 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -466,8 +466,8 @@ def model_to_estimator(keras_model=None, keras_model_fn, model_dir=model_dir, config=config) # Pass the config into keras backend's default session. - with session.Session(config=estimator._session_config) as sess: - K.set_session(sess) + sess = session.Session(config=estimator._session_config) + K.set_session(sess) keras_weights = keras_model.get_weights() if keras_model._is_graph_network: diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py index c40ee109aa..11ca89d625 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py @@ -26,6 +26,7 @@ from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -146,7 +147,7 @@ class PReLU(Layer): if K.backend() == 'theano': neg = ( K.pattern_broadcast(self.alpha, self.param_broadcast) * - (inputs - K.abs(inputs)) * 0.5) + (inputs - math_ops.abs(inputs)) * 0.5) else: neg = -self.alpha * K.relu(-inputs) return pos + neg @@ -232,7 +233,8 @@ class ThresholdedReLU(Layer): self.theta = K.cast_to_floatx(theta) def call(self, inputs, mask=None): - return inputs * K.cast(K.greater(inputs, self.theta), K.floatx()) + return inputs * math_ops.cast( + math_ops.greater(inputs, self.theta), K.floatx()) def get_config(self): config = {'theta': float(self.theta)} diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index d95a094245..b78962d66a 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -29,6 +29,8 @@ from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -438,9 +440,9 @@ class ConvLSTM2D(ConvRecurrent2D): def get_initial_state(self, inputs): # (samples, timesteps, rows, cols, filters) - initial_state = K.zeros_like(inputs) + initial_state = array_ops.zeros_like(inputs) # (samples, rows, cols, filters) - initial_state = K.sum(initial_state, axis=1) + initial_state = math_ops.reduce_sum(initial_state, axis=1) shape = list(self.kernel_shape) shape[-1] = self.filters initial_state = self.input_conv( @@ -483,8 +485,8 @@ class ConvLSTM2D(ConvRecurrent2D): def get_constants(self, inputs, training=None): constants = [] if self.implementation == 0 and 0 < self.dropout < 1: - ones = K.zeros_like(inputs) - ones = K.sum(ones, axis=1) + ones = array_ops.zeros_like(inputs) + ones = math_ops.reduce_sum(ones, axis=1) ones += 1 def dropped_inputs(): @@ -501,8 +503,8 @@ class ConvLSTM2D(ConvRecurrent2D): if 0 < self.recurrent_dropout < 1: shape = list(self.kernel_shape) shape[-1] = self.filters - ones = K.zeros_like(inputs) - ones = K.sum(ones, axis=1) + ones = array_ops.zeros_like(inputs) + ones = math_ops.reduce_sum(ones, axis=1) ones = self.input_conv(ones, K.zeros(shape), padding=self.padding) ones += 1. diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index 73e4f15f7e..c74fc1e4c0 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -37,6 +37,8 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import func_dump from tensorflow.python.keras._impl.keras.utils.generic_utils import func_load from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import core as tf_core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -75,12 +77,12 @@ class Masking(Layer): self.mask_value = mask_value def compute_mask(self, inputs, mask=None): - return K.any(K.not_equal(inputs, self.mask_value), axis=-1) + return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1) def call(self, inputs): boolean_mask = K.any( - K.not_equal(inputs, self.mask_value), axis=-1, keepdims=True) - return inputs * K.cast(boolean_mask, inputs.dtype) + math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True) + return inputs * math_ops.cast(boolean_mask, inputs.dtype) def compute_output_shape(self, input_shape): return input_shape @@ -170,7 +172,7 @@ class SpatialDropout1D(Dropout): self.input_spec = InputSpec(ndim=3) def _get_noise_shape(self, inputs): - input_shape = K.shape(inputs) + input_shape = array_ops.shape(inputs) noise_shape = (input_shape[0], 1, input_shape[2]) return noise_shape @@ -222,7 +224,7 @@ class SpatialDropout2D(Dropout): self.input_spec = InputSpec(ndim=4) def _get_noise_shape(self, inputs): - input_shape = K.shape(inputs) + input_shape = array_ops.shape(inputs) if self.data_format == 'channels_first': return (input_shape[0], input_shape[1], 1, 1) elif self.data_format == 'channels_last': @@ -275,7 +277,7 @@ class SpatialDropout3D(Dropout): self.input_spec = InputSpec(ndim=5) def _get_noise_shape(self, inputs): - input_shape = K.shape(inputs) + input_shape = array_ops.shape(inputs) if self.data_format == 'channels_first': return (input_shape[0], input_shape[1], 1, 1, 1) elif self.data_format == 'channels_last': @@ -414,7 +416,8 @@ class Reshape(Layer): return tensor_shape.TensorShape(output_shape) def call(self, inputs): - return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape) + return array_ops.reshape(inputs, + (array_ops.shape(inputs)[0],) + self.target_shape) def get_config(self): config = {'target_shape': self.target_shape} @@ -467,7 +470,7 @@ class Permute(Layer): return tensor_shape.TensorShape(output_shape) def call(self, inputs): - return K.permute_dimensions(inputs, (0,) + self.dims) + return array_ops.transpose(inputs, perm=(0,) + self.dims) def get_config(self): config = {'dims': self.dims} diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py index 2ca816adbd..551d1b1c3a 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -159,7 +160,7 @@ class CoreLayersTest(test.TestCase): # test with lambda ld = keras.layers.Lambda( - lambda x: keras.backend.concatenate([keras.backend.square(x), x])) + lambda x: keras.backend.concatenate([math_ops.square(x), x])) config = ld.get_config() ld = keras.layers.Lambda.from_config(config) @@ -235,4 +236,3 @@ class CoreLayersTest(test.TestCase): if __name__ == '__main__': test.main() - diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py index 006ecd3135..540e2d945c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py +++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py @@ -24,6 +24,8 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -126,7 +128,7 @@ class Embedding(Layer): if not self.mask_zero: return None else: - return K.not_equal(inputs, 0) + return math_ops.not_equal(inputs, 0) @shape_type_conversion def compute_output_shape(self, input_shape): @@ -152,8 +154,8 @@ class Embedding(Layer): def call(self, inputs): if K.dtype(inputs) != 'int32': - inputs = K.cast(inputs, 'int32') - out = K.gather(self.embeddings, inputs) + inputs = math_ops.cast(inputs, 'int32') + out = array_ops.gather(self.embeddings, inputs) return out def get_config(self): diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index c660cbd449..7c87e6c067 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -23,6 +23,9 @@ from __future__ import print_function from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine.base_layer import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @@ -127,7 +130,7 @@ class _Merge(Layer): for x in inputs: x_ndim = K.ndim(x) for _ in range(max_ndim - x_ndim): - x = K.expand_dims(x, 1) + x = array_ops.expand_dims(x, axis=1) reshaped_inputs.append(x) return self._merge_function(reshaped_inputs) else: @@ -137,19 +140,22 @@ class _Merge(Layer): for x in inputs: x_ndim = K.ndim(x) if x_ndim is None: - x_shape = K.shape(x) + x_shape = array_ops.shape(x) batch_size = x_shape[0] - new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)]) - x_transposed = K.reshape(x, - K.stack([batch_size, - K.prod(x_shape[1:])])) - x_transposed = K.permute_dimensions(x_transposed, (1, 0)) - x_transposed = K.reshape(x_transposed, new_shape) + new_shape = K.concatenate( + [x_shape[1:], + array_ops.expand_dims(batch_size, axis=-1)]) + x_transposed = array_ops.reshape( + x, + array_ops.stack( + [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0)) + x_transposed = array_ops.transpose(x_transposed, perm=(1, 0)) + x_transposed = array_ops.reshape(x_transposed, new_shape) reshaped_inputs.append(x_transposed) transposed = True elif x_ndim > 1: dims = list(range(1, x_ndim)) + [0] - reshaped_inputs.append(K.permute_dimensions(x, dims)) + reshaped_inputs.append(array_ops.transpose(x, perm=dims)) transposed = True else: # We don't transpose inputs if they are 1D vectors or scalars. @@ -159,17 +165,18 @@ class _Merge(Layer): if transposed: # If inputs have been transposed, we have to transpose the output too. if y_ndim is None: - y_shape = K.shape(y) - y_ndim = K.shape(y_shape)[0] + y_shape = array_ops.shape(y) + y_ndim = array_ops.shape(y_shape)[0] batch_size = y_shape[y_ndim - 1] - new_shape = K.concatenate( - [K.expand_dims(batch_size), y_shape[:y_ndim - 1]]) - y = K.reshape(y, (-1, batch_size)) - y = K.permute_dimensions(y, (1, 0)) - y = K.reshape(y, new_shape) + new_shape = K.concatenate([ + array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1] + ]) + y = array_ops.reshape(y, (-1, batch_size)) + y = array_ops.transpose(y, perm=(1, 0)) + y = array_ops.reshape(y, new_shape) elif y_ndim > 1: dims = [y_ndim - 1] + list(range(y_ndim - 1)) - y = K.permute_dimensions(y, dims) + y = array_ops.transpose(y, perm=dims) return y else: return self._merge_function(inputs) @@ -207,7 +214,7 @@ class _Merge(Layer): 'should have the same length.') if all([m is None for m in mask]): return None - masks = [K.expand_dims(m, 0) for m in mask if m is not None] + masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None] return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False) @@ -325,7 +332,7 @@ class Maximum(_Merge): def _merge_function(self, inputs): output = inputs[0] for i in range(1, len(inputs)): - output = K.maximum(output, inputs[i]) + output = math_ops.maximum(output, inputs[i]) return output @@ -340,7 +347,7 @@ class Minimum(_Merge): def _merge_function(self, inputs): output = inputs[0] for i in range(1, len(inputs)): - output = K.minimum(output, inputs[i]) + output = math_ops.minimum(output, inputs[i]) return output @@ -418,10 +425,10 @@ class Concatenate(_Merge): for input_i, mask_i in zip(inputs, mask): if mask_i is None: # Input is unmasked. Append all 1s to masks, - masks.append(K.ones_like(input_i, dtype='bool')) + masks.append(array_ops.ones_like(input_i, dtype='bool')) elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it - masks.append(K.expand_dims(mask_i)) + masks.append(array_ops.expand_dims(mask_i, axis=-1)) else: masks.append(mask_i) concatenated = K.concatenate(masks, axis=self.axis) @@ -511,8 +518,8 @@ class Dot(_Merge): else: axes.append(self.axes[i]) if self.normalize: - x1 = K.l2_normalize(x1, axis=axes[0]) - x2 = K.l2_normalize(x2, axis=axes[1]) + x1 = nn.l2_normalize(x1, axis=axes[0]) + x2 = nn.l2_normalize(x2, axis=axes[1]) output = K.batch_dot(x1, x2, axes) return output diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py index e309d160e5..72dc7a1ff8 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise.py +++ b/tensorflow/python/keras/_impl/keras/layers/noise.py @@ -23,6 +23,8 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -58,7 +60,7 @@ class GaussianNoise(Layer): def noised(): return inputs + K.random_normal( - shape=K.shape(inputs), mean=0., stddev=self.stddev) + shape=array_ops.shape(inputs), mean=0., stddev=self.stddev) return K.in_train_phase(noised, inputs, training=training) @@ -104,7 +106,7 @@ class GaussianDropout(Layer): def noised(): stddev = np.sqrt(self.rate / (1.0 - self.rate)) return inputs * K.random_normal( - shape=K.shape(inputs), mean=1.0, stddev=stddev) + shape=array_ops.shape(inputs), mean=1.0, stddev=stddev) return K.in_train_phase(noised, inputs, training=training) return inputs @@ -153,7 +155,7 @@ class AlphaDropout(Layer): self.supports_masking = True def _get_noise_shape(self, inputs): - return self.noise_shape if self.noise_shape else K.shape(inputs) + return self.noise_shape if self.noise_shape else array_ops.shape(inputs) def call(self, inputs, training=None): if 0. < self.rate < 1.: @@ -164,9 +166,9 @@ class AlphaDropout(Layer): scale = 1.0507009873554804934193349852946 alpha_p = -alpha * scale - kept_idx = K.greater_equal( + kept_idx = math_ops.greater_equal( K.random_uniform(noise_shape, seed=seed), rate) - kept_idx = K.cast(kept_idx, K.floatx()) + kept_idx = math_ops.cast(kept_idx, K.floatx()) # Get affine transformation params a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5 diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 791f9b3113..7f9f77c296 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -33,6 +33,9 @@ from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -503,9 +506,12 @@ class RNN(Layer): def get_initial_state(self, inputs): # build an all-zero tensor of shape (samples, output_dim) - initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim) - initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,) - initial_state = K.expand_dims(initial_state) # (samples, 1) + initial_state = array_ops.zeros_like(inputs) + # shape of initial_state = (samples, timesteps, input_dim) + initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2)) + # shape of initial_state = (samples,) + initial_state = array_ops.expand_dims(initial_state, axis=-1) + # shape of initial_state = (samples, 1) if hasattr(self.cell.state_size, '__len__'): return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size] else: @@ -631,7 +637,7 @@ class RNN(Layer): if self.stateful: updates = [] for i in range(len(states)): - updates.append(K.update(self.states[i], states[i])) + updates.append(state_ops.assign(self.states[i], states[i])) self.add_update(updates, inputs) if self.return_sequences: @@ -907,8 +913,7 @@ class SimpleRNNCell(Layer): prev_output = states[0] if 0 < self.dropout < 1 and self._dropout_mask is None: self._dropout_mask = _generate_dropout_mask( - _generate_dropout_ones(inputs, - K.shape(inputs)[-1]), + _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]), self.dropout, training=training) if (0 < self.recurrent_dropout < 1 and @@ -1309,8 +1314,7 @@ class GRUCell(Layer): if 0 < self.dropout < 1 and self._dropout_mask is None: self._dropout_mask = _generate_dropout_mask( - _generate_dropout_ones(inputs, - K.shape(inputs)[-1]), + _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]), self.dropout, training=training, count=3) @@ -1793,8 +1797,7 @@ class LSTMCell(Layer): def call(self, inputs, states, training=None): if 0 < self.dropout < 1 and self._dropout_mask is None: self._dropout_mask = _generate_dropout_mask( - _generate_dropout_ones(inputs, - K.shape(inputs)[-1]), + _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]), self.dropout, training=training, count=4) @@ -2176,7 +2179,7 @@ class LSTM(RNN): def _generate_dropout_ones(inputs, dims): - return K.ones((K.shape(inputs)[0], dims)) + return K.ones((array_ops.shape(inputs)[0], dims)) def _generate_dropout_mask(ones, rate, training=None, count=1): @@ -2351,9 +2354,12 @@ class Recurrent(Layer): def get_initial_state(self, inputs): # build an all-zero tensor of shape (samples, output_dim) - initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim) - initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,) - initial_state = K.expand_dims(initial_state) # (samples, 1) + initial_state = array_ops.zeros_like(inputs) + # shape of initial_state = (samples, timesteps, input_dim) + initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2)) + # shape of initial_state = (samples,) + initial_state = array_ops.expand_dims(initial_state, axis=-1) + # shape of initial_state = (samples, 1) initial_state = K.tile(initial_state, [1, self.units]) # (samples, output_dim) initial_state = [initial_state for _ in range(len(self.states))] @@ -2456,7 +2462,7 @@ class Recurrent(Layer): if self.stateful: updates = [] for i in range(len(states)): - updates.append(K.update(self.states[i], states[i])) + updates.append(state_ops.assign(self.states[i], states[i])) self.add_update(updates, inputs) # Properly set learning phase diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py index de022153f6..fb743b617f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -24,6 +24,9 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras._impl import keras +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import test @@ -395,8 +398,8 @@ class RNNTest(test.TestCase): # Test `get_losses_for` and `losses` x = keras.Input((None, 1)) - loss_1 = keras.backend.sum(x) - loss_2 = keras.backend.sum(cells[0].kernel) + loss_1 = math_ops.reduce_sum(x) + loss_2 = math_ops.reduce_sum(cells[0].kernel) cells[0].add_loss(loss_1, inputs=x) cells[0].add_loss(loss_2) self.assertEqual(len(layer.losses), 2) @@ -410,10 +413,10 @@ class RNNTest(test.TestCase): layer.build((None, None, 1)) x = keras.Input((None, 1)) - update_1 = keras.backend.update_add( - cells[0].kernel, x[0, 0, 0] * cells[0].kernel) - update_2 = keras.backend.update_add( - cells[0].kernel, keras.backend.ones_like(cells[0].kernel)) + update_1 = state_ops.assign_add(cells[0].kernel, + x[0, 0, 0] * cells[0].kernel) + update_2 = state_ops.assign_add(cells[0].kernel, + array_ops.ones_like(cells[0].kernel)) cells[0].add_update(update_1, inputs=x) cells[0].add_update(update_2) self.assertEqual(len(layer.updates), 2) diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index 76ddd9299d..c510e464ae 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -28,6 +28,7 @@ from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import utils as tf_layers_util +from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import tf_export @@ -209,11 +210,11 @@ class TimeDistributed(Wrapper): # We can go with reshape-based implementation for performance. input_length = input_shape[1] if not input_length: - input_length = K.shape(inputs)[1] + input_length = array_ops.shape(inputs)[1] # Shape: (num_samples * timesteps, ...). And track the # transformation in self._input_map. input_uid = tf_layers_util.object_list_uid(inputs) - inputs = K.reshape(inputs, (-1,) + input_shape[2:]) + inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:]) self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) y = self.layer.call(inputs, **kwargs) @@ -221,7 +222,7 @@ class TimeDistributed(Wrapper): uses_learning_phase = y._uses_learning_phase # Shape: (num_samples, timesteps, ...) output_shape = self.compute_output_shape(input_shape).as_list() - y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:])) + y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:])) # Apply activity regularizer if any: if (hasattr(self.layer, 'activity_regularizer') and diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py index 1576ed7b99..1d634d3801 100644 --- a/tensorflow/python/keras/_impl/keras/losses.py +++ b/tensorflow/python/keras/_impl/keras/losses.py @@ -24,51 +24,55 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @tf_export('keras.metrics.mean_squared_error', 'keras.losses.mean_squared_error') def mean_squared_error(y_true, y_pred): - return K.mean(K.square(y_pred - y_true), axis=-1) + return K.mean(math_ops.square(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_error', 'keras.losses.mean_absolute_error') def mean_absolute_error(y_true, y_pred): - return K.mean(K.abs(y_pred - y_true), axis=-1) + return K.mean(math_ops.abs(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_percentage_error', 'keras.losses.mean_absolute_percentage_error') def mean_absolute_percentage_error(y_true, y_pred): - diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None)) + diff = math_ops.abs( + (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None)) return 100. * K.mean(diff, axis=-1) @tf_export('keras.metrics.mean_squared_logarithmic_error', 'keras.losses.mean_squared_logarithmic_error') def mean_squared_logarithmic_error(y_true, y_pred): - first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.) - second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.) - return K.mean(K.square(first_log - second_log), axis=-1) + first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.) + second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.) + return K.mean(math_ops.square(first_log - second_log), axis=-1) @tf_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') def squared_hinge(y_true, y_pred): - return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1) + return K.mean( + math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1) @tf_export('keras.metrics.hinge', 'keras.losses.hinge') def hinge(y_true, y_pred): - return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1) + return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1) @tf_export('keras.losses.categorical_hinge') def categorical_hinge(y_true, y_pred): - pos = K.sum(y_true * y_pred, axis=-1) - neg = K.max((1. - y_true) * y_pred, axis=-1) - return K.maximum(0., neg - pos + 1.) + pos = math_ops.reduce_sum(y_true * y_pred, axis=-1) + neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1) + return math_ops.maximum(0., neg - pos + 1.) @tf_export('keras.losses.logcosh') @@ -89,7 +93,7 @@ def logcosh(y_true, y_pred): """ def _logcosh(x): - return x + K.softplus(-2. * x) - K.log(2.) + return x + nn.softplus(-2. * x) - math_ops.log(2.) return K.mean(_logcosh(y_pred - y_true), axis=-1) @@ -117,19 +121,19 @@ def binary_crossentropy(y_true, y_pred): def kullback_leibler_divergence(y_true, y_pred): y_true = K.clip(y_true, K.epsilon(), 1) y_pred = K.clip(y_pred, K.epsilon(), 1) - return K.sum(y_true * K.log(y_true / y_pred), axis=-1) + return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1) @tf_export('keras.metrics.poisson', 'keras.losses.poisson') def poisson(y_true, y_pred): - return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1) + return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1) @tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity') def cosine_proximity(y_true, y_pred): - y_true = K.l2_normalize(y_true, axis=-1) - y_pred = K.l2_normalize(y_pred, axis=-1) - return -K.sum(y_true * y_pred, axis=-1) + y_true = nn.l2_normalize(y_true, axis=-1) + y_pred = nn.l2_normalize(y_pred, axis=-1) + return -math_ops.reduce_sum(y_true * y_pred, axis=-1) # Aliases. diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py index 82778a3dc4..747c3e6515 100644 --- a/tensorflow/python/keras/_impl/keras/metrics.py +++ b/tensorflow/python/keras/_impl/keras/metrics.py @@ -37,37 +37,45 @@ from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crosse from tensorflow.python.keras._impl.keras.losses import squared_hinge from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @tf_export('keras.metrics.binary_accuracy') def binary_accuracy(y_true, y_pred): - return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1) + return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1) @tf_export('keras.metrics.categorical_accuracy') def categorical_accuracy(y_true, y_pred): - return K.cast( - K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx()) + return math_ops.cast( + math_ops.equal( + math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)), + K.floatx()) def sparse_categorical_accuracy(y_true, y_pred): - return K.cast( - K.equal( - K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), - K.floatx())), K.floatx()) + return math_ops.cast( + math_ops.equal( + math_ops.reduce_max(y_true, axis=-1), + math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())), + K.floatx()) @tf_export('keras.metrics.top_k_categorical_accuracy') def top_k_categorical_accuracy(y_true, y_pred, k=5): - return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1) + return K.mean( + nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1) @tf_export('keras.metrics.sparse_top_k_categorical_accuracy') def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): return K.mean( - K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1) - + nn.in_top_k(y_pred, + math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'), + k), + axis=-1) # Aliases diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py index 44289ea02a..9deaab0c05 100644 --- a/tensorflow/python/keras/_impl/keras/metrics_test.py +++ b/tensorflow/python/keras/_impl/keras/metrics_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras._impl import keras +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import test @@ -104,16 +106,15 @@ class KerasMetricsTest(test.TestCase): The total number of true positives seen this epoch at the completion of the batch. """ - y_true = keras.backend.cast(y_true, 'int32') - y_pred = keras.backend.cast(keras.backend.round(y_pred), 'int32') - correct_preds = keras.backend.cast( - keras.backend.equal(y_pred, y_true), 'int32') - true_pos = keras.backend.cast( - keras.backend.sum(correct_preds * y_true), 'int32') + y_true = math_ops.cast(y_true, 'int32') + y_pred = math_ops.cast(math_ops.round(y_pred), 'int32') + correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32') + true_pos = math_ops.cast( + math_ops.reduce_sum(correct_preds * y_true), 'int32') current_true_pos = self.true_positives * 1 - self.add_update(keras.backend.update_add(self.true_positives, - true_pos), - inputs=[y_true, y_pred]) + self.add_update( + state_ops.assign_add(self.true_positives, true_pos), + inputs=[y_true, y_pred]) return current_true_pos + true_pos metric_fn = BinaryTruePositives() diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index acbb9091d3..9f383deb72 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -31,6 +31,7 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_ from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util @@ -118,7 +119,8 @@ class Optimizer(object): 'Common ops without gradient: ' 'K.argmax, K.round, K.eval.') if hasattr(self, 'clipnorm') and self.clipnorm > 0: - norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) + norm = K.sqrt( + sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads])) grads = [clip_norm(g, self.clipnorm, norm) for g in grads] if hasattr(self, 'clipvalue') and self.clipvalue > 0: grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] @@ -204,20 +206,20 @@ class SGD(Optimizer): def get_updates(self, loss, params): grads = self.get_gradients(loss, params) - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) # momentum shapes = [K.int_shape(p) for p in params] moments = [K.zeros(shape) for shape in shapes] self.weights = [self.iterations] + moments for p, g, m in zip(params, grads, moments): v = self.momentum * m - lr * g # velocity - self.updates.append(K.update(m, v)) + self.updates.append(state_ops.assign(m, v)) if self.nesterov: new_p = p + self.momentum * v - lr * g @@ -228,7 +230,7 @@ class SGD(Optimizer): if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -277,25 +279,25 @@ class RMSprop(Optimizer): grads = self.get_gradients(loss, params) accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] self.weights = accumulators - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a in zip(params, grads, accumulators): # update accumulator - new_a = self.rho * a + (1. - self.rho) * K.square(g) - self.updates.append(K.update(a, new_a)) + new_a = self.rho * a + (1. - self.rho) * math_ops.square(g) + self.updates.append(state_ops.assign(a, new_a)) new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) # Apply constraints. if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -339,24 +341,24 @@ class Adagrad(Optimizer): shapes = [K.int_shape(p) for p in params] accumulators = [K.zeros(shape) for shape in shapes] self.weights = accumulators - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a in zip(params, grads, accumulators): - new_a = a + K.square(g) # update accumulator - self.updates.append(K.update(a, new_a)) + new_a = a + math_ops.square(g) # update accumulator + self.updates.append(state_ops.assign(a, new_a)) new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) # Apply constraints. if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -403,18 +405,18 @@ class Adadelta(Optimizer): accumulators = [K.zeros(shape) for shape in shapes] delta_accumulators = [K.zeros(shape) for shape in shapes] self.weights = accumulators + delta_accumulators - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): # update accumulator - new_a = self.rho * a + (1. - self.rho) * K.square(g) - self.updates.append(K.update(a, new_a)) + new_a = self.rho * a + (1. - self.rho) * math_ops.square(g) + self.updates.append(state_ops.assign(a, new_a)) # use the new accumulator and the *old* delta_accumulator update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon) @@ -424,11 +426,11 @@ class Adadelta(Optimizer): if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) # update delta_accumulator - new_d_a = self.rho * d_a + (1 - self.rho) * K.square(update) - self.updates.append(K.update(d_a, new_d_a)) + new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update) + self.updates.append(state_ops.assign(d_a, new_d_a)) return self.updates def get_config(self): @@ -483,17 +485,18 @@ class Adam(Optimizer): def get_updates(self, loss, params): grads = self.get_gradients(loss, params) - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) - t = K.cast(self.iterations, K.floatx()) + 1 + t = math_ops.cast(self.iterations, K.floatx()) + 1 lr_t = lr * ( - K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))) + K.sqrt(1. - math_ops.pow(self.beta_2, t)) / + (1. - math_ops.pow(self.beta_1, t))) ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] @@ -505,23 +508,23 @@ class Adam(Optimizer): for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats): m_t = (self.beta_1 * m) + (1. - self.beta_1) * g - v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) + v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g) if self.amsgrad: - vhat_t = K.maximum(vhat, v_t) + vhat_t = math_ops.maximum(vhat, v_t) p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon) - self.updates.append(K.update(vhat, vhat_t)) + self.updates.append(state_ops.assign(vhat, vhat_t)) else: p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) - self.updates.append(K.update(m, m_t)) - self.updates.append(K.update(v, v_t)) + self.updates.append(state_ops.assign(m, m_t)) + self.updates.append(state_ops.assign(v, v_t)) new_p = p_t # Apply constraints. if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -573,16 +576,16 @@ class Adamax(Optimizer): def get_updates(self, loss, params): grads = self.get_gradients(loss, params) - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr = lr * (1. / # pylint: disable=g-no-augmented-assignment - (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * ( # pylint: disable=g-no-augmented-assignment + 1. / (1. + self.decay * math_ops.cast(self.iterations, + K.dtype(self.decay)))) - t = K.cast(self.iterations, K.floatx()) + 1 - lr_t = lr / (1. - K.pow(self.beta_1, t)) + t = math_ops.cast(self.iterations, K.floatx()) + 1 + lr_t = lr / (1. - math_ops.pow(self.beta_1, t)) shapes = [K.int_shape(p) for p in params] # zero init of 1st moment @@ -594,18 +597,18 @@ class Adamax(Optimizer): for p, g, m, u in zip(params, grads, ms, us): m_t = (self.beta_1 * m) + (1. - self.beta_1) * g - u_t = K.maximum(self.beta_2 * u, K.abs(g)) + u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g)) p_t = p - lr_t * m_t / (u_t + self.epsilon) - self.updates.append(K.update(m, m_t)) - self.updates.append(K.update(u, u_t)) + self.updates.append(state_ops.assign(m, m_t)) + self.updates.append(state_ops.assign(u, u_t)) new_p = p_t # Apply constraints. if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -659,16 +662,17 @@ class Nadam(Optimizer): def get_updates(self, loss, params): grads = self.get_gradients(loss, params) - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] - t = K.cast(self.iterations, K.floatx()) + 1 + t = math_ops.cast(self.iterations, K.floatx()) + 1 # Due to the recommendations in [2], i.e. warming momentum schedule momentum_cache_t = self.beta_1 * ( - 1. - 0.5 * (K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay))) + 1. - 0.5 * + (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay))) momentum_cache_t_1 = self.beta_1 * ( 1. - 0.5 * - (K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay))) + (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay))) m_schedule_new = self.m_schedule * momentum_cache_t m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1 self.updates.append((self.m_schedule, m_schedule_new)) @@ -684,13 +688,13 @@ class Nadam(Optimizer): g_prime = g / (1. - m_schedule_new) m_t = self.beta_1 * m + (1. - self.beta_1) * g m_t_prime = m_t / (1. - m_schedule_next) - v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g) - v_t_prime = v_t / (1. - K.pow(self.beta_2, t)) + v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g) + v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t)) m_t_bar = ( 1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime - self.updates.append(K.update(m, m_t)) - self.updates.append(K.update(v, v_t)) + self.updates.append(state_ops.assign(m, m_t)) + self.updates.append(state_ops.assign(v, v_t)) p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon) new_p = p_t @@ -699,7 +703,7 @@ class Nadam(Optimizer): if getattr(p, 'constraint', None) is not None: new_p = p.constraint(new_p) - self.updates.append(K.update(p, new_p)) + self.updates.append(state_ops.assign(p, new_p)) return self.updates def get_config(self): @@ -743,7 +747,7 @@ class TFOptimizer(Optimizer): global_step = training_util.get_global_step() opt_update = self.optimizer.apply_gradients(grads, global_step) else: - self.updates = [K.update_add(self.iterations, 1)] + self.updates = [state_ops.assign_add(self.iterations, 1)] if not params: return self.updates diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py index 2c30844647..74c37d370e 100644 --- a/tensorflow/python/keras/_impl/keras/regularizers.py +++ b/tensorflow/python/keras/_impl/keras/regularizers.py @@ -23,6 +23,7 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -55,9 +56,9 @@ class L1L2(Regularizer): def __call__(self, x): regularization = 0. if self.l1: - regularization += K.sum(self.l1 * K.abs(x)) + regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x)) if self.l2: - regularization += K.sum(self.l2 * K.square(x)) + regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x)) return regularization def get_config(self): diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index 4c8009dfd8..902972ecbb 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -35,7 +35,7 @@ def count_params(weights): Returns: The total number of scalars composing the weights """ - return int(np.sum([K.count_params(p) for p in set(weights)])) + return int(np.sum([np.prod(p.get_shape().as_list()) for p in set(weights)])) def print_summary(model, line_length=None, positions=None, print_fn=None): @@ -193,8 +193,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): else: trainable_count = count_params(model.trainable_weights) - non_trainable_count = int( - np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])) + non_trainable_count = count_params(model.non_trainable_weights) print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count)) print_fn('Trainable params: {:,}'.format(trainable_count)) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index ea210346c1..6c34ea1816 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -295,7 +295,6 @@ tf_py_test( "//tensorflow/python:nn_grad", ], data = ["//tensorflow/core:image_testdata"], - tags = ["no_windows"], ) tf_py_test( @@ -1142,7 +1141,6 @@ tf_py_test( "//tensorflow/python:variables", ], data = ["//tensorflow/core:lmdb_testdata"], - tags = ["no_windows"], ) cuda_py_test( @@ -2332,7 +2330,6 @@ cuda_py_test( "//tensorflow/python:variables", ], shard_count = 4, - tags = ["no_windows"], ) cuda_py_test( @@ -2463,7 +2460,6 @@ cuda_py_test( "//tensorflow/python/eager:context", ], shard_count = 10, - tags = ["no_windows"], ) cuda_py_test( @@ -2523,7 +2519,10 @@ cuda_py_test( "//tensorflow/python:sparse_ops", ], shard_count = 5, - tags = ["noasan"], + tags = [ + "noasan", + "optonly", # b/77589990 + ], ) cuda_py_test( @@ -2726,6 +2725,7 @@ cuda_py_test( ], data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"], shard_count = 20, + tags = ["no_windows"], ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 64c1760d5e..5a20eebbc5 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -780,6 +780,14 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase): grad = GradSliceChecker(self, sess, var, np.array(8)) _ = grad[tuple()] + def testInt64Indices(self): + with self.test_session(use_gpu=True) as sess: + a = math_ops.range(3) + index = constant_op.constant(1, dtype=dtypes.int64) + b = 2 * a[index] + grad, = gradients_impl.gradients(b, a) + self.assertAllEqual(sess.run(grad), [0, 2, 0]) + class StridedSliceGradTypeTest(test_util.TensorFlowTestCase): """Test varied index types and host located memory.""" @@ -999,30 +1007,38 @@ class SliceAssignTest(test_util.TensorFlowTestCase): class ShapeSizeRankTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes() def testDenseShape(self): - with self.test_session(): - t_value = [[0, 42], [24, 0]] - self.assertAllEqual((2, 2), array_ops.shape(t_value).eval()) - self.assertEqual(4, array_ops.size(t_value).eval()) - self.assertEqual(2, array_ops.rank(t_value).eval()) + t_value = [[0, 42], [24, 0]] + self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value))) + self.assertEqual(4, self.evaluate(array_ops.size(t_value))) + self.assertEqual(2, self.evaluate(array_ops.rank(t_value))) - t = constant_op.constant(t_value) - self.assertAllEqual((2, 2), array_ops.shape(t).eval()) - self.assertEqual(4, array_ops.size(t).eval()) - self.assertEqual(2, array_ops.rank(t).eval()) + t = constant_op.constant(t_value) + self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t))) + self.assertEqual(4, self.evaluate(array_ops.size(t))) + self.assertEqual(2, self.evaluate(array_ops.rank(t))) + @test_util.run_in_graph_and_eager_modes() def testSparseShape(self): - with self.test_session(): - sp_value = sparse_tensor.SparseTensorValue( - indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2)) - self.assertAllEqual((2, 2), array_ops.shape(sp_value).eval()) - self.assertEqual(4, array_ops.size(sp_value).eval()) - self.assertEqual(2, array_ops.rank(sp_value).eval()) - - sp = sparse_tensor.SparseTensor.from_value(sp_value) - self.assertAllEqual((2, 2), array_ops.shape(sp).eval()) - self.assertEqual(4, array_ops.size(sp).eval()) - self.assertEqual(2, array_ops.rank(sp).eval()) + sp_value = sparse_tensor.SparseTensorValue( + indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2)) + self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp_value))) + self.assertEqual(4, self.evaluate(array_ops.size(sp_value))) + self.assertEqual(2, self.evaluate(array_ops.rank(sp_value))) + + sp = sparse_tensor.SparseTensor.from_value(sp_value) + self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp))) + self.assertEqual(4, self.evaluate(array_ops.size(sp))) + self.assertEqual(2, self.evaluate(array_ops.rank(sp))) + + @test_util.run_in_graph_and_eager_modes() + def testSizeDtype(self): + tensor = [1] + self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype) + self.assertEqual( + dtypes.int64, + self.evaluate(array_ops.size(tensor, out_type=dtypes.int64)).dtype) @test_util.with_c_api diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index ec8ac74163..f4616fd661 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn_ops @@ -344,6 +345,8 @@ class Conv3DTest(test.TestCase): if data_format == "NCDHW": conv = test_util.NCHWToNHWC(conv) + self.assertEqual(conv.shape, tensor_shape.TensorShape(output_shape)) + if test_input: jacob_t, jacob_n = gradient_checker.compute_gradient( orig_input_tensor, input_shape, conv, output_shape) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 8db0bb6f0d..34e7751243 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -2165,5 +2165,47 @@ class AccumulateTest(test.TestCase): math_ops.accumulate_n([a], tensor_dtype=np.int32) +class PolyvalTest(test.TestCase): + + def _runtest(self, dtype, degree): + x = np.random.rand(2, 2).astype(dtype) + coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)] + np_val = np.polyval(coeffs, x) + with self.test_session(): + tf_val = math_ops.polyval(coeffs, x) + self.assertAllClose(np_val, tf_val.eval()) + + def testSimple(self): + for dtype in [ + np.int32, np.float32, np.float64, np.complex64, np.complex128 + ]: + for degree in range(5): + self._runtest(dtype, degree) + + def testBroadcast(self): + dtype = np.float32 + degree = 3 + shapes = [(1,), (2, 1), (1, 2), (2, 2)] + for x_shape in shapes: + for coeff_shape in shapes: + x = np.random.rand(*x_shape).astype(dtype) + coeffs = [ + np.random.rand(*coeff_shape).astype(dtype) + for _ in range(degree + 1) + ] + np_val = np.polyval(coeffs, x) + with self.test_session(): + tf_val = math_ops.polyval(coeffs, x) + self.assertAllClose(np_val, tf_val.eval()) + + def testEmpty(self): + x = np.random.rand(2, 2).astype(np.float32) + coeffs = [] + np_val = np.polyval(coeffs, x) + with self.test_session(): + tf_val = math_ops.polyval(coeffs, x) + self.assertAllClose(np_val, tf_val.eval()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index df99a0ed25..a8def95b14 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -281,6 +281,22 @@ class UniformTest(test.TestCase): expected_pdf = [1.0, 0.1] self.assertAllClose(expected_pdf, pdf.eval()) + def testUniformFloat64(self): + uniform = uniform_lib.Uniform( + low=np.float64(0.), high=np.float64(1.)) + + self.assertAllClose( + [1., 1.], + self.evaluate(uniform.prob(np.array([0.5, 0.6], dtype=np.float64)))) + + self.assertAllClose( + [0.5, 0.6], + self.evaluate(uniform.cdf(np.array([0.5, 0.6], dtype=np.float64)))) + + self.assertAllClose(0.5, self.evaluate(uniform.mean())) + self.assertAllClose(1 / 12., self.evaluate(uniform.variance())) + self.assertAllClose(0., self.evaluate(uniform.entropy())) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 1301ef9d19..34fb655035 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -39,6 +40,7 @@ import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import from tensorflow.python.platform import test +# pylint: disable=invalid-name def simple_scoped_fn(a, x): """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope.""" with variable_scope.variable_scope("body"): @@ -158,6 +160,13 @@ class FunctionalOpsTest(test.TestCase): values=constant_op.constant([0, 1, 2]), dense_shape=[2, 2])) + @test_util.run_in_graph_and_eager_modes() + def testMapOverScalarErrors(self): + with self.assertRaisesRegexp(ValueError, "not scalars"): + functional_ops.map_fn(lambda x: x, [1, 2]) + with self.assertRaisesRegexp(ValueError, "not a scalar"): + functional_ops.map_fn(lambda x: x, 1) + def testMap_Scoped(self): with self.test_session() as sess: @@ -607,6 +616,276 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, 9) + def testIf(self): + + @function.Defun(dtypes.float32) + def Twice(x): + return x * 2 + + @function.Defun(dtypes.float32) + def Thrice(x): + return x * 3 + 1 + + with self.test_session(use_gpu=False) as sess: + + def Run(x): + return sess.run( + functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice))[0] + + self.assertAllEqual(Run(9.), 18.) + self.assertAllEqual(Run(-8.), -23.) + self.assertAllEqual(Run(0.), 1.) + + def testWhile(self): + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # TODO(b/65752372): Set `use_gpu=False` because + # `functional_ops.While()` does not reliably work on GPU (apparently + # because the result of evaluating the condition may be in device + # memory, but it is read on the host). + with self.test_session(use_gpu=False) as sess: + + def Run(n): + return sess.run(functional_ops.While([n, 0.], Cond, Body))[1] + + self.assertAllEqual(Run(20.), 210.) + self.assertAllEqual(Run(100.), 5050.) + + def testWhileError(self): + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def CondReturnsTooManyArgs(n, x): + return n > 0, x + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + @function.Defun(*[dtypes.float32] * 2) + def BodyReturnsTooManyArgs(n, x): + return n - 1, x + n, x + + # TODO(b/65752372): Set `use_gpu=False` because + # `functional_ops.While()` does not reliably work on GPU (apparently + # because the result of evaluating the condition may be in device + # memory, but it is read on the host). + with self.test_session(use_gpu=False): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Expected a single scalar.*got 2 tensors."): + functional_ops.While([5., 0.], CondReturnsTooManyArgs, Body)[0].eval() + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "While loop body returned 3 arguments. Expected: 2"): + functional_ops.While([5., 0.], Cond, BodyReturnsTooManyArgs)[0].eval() + + def testWhileInMultipleSubgraphs(self): + + @function.Defun(* [dtypes.float32] * 2) + def Cond(n, x): # pylint: disable=unused-argument + return n > 0 + + @function.Defun(* [dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # TODO(b/65752372): Set `use_gpu=False` because + # `functional_ops.While()` does not reliably work on GPU (apparently + # because the result of evaluating the condition may be in device + # memory, but it is read on the host). + with self.test_session(use_gpu=False) as sess: + n = array_ops.placeholder(dtypes.float32) + _, result = functional_ops.While([n, 0.], Cond, Body) + c = constant_op.constant(37.) + + self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.})) + self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.})) + # Test that the result is the same when we run a different subgraph. + self.assertAllEqual(5050., sess.run([result, c], feed_dict={n: 100.})[0]) + + def _tfSum(self, rewrite_with_while): + # On GPU, don't rewrite using a while loop. + use_gpu = not rewrite_with_while + with self.test_session(use_gpu=use_gpu) as sess: + + @function.Defun(dtypes.int32, dtypes.float32) + def Body(n, x): + return x + math_ops.to_float(n) + + xs = [ + # 1 + 2 + ... + 20 + functional_ops.For( + 1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0], + # 100 + 99 + ... + 1 + functional_ops.For( + 100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)[0], + ] + xvals = sess.run(xs) + self.assertAllEqual(210, xvals[0]) + self.assertAllEqual(5050, xvals[1]) + + def testFor(self): + self._tfSum(False) + + def testForWithWhile(self): + self._tfSum(True) + + def testForWithWhileNaming(self): + g = ops.Graph() + with g.as_default(): + + @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody") + def TestBody(n, x): + return x + math_ops.to_float(n) + + _ = functional_ops.For( + 1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0] + + names = [] + for func in g.as_graph_def().library.function: + names.append(func.signature.name) + self.assertTrue("TestBody" in names) + self.assertTrue("TestBody_Cond" in names) + self.assertTrue("TestBody_Body" in names) + + def testForCapturedInputs(self): + v = variables.Variable(1.0) + + @function.Defun(dtypes.int32) + def TestNullary(n): + v + math_ops.to_float(n) # pylint: disable=expression-not-assigned + + @function.Defun(dtypes.int32, dtypes.float32) + def TestUnary(n, x): + return x + math_ops.to_float(n) + v + + @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32) + def TestBinary(n, x, x2): + return x + math_ops.to_float(n) + v, x2 + v + + for rewrite_with_while in (True, False): + # TODO(b/65752372): Set `use_gpu=False` because + # `functional_ops.While()` does not reliably work on GPU (apparently + # because the result of evaluating the condition may be in device + # memory, but it is read on the host). + use_gpu = not rewrite_with_while + with self.test_session(use_gpu=use_gpu) as sess: + result_nullary = functional_ops.For( + 1, 10, 1, [], TestNullary, + rewrite_with_while=rewrite_with_while) + result_unary = functional_ops.For( + 1, 10, 1, [0.], TestUnary, + rewrite_with_while=rewrite_with_while) + result_binary = functional_ops.For( + 1, 10, 1, [0., 0.], TestBinary, + rewrite_with_while=rewrite_with_while) + sess.run(variables.global_variables_initializer()) + assert not result_nullary + # The nullary variant doesn't return anything so we can't easily run it. + # As a total hack, fetch the operation by name and run it. + sess.run(ops.get_default_graph().get_operation_by_name( + "While" if rewrite_with_while else "For")) + assert len(result_unary) == 1 + self.assertEqual([54.0], sess.run(result_unary)) + assert len(result_binary) == 2 + self.assertEqual([54.0, 9.0], sess.run(result_binary)) + + def _tfMLP(self, xval, wsval, bsval, rewrite_with_while): + # On GPU, don't rewrite using a while loop. + use_gpu = not rewrite_with_while + with self.test_session(use_gpu=use_gpu): + + @function.Defun(dtypes.int32, *[dtypes.float64] * 3) + def MLP(i, a, ws, bs): + a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :]) + return a, ws, bs + + ret = functional_ops.For( + 0, + wsval.shape[0], + 1, [xval, wsval, bsval], + MLP, + rewrite_with_while=rewrite_with_while)[0] + + return ret.eval() + + def _npMLP(self, xval, wsval, bsval): + for i in range(wsval.shape[0]): + xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :]) + return xval + + def _testForMLP(self, rewrite_with_while): + # We construct a 5-layer Multi-Layer Perceptron network here. + # Each layer have the same number of hidden unites (3), and the + # activation function is tanh(). We feed the input (xval) with + # batch size 2. + xval = np.random.normal(size=(2, 3)) + wsval = np.random.normal(size=(5, 3, 3)) + bsval = np.random.normal(size=(5, 3)) + np_ans = self._npMLP(xval, wsval, bsval) + tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while) + self.assertAllClose(np_ans, tf_for_ans) + + def testForMLP(self): + self._testForMLP(False) + + def testForMLPWhile(self): + self._testForMLP(True) + + def testForError(self): + + @function.Defun(dtypes.int32, dtypes.float32) + def Foo(i, v): + return math_ops.to_float(i) + v + + @function.Defun(dtypes.int32, dtypes.float32) + def ReturnsTooManyArgs(unused_i, v): + return v, v + + with self.test_session(use_gpu=True): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must be a scalar"): + functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Invalid start/limit/delta"): + functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval() + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "For loop body returned 2 arguments. Expected: 1"): + functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval() + + def testGradient(self): + + @function.Defun(dtypes.float32) + def Poly(x): + # y = 2x^3+3x^2+4x+8 + return 2 * x * x * x + 3 * x * x + 4 * x + 8 + + @function.Defun(dtypes.float32) + def Grad(x): + # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4) + return functional_ops.Gradient([x, 1.0], Poly)[0] + + with self.test_session(use_gpu=False) as sess: + a = constant_op.constant(0.) + avals = [Poly(a), Grad(a)] + b = constant_op.constant(1.) + bvals = [Poly(b), Grad(b)] + self.assertAllEqual(sess.run(avals), [8., 4.]) + self.assertAllEqual(sess.run(bvals), [17., 16.]) + if __name__ == "__main__": test.main() + +# pylint: enable=invalid-name diff --git a/tensorflow/python/kernel_tests/large_concat_op_test.py b/tensorflow/python/kernel_tests/large_concat_op_test.py index 66afb6ec01..184d1dde2a 100644 --- a/tensorflow/python/kernel_tests/large_concat_op_test.py +++ b/tensorflow/python/kernel_tests/large_concat_op_test.py @@ -19,10 +19,12 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +@test_util.with_c_api class LargeConcatOpTest(test.TestCase): """Tests that belong in concat_op_test.py, but run over large tensors.""" diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py index e1edffc3d9..7b291e29de 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.platform import test @@ -94,8 +95,8 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase): class BroadcastMatrixBatchDimsTest(test.TestCase): def test_zero_batch_matrices_returned_as_empty_list(self): - self.assertAllEqual( - [], linear_operator_util.broadcast_matrix_batch_dims([])) + self.assertAllEqual([], + linear_operator_util.broadcast_matrix_batch_dims([])) def test_one_batch_matrix_returned_after_tensor_conversion(self): arr = rng.rand(2, 3, 4) @@ -194,6 +195,44 @@ class BroadcastMatrixBatchDimsTest(test.TestCase): linear_operator_util.broadcast_matrix_batch_dims([y, x]) +class CholeskySolveWithBroadcastTest(test.TestCase): + + def test_static_dims_broadcast(self): + # batch_shape = [2] + chol = rng.rand(3, 3) + rhs = rng.rand(2, 3, 7) + chol_broadcast = chol + np.zeros((2, 1, 1)) + + with self.test_session(): + result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs) + self.assertAllEqual((2, 3, 7), result.get_shape()) + expected = linalg_ops.cholesky_solve(chol_broadcast, rhs) + self.assertAllEqual(expected.eval(), result.eval()) + + def test_dynamic_dims_broadcast_64bit(self): + # batch_shape = [2, 2] + chol = rng.rand(2, 3, 3) + rhs = rng.rand(2, 1, 3, 7) + chol_broadcast = chol + np.zeros((2, 2, 1, 1)) + rhs_broadcast = rhs + np.zeros((2, 2, 1, 1)) + + chol_ph = array_ops.placeholder(dtypes.float64) + rhs_ph = array_ops.placeholder(dtypes.float64) + + with self.test_session() as sess: + result, expected = sess.run( + [ + linear_operator_util.cholesky_solve_with_broadcast( + chol_ph, rhs_ph), + linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast) + ], + feed_dict={ + chol_ph: chol, + rhs_ph: rhs, + }) + self.assertAllEqual(expected, result) + + class MatmulWithBroadcastTest(test.TestCase): def test_static_dims_broadcast(self): @@ -209,7 +248,7 @@ class MatmulWithBroadcastTest(test.TestCase): expected = math_ops.matmul(x, y_broadcast) self.assertAllEqual(expected.eval(), result.eval()) - def test_dynamic_dims_broadcast_32bit(self): + def test_dynamic_dims_broadcast_64bit(self): # batch_shape = [2] # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7 x = rng.rand(2, 1, 3) @@ -221,9 +260,90 @@ class MatmulWithBroadcastTest(test.TestCase): with self.test_session() as sess: result, expected = sess.run( - [linear_operator_util.matmul_with_broadcast(x_ph, y_ph), - math_ops.matmul(x, y_broadcast)], - feed_dict={x_ph: x, y_ph: y}) + [ + linear_operator_util.matmul_with_broadcast(x_ph, y_ph), + math_ops.matmul(x, y_broadcast) + ], + feed_dict={ + x_ph: x, + y_ph: y + }) + self.assertAllEqual(expected, result) + + +class MatrixSolveWithBroadcastTest(test.TestCase): + + def test_static_dims_broadcast(self): + # batch_shape = [2] + matrix = rng.rand(3, 3) + rhs = rng.rand(2, 3, 7) + matrix_broadcast = matrix + np.zeros((2, 1, 1)) + + with self.test_session(): + result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs) + self.assertAllEqual((2, 3, 7), result.get_shape()) + expected = linalg_ops.matrix_solve(matrix_broadcast, rhs) + self.assertAllEqual(expected.eval(), result.eval()) + + def test_dynamic_dims_broadcast_64bit(self): + # batch_shape = [2, 2] + matrix = rng.rand(2, 3, 3) + rhs = rng.rand(2, 1, 3, 7) + matrix_broadcast = matrix + np.zeros((2, 2, 1, 1)) + rhs_broadcast = rhs + np.zeros((2, 2, 1, 1)) + + matrix_ph = array_ops.placeholder(dtypes.float64) + rhs_ph = array_ops.placeholder(dtypes.float64) + + with self.test_session() as sess: + result, expected = sess.run( + [ + linear_operator_util.matrix_solve_with_broadcast( + matrix_ph, rhs_ph), + linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast) + ], + feed_dict={ + matrix_ph: matrix, + rhs_ph: rhs, + }) + self.assertAllEqual(expected, result) + + +class MatrixTriangularSolveWithBroadcastTest(test.TestCase): + + def test_static_dims_broadcast(self): + # batch_shape = [2] + matrix = rng.rand(2, 3, 3) + rhs = rng.rand(3, 7) + rhs_broadcast = rhs + np.zeros((2, 1, 1)) + + with self.test_session(): + result = linear_operator_util.matrix_triangular_solve_with_broadcast( + matrix, rhs) + self.assertAllEqual((2, 3, 7), result.get_shape()) + expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast) + self.assertAllEqual(expected.eval(), result.eval()) + + def test_dynamic_dims_broadcast_64bit(self): + # batch_shape = [2] + matrix = rng.rand(2, 3, 3) + rhs = rng.rand(3, 7) + rhs_broadcast = rhs + np.zeros((2, 1, 1)) + + matrix_ph = array_ops.placeholder(dtypes.float64) + rhs_ph = array_ops.placeholder(dtypes.float64) + + with self.test_session() as sess: + result, expected = sess.run( + [ + linear_operator_util.matrix_triangular_solve_with_broadcast( + matrix_ph, rhs_ph), + linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast) + ], + feed_dict={ + matrix_ph: matrix, + rhs_ph: rhs, + }) self.assertAllEqual(expected, result) @@ -244,7 +364,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase): operator = DomainDimensionStubOperator(3) # Should not raise linear_operator_util.assert_compatible_matrix_dimensions( - operator, x).run() + operator, x).run() # pyformat: disable def test_incompatible_dimensions_raise(self): with self.test_session(): @@ -252,7 +372,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase): operator = DomainDimensionStubOperator(3) with self.assertRaisesOpError("Incompatible matrix dimensions"): linear_operator_util.assert_compatible_matrix_dimensions( - operator, x).run() + operator, x).run() # pyformat: disable if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index ad802f7e1f..55653489af 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -1124,40 +1124,91 @@ class AUCTest(test.TestCase): self.assertAlmostEqual(0.7, auc.eval(), 5) - def testAUCPRSpecialCase(self): + # Regarding the AUC-PR tests: note that the preferred method when + # calculating AUC-PR is summation_method='careful_interpolation'. + def testCorrectAUCPRSpecialCase(self): with self.test_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) - auc, update_op = metrics.auc(labels, predictions, curve='PR') + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='careful_interpolation') + + sess.run(variables.local_variables_initializer()) + # expected ~= 0.79726744594 + expected = 1 - math.log(1.5) / 2 + self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(expected, auc.eval(), delta=1e-3) + + def testCorrectAnotherAUCPRSpecialCase(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], + shape=(1, 7), + dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7)) + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='careful_interpolation') + + sess.run(variables.local_variables_initializer()) + # expected ~= 0.61350593198 + expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3 + self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(expected, auc.eval(), delta=1e-3) + + def testThirdCorrectAUCPRSpecialCase(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], + shape=(1, 7), + dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='careful_interpolation') + + sess.run(variables.local_variables_initializer()) + # expected ~= 0.90410597584 + expected = 1 - math.log(4./3) / 3 + self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(expected, auc.eval(), delta=1e-3) + + def testIncorrectAUCPRSpecialCase(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='trapezoidal') sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) - def testAnotherAUCPRSpecialCase(self): + def testAnotherIncorrectAUCPRSpecialCase(self): with self.test_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], shape=(1, 7), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7)) - auc, update_op = metrics.auc(labels, predictions, curve='PR') + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='trapezoidal') sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) - def testThirdAUCPRSpecialCase(self): + def testThirdIncorrectAUCPRSpecialCase(self): with self.test_session() as sess: predictions = constant_op.constant( [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], shape=(1, 7), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) - auc, update_op = metrics.auc(labels, predictions, curve='PR') + auc, update_op = metrics.auc(labels, predictions, curve='PR', + summation_method='trapezoidal') sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index c31d5a1f91..edc63264a3 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -802,6 +802,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [1], [3.0]) self.assertAllEqual([1.0, 3.0], v.numpy()) + def testScatterAddStateOps(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add") + state_ops.scatter_add(v, [1], [3]) + self.assertAllEqual([1.0, 5.0], v.numpy()) + def testScatterUpdateCast(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update") diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 242cdff6f3..ec741d3265 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -694,7 +694,8 @@ class Layer(checkpointable.CheckpointableBase): self._dtype = input_list[0].dtype.base_dtype.name except AttributeError: pass - input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) + if all(hasattr(x, 'get_shape') for x in input_list): + input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) self.build(input_shapes) try: # Note: not all sub-classes of Layer call Layer.__init__ (especially diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc new file mode 100644 index 0000000000..6637de632b --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/py_exception_registry.h" + +#include <Python.h> + +namespace tensorflow { + +PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr; + +void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) { + DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called"; + singleton_ = new PyExceptionRegistry; + + DCHECK(PyDict_Check(code_to_exc_type_map)); + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) { + TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key)); + singleton_->exc_types_[code] = value; + // The exception classes should also have the lifetime of the process, but + // incref just in case. + Py_INCREF(value); + } +} + +PyObject* PyExceptionRegistry::Lookup(TF_Code code) { + DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() " + "before PyExceptionRegistry::Lookup()"; + DCHECK_NE(code, TF_OK); + DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end()) + << "Unknown error code passed to PyExceptionRegistry::Lookup: " << code; + return singleton_->exc_types_[code]; +} + +} // namespace tensorflow diff --git a/tensorflow/python/lib/core/py_exception_registry.h b/tensorflow/python/lib/core/py_exception_registry.h new file mode 100644 index 0000000000..2b0f23b548 --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ +#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ + +#include <map> + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/platform/logging.h" + +#ifndef PyObject_HEAD +struct _object; +typedef _object PyObject; +#endif + +namespace tensorflow { + +// Global registry mapping C API error codes to the corresponding custom Python +// exception type. This is used to expose the exception types to C extension +// code (i.e. so we can raise custom exceptions via SWIG). +// +// Init() must be called exactly once at the beginning of the process before +// Lookup() can be used. +// +// Example usage: +// TF_Status* status = TF_NewStatus(); +// TF_Foo(..., status); +// +// if (TF_GetCode(status) != TF_OK) { +// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status)); +// // Arguments to OpError base class. Set `node_def` and `op` to None. +// PyObject* args = +// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status)); +// PyErr_SetObject(exc_type, args); +// Py_DECREF(args); +// TF_DeleteStatus(status); +// return NULL; +// } +class PyExceptionRegistry { + public: + // Initializes the process-wide registry. Should be called exactly once near + // the beginning of the process. The arguments are the various Python + // exception types (e.g. `cancelled_exc` corresponds to + // errors.CancelledError). + static void Init(PyObject* code_to_exc_type_map); + + // Returns the Python exception type corresponding to `code`. Init() must be + // called before using this function. `code` should not be TF_OK. + static PyObject* Lookup(TF_Code code); + + private: + static PyExceptionRegistry* singleton_; + PyExceptionRegistry() = default; + + // Maps error codes to the corresponding Python exception type. + std::map<TF_Code, PyObject*> exc_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ diff --git a/tensorflow/python/lib/core/py_exception_registry.i b/tensorflow/python/lib/core/py_exception_registry.i new file mode 100644 index 0000000000..e872b74985 --- /dev/null +++ b/tensorflow/python/lib/core/py_exception_registry.i @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/lib/core/py_exception_registry.h" +%} + +%ignoreall + +%unignore tensorflow::PyExceptionRegistry; +%unignore tensorflow::PyExceptionRegistry::Init; + +%include "tensorflow/python/lib/core/py_exception_registry.h" +%unignoreall diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 8247d354db..32ea737a99 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/lib/core/numpy.h" #include "tensorflow/python/lib/core/py_util.h" @@ -77,9 +78,9 @@ string PyRepr(PyObject* obj) { bool IsPyDimension(PyObject* obj) { const char* tp_name = obj->ob_type->tp_name; if (strcmp(tp_name, "Dimension") != 0) return false; - bool ret = - StringPiece(PyRepr(PyType(obj))) - .ends_with("tensorflow.python.framework.tensor_shape.Dimension'>"); + bool ret = str_util::EndsWith( + PyRepr(PyType(obj)), + "tensorflow.python.framework.tensor_shape.Dimension'>"); return ret; } diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 6fcf9c91d8..bf2d6f68b5 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -78,8 +78,7 @@ def tf_record_iterator(path, options=None): try: while True: try: - with errors.raise_exception_on_not_ok_status() as status: - reader.GetNext(status) + reader.GetNext() except errors.OutOfRangeError: break yield reader.record() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3c6a5c9e56..57d2657838 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -255,10 +255,15 @@ def _SliceGrad(op, grad): @ops.RegisterGradient("StridedSlice") def _StridedSliceGrad(op, grad): """Gradient for StridedSlice op.""" - x = array_ops.shape(op.inputs[0]) begin = op.inputs[1] end = op.inputs[2] strides = op.inputs[3] + # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the + # same dtype so we build a shape of the same type as other args. + # Note that the choice of `begin` for specifying `out_type` is arbitrary. + # We could choose any of {begin|end|strides}.dtype since they are required to + # be the same. + x = array_ops.shape(op.inputs[0], out_type=begin.dtype) return array_ops.strided_slice_grad( x, diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 207866610b..68d446602e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -387,7 +387,10 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): """ if context.executing_eagerly() and not isinstance( input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): - return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access + input = ops.convert_to_tensor(input) + np_out_type = out_type.as_numpy_dtype + num_elements = np.prod(input._shape_tuple(), dtype=np_out_type) # pylint: disable=protected-acces: + return ops.convert_to_tensor(num_elements, dtype=out_type) with ops.name_scope(name, "Size", [input]) as name: if isinstance(input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py index ec623b55eb..0891bffdd5 100644 --- a/tensorflow/python/ops/distributions/uniform.py +++ b/tensorflow/python/ops/distributions/uniform.py @@ -166,7 +166,8 @@ class Uniform(distribution.Distribution): return self.low + self.range() * samples def _prob(self, x): - broadcasted_x = x * array_ops.ones(self.batch_shape_tensor()) + broadcasted_x = x * array_ops.ones( + self.batch_shape_tensor(), dtype=x.dtype) return array_ops.where( math_ops.is_nan(broadcasted_x), broadcasted_x, diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index a840b1eddf..161f6f3659 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,22 +27,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.python.ops.gen_functional_ops import * -# pylint: enable=wildcard-import # pylint: disable=unused-import -from tensorflow.python.ops.gen_functional_ops import symbolic_gradient +from tensorflow.python.ops.gen_functional_ops import remote_call # pylint: enable=unused-import +from tensorflow.python.ops.gen_functional_ops import symbolic_gradient from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -365,7 +367,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, dtype_flat = output_flatten(dtype) # Convert elems to tensor array. n may be known statically. - n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0] + static_shape = elems_flat[0].shape + if static_shape.ndims is not None and static_shape.ndims < 1: + if len(elems_flat) == 1: + raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar") + else: + raise ValueError( + "elements in elems must be 1+ dimensional Tensors, not scalars" + ) + n = static_shape[0].value or array_ops.shape(elems_flat[0])[0] # TensorArrays are always flat elems_ta = [ @@ -634,3 +644,249 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, varscope.set_caching_device(None) return output_pack(results_flat) + + +# pylint: disable=invalid-name +def If(cond, inputs, then_branch, else_branch, name=None): + r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs). + + Args: + cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is + converted to a boolean according to the following rule: if the + scalar is a numerical value, non-zero means True and zero means + False; if the scalar is a string, non-empty means True and empty + means False. + inputs: A list of input tensors. + then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. + else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. + name: A name for the operation (optional). + + Returns: + A list of tensors returned by either then_branch(inputs) + or else_branch(inputs). + """ + # pylint: disable=protected-access + return gen_functional_ops._if( + cond, + inputs, [_.type for _ in then_branch.definition.signature.output_arg], + then_branch, + else_branch, + name=name) + + +def Gradient(inputs, f, name=None): + r"""Computes the gradient function for function f via backpropagation. + + Args: + inputs: A list of tensors of size N + M. + f: The function we want to compute the gradient for. + + The function 'f' must be a numerical function which takes N inputs and + produces M outputs. Its gradient function 'g', which is a function + taking N + M inputs and produces N outputs. + + I.e. if we have + (y1, y2, ..., yM) = f(x1, x2, ..., xN), + then, g is + (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, + dL/dy1, dL/dy2, ..., dL/dyM), + + where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + loss function). dL/dxi is the partial derivative of L with respect + to xi. + + name: A name for the operation (optional). + + Returns: + A list of tensors of size N. + """ + # TODO(zhifengc): Pretty-print the above spec in latex. + # TODO(zhfiengc): Needs some math expert to say the comment above better. + tlist = [_.type for _ in f.definition.signature.input_arg] + return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name) + + +# pylint: disable=invalid-name,protected-access +def While(input_, cond, body, name=None, hostmem=None): + r"""output = input; While (Cond(output)) { output = Body(output) }. + + Args: + input_: A list of `Tensor` objects. + A list of input tensors whose types are T. + cond: . A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. + body: . A funcion takes a list of tensors and returns another + list tensors. Both lists have the same types as specified + by T. + name: A name for the operation (optional). + hostmem: A list of integer. If i is in the list, input[i] is a + host memory tensor. + + Returns: + A list of `Tensor` objects. Has the same type as `input`. + A list of output tensors whose types are T. + """ + ret = gen_functional_ops._while(input_, cond, body, name=name) + if hostmem: + input_attr = attr_value_pb2.AttrValue() + input_attr.list.i.extend(hostmem) + ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access + + output_attr = attr_value_pb2.AttrValue() + output_attr.list.i.extend(hostmem) + ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access + return ret + + +# b/36459430 +# +# Ideally, we do not need this rewrite For loop into a While loop. +# However, today, if a While runs on GPU and the condition returns a +# boolean, the While kernel crashes. Even if we fix the crash, the +# bool needs to be copied between GPU and CPU. So, a for loop is much +# preferred when running on GPU. +# +# On the other hand, For op has no directly XLA kernel. So, when we run +# a for loop, we need to rewrite it using a While op. +# +# It should be possible and probably better to write a XLA C++ kernel +# implementing the logic in _ForUsingWhile. +def _ForUsingWhile(start, + limit, + delta, + inputs, + forbody, + name=None, + hostmem=None): + """Helper to implement a For loop using a While.""" + # To support negative delta (e.g., range(100, 0, -3)), we iterate + # over the range(n) and use iter * delta + start as the real + # iteration index. (e.g., for i in range(34): iter = i * (-3) + + # 100). + d = math_ops.abs(delta) + # XLA on TPUs doesn't support integer division + n = math_ops.cast( + math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) / + math_ops.cast(d, dtypes.float32), dtypes.int32) + + # Carried loop variables ("extra_args") are implicitly added to the input list + # of the WhileBody function. WhileCond does not call forbody, and so does not + # depend on any of forbody's extra_args. Since WhileCond and WhileBody + # must have identical inputs, we have to augment the cond signature to take + # the same types as the carried loop variables. + body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:] + cond_sig = body_sig + [t.dtype for t in forbody.captured_inputs] + + cond_name = "%s_Cond" % forbody.name + + @function.Defun(*cond_sig, func_name=cond_name) + def WhileCond(i, n, *args): + del args + return i < n + + body_name = "%s_Body" % forbody.name + + @function.Defun(*body_sig, func_name=body_name) + def WhileBody(i, n, start, delta, *args): + """A While wrapper for forbody that handles loop-carried captured inputs.""" + for_result = forbody(start + i * delta, *args) + # Nullary functions return an Operation. Normal functions can't do this + # because their return values are converted to Tensors. + if isinstance(for_result, ops.Operation): + for_result = () + # Unary functions return a single Tensor value. + elif isinstance(for_result, ops.Tensor): + for_result = (for_result,) + extra_args = tuple(function.get_extra_args()) + return (i + 1, n, start, delta) + tuple(for_result) + extra_args + + if hostmem is not None: + hostmem = [(4 + _) for _ in hostmem] + + results = While( + input_=[0, n, start, delta] + inputs + WhileBody.captured_inputs, + cond=WhileCond, + body=WhileBody, + name=name, + hostmem=hostmem) + # Slice off the loop-carried captured inputs. + return list(results[4:len(results) - len(WhileBody.captured_inputs)]) + + +def For(start, + limit, + delta, + inputs, + body, + name=None, + hostmem=None, + rewrite_with_while=None): + r"""out = input; for i in range(start, limit, delta) out = body(i, out). + + Args: + start: A `Tensor` of type `int32`. + limit: A `Tensor` of type `int32`. + delta: A `Tensor` of type `int32`. + inputs: A list of `Tensor` objects. + A list of input tensors whose types are T. + body: A function takes a list of tensors and returns another + list of tensors. Both lists have the same types as (int32, T...). + name: A name for the operation (optional). + hostmem: A list of integer. If i is in the list, inputs[i] is a + host memory tensor. In other words, (i+1)-th argument of the body + function is expecting a host memory. + rewrite_with_while: If True, using While op to implement the For. + + Returns: + A list of `Tensor` objects. Has the same type as `input`. + A list of output tensors whose types are T. + """ + if rewrite_with_while: + return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem) + if body.captured_inputs: + wrapper_name = "%s_BodyWrapper" % body.name + + @function.Defun(*body.declared_input_types, func_name=wrapper_name) + def BodyWrapper(*args): + """A wrapper for body that handles loop-carried captured inputs.""" + body_result = body(*args) + extra_args = tuple(function.get_extra_args()) + # Nullary functions return an Operation. Normal functions can't do this + # because their return values are converted to Tensors. + if isinstance(body_result, ops.Operation): + return extra_args + # Unary functions return a single Tensor value. + elif not isinstance(body_result, tuple): + return (body_result,) + extra_args + # N-ary functions return a tuple of Tensors. + else: + return body_result + extra_args + + inputs += BodyWrapper.captured_inputs + ret = gen_functional_ops._for( + start, limit, delta, inputs, BodyWrapper, name=name) + # Slice off the loop-carried captured inputs. + ret = ret[:-len(BodyWrapper.captured_inputs)] + else: + ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name) + if hostmem: + num_for_params = 3 # start/limit/delta + + input_attr = attr_value_pb2.AttrValue() + input_attr.list.i.extend([num_for_params + i for i in hostmem]) + ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access + + output_attr = attr_value_pb2.AttrValue() + output_attr.list.i.extend(hostmem) + ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access + return ret + + +# pylint: enable=invalid-name,protected-access diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py index 427bd1e890..9dd40765c2 100644 --- a/tensorflow/python/ops/linalg/linear_operator_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_util.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -102,6 +103,22 @@ def assert_is_batch_matrix(tensor): "%s" % tensor) +def shape_tensor(shape, name=None): + """Convert Tensor using default type, unless empty list or tuple.""" + # Works just like random_ops._ShapeTensor. + if isinstance(shape, (tuple, list)) and not shape: + dtype = dtypes.int32 + else: + dtype = None + return ops.convert_to_tensor(shape, dtype=dtype, name=name) + + +################################################################################ +# Broadcasting versions of common linear algebra functions. +# TODO(b/77519145) Do this more efficiently in some special cases. +################################################################################ + + def broadcast_matrix_batch_dims(batch_matrices, name=None): """Broadcast leading dimensions of zero or more [batch] matrices. @@ -170,7 +187,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None): bcast_batch_shape = batch_matrices[0].get_shape()[:-2] for mat in batch_matrices[1:]: bcast_batch_shape = array_ops.broadcast_static_shape( - bcast_batch_shape, mat.get_shape()[:-2]) + bcast_batch_shape, + mat.get_shape()[:-2]) if bcast_batch_shape.is_fully_defined(): # The [1, 1] at the end will broadcast with anything. bcast_shape = bcast_batch_shape.concatenate([1, 1]) @@ -183,7 +201,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None): bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2] for mat in batch_matrices[1:]: bcast_batch_shape = array_ops.broadcast_dynamic_shape( - bcast_batch_shape, array_ops.shape(mat)[:-2]) + bcast_batch_shape, + array_ops.shape(mat)[:-2]) bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0) for i, mat in enumerate(batch_matrices): batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape) @@ -195,6 +214,13 @@ def _broadcast_to_shape(x, shape): return x + array_ops.zeros(shape=shape, dtype=x.dtype) +def cholesky_solve_with_broadcast(chol, rhs, name=None): + """Solve systems of linear equations.""" + with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]): + chol, rhs = broadcast_matrix_batch_dims([chol, rhs]) + return linalg_ops.cholesky_solve(chol, rhs) + + def matmul_with_broadcast(a, b, transpose_a=False, @@ -206,6 +232,11 @@ def matmul_with_broadcast(a, name=None): """Multiplies matrix `a` by matrix `b`, producing `a @ b`. + Works identically to `tf.matmul`, but broadcasts batch dims + of `a` and `b` (by replicating) if they are determined statically to be + different, or if static shapes are not fully defined. Thus, this may result + in an inefficient replication of data. + The inputs must be matrices (or tensors of rank > 2, representing batches of matrices). @@ -276,7 +307,7 @@ def matmul_with_broadcast(a, ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b are both set to True. """ - with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name: + with ops.name_scope(name, "MatMulWithBroadcast", [a, b]): a, b = broadcast_matrix_batch_dims([a, b]) return math_ops.matmul( a, @@ -289,11 +320,43 @@ def matmul_with_broadcast(a, b_is_sparse=b_is_sparse) -def shape_tensor(shape, name=None): - """Convert Tensor using default type, unless empty list or tuple.""" - # Works just like random_ops._ShapeTensor. - if isinstance(shape, (tuple, list)) and not shape: - dtype = dtypes.int32 - else: - dtype = None - return ops.convert_to_tensor(shape, dtype=dtype, name=name) +def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None): + """Solve systems of linear equations.""" + with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]): + matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs]) + return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint) + + +def matrix_triangular_solve_with_broadcast(matrix, + rhs, + lower=True, + adjoint=False, + name=None): + """Solves triangular systems of linear equations with by backsubstitution. + + Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims + of `matrix` and `rhs` (by replicating) if they are determined statically to be + different, or if static shapes are not fully defined. Thus, this may result + in an inefficient replication of data. + + Args: + matrix: A Tensor. Must be one of the following types: + `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`. + rhs: A `Tensor`. Must have the same `dtype` as `matrix`. + Shape is `[..., M, K]`. + lower: An optional `bool`. Defaults to `True`. Indicates whether the + innermost matrices in `matrix` are lower or upper triangular. + adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve + with matrix or its (block-wise) adjoint. + name: A name for the operation (optional). + + Returns: + `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`. + """ + with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]): + matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs]) + return linalg_ops.matrix_triangular_solve( + matrix, + rhs, + lower=lower, + adjoint=adjoint) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 276897ab99..b460ce5b95 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -174,6 +174,7 @@ from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export # Aliases for some automatically-generated names. @@ -184,7 +185,6 @@ arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min) # pylin tf_export("arg_max")(arg_max) tf_export("arg_min")(arg_min) - # This is set by resource_variable_ops.py. It is included in this way since # there is a circular dependency between math_ops and resource_variable_ops _resource_variable_type = None @@ -1343,8 +1343,7 @@ def _ReductionDims(x, axis, reduction_indices): else: # Fast path: avoid creating Rank and Range ops if ndims is known. if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access - return constant_op.constant( - np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access + return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access if (isinstance(x, sparse_tensor.SparseTensor) and x.dense_shape.get_shape().is_fully_defined()): rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. @@ -1522,7 +1521,7 @@ def reduce_mean(input_tensor, input_tensor: The tensor to reduce. Should have numeric type. axis: The dimensions to reduce. If `None` (the default), reduces all dimensions. Must be in the range - `[-rank(input_tensor), rank(input_tensor)]`. + `[-rank(input_tensor), rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. @@ -2273,10 +2272,11 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): ValueError: If `inputs` don't all have same shape and dtype or the shape cannot be inferred. """ + def _input_error(): - return ValueError( - "inputs must be a list of at least one Tensor with the " - "same dtype and shape") + return ValueError("inputs must be a list of at least one Tensor with the " + "same dtype and shape") + if not inputs or not isinstance(inputs, (list, tuple)): raise _input_error() inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) @@ -2294,8 +2294,8 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): # tensor_dtype is for safety only; operator's output type computed in C++ if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) + raise TypeError("tensor_dtype is {}, but input is of type {}".format( + tensor_dtype, inputs[0].dtype)) if len(inputs) == 1 and name is None: return inputs[0] @@ -2761,14 +2761,14 @@ def sparse_segment_sum(data, indices, segment_ids, name=None, name=name) else: return gen_math_ops.sparse_segment_sum( - data=data, - indices=indices, - segment_ids=segment_ids, - name=name) + data=data, indices=indices, segment_ids=segment_ids, name=name) @tf_export("sparse_segment_mean") -def sparse_segment_mean(data, indices, segment_ids, name=None, +def sparse_segment_mean(data, + indices, + segment_ids, + name=None, num_segments=None): r"""Computes the mean along sparse segments of a tensor. @@ -2805,14 +2805,14 @@ def sparse_segment_mean(data, indices, segment_ids, name=None, name=name) else: return gen_math_ops.sparse_segment_mean( - data=data, - indices=indices, - segment_ids=segment_ids, - name=name) + data=data, indices=indices, segment_ids=segment_ids, name=name) @tf_export("sparse_segment_sqrt_n") -def sparse_segment_sqrt_n(data, indices, segment_ids, name=None, +def sparse_segment_sqrt_n(data, + indices, + segment_ids, + name=None, num_segments=None): r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N). @@ -2842,10 +2842,7 @@ def sparse_segment_sqrt_n(data, indices, segment_ids, name=None, name=name) else: return gen_math_ops.sparse_segment_sqrt_n( - data=data, - indices=indices, - segment_ids=segment_ids, - name=name) + data=data, indices=indices, segment_ids=segment_ids, name=name) @tf_export("tensordot", "linalg.tensordot") @@ -3016,6 +3013,47 @@ def tensordot(a, b, axes, name=None): return product +@tf_export("math.polyval") +def polyval(coeffs, x, name=None): + r"""Computes the elementwise value of a polynomial. + + If `x` is a tensor and `coeffs` is a list n + 1 tensors, this function returns + the value of the n-th order polynomial + + p(x) = coeffs[n-1] + coeffs[n-2] * x + ... + coeffs[0] * x**(n-1) + + evaluated using Horner's method, i.e. + + p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] + + x * coeffs[0])) + + Args: + coeffs: A list of `Tensor` representing the coefficients of the polynomial. + x: A `Tensor` representing the variable of the polynomial. + name: A name for the operation (optional). + + Returns: + A `tensor` of the shape as the expression p(x) with usual broadcasting rules + for element-wise addition and multiplication applied. + + @compatibility(numpy) + Equivalent to numpy.polyval. + @end_compatibility + """ + + with ops.name_scope(name, "polyval", nest.flatten(coeffs) + [x]) as name: + x = ops.convert_to_tensor(x, name="x") + if len(coeffs) < 1: + return array_ops.zeros_like(x, name=name) + coeffs = [ + ops.convert_to_tensor(coeff, name=("coeff_%d" % index)) + for index, coeff in enumerate(coeffs) + ] + p = coeffs[0] + for c in coeffs[1:]: + p = c + p * x + return p + # FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow # 1.0 API so we leave these here for backwards compatibility. fft = gen_spectral_ops.fft diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 9f85188b35..05bcee8801 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -155,9 +155,7 @@ class RoundTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testRounding(self): - x = [0.49, 0.7, -0.3, -0.8] - # TODO(nolivia): Remove this when RoundOp is forwards compatible - # x = np.arange(-5.0, 5.0, .25) + x = np.arange(-5.0, 5.0, .25) for dtype in [np.float32, np.double, np.int32]: x_np = np.array(x, dtype=dtype) with test_util.device(use_gpu=True): diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 9ec4954579..47eea6ef6b 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -626,10 +627,16 @@ def auc(labels, curve: Specifies the name of the curve to be computed, 'ROC' [default] or 'PR' for the Precision-Recall-curve. name: An optional variable_scope name. - summation_method: Specifies the Riemann summation method used, 'trapezoidal' - [default] that applies the trapezoidal rule, 'minoring' that applies - left summation for increasing intervals and right summation for decreasing - intervals or 'majoring' that applies the opposite. + summation_method: Specifies the Riemann summation method used + (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that + applies the trapezoidal rule; 'careful_interpolation', a variant of it + differing only by a more correct interpolation scheme for PR-AUC - + interpolating (true/false) positives but not the ratio that is precision; + 'minoring' that applies left summation for increasing intervals and right + summation for decreasing intervals; 'majoring' that does the opposite. + Note that 'careful_interpolation' is strictly preferred to 'trapezoidal' + (to be deprecated soon) as it applies the same method for ROC, and a + better one (see Davis & Goadrich 2006 for details) for the PR curve. Returns: auc: A scalar `Tensor` representing the current area-under-curve. @@ -664,8 +671,62 @@ def auc(labels, # Add epsilons to avoid dividing by 0. epsilon = 1.0e-6 + def interpolate_pr_auc(tp, fp, fn): + """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. + + Note here we derive & use a closed formula not present in the paper + - as follows: + Modeling all of TP (true positive weight), + FP (false positive weight) and their sum P = TP + FP (positive weight) + as varying linearly within each interval [A, B] between successive + thresholds, we get + Precision = (TP_A + slope * (P - P_A)) / P + with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A). + The area within the interval is thus (slope / total_pos_weight) times + int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} + int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} + where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) + Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight + where dTP == TP_B - TP_A. + Note that when P_A == 0 the above calculation simplifies into + int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) + which is really equivalent to imputing constant precision throughout the + first bucket having >0 true positives. + + Args: + tp: true positive counts + fp: false positive counts + fn: false negative counts + Returns: + pr_auc: an approximation of the area under the P-R curve. + """ + dtp = tp[:num_thresholds - 1] - tp[1:] + p = tp + fp + prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope') + intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:]) + safe_p_ratio = array_ops.where( + math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0), + _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'), + array_ops.ones_like(p[1:])) + return math_ops.reduce_sum( + _safe_div( + prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), + tp[1:] + fn[1:], + name='pr_auc_increment'), + name='interpolate_pr_auc') + def compute_auc(tp, fn, tn, fp, name): """Computes the roc-auc or pr-auc based on confusion counts.""" + if curve == 'PR': + if summation_method == 'trapezoidal': + logging.warning( + 'Trapezoidal rule is known to produce incorrect PR-AUCs; ' + 'please switch to "careful_interpolation" instead.') + elif summation_method == 'careful_interpolation': + # This one is a bit tricky and is handled separately. + return interpolate_pr_auc(tp, fp, fn) rec = math_ops.div(tp + epsilon, tp + fn + epsilon) if curve == 'ROC': fp_rate = math_ops.div(fp, fp + tn + epsilon) @@ -675,7 +736,9 @@ def auc(labels, prec = math_ops.div(tp + epsilon, tp + fp + epsilon) x = rec y = prec - if summation_method == 'trapezoidal': + if summation_method in ('trapezoidal', 'careful_interpolation'): + # Note that the case ('PR', 'careful_interpolation') has been handled + # above. return math_ops.reduce_sum( math_ops.multiply(x[:num_thresholds - 1] - x[1:], (y[:num_thresholds - 1] + y[1:]) / 2.), @@ -923,8 +986,8 @@ def mean_per_class_accuracy(labels, weights = array_ops.reshape(weights, [-1]) weights = math_ops.to_float(weights) - is_correct = is_correct * weights - ones = ones * weights + is_correct *= weights + ones *= weights update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 0c55386241..07ca32953f 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1808,7 +1808,7 @@ def softmax_cross_entropy_with_logits_v2( or `float64`). Backpropagation will happen into both `logits` and `labels`. To disallow - backpropagation into `labels`, pass label tensors through a `stop_gradients` + backpropagation into `labels`, pass label tensors through @{tf.stop_gradient} before feeding it to this function. **Note that to avoid confusion, it is required to pass only named arguments to @@ -1895,7 +1895,7 @@ _XENT_DEPRECATION = """ Future major versions of TensorFlow will allow gradients to flow into the labels input on backprop by default. -See tf.nn.softmax_cross_entropy_with_logits_v2. +See @{tf.nn.softmax_cross_entropy_with_logits_v2}. """ diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index da86d5f6ca..46a5f4fae6 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1081,6 +1081,42 @@ class DataFormatDimMapTest(test_lib.TestCase): self._test([1, -3, -2], [2, 2, 3]) self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + def testNHWCtoNCHW(self): + x_val = [1, -3, -2] + y_val_expected = [2, 2, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoHWNC(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoWHCN(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + + def testArbitraryASCII(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, y_val_expected) + class DataFormatVectorPermuteTest(test_lib.TestCase): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 2f39ea2e7d..07e25e540c 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -171,7 +171,9 @@ class ResourceVariable(variables.Variable): to see all modifications to the value of the variable which happen in any operation on which the read_value depends on (either directly, indirectly, or via a control dependency) and guaranteed to not see any modification to the - value of the variable on which the read_value operation does not depend on. + value of the variable from operations that depend on the read_value operation. + Updates from operations that have no dependency relationship to the read_value + operation might or might not be visible to read_value. For example, if there is more than one assignment to a ResourceVariable in a single session.run call there is a well-defined value for each operation diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 01fc3182bc..f6a11ca625 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -423,3 +423,55 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype), use_locking, name)]): return ref.read_value() + + +@tf_export("scatter_add") +def scatter_add(ref, indices, updates, use_locking=False, name=None): + # pylint: disable=line-too-long + r"""Adds sparse updates to the variable referenced by `resource`. + + This operation computes + + ```python + # Scalar indices + ref[indices, ...] += updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] += updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + ``` + + This operation outputs `ref` after the update is done. + This makes it easier to chain operations that need to use the updated value. + Duplicate entries are handled correctly: if multiple `indices` reference + the same location, their contributions add. + + Requires `updates.shape = indices.shape + ref.shape[1:]`. + + <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> + <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> + </div> + + Args: + ref: A `Variable`. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A tensor of indices into the first dimension of `ref`. + updates: A `Tensor`. Must have the same type as `ref`. + A tensor of updated values to store in `ref`. + use_locking: An optional `bool`. Defaults to `True`. + If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_add(ref, indices, updates, + use_locking=use_locking, name=name) + return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index c35735ca65..e33085ba62 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1164,7 +1164,7 @@ class _VariableScopeStore(threading.local): self.variable_scopes_count[scope_name] = 1 def close_variable_subscopes(self, scope_name): - for k in self.variable_scopes_count: + for k in list(self.variable_scopes_count.keys()): if not scope_name or k.startswith(scope_name + "/"): self.variable_scopes_count[k] = 0 diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index dbefca2be9..478dd46f7e 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -229,3 +229,25 @@ _COPY_TYPEMAPS(unsigned int, mode_t); %define final %enddef %define override %enddef #endif + +// Typemaps to automatically raise a Python exception from bad output TF_Status. +// TODO(b/77295559): expand this to all TF_Status* output params and deprecate +// raise_exception_on_not_ok_status (currently it only affects the C API). +%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) { + $1 = TF_NewStatus(); +} + +%typemap(freearg) (TF_Status* status) { + TF_DeleteStatus($1); +} + +%typemap(argout) TF_Status* status { + TF_Code code = TF_GetCode($1); + if (code != TF_OK) { + PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code); + // Arguments to OpError. + PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1)); + SWIG_SetErrorObj(exc, exc_args); + SWIG_fail; + } +} diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 39fabb9c1b..7acb8eeb1a 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include "tensorflow/python/platform/base.i" + %ignore ""; %rename("%s") TFE_NewContext; diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 82b908ac0e..26e8acd897 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -25,6 +25,7 @@ limitations under the License. %include "tensorflow/python/util/tfprof.i" %include "tensorflow/python/lib/core/py_func.i" +%include "tensorflow/python/lib/core/py_exception_registry.i" %include "tensorflow/python/lib/io/py_record_reader.i" %include "tensorflow/python/lib/io/py_record_writer.i" @@ -54,4 +55,3 @@ limitations under the License. %include "tensorflow/python/grappler/tf_optimizer.i" %include "tensorflow/python/grappler/cost_analyzer.i" %include "tensorflow/python/grappler/model_analyzer.i" - diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py index 902748d55e..dac6a06a89 100644 --- a/tensorflow/python/tools/optimize_for_inference.py +++ b/tensorflow/python/tools/optimize_for_inference.py @@ -87,7 +87,9 @@ def main(unused_args): output_graph_def = optimize_for_inference_lib.optimize_for_inference( input_graph_def, FLAGS.input_names.split(","), - FLAGS.output_names.split(","), FLAGS.placeholder_type_enum) + FLAGS.output_names.split(","), + FLAGS.placeholder_type_enum, + FLAGS.toco_compatible) if FLAGS.frozen_graph: f = gfile.FastGFile(FLAGS.output, "w") @@ -138,6 +140,14 @@ def parse_args(): type=int, default=dtypes.float32.as_datatype_enum, help="The AttrValue enum to use for placeholders.") + parser.add_argument( + "--toco_compatible", + type=bool, + default=False, + help="""\ + If true, only use ops compatible with Tensorflow + Lite Optimizing Converter.\ + """) return parser.parse_known_args() diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index 9c19271222..bb90d1cd6e 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -87,7 +87,7 @@ EPSILON_ATTR = { def optimize_for_inference(input_graph_def, input_node_names, output_node_names, - placeholder_type_enum): + placeholder_type_enum, toco_compatible=False): """Applies a series of inference optimizations on the input graph. Args: @@ -98,6 +98,8 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names, results. placeholder_type_enum: The AttrValue enum for the placeholder data type, or a list that specifies one value per input node name. + toco_compatible: Boolean, if True, only runs optimizations that result in + TOCO compatible graph operations (default=False). Returns: An optimized version of the input graph. @@ -110,8 +112,9 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names, optimized_graph_def = graph_util.remove_training_nodes( optimized_graph_def, output_node_names) optimized_graph_def = fold_batch_norms(optimized_graph_def) - optimized_graph_def = fuse_resize_and_conv(optimized_graph_def, - output_node_names) + if not toco_compatible: + optimized_graph_def = fuse_resize_and_conv(optimized_graph_def, + output_node_names) ensure_graph_is_valid(optimized_graph_def) return optimized_graph_def diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index c82b898bd0..16e200d64d 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -376,7 +376,9 @@ class DistributionStrategy(object): update. Allreduce is an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices. - * TODO(josh11b): Future: partitioned variables + * In the future we will have support for TensorFlows' partitioned + variables, where a single variable is split across multiple + devices. We have then a few approaches we want to support: * Code written (as if) with no knowledge of class `DistributionStrategy`. @@ -390,7 +392,6 @@ class DistributionStrategy(object): ``` with my_distribution.scope(): iterator = my_distribution.distribute_dataset(dataset) - # TODO(josh11b): iterator = dataset.make_one_shot_iterator() tower_train_ops = my_distribution.call_for_each_tower( tower_fn, iterator.get_next()) train_op = tf.group(my_distribution.unwrap(tower_train_ops)) @@ -402,6 +403,10 @@ class DistributionStrategy(object): using `my_distribution`'s policy, and library functions called by `tower_fn` can use the `get_tower_context()` API to get enhanced behavior in this case. + + Note that in the future we will add support for initializable + Dataset iterators, at which point this example code will change. + * If you want to write a distributed algorithm, you may use any of the `DistributionStrategy` APIs inside a `with my_distribution.scope():` block of code. @@ -514,7 +519,7 @@ class DistributionStrategy(object): Steps 3 and 4 are done automatically by class `Optimizer` if you call its `apply_gradients` method in a tower context. Otherwise you can - manually call its `distributed_apply` method in a cross-tower context. + manually call its `_distributed_apply` method in a cross-tower context. Another thing you might want to do in the middle of your tower function is an all-reduce of some intermediate value, using `d.reduce()` or diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 44f00a96de..caa26581e8 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -515,8 +515,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, def _sparse_values_to_keep(t, keep_input): """Convert a per-row `keep_input` vector to a per-value one.""" # Get the rows of every value in the sparse Tensor. - row_values = array_ops.reshape( - t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0] + row_values = t.indices[:, 0] # The value should be kept iff the row should be kept. return array_ops.gather(keep_input, row_values) if keep_input.shape.ndims == 1: diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index 27cdb860fe..1913fc20ee 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -75,7 +75,6 @@ cc_library( ":stream_executor", "//tensorflow/core:lib", "//tensorflow/core/kernels:ops_util", - "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", ] + if_cuda_is_configured([ "//tensorflow/core:cuda", diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 1aea0485fd..f408c06f46 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -18,7 +18,6 @@ limitations under the License. #include <functional> #include <memory> -#include "absl/strings/str_cat.h" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/env_var.h" @@ -113,7 +112,7 @@ string ToString(libraryPropertyType type) { case PATCH_LEVEL: return "PATCH_LEVEL"; default: - return absl::StrCat( + return port::StrCat( "<unknown libraryPropertyType: ", static_cast<int>(type), ">"); } } @@ -375,7 +374,7 @@ port::Status GetCudnnProperty(libraryPropertyType type, int* value) { cudnnStatus_t status = cudnnGetProperty(type, value); if (status != CUDNN_STATUS_SUCCESS) { const string error = - absl::StrCat("cudnnGetProperty failed for type: ", ToString(type), + port::StrCat("cudnnGetProperty failed for type: ", ToString(type), " with status: ", ToString(status)); LOG(ERROR) << error; return port::Status{port::error::INTERNAL, error}; @@ -419,7 +418,7 @@ port::Status CudnnSupport::Init() { CudnnVersion loaded_version; TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version)); if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) { - const tensorflow::string error = absl::StrCat( + const tensorflow::string error = port::StrCat( "Loaded runtime CuDNN library: ", loaded_version.ToString(), " but source was compiled with: ", source_version.ToString(), ". CuDNN library major and minor version needs to match or have " diff --git a/tensorflow/stream_executor/cuda/cudnn_version.h b/tensorflow/stream_executor/cuda/cudnn_version.h index 058cc87bfa..2ed02e1700 100644 --- a/tensorflow/stream_executor/cuda/cudnn_version.h +++ b/tensorflow/stream_executor/cuda/cudnn_version.h @@ -18,7 +18,7 @@ limitations under the License. #include <string> -#include "absl/strings/str_join.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace perftools { namespace gputools { @@ -30,8 +30,9 @@ struct CudnnVersion { CudnnVersion(int major, int minor, int patch) : major_version(major), minor_version(minor), patch_level(patch) {} - std::string ToString() const { - return absl::StrJoin({major_version, minor_version, patch_level}, "."); + tensorflow::string ToString() const { + return tensorflow::strings::StrCat(major_version, ".", minor_version, ".", + patch_level); } int major_version; diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index fcc57d506e..fd44b0eb3b 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -304,6 +304,7 @@ def tf_cc_shared_object( clean_dep("//tensorflow:darwin"): [ "-Wl,-install_name,@rpath/" + name.split("/")[-1], ], + clean_dep("//tensorflow:windows"): [], "//conditions:default": [ "-Wl,-soname," + name.split("/")[-1], ], @@ -342,6 +343,22 @@ register_extension_info( label_regex_for_dep = "{extension_name}.*", ) +# A simple wrap around native.cc_binary rule. +# When using this rule, you should realize it doesn't link to any tensorflow +# dependencies by default. +def tf_native_cc_binary(name, + copts=tf_copts(), + **kwargs): + native.cc_binary( + name=name, + copts=copts, + **kwargs) + +register_extension_info( + extension_name = "tf_native_cc_binary", + label_regex_for_dep = "{extension_name}.*", +) + def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", @@ -622,9 +639,12 @@ def tf_cc_test(name, linkopts=select({ clean_dep("//tensorflow:android"): [ "-pie", - ], + ], clean_dep("//tensorflow:windows"): [], clean_dep("//tensorflow:windows_msvc"): [], + clean_dep("//tensorflow:darwin"): [ + "-lm", + ], "//conditions:default": [ "-lpthread", "-lm" @@ -910,6 +930,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): if 'linkstatic' not in kwargs or kwargs['linkstatic'] != 1: enable_text_relocation_linkopt = select({ clean_dep("//tensorflow:darwin"): [], + clean_dep("//tensorflow:windows"): [], "//conditions:default": ['-Wl,-z,notext'],}) if 'linkopts' in kwargs: kwargs['linkopts'] += enable_text_relocation_linkopt @@ -1178,6 +1199,20 @@ def tf_custom_op_library_additional_deps(): "@protobuf_archive//:protobuf_headers", clean_dep("//third_party/eigen3"), clean_dep("//tensorflow/core:framework_headers_lib"), + ] + if_windows(["//tensorflow/python:pywrap_tensorflow_import_lib"]) + +# A list of targets that contains the implemenation of +# tf_custom_op_library_additional_deps. It's used to generate a DEF file for +# exporting symbols from _pywrap_tensorflow.dll on Windows. +def tf_custom_op_library_additional_deps_impl(): + return [ + "@protobuf_archive//:protobuf", + "@nsync//:nsync_cpp", + # for //third_party/eigen3 + clean_dep("//third_party/eigen3"), + # for //tensorflow/core:framework_headers_lib + clean_dep("//tensorflow/core:framework"), + clean_dep("//tensorflow/core:reader_base"), ] # Traverse the dependency graph along the "deps" attribute of the @@ -1264,6 +1299,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): deps=deps + if_cuda(cuda_deps), data=[name + "_check_deps"], copts=tf_copts(is_external=True), + features = ["windows_export_all_symbols"], linkopts=linkopts + select({ "//conditions:default": [ "-lm", @@ -1410,7 +1446,8 @@ def tf_py_wrap_cc(name, ]) + tf_extension_copts()), linkopts=tf_extension_linkopts() + extra_linkopts, linkstatic=1, - deps=deps + extra_deps) + deps=deps + extra_deps, + **kwargs) native.genrule( name="gen_" + cc_library_pyd_name, srcs=[":" + cc_library_name], diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index 6722536358..9f1bdd8aae 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -93,6 +93,7 @@ genrule( "api/logging/__init__.py", "api/losses/__init__.py", "api/manip/__init__.py", + "api/math/__init__.py", "api/metrics/__init__.py", "api/nn/__init__.py", "api/nn/rnn_cell/__init__.py", diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index 759ff752b0..05e603efb7 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -7,10 +7,6 @@ tf_class { mtype: "<type \'property\'>" } member { - name: "distribute" - mtype: "<type \'property\'>" - } - member { name: "evaluation_master" mtype: "<type \'property\'>" } @@ -82,9 +78,13 @@ tf_class { name: "tf_random_seed" mtype: "<type \'property\'>" } + member { + name: "train_distribute" + mtype: "<type \'property\'>" + } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\'], " } member_method { name: "replace" diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt new file mode 100644 index 0000000000..897718c05e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.math" +tf_module { + member_method { + name: "polyval" + argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 937044aece..afa3b78eb7 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -405,6 +405,10 @@ tf_module { mtype: "<type \'module\'>" } member { + name: "math" + mtype: "<type \'module\'>" + } + member { name: "metrics" mtype: "<type \'module\'>" } diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 603b2a4327..7eeae05847 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -145,6 +145,9 @@ class ApiCompatibilityTest(test.TestCase): verbose_diff_message = '' # First check if the key is not found in one or the other. if key in only_in_expected: + # TODO(annarev): remove once we switch to using tf_export decorators. + if key == 'tensorflow.math': + continue diff_message = 'Object %s expected but not found (removed). %s' % ( key, additional_missing_object_message) verbose_diff_message = diff_message @@ -229,6 +232,13 @@ class ApiCompatibilityTest(test.TestCase): for filename in golden_file_list } + # TODO(annarev): remove once we switch to using tf_export decorators. + tf_module = golden_proto_dict['tensorflow'].tf_module + for i in range(len(tf_module.member)): + if tf_module.member[i].name == 'math': + del tf_module.member[i] + break + # Diff them. Do not fail if called with update. # If the test is run to update goldens, only report diffs but do not fail. self._AssertProtoDictEquals( diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 8b8ba31a0d..438c5d52f6 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -65,4 +65,6 @@ bazel test -c opt $BUILD_OPTS -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ --test_tag_filters=-no_pip,-no_windows,-no_oss \ --build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \ - //${PY_TEST_DIR}/tensorflow/python/... + --flaky_test_attempts=3 \ + //${PY_TEST_DIR}/tensorflow/python/... \ + //${PY_TEST_DIR}/tensorflow/contrib/... diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD new file mode 100644 index 0000000000..e390e0fb05 --- /dev/null +++ b/tensorflow/tools/def_file_filter/BUILD @@ -0,0 +1,9 @@ +# Description: +# Tools for filtering DEF file for TensorFlow on Windows +# +# On Windows, we use a DEF file generated by Bazel to export +# symbols from the tensorflow dynamic library(_pywrap_tensorflow.dll). +# The maximum number of symbols that can be exported per DLL is 64K, +# so we have to filter some useless symbols through this python script. + +package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl new file mode 100644 index 0000000000..3cb72f4979 --- /dev/null +++ b/tensorflow/tools/def_file_filter/BUILD.tpl @@ -0,0 +1,15 @@ +# Description: +# Tools for filtering DEF file for TensorFlow on Windows +# +# On Windows, we use a DEF file generated by Bazel to export +# symbols from the tensorflow dynamic library(_pywrap_tensorflow.dll). +# The maximum number of symbols that can be exported per DLL is 64K, +# so we have to filter some useless symbols through this python script. + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "def_file_filter", + srcs = ["def_file_filter.py"], + srcs_version = "PY2AND3", +) diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl new file mode 100644 index 0000000000..8bdc03eb0f --- /dev/null +++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""def_file_filter.py - tool to filter a windows def file. + +The def file can be used to export symbols from the tensorflow dll to enable +tf.load_library(). + +Because the linker allows only 64K symbols to be exported per dll +we filter the symbols down to the essentials. The regular expressions +we use for this are specific to tensorflow. + +TODO: this works fine but there is an issue with exporting +'const char * const' and importing it from a user_ops. The problem is +on the importing end and using __declspec(dllimport) works around it. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import io +import os +import re +import subprocess +import sys +import tempfile + +# External tools we use that come with visual studio sdk +UNDNAME = "%{undname_bin_path}" + +# Exclude if matched +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") + +# Include if matched before exclude +INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" + r"google::protobuf::internal::ArenaImpl::AllocateAligned|" # for contrib/data/_prefetching_ops + r"google::protobuf::internal::ArenaImpl::AddCleanup|" # for contrib/data/_prefetching_ops + r"google::protobuf::Arena::OnArenaAllocation|" # for contrib/data/_prefetching_ops + r"tensorflow::internal::LogMessage|" + r"tensorflow::internal::LogString|" + r"tensorflow::internal::CheckOpMessageBuilder|" + r"tensorflow::internal::MakeCheckOpValueString|" + r"tensorflow::internal::PickUnusedPortOrDie|" + r"tensorflow::internal::ValidateDevice|" + r"tensorflow::ops::internal::Enter|" + r"tensorflow::strings::internal::AppendPieces|" + r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::io::internal::JoinPathImpl") + +# Include if matched after exclude +INCLUDE_RE = re.compile(r"^(TF_\w*)$|" + r"^(TFE_\w*)$|" + r"nsync::|" + r"tensorflow::|" + r"functor::|" + r"perftools::gputools") + +# We want to identify data members explicitly in the DEF file, so that no one +# can implicitly link against the DLL if they use one of the variables exported +# from the DLL and the header they use does not decorate the symbol with +# __declspec(dllimport). It is easier to detect what a data symbol does +# NOT look like, so doing it with the below regex. +DATA_EXCLUDE_RE = re.compile(r"[)(]|" + r"vftable|" + r"vbtable|" + r"vcall|" + r"RTTI|" + r"protobuf::internal::ExplicitlyConstructed") + +def get_args(): + """Parse command line.""" + filename_list = lambda x: x.split(";") + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=filename_list, + help="paths to input def file", + required=True) + parser.add_argument("--output", help="output deffile", required=True) + parser.add_argument("--target", help="name of the target", required=True) + args = parser.parse_args() + return args + + +def main(): + """main.""" + args = get_args() + + # Pipe dumpbin to extract all linkable symbols from libs. + # Good symbols are collected in candidates and also written to + # a temp file. + candidates = [] + tmpfile = tempfile.NamedTemporaryFile(mode="w", delete=False) + for def_file_path in args.input: + def_file = open(def_file_path, 'r') + for line in def_file: + cols = line.split() + sym = cols[0] + tmpfile.file.write(sym + "\n") + candidates.append(sym) + tmpfile.file.close() + + # Run the symbols through undname to get their undecorated name + # so we can filter on something readable. + with open(args.output, "w") as def_fp: + # track dupes + taken = set() + + # Header for the def file. + def_fp.write("LIBRARY " + args.target + "\n") + def_fp.write("EXPORTS\n") + def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n") + + # Each symbols returned by undname matches the same position in candidates. + # We compare on undname but use the decorated name from candidates. + dupes = 0 + proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE) + for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")): + decorated = candidates[idx] + if decorated in taken: + # Symbol is already in output, done. + dupes += 1 + continue + + if not INCLUDEPRE_RE.search(line): + if EXCLUDE_RE.search(line): + continue + if not INCLUDE_RE.search(line): + continue + + if "deleting destructor" in line: + # Some of the symbols convered by INCLUDEPRE_RE export deleting + # destructor symbols, which is a bad idea. + # So we filter out such symbols here. + continue + + if DATA_EXCLUDE_RE.search(line): + def_fp.write("\t" + decorated + "\n") + else: + def_fp.write("\t" + decorated + " DATA\n") + taken.add(decorated) + def_fp.close() + + exit_code = proc.wait() + if exit_code != 0: + print("{} failed, exit={}".format(UNDNAME, exit_code)) + return exit_code + + os.unlink(tmpfile.name) + + print("symbols={}, taken={}, dupes={}" + .format(len(candidates), len(taken), dupes)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl new file mode 100644 index 0000000000..47539b2423 --- /dev/null +++ b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl @@ -0,0 +1,56 @@ +"""Repository rule for def file filter autoconfiguration. + +This repository reuses Bazel's VC detect mechanism to find undname.exe, +which is a tool used in def_file_filter.py. + +def_file_filter.py is for filtering the DEF file for TensorFlow on Windows. +On Windows, we use a DEF file generated by Bazel to export symbols from the +tensorflow dynamic library(_pywrap_tensorflow.dll). The maximum number of +symbols that can be exported per DLL is 64K, so we have to filter some useless +symbols through this python script. + +`def_file_filter_config` depends on the following environment variables: + * `BAZEL_VC` + * `BAZEL_VS` + * `VS90COMNTOOLS` + * `VS100COMNTOOLS` + * `VS110COMNTOOLS` + * `VS120COMNTOOLS` + * `VS140COMNTOOLS` +""" + +load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_vc_path") +load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool") +load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail") + +def _def_file_filter_configure_impl(repository_ctx): + if repository_ctx.os.name.lower().find("windows") == -1: + repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD") + repository_ctx.file("def_file_filter.py", "") + return + vc_path = find_vc_path(repository_ctx) + if vc_path == "visual-studio-not-found": + auto_configure_fail("Visual C++ build tools not found on your machine") + undname_bin_path = find_msvc_tool(repository_ctx, vc_path, "undname.exe").replace("\\", "\\\\") + + repository_ctx.template( + "def_file_filter.py", + Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"), + { + "%{undname_bin_path}": undname_bin_path, + }) + repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD") + + +def_file_filter_configure = repository_rule( + implementation = _def_file_filter_configure_impl, + environ = [ + "BAZEL_VC", + "BAZEL_VS", + "VS90COMNTOOLS", + "VS100COMNTOOLS", + "VS110COMNTOOLS", + "VS120COMNTOOLS", + "VS140COMNTOOLS" + ], +) diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc index ab9a61afa7..80a954e062 100644 --- a/tensorflow/tools/graph_transforms/backports_test.cc +++ b/tensorflow/tools/graph_transforms/backports_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" @@ -191,7 +192,7 @@ TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) { std::map<string, const NodeDef*> node_lookup; MapNamesToNodes(result, &node_lookup); ASSERT_EQ(1, node_lookup.count("v3_node")); - EXPECT_TRUE(StringPiece(node_lookup.at("v3_node")->op()).ends_with("V2")); + EXPECT_TRUE(str_util::EndsWith(node_lookup.at("v3_node")->op(), "V2")); } } diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 250f54e20f..85660f94a8 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -283,6 +283,10 @@ Status FoldConstants(const GraphDef& input_graph_def, }; } + TF_RETURN_IF_ERROR(context.GetOneInt64Parameter( + "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes, + &cf_opts.max_constant_size_in_bytes)); + // Constant folding. bool was_mutated; TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr, diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 41106de008..a082399a87 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" @@ -209,10 +210,10 @@ class ConstantFoldingTest : public ::testing::Test { for (const NodeDef& node : graph_def.node()) { const StringPiece name(node.name()); const int occurrence_count = folded_node_map.count(node.name()); - if (name.ends_with("expect_removed")) { + if (str_util::EndsWith(name, "expect_removed")) { EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); } - if (name.ends_with("expect_remains")) { + if (str_util::EndsWith(name, "expect_remains")) { EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); } } @@ -370,6 +371,46 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("b")); EXPECT_EQ(1, node_map.count("c")); } + + void TestMaxConstantSizeInBytes() { + auto root = tensorflow::Scope::NewRootScope(); + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&a_data, 1.0f); + Output a_const = ::tensorflow::ops::Const( + root.WithOpName("a_expect_remains"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&b_data, 1.0f); + Output b_const = ::tensorflow::ops::Const( + root.WithOpName("b_expect_remains"), Input::Initializer(b_data)); + + Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"), + a_const, b_const); + + Output placeholder = ::tensorflow::ops::Placeholder( + root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + + Output mul = ::tensorflow::ops::Mul( + root.WithOpName("output_expect_remains"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&placeholder_tensor, 1.0f); + + // Setting the maximum constant size to 10 bytes should stop the constant + // folding at add(a, b) that would have yielded a constant of + // 100*sizeof(float) bytes. + graph_transforms::TransformFuncContext context; + context.params["max_constant_size_in_bytes"] = {"10"}; + TestConstantFolding(graph_def, + {{"placeholder_expect_remains", placeholder_tensor}}, + {}, {"output_expect_remains"}, context); + } }; TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); } @@ -394,5 +435,9 @@ TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { TestRemoveUnusedNodesMultipleOutputs(); } +TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) { + TestMaxConstantSizeInBytes(); +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc index 2436c7e4a2..f401723808 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc @@ -40,8 +40,8 @@ Status ExtractMinMaxRecords(const string& log_file_name, for (const string& file_line : file_lines) { // We expect to find a line with components separated by semicolons, so to // start make sure that the basic structure is in place/ - StringPiece line(file_line); - if (!line.contains(print_suffix + ";" + requant_prefix)) { + if (!str_util::StrContains(file_line, + print_suffix + ";" + requant_prefix)) { continue; } std::vector<string> line_parts = str_util::Split(file_line, ';'); @@ -53,8 +53,7 @@ Status ExtractMinMaxRecords(const string& log_file_name, bool min_max_found = false; int min_max_index; for (int i = 1; i < line_parts.size(); ++i) { - StringPiece line_part(line_parts[i]); - if (line_part.starts_with(requant_prefix)) { + if (str_util::StartsWith(line_parts[i], requant_prefix)) { min_max_found = true; min_max_index = i; } @@ -90,7 +89,7 @@ Status ExtractMinMaxRecords(const string& log_file_name, continue; } StringPiece name_string = line_parts[min_max_index - 1]; - if (!name_string.ends_with(print_suffix)) { + if (!str_util::EndsWith(name_string, print_suffix)) { continue; } string name = diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc index e1ee2b420b..377665448c 100644 --- a/tensorflow/tools/graph_transforms/insert_logging.cc +++ b/tensorflow/tools/graph_transforms/insert_logging.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" @@ -101,7 +102,7 @@ Status InsertLogging(const GraphDef& input_graph_def, const bool op_matches = (ops.count(node.op()) > 0); bool prefix_matches = false; for (const string& prefix : prefixes) { - if (StringPiece(node.name()).starts_with(prefix)) { + if (str_util::StartsWith(node.name(), prefix)) { prefix_matches = true; } } diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 701e350fc3..cc82100148 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" @@ -88,7 +89,7 @@ void CreateConstNode(const Tensor& tensor, const string& name, string GetMonolithicTensorKey(const string& tensor_slice_name) { std::vector<string> names = Split(tensor_slice_name, "/"); - if (StringPiece(names[names.size() - 1]).starts_with("part_")) { + if (str_util::StartsWith(names[names.size() - 1], "part_")) { CHECK_GE(names.size(), 2); names.pop_back(); } @@ -102,8 +103,8 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, for (const auto& node : input_graph_def.node()) { std::vector<string> node_name_parts = Split(node.name(), "/"); if (node_name_parts.size() == 2 && - StringPiece(node_name_parts[0]).starts_with("save") && - StringPiece(node_name_parts[1]).starts_with("Assign") && + str_util::StartsWith(node_name_parts[0], "save") && + str_util::StartsWith(node_name_parts[1], "Assign") && node.input(0) == target_name) { restore_node_name = node.input(1); break; diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc index bc2412fcbd..b276229aa4 100644 --- a/tensorflow/tools/graph_transforms/transform_graph_test.cc +++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" @@ -112,12 +113,11 @@ class TransformGraphTest : public ::testing::Test { graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map); for (const NodeDef& node : out_graph_def.node()) { - const StringPiece name(node.name()); const int occurrence_count = out_node_map.count(node.name()); - if (name.ends_with("expect_removed")) { + if (str_util::EndsWith(node.name(), "expect_removed")) { EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); } - if (name.ends_with("expect_remains")) { + if (str_util::EndsWith(node.name(), "expect_remains")) { EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); } } @@ -139,7 +139,7 @@ class TransformGraphTest : public ::testing::Test { Status no_such_status = TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def); EXPECT_TRUE( - StringPiece(no_such_status.ToString()).contains("not recognized")); + str_util::StrContains(no_such_status.ToString(), "not recognized")); } void TestParseTransformParameters() { diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 55f28a9e1d..367048965d 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -88,7 +88,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix, *suffix = ":" + input_parts[1]; } StringPiece node_name_piece(input_parts[0]); - if (node_name_piece.Consume("^")) { + if (str_util::ConsumePrefix(&node_name_piece, "^")) { *prefix = "^"; } else { *prefix = ""; @@ -200,8 +200,7 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def, // for merge only wait for one non-control input. int32 num_control_edges = 0; for (int i = 0; i < node_def.input_size(); ++i) { - StringPiece input_name(node_def.input(i)); - if (input_name.starts_with("^")) { + if (str_util::StartsWith(node_def.input(i), "^")) { num_control_edges++; } } @@ -504,7 +503,7 @@ Status RenameNodeInputs(const GraphDef& input_graph_def, const string& dest_name = input_to_rename.second; bool is_match; string match_name; - if (StringPiece(source_name).ends_with(":*")) { + if (str_util::EndsWith(source_name, ":*")) { is_match = true; string prefix; string unused_node_name; diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 62fec2c402..376644718f 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -48,36 +48,66 @@ py_binary( deps = ["//tensorflow:tensorflow_py"], ) +COMMON_PIP_DEPS = [ + ":licenses", + "MANIFEST.in", + "README", + "setup.py", + ":included_headers", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/autograph:autograph", + "//tensorflow/contrib/autograph/converters:converters", + "//tensorflow/contrib/autograph/converters:test_lib", + "//tensorflow/contrib/autograph/impl:impl", + "//tensorflow/contrib/autograph/pyct:pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", + "//tensorflow/contrib/boosted_trees:boosted_trees_pip", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", + "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/eager/python/examples:examples_pip", + "//tensorflow/contrib/eager/python:checkpointable_utils", + "//tensorflow/contrib/eager/python:evaluator", + "//tensorflow/contrib/gan:gan", + "//tensorflow/contrib/graph_editor:graph_editor_pip", + "//tensorflow/contrib/keras:keras", + "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", + "//tensorflow/contrib/nn:nn_py", + "//tensorflow/contrib/predictor:predictor_pip", + "//tensorflow/contrib/receptive_field:receptive_field_pip", + "//tensorflow/contrib/session_bundle:session_bundle_pip", + "//tensorflow/contrib/signal:signal_py", + "//tensorflow/contrib/signal:test_util", + "//tensorflow/contrib/slim:slim", + "//tensorflow/contrib/slim/python/slim/data:data_pip", + "//tensorflow/contrib/slim/python/slim/nets:nets_pip", + "//tensorflow/contrib/specs:specs", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/contrib/tensor_forest:init_py", + "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", + "//tensorflow/contrib/timeseries:timeseries_pip", + "//tensorflow/contrib/tpu", + "//tensorflow/examples/tutorials/mnist:package", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:meta_graph_testdata", + "//tensorflow/python:spectral_ops_test_util", + "//tensorflow/python:util_example_parser_configuration", + "//tensorflow/python/debug:debug_pip", + "//tensorflow/python/eager:eager_pip", + "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files", + "//tensorflow/python/saved_model:saved_model", + "//tensorflow/python/tools:tools_pip", + "//tensorflow/python:test_ops", + "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", +] + # On Windows, python binary is a zip file of runfiles tree. # Add everything to its data dependency for generating a runfiles tree # for building the pip package on Windows. py_binary( name = "simple_console_for_windows", srcs = ["simple_console_for_windows.py"], - data = [ - "MANIFEST.in", - "README", - "setup.py", - ":included_headers", - "//tensorflow/contrib/nn:nn_py", - "//tensorflow/contrib/session_bundle:session_bundle_pip", - "//tensorflow/contrib/signal:signal_py", - "//tensorflow/contrib/slim/python/slim/data:data_pip", - "//tensorflow/python:util_example_parser_configuration", - "//tensorflow/python/debug:debug_pip", - "//tensorflow/python/saved_model", - "//tensorflow/python:spectral_ops_test_util", - "//tensorflow/python/tools:tools_pip", - "//tensorflow/python/eager:eager_pip", - "//tensorflow/contrib/summary:summary_test_util", - # These targets don't build on Windows yet. Exclude them for now. - # "//tensorflow/contrib/slim", - # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", - # "//tensorflow/contrib/specs", - # "//tensorflow/contrib/tensor_forest:init_py", - # "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", - # "//tensorflow/examples/tutorials/mnist:package", - ], + data = COMMON_PIP_DEPS, srcs_version = "PY2AND3", deps = ["//tensorflow:tensorflow_py"], ) @@ -111,6 +141,7 @@ filegroup( "@kafka//:LICENSE", "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", + "@local_config_nccl//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@grpc//third_party/nanopb:LICENSE.txt", "@grpc//third_party/address_sorting:LICENSE", @@ -127,8 +158,6 @@ filegroup( "@org_python_pypi_backports_weakref//:LICENSE", ] + if_mkl([ "//third_party/mkl:LICENSE", - ]) + if_not_windows([ - "@nccl_archive//:LICENSE.txt", ]) + tf_additional_license_deps(), ) @@ -138,63 +167,13 @@ sh_binary( data = select({ "//tensorflow:windows": [":simple_console_for_windows"], "//tensorflow:windows_msvc": [":simple_console_for_windows"], - "//conditions:default": [ - ":licenses", - "MANIFEST.in", - "README", - "setup.py", - ":included_headers", + "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/boosted_trees:boosted_trees_pip", - "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:contrib_op_loader", - "//tensorflow/contrib/eager/python/examples:examples_pip", - "//tensorflow/contrib/eager/python:checkpointable_utils", - "//tensorflow/contrib/eager/python:evaluator", - "//tensorflow/contrib/gan:gan", - "//tensorflow/contrib/graph_editor:graph_editor_pip", - "//tensorflow/contrib/keras:keras", - "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", "//tensorflow/contrib/lite/python:interpreter_test_data", "//tensorflow/contrib/lite/python:tf_lite_py_pip", "//tensorflow/contrib/lite/toco:toco", "//tensorflow/contrib/lite/toco/python:toco_wrapper", "//tensorflow/contrib/lite/toco/python:toco_from_protos", - "//tensorflow/contrib/nn:nn_py", - "//tensorflow/contrib/predictor:predictor_pip", - "//tensorflow/contrib/autograph:autograph", - "//tensorflow/contrib/autograph/converters:converters", - "//tensorflow/contrib/autograph/converters:test_lib", - "//tensorflow/contrib/autograph/impl:impl", - "//tensorflow/contrib/autograph/pyct:pyct", - "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", - "//tensorflow/contrib/receptive_field:receptive_field_pip", - "//tensorflow/contrib/session_bundle:session_bundle_pip", - "//tensorflow/contrib/signal:signal_py", - "//tensorflow/contrib/signal:test_util", - "//tensorflow/contrib/slim:slim", - "//tensorflow/contrib/slim/python/slim/data:data_pip", - "//tensorflow/contrib/slim/python/slim/nets:nets_pip", - "//tensorflow/contrib/specs:specs", - "//tensorflow/contrib/summary:summary_test_util", - "//tensorflow/contrib/tensor_forest:init_py", - "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", - "//tensorflow/contrib/timeseries:timeseries_pip", - "//tensorflow/contrib/tpu", - "//tensorflow/examples/tutorials/mnist:package", - "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:meta_graph_testdata", - "//tensorflow/python:spectral_ops_test_util", - "//tensorflow/python:util_example_parser_configuration", - "//tensorflow/python/debug:debug_pip", - "//tensorflow/python/eager:eager_pip", - "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files", - "//tensorflow/python/saved_model:saved_model", - "//tensorflow/python/tools:tools_pip", - "//tensorflow/python:test_ops", - "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index feb3114bde..8f0cf8c3d1 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -162,7 +162,9 @@ function main() { # Before we leave the top-level directory, make sure we know how to # call python. - source tools/python_bin_path.sh + if [[ -e tools/python_bin_path.sh ]]; then + source tools/python_bin_path.sh + fi pushd ${TMPDIR} rm -f MANIFEST diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index fe6b9407d6..ace0d411b9 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -2,6 +2,7 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("//third_party:nccl/nccl_configure.bzl", "nccl_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") @@ -13,6 +14,8 @@ load("//third_party:repo.bzl", "tf_http_archive") load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") +load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", + "def_file_filter_configure") # Sanitize a dependency so that it works correctly from code that includes @@ -29,10 +32,15 @@ def tf_workspace(path_prefix="", tf_repo_name=""): cc_download_clang_toolchain(name="local_config_download_clang") cuda_configure(name="local_config_cuda") tensorrt_configure(name="local_config_tensorrt") + nccl_configure(name="local_config_nccl") git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") python_configure(name="local_config_python") + # For windows bazel build + # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows. + def_file_filter_configure(name = "local_config_def_file_filter") + # Point //external/local_config_arm_compiler to //external/arm_compiler arm_compiler_configure( name="local_config_arm_compiler", @@ -42,7 +50,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl_linux", urls = [ - "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz", "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz", ], sha256 = "feacc3d82565c1231470359b42c696236fae873704e0b013436afba5fd4fd30f", @@ -52,7 +60,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl_windows", urls = [ - "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip", "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip" ], sha256 = "24bae8d7b22b431a654acadea43f2243c46ae6b1e5a73a4a936825f31d284ee4", @@ -62,7 +70,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl_darwin", urls = [ - "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz", "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz" ], sha256 = "0e954ec6fd3dc5e37f64c4043f6b5613dd687558da3df1028b3b7c29ff5cf77f", @@ -454,11 +462,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7e78daafdd22f3f17720a103d29d89590534004e.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/7e78daafdd22f3f17720a103d29d89590534004e.tar.gz", ], - sha256 = "1efbb9b05af88368be984d2f6526061d4a857181ef10f8841889a3a46869bb01", - strip_prefix = "llvm-1c3cdea2f181d8e14ee184466c5fb237f1b4cda8", + sha256 = "a6d94bd9de23515a1e3792a830421e3885977ea43d03427cdbe68f98cb7e0045", + strip_prefix = "llvm-7e78daafdd22f3f17720a103d29d89590534004e", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) @@ -497,11 +505,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "zlib_archive", urls = [ - "https://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", - "http://zlib.net/fossils/zlib-1.2.8.tar.gz", + "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", + "https://zlib.net/zlib-1.2.11.tar.gz", ], - sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d", - strip_prefix = "zlib-1.2.8", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", build_file = clean_dep("//third_party:zlib.BUILD"), ) @@ -518,11 +526,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "snappy", urls = [ - "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", - "https://github.com/google/snappy/archive/1.1.4.tar.gz", + "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz", + "https://github.com/google/snappy/archive/1.1.7.tar.gz", ], - sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", - strip_prefix = "snappy-1.1.4", + sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4", + strip_prefix = "snappy-1.1.7", build_file = clean_dep("//third_party:snappy.BUILD"), ) @@ -534,7 +542,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", - build_file = clean_dep("//third_party:nccl.BUILD"), + build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"), ) tf_http_archive( diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 28293a3659..075b46896e 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -163,13 +163,6 @@ all_cmake_vars = select({ # Performs CMake variable substitutions on configuration header files. expand_cmake_vars( - name = "datatypes_gen", - src = "include/llvm/Support/DataTypes.h.cmake", - cmake_vars = all_cmake_vars, - dst = "include/llvm/Support/DataTypes.h", -) - -expand_cmake_vars( name = "config_gen", src = "include/llvm/Config/config.h.cmake", cmake_vars = all_cmake_vars, @@ -305,9 +298,7 @@ cc_binary( srcs = glob([ "utils/TableGen/*.cpp", "utils/TableGen/*.h", - ]) + [ - "lib/Target/X86/Disassembler/X86DisassemblerDecoderCommon.h", - ], + ]), linkopts = [ "-lm", "-ldl", @@ -2014,7 +2005,6 @@ cc_library( "include/llvm/Support/WasmRelocs/*.def", ]) + [ "include/llvm/BinaryFormat/MachO.def", - "include/llvm/Support/DataTypes.h", "include/llvm/Support/VCSRevision.h", "include/llvm/ExecutionEngine/ObjectMemoryBuffer.h", ], diff --git a/third_party/nccl/LICENSE b/third_party/nccl/LICENSE new file mode 100644 index 0000000000..146d9b765c --- /dev/null +++ b/third_party/nccl/LICENSE @@ -0,0 +1,203 @@ +Copyright 2018 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/nccl.BUILD b/third_party/nccl/nccl_archive.BUILD index b2b8e18824..a05899e38d 100644 --- a/third_party/nccl.BUILD +++ b/third_party/nccl/nccl_archive.BUILD @@ -43,6 +43,7 @@ cc_library( "-Iexternal/nccl_archive/src", "-O3", ] + cuda_default_copts(), + include_prefix = "third_party/nccl", linkopts = select({ "@org_tensorflow//tensorflow:android": [ "-pie", @@ -61,6 +62,7 @@ cc_library( "-lrt", ], }), + strip_include_prefix = "src", visibility = ["//visibility:public"], deps = ["@local_config_cuda//cuda:cuda_headers"], ) diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl new file mode 100644 index 0000000000..9dfcb18369 --- /dev/null +++ b/third_party/nccl/nccl_configure.bzl @@ -0,0 +1,172 @@ +# -*- Python -*- +"""Repository rule for NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_VERSION`: The NCCL version. + * `NCCL_INSTALL_PATH`: The installation path of the NCCL library. +""" + +load( + "//third_party/gpus:cuda_configure.bzl", + "auto_configure_fail", + "find_cuda_define", + "matches_version", +) + +_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH" +_TF_NCCL_VERSION = "TF_NCCL_VERSION" + +_DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" +_DEFINE_NCCL_MINOR = "#define NCCL_MINOR" +_DEFINE_NCCL_PATCH = "#define NCCL_PATCH" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = "@nccl_archive//:nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_LOCAL_BUILD_TEMPLATE = """ +filegroup( + name = "LICENSE", + data = ["nccl/NCCL-SLA.txt"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + srcs = ["nccl/lib/libnccl.so.%s"], + hdrs = ["nccl/include/nccl.h"], + include_prefix = "third_party/nccl", + strip_include_prefix = "nccl/include", + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], + visibility = ["//visibility:public"], +) +""" + + +def _find_nccl_header(repository_ctx, nccl_install_path): + """Finds the NCCL header on the system. + + Args: + repository_ctx: The repository context. + nccl_install_path: The NCCL library install directory. + + Returns: + The path to the NCCL header. + """ + header_path = repository_ctx.path("%s/include/nccl.h" % nccl_install_path) + if not header_path.exists: + auto_configure_fail("Cannot find %s" % str(header_path)) + return header_path + + +def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version): + """Checks whether the header file matches the specified version of NCCL. + + Args: + repository_ctx: The repository context. + nccl_install_path: The NCCL library install directory. + nccl_version: The expected NCCL version. + + Returns: + A string containing the library version of NCCL. + """ + header_path = _find_nccl_header(repository_ctx, nccl_install_path) + header_dir = str(header_path.realpath.dirname) + major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h", + _DEFINE_NCCL_MAJOR) + minor_version = find_cuda_define(repository_ctx, header_dir, "nccl.h", + _DEFINE_NCCL_MINOR) + patch_version = find_cuda_define(repository_ctx, header_dir, "nccl.h", + _DEFINE_NCCL_PATCH) + header_version = "%s.%s.%s" % (major_version, minor_version, patch_version) + if not matches_version(nccl_version, header_version): + auto_configure_fail( + ("NCCL library version detected from %s/nccl.h (%s) does not match " + + "TF_NCCL_VERSION (%s). To fix this rerun configure again.") % + (header_dir, header_version, nccl_version)) + + +def _find_nccl_lib(repository_ctx, nccl_install_path, nccl_version): + """Finds the given NCCL library on the system. + + Args: + repository_ctx: The repository context. + nccl_install_path: The NCCL library installation directory. + nccl_version: The version of NCCL library files as returned + by _nccl_version. + + Returns: + The path to the NCCL library. + """ + lib_path = repository_ctx.path("%s/lib/libnccl.so.%s" % (nccl_install_path, + nccl_version)) + if not lib_path.exists: + auto_configure_fail("Cannot find NCCL library %s" % str(lib_path)) + return lib_path + + +def _nccl_configure_impl(repository_ctx): + """Implementation of the nccl_configure repository rule.""" + if _TF_NCCL_VERSION not in repository_ctx.os.environ: + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + return + + nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip() + if matches_version("1", nccl_version): + # Alias to GitHub target from @nccl_archive. + if not matches_version(nccl_version, "1.3"): + auto_configure_fail( + "NCCL from GitHub must use version 1.3 (got %s)" % nccl_version) + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + # Create target for locally installed NCCL. + nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip() + _check_nccl_version(repository_ctx, nccl_install_path, nccl_version) + repository_ctx.symlink(nccl_install_path, "nccl") + repository_ctx.file("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE % nccl_version) + + +nccl_configure = repository_rule( + implementation=_nccl_configure_impl, + environ=[ + _NCCL_INSTALL_PATH, + _TF_NCCL_VERSION, + ], +) +"""Detects and configures the NCCL configuration. + +Add the following to your WORKSPACE FILE: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index fd48ed8941..cc11f52d0e 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -4,25 +4,12 @@ licenses(["notice"]) # BSD 3-Clause exports_files(["COPYING"]) -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - cc_library( name = "snappy", srcs = [ + "config.h", "snappy.cc", "snappy.h", - "snappy-c.cc", - "snappy-c.h", "snappy-internal.h", "snappy-sinksource.cc", "snappy-sinksource.h", @@ -32,9 +19,18 @@ cc_library( ], hdrs = ["snappy.h"], copts = select({ - ":windows": [], - ":windows_msvc": [], + "@org_tensorflow//tensorflow:windows": [ + "/DHAVE_CONFIG_H", + "/EHsc", + ], + "@org_tensorflow//tensorflow:windows_msvc": [ + "/DHAVE_CONFIG_H", + "/EHsc", + ], "//conditions:default": [ + "-DHAVE_CONFIG_H", + "-fno-exceptions", + "-Wno-sign-compare", "-Wno-shift-negative-value", "-Wno-implicit-function-declaration", ], @@ -42,20 +38,66 @@ cc_library( ) genrule( + name = "config_h", + outs = ["config.h"], + cmd = "\n".join([ + "cat <<'EOF' >$@", + "#define HAVE_STDDEF_H 1", + "#define HAVE_STDINT_H 1", + "", + "#ifdef __has_builtin", + "# if !defined(HAVE_BUILTIN_EXPECT) && __has_builtin(__builtin_expect)", + "# define HAVE_BUILTIN_EXPECT 1", + "# endif", + "# if !defined(HAVE_BUILTIN_CTZ) && __has_builtin(__builtin_ctzll)", + "# define HAVE_BUILTIN_CTZ 1", + "# endif", + "#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)", + "# ifndef HAVE_BUILTIN_EXPECT", + "# define HAVE_BUILTIN_EXPECT 1", + "# endif", + "# ifndef HAVE_BUILTIN_CTZ", + "# define HAVE_BUILTIN_CTZ 1", + "# endif", + "#endif", + "", + "#ifdef __has_include", + "# if !defined(HAVE_BYTESWAP_H) && __has_include(<byteswap.h>)", + "# define HAVE_BYTESWAP_H 1", + "# endif", + "# if !defined(HAVE_UNISTD_H) && __has_include(<unistd.h>)", + "# define HAVE_UNISTD_H 1", + "# endif", + "# if !defined(HAVE_SYS_ENDIAN_H) && __has_include(<sys/endian.h>)", + "# define HAVE_SYS_ENDIAN_H 1", + "# endif", + "# if !defined(HAVE_SYS_MMAN_H) && __has_include(<sys/mman.h>)", + "# define HAVE_SYS_MMAN_H 1", + "# endif", + "# if !defined(HAVE_SYS_UIO_H) && __has_include(<sys/uio.h>)", + "# define HAVE_SYS_UIO_H 1", + "# endif", + "#endif", + "", + "#ifndef SNAPPY_IS_BIG_ENDIAN", + "# ifdef __s390x__", + "# define SNAPPY_IS_BIG_ENDIAN 1", + "# elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__", + "# define SNAPPY_IS_BIG_ENDIAN 1", + "# endif", + "#endif", + "EOF", + ]), +) + +genrule( name = "snappy_stubs_public_h", srcs = ["snappy-stubs-public.h.in"], outs = ["snappy-stubs-public.h"], cmd = ("sed " + - "-e 's/@ac_cv_have_stdint_h@/1/g' " + - "-e 's/@ac_cv_have_stddef_h@/1/g' " + - "-e 's/@ac_cv_have_stdint_h@/1/g' " + - select({ - "@org_tensorflow//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", - "@org_tensorflow//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ", - "//conditions:default": "-e 's/@ac_cv_have_sys_uio_h@/1/g' ", - }) + - "-e 's/@SNAPPY_MAJOR@/1/g' " + - "-e 's/@SNAPPY_MINOR@/1/g' " + - "-e 's/@SNAPPY_PATCHLEVEL@/4/g' " + + "-e 's/$${\\(.*\\)_01}/\\1/g' " + + "-e 's/$${SNAPPY_MAJOR}/1/g' " + + "-e 's/$${SNAPPY_MINOR}/1/g' " + + "-e 's/$${SNAPPY_PATCHLEVEL}/4/g' " + "$< >$@"), ) diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD index d164ee719c..e8048dd98a 100644 --- a/third_party/zlib.BUILD +++ b/third_party/zlib.BUILD @@ -2,18 +2,6 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # BSD/MIT-like license (for zlib) -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - cc_library( name = "zlib", srcs = [ @@ -45,8 +33,8 @@ cc_library( ], hdrs = ["zlib.h"], copts = select({ - ":windows": [], - ":windows_msvc": [], + "@org_tensorflow//tensorflow:windows": [], + "@org_tensorflow//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wno-shift-negative-value", "-DZ_HAVE_UNISTD_H", |