diff options
167 files changed, 5041 insertions, 3620 deletions
diff --git a/configure.py b/configure.py index 8930c3a1f1..d411214817 100644 --- a/configure.py +++ b/configure.py @@ -35,7 +35,7 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '1.3' +_DEFAULT_NCCL_VERSION = '2.2' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -1097,8 +1097,10 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. ' - '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + 'Please specify the NCCL version you want to use. If NCCL %s is not ' + 'installed, then you can use version 1.3 that can be fetched ' + 'automatically but it may have worse performance with multiple GPUs. ' + '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 06a3be18e0..730b1b669b 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -34,6 +34,35 @@ cc_library( ) cc_library( + name = "reader", + srcs = ["reader.cc"], + hdrs = ["reader.h"], + deps = [ + ":constants", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "reader_test", + srcs = ["reader_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":constants", + ":reader", + ":tag_constants", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( name = "loader", hdrs = ["loader.h"], deps = [ @@ -54,6 +83,7 @@ cc_library( hdrs = ["loader.h"], deps = [ ":constants", + ":reader", ] + if_not_mobile([ "//tensorflow/core:core_cpu", 
"//tensorflow/core:framework", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index faa1e378d0..07807ed2f3 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,8 +18,10 @@ limitations under the License. #include <unordered_set> #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" @@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New( constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; -Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { - const string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - return ReadBinaryProto(Env::Default(), saved_model_pb_path, - saved_model_proto); - } - const string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - return ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - } - return Status(error::Code::NOT_FOUND, - "Could not find SavedModel .pb or .pbtxt at supplied export " - "directory path: " + - export_dir); -} - -string GetTagsAsString(const std::unordered_set<string>& tags) { - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); - return tags_as_string; -} - -Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, - const std::unordered_set<string>& tags, - MetaGraphDef* 
meta_graph_def_to_load) { - for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) { - // Get tags from the meta_graph_def. - std::unordered_set<string> graph_tags; - for (const string& tag : meta_graph_def.meta_info_def().tags()) { - graph_tags.insert(tag); - } - // Match with the set of tags provided. - if (graph_tags == tags) { - *meta_graph_def_to_load = meta_graph_def; - return Status::OK(); - } - } - return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags: " + - GetTagsAsString(tags) + - ". To inspect available tag-sets in the SavedModel, please " - "use the SavedModel CLI: `saved_model_cli`"); -} - Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr<Session>* session) { @@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set<string>& tags, SavedModelBundle* const bundle) { - if (!MaybeSavedModelDirectory(export_dir)) { - return Status(error::Code::NOT_FOUND, - "SavedModel not found in export directory: " + export_dir); - } - LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) - << "; from: " << export_dir; - - SavedModel saved_model_proto; - TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); - - TF_RETURN_IF_ERROR( - FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); + TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &bundle->meta_graph_def)); TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( bundle->meta_graph_def, session_options, &bundle->session)); @@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) - << "; Status: " << status_str << ". 
Took " + LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") + << " }; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc new file mode 100644 index 0000000000..2146c8a197 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include <unordered_set> + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace tensorflow { +namespace { + +Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { + LOG(INFO) << "Reading SavedModel from: " << export_dir; + + const string saved_model_pb_path = + io::JoinPath(export_dir, kSavedModelFilenamePb); + if (Env::Default()->FileExists(saved_model_pb_path).ok()) { + return ReadBinaryProto(Env::Default(), saved_model_pb_path, + saved_model_proto); + } + const string saved_model_pbtxt_path = + io::JoinPath(export_dir, kSavedModelFilenamePbTxt); + if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { + return ReadTextProto(Env::Default(), saved_model_pbtxt_path, + saved_model_proto); + } + return Status(error::Code::NOT_FOUND, + "Could not find SavedModel .pb or .pbtxt at supplied export " + "directory path: " + + export_dir); +} + +Status FindMetaGraphDef(const SavedModel& saved_model_proto, + const std::unordered_set<string>& tags, + MetaGraphDef* meta_graph_def) { + LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ") + << " }"; + for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { + // Get tags from the graph_def. + std::unordered_set<string> graph_tags; + for (const string& tag : graph_def.meta_info_def().tags()) { + graph_tags.insert(tag); + } + // Match with the set of tags provided. 
+ if (graph_tags == tags) { + *meta_graph_def = graph_def; + return Status::OK(); + } + } + return Status( + error::Code::NOT_FOUND, + strings::StrCat( + "Could not find meta graph def matching supplied tags: { ", + str_util::Join(tags, " "), + " }. To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`")); +} + +} // namespace + +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set<string>& tags, + MetaGraphDef* const meta_graph_def) { + SavedModel saved_model_proto; + TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); + TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h new file mode 100644 index 0000000000..5815108df2 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// Functions to read the SavedModel proto, or parts of it. 
+ +#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_READER_H_ + +#include <string> +#include <unordered_set> + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +// Reads the SavedModel proto from saved_model.pb(txt) in the given directory, +// finds the MetaGraphDef that matches the given set of tags and writes it to +// the `meta_graph_def` parameter. Returns a failure status when the SavedModel +// file does not exist or no MetaGraphDef matches the tags. +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set<string>& tags, + MetaGraphDef* const meta_graph_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_ diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc new file mode 100644 index 0000000000..620e9c2eec --- /dev/null +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/tag_constants.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestDataPbTxt[] = + "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataSharded[] = + "cc/saved_model/testdata/half_plus_two/00000123"; + +class ReaderTest : public ::testing::Test { + protected: + ReaderTest() {} + + void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) { + const auto& tags = meta_graph_def.meta_info_def().tags(); + EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) != + tags.end()); + EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), ""); + EXPECT_EQ( + meta_graph_def.signature_def().at("serving_default").method_name(), + "tensorflow/serving/predict"); + } +}; + +TEST_F(ReaderTest, TagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, NoTagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) + << st.error_message(); +} + +TEST_F(ReaderTest, NoTagMatchMultiple) { + MetaGraphDef 
meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel( + export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) + << st.error_message(); +} + +TEST_F(ReaderTest, PbtxtFormat) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, InvalidExportPath) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5a335aa43c..d88a34dfd9 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -127,6 +127,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a6f5769e7b..cc4b13d3b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,134 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -// Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, - int distance) { - return xla::Or( - xla::ShiftLeft(v, xla::ConstantR0<int>(builder, distance)), - xla::ShiftRightLogical(v, xla::ConstantR0<int>(builder, 32 - distance))); -} - -using ThreeFry2x32State = std::array<xla::XlaOp, 2>; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, - ThreeFry2x32State input, ThreeFry2x32State key) { - // Rotation distances specified by the Threefry2x32 algorithm. - constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ThreeFry2x32State x; - - std::array<xla::XlaOp, 3> ks; - // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = xla::ConstantR0<int32>(builder, 0x1BD11BDA); - for (int i = 0; i < 2; ++i) { - ks[i] = key[i]; - x[i] = input[i]; - ks[2] = xla::Xor(ks[2], key[i]); - } - - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(x[1], ks[1]); - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. 
- auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = xla::Add(v[0], v[1]); - v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = xla::Xor(v[0], v[1]); - return v; - }; - - // There are no known statistical flaws with 13 rounds of Threefry2x32. - // We are conservative and use 20 rounds. - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 1)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 2)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0<int32>(builder, 3)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 4)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 5)); - - return x; -} - -// Returns a tensor of 'shape' random values uniformly distributed in the range -// [minval, maxval) -xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, - const TensorShape& shape, double minval, - double maxval) { - // Split the seed into two 32-bit scalars to form a key. 
- auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - ThreeFry2x32State key = {seed0, seed1}; - const int64 size = shape.num_elements(); - - const int64 half_size = MathUtil::CeilOfRatio<int64>(size, 2); - const bool size_is_odd = (half_size * 2 != size); - - // Fill the generator inputs with unique counter values. - ThreeFry2x32State inputs; - inputs[0] = xla::Iota(builder, xla::S32, half_size); - inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size)); - ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); - - if (size_is_odd) { - outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - - auto bits = - xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); - - // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit - // forces the random bits into the mantissa. - constexpr int kFloatBits = 32; - constexpr int kMantissaBits = 23; - bits = xla::Or( - xla::ShiftRightLogical( - bits, xla::ConstantR0<int32>(builder, kFloatBits - kMantissaBits)), - xla::ConstantR0<int32>(builder, bit_cast<int32>(1.0f))); - auto floats = xla::BitcastConvertType(bits, xla::F32); - - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = xla::Sub(floats, xla::ConstantR0<float>(builder, 1.0f)); - // Multiply and add to shift to the range [minval, maxval). 
- floats = xla::Mul(floats, xla::ConstantR0<float>(builder, maxval - minval)); - floats = xla::Add(floats, xla::ConstantR0<float>(builder, minval)); - return floats; -} - -} // namespace - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -177,7 +50,17 @@ class StatelessRandomUniformOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, xla::ConstantR0<float>(builder, 0.0), + xla::ConstantR0<float>(builder, 1.0)); + ctx->SetOutput(0, uniform); } private: @@ -206,8 +89,16 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = - RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0<float>(builder, 1.0)); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = @@ -240,10 +131,18 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - 
xla::XlaBuilder* b = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0<float>(builder, std::numeric_limits<float>::min()), + xla::ConstantR0<float>(builder, 1.0)); - auto uniform = - RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0); ctx->SetOutput(0, TruncatedNormal(uniform)); } @@ -257,4 +156,5 @@ REGISTER_XLA_OP(Name("StatelessTruncatedNormal") .TypeConstraint("Tseed", DT_INT32), StatelessTruncatedNormalOp); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 2d4593ea49..fc14834ca6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -279,7 +279,7 @@ class XlaOpRegistrar { #define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ - XlaOpRegistrationBuilder::BUILDER.Build( \ + ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build( \ [](::tensorflow::OpKernelConstruction* context) \ -> ::tensorflow::OpKernel* { return new OP(context); })); diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 6933e9a838..ece5a885b5 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -119,6 +119,21 @@ xla_test( ) cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + ":math", + ":numeric", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/core:lib", + ], 
+) + +cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc new file mode 100644 index 0000000000..299a6ac2b6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +XlaOp RotateLeftS32(XlaOp v, int distance) { + return (v << ConstantR0<int32>(v.builder(), distance)) | + ShiftRightLogical(v, ConstantR0<int32>(v.builder(), 32 - distance)); +} + +using ThreeFry2x32State = std::array<XlaOp, 2>; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { + XlaBuilder* builder = input[0].builder(); + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array<XlaOp, 3> ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = ConstantR0<int32>(builder, 0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = ks[2] ^ key[i]; + } + + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1]; + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [](ThreeFry2x32State v, int rotation) { + v[0] = v[0] + v[1]; + v[1] = RotateLeftS32(v[1], rotation); + v[1] = v[0] ^ v[1]; + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. 
+ x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0<int32>(builder, 1); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0<int32>(builder, 2); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1] + ConstantR0<int32>(builder, 3); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0<int32>(builder, 4); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0<int32>(builder, 5); + + return x; +} + +} // namespace + +XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, + XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = seeds[0].builder(); + if (shape.element_type() != F32) { + return builder->ReportError(Unimplemented( + "Types other than F32 are not implemented by StatelessRngUniform.")); + } + ThreeFry2x32State key = seeds; + const int64 size = ShapeUtil::ElementsIn(shape); + + const int64 half_size = CeilOfRatio<int64>(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. 
+ ThreeFry2x32State inputs; + inputs[0] = Iota(builder, S32, half_size); + inputs[1] = inputs[0] + ConstantR0<int32>(builder, half_size); + ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); + + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = Reshape(ConcatInDim(builder, outputs, 0), + AsInt64Slice(shape.dimensions())); + + // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = ShiftRightLogical( + bits, ConstantR0<int32>(builder, kFloatBits - kMantissaBits)) | + ConstantR0<int32>(builder, tensorflow::bit_cast<int32>(1.0f)); + auto floats = BitcastConvertType(bits, F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = floats - ConstantR0<float>(builder, 1.0f); + // Multiply and add to shift to the range [minval, maxval). + return floats * (maxval - minval) + minval; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h new file mode 100644 index 0000000000..ac86390239 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ + +#include <array> + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). Requires 2 32-bit integer seeds. +// Currently only 'shape's of type F32 are implemented. +XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape, + XlaOp minval, XlaOp maxval); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 85c6c632cd..989bb759e3 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -182,6 +182,7 @@ tf_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ + ":hlo", ":shape_inference", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index e1da8d940c..5e5d893582 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -354,16 +354,30 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { return; } + // We need ptxas >= 9.0 as a hard requirement, because we compile targeting + // PTX 6.0. An older ptxas will just fail to compile any of our code. + // // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some // address calculations with large offsets (e.g. "load ptr + large_constant"), // b/70245379. 
- if ((vmaj == 9 && vmin == 0 && vdot < 276) || - (vmaj == 9 && vmin == 1 && vdot < 121)) { - LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." - << vmin << "." << vdot - << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). " - "These versions are known to miscompile XLA code, leading " - "to incorrect results or invalid-address errors."; + // + // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way + // that appears related to address calculations. ptxas 9.2.88 appears to + // work, as far as we can tell. + if (vmaj < 9) { + LOG(ERROR) + << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " + "prefers >= 9.2.88). Compilation of XLA kernels below will likely " + "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " + "binary is sufficient."; + } else if (vmaj == 9 && (vmin < 2 || (vmin == 2 && vdot < 88))) { + LOG(WARNING) + << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." + << vdot + << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\n\nYou do not need to update to CUDA " + "9.2.88; cherry-picking the ptxas binary is sufficient."; } } @@ -391,6 +405,10 @@ void WarnIfBadDriverJITVersion() { // - 384.x before 384.108 // - 387.x before 387.40 // - 390.x before 390.10. + // + // TODO(jlebar): This list does not cover the address-calculation bug we've + // observed in ptxas 9.1.121. Need to get a new safe range from nvidia + // corresponding to ptxas >= 9.2.88. 
auto vmaj = std::get<0>(version); auto vmin = std::get<1>(version); if ((vmaj == 384 && vmin < 108) || // diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index e291d74dd3..6c23228976 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -178,9 +178,9 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() { - return xla::MakeUnique<xla::GpuTransferManager>( - /*id=*/ stream_executor::cuda::kCudaPlatformId, - /*pointer_size=*/ llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout) + return xla::MakeUnique<xla::gpu::GpuTransferManager>( + /*id=*/stream_executor::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index d420863b85..6f2a7e1850 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -145,7 +145,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, llvm::Value* typed_ir_value; if (llvm::isa<llvm::GlobalVariable>(ir_value)) { - typed_ir_value = llvm::ConstantExpr::getBitCast( + typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast<llvm::GlobalVariable>(ir_value), dest_type); } else { typed_ir_value = diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 388aa35d7d..2799baab41 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -242,15 +242,17 @@ llvm::Value* 
EmitPrintf(tensorflow::StringPiece fmt, arguments_ptr}); } -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder) { +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder) { int bit_width = value->getType()->getPrimitiveSizeInBits(); + llvm::Value* all_lanes_mask = builder->getInt32(-1); // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_shfl_down_f32, - {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder); + llvm::Intrinsic::nvvm_shfl_sync_down_f32, + {all_lanes_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, + builder); } // We must split values wider than 32 bits as the "shfl" instruction operates @@ -264,10 +266,11 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, for (int i = 0; i < num_segments; ++i) { x = builder->CreateInsertElement( x, - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32, - {builder->CreateExtractElement(x, i), - offset, builder->getInt32(kWarpSize - 1)}, - {}, builder), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_shfl_sync_down_i32, + {all_lanes_mask, builder->CreateExtractElement(x, i), offset, + builder->getInt32(kWarpSize - 1)}, + {}, builder), i); } return builder->CreateBitCast( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59455f389e..9bb4c42b15 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -125,13 +125,17 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same -// semantics as the PTX "shfl.down" instruction [0] but works for values of any -// size. 
The last operand of the emitted "shfl" is `kWarpSize - 1`. +// semantics as the PTX "shfl.sync.down" instruction but works for values that +// aren't 32 bits in size. The last operand of the emitted "shfl" is +// `kWarpSize - 1`. // -// [0] -// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder); +// This function emits a "full-warp" shuffle, which all threads of a warp +// participate in. *Do not use this function from a divergent context:* You +// can't correctly do so on both Volta and earlier GPUs. +// +// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 673ba530df..75bbbbe8ef 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -918,10 +918,13 @@ Status IrEmitterUnnested::EmitReductionToScalar( ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], shuffle_ir_type->getPointerTo()), "partial_reduction_result"); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), + EmitFullWarpShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), ir_builder_.CreateBitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( @@ -1498,10 +1501,13 @@ Status 
IrEmitterUnnested::EmitRowReduction( ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], shuffle_ir_type->getPointerTo()), "partial_reduction_result"); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), + EmitFullWarpShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), ir_builder_.CreateBitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index a4e4e85bf3..2b0d6924a2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -206,7 +206,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine( codegen_opt_level = CodeGenOpt::None; } return WrapUnique(target->createTargetMachine( - triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx42", target_options, + triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional<Reloc::Model>(RelocModel), Optional<CodeModel::Model>(CMModel), codegen_opt_level)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 830ebfb125..19bee38790 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -386,6 +386,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( slice_sizes); break; } + case HloOpcode::kGather: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Gather instruction should have 2 operands but sees " + << 
proto.operand_ids_size(); + TF_RET_CHECK(proto.has_gather_dimension_numbers()) + << "Gather instruction should have GatherDimensionNumbers set."; + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers = + MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); + std::vector<int64> gather_window_bounds; + for (int64 bound : proto.gather_window_bounds()) { + gather_window_bounds.push_back(bound); + } + instruction = + CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_window_bounds); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -427,13 +444,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } - if (proto.has_gather_dimension_numbers()) { - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); - } - for (int64 bound : proto.gather_window_bounds()) { - instruction->gather_window_bounds_.push_back(bound); - } return std::move(instruction); } @@ -1036,34 +1046,8 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds) { - std::unique_ptr<HloInstruction> instruction = - WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(gather_indices); - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); - return instruction; -} - -/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - 
tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim) { - GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); - } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); - } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); - } - - gather_dim_numbers.set_index_vector_dim(index_vector_dim); - return gather_dim_numbers; + return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices, + gather_dim_numbers, window_bounds); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( @@ -1127,6 +1111,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kSort: + case HloOpcode::kGather: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1228,11 +1213,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); - break; case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); clone = @@ -1539,11 +1519,6 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); - case HloOpcode::kGather: - return protobuf_util::ProtobufEquals(gather_dimension_numbers(), - other.gather_dimension_numbers()) && - gather_window_bounds() == other.gather_window_bounds(); - // Remaining instructions with special values. 
case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1590,6 +1565,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1955,11 +1931,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } - if (gather_dimension_numbers_ != nullptr) { - extra.push_back(GatherDimensionNumbersToString()); - extra.push_back( - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2089,14 +2060,6 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - if (opcode() == HloOpcode::kGather) { - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); - } - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2857,26 +2820,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } -string HloInstruction::GatherDimensionNumbersToString() const { - CHECK_NE(gather_dimension_numbers_.get(), nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); - string 
index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - - return Join<std::initializer_list<string>>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, - ", "); -} - bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3190,4 +3133,14 @@ int64 HloInstruction::slice_sizes(int64 dimension) const { const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const { return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes(); } + +const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { + return Cast<HloGatherInstruction>(this)->gather_dimension_numbers(); +} + +tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds() + const { + return Cast<HloGatherInstruction>(this)->gather_window_bounds(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b392d65636..cbd78fa124 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -700,13 +700,6 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr<HloInstruction> CreateToken(); - // Creates an instance of GatherDimensionNumbers. - static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim); - // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -1081,19 +1074,6 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. 
string DotDimensionNumbersToString() const; - const GatherDimensionNumbers& gather_dimension_numbers() const { - CHECK(gather_dimension_numbers_ != nullptr); - return *gather_dimension_numbers_; - } - - tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { - CHECK_EQ(opcode(), HloOpcode::kGather); - return gather_window_bounds_; - } - - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1460,6 +1440,12 @@ class HloInstruction { // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. const std::vector<int64>& dynamic_slice_sizes() const; + + // Delegates to HloGatherInstruction::gather_dimension_numbers. + const GatherDimensionNumbers& gather_dimension_numbers() const; + // Delegates to HloGatherInstruction::gather_window_bounds. + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1603,9 +1589,6 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; - std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; - std::vector<int64> gather_window_bounds_; - // Used to tag kCopy instructions that are eligible for copy elision. bool copy_elision_allowed_ = true; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 87c048930f..b75a2bd34b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -1369,7 +1370,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1405,7 +1406,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7ea42caa7b..f333c489ed 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1914,4 +1914,93 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( return MakeUnique<HloDynamicSliceInstruction>( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } + +HloGatherInstruction::HloGatherInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds) + : HloInstruction(HloOpcode::kGather, shape) { + AppendOperand(operand); + 
AppendOperand(gather_indices); + gather_dimension_numbers_ = + MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); +} + +string HloGatherInstruction::GatherDimensionNumbersToString() const { + CHECK(gather_dimension_numbers_ != nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join<std::initializer_list<string>>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + +HloInstructionProto HloGatherInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_gather_dimension_numbers() = 
gather_dimension_numbers(); + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + return proto; +} + +std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {GatherDimensionNumbersToString(), + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; +} + +bool HloGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloGatherInstruction&>(other); + return protobuf_util::ProtobufEquals( + gather_dimension_numbers(), + casted_other.gather_dimension_numbers()) && + gather_window_bounds() == casted_other.gather_window_bounds(); +} + +std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloGatherInstruction>( + shape, new_operands[0], new_operands[1], gather_dimension_numbers(), + gather_window_bounds()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e922d94234..65a93cdcf1 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1148,6 +1148,49 @@ class HloDynamicSliceInstruction : public HloInstruction { // ('start' is specified dynamically in the second operand of the operation). 
std::vector<int64> dynamic_slice_sizes_; }; + +class HloGatherInstruction : public HloInstruction { + public: + explicit HloGatherInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { + return gather_window_bounds_; + } + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; + std::vector<int64> gather_window_bounds_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f162d52d3c..d387539350 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ 
b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -1192,11 +1193,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } - GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, - /*index_vector_dim=*/*index_vector_dim); + GatherDimensionNumbers dim_numbers = + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 7c63c0acc7..39fe3c7835 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -75,19 +75,6 @@ PlatformUtil::GetSupportedPlatforms() { auto* platform = platform_pair.second; auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { - if (platform->VisibleDeviceCount() > 0) { - LOG(INFO) << "platform " << platform->Name() << " present with " - << platform->VisibleDeviceCount() << " visible devices"; - } else { - LOG(WARNING) << "platform " << platform->Name() << " present but no " - << 
"visible devices found"; - } - // Note: currently we call zero device platforms "supported" on the basis - // that, if the platform support was linked in, it was probably intended - // to be used for execution, and this way we can flag an error. - // - // TODO(b/33730287) If we want an alternative version of this behavior we - // could add an --xla_fallback_to_host flag. platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index bafe14d6f4..9b1ce143c6 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <string> +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1543,45 +1544,45 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { }; TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - 
TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1592,7 +1593,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ 
-1609,7 +1610,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1627,7 +1628,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1646,7 +1647,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3, 4}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1664,7 +1665,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3}, /*elided_window_dims=*/{0}, /*gather_dims_to_operand_dims=*/{0}, @@ -1679,10 +1680,11 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), + 
HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1693,10 +1695,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1707,10 +1710,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1722,7 +1726,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, 
/*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1739,7 +1743,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1756,7 +1760,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1772,7 +1776,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 9}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1788,7 +1792,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1806,7 +1810,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1823,7 +1827,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1841,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, @@ -1860,7 +1864,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, @@ -1878,7 +1882,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, 
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, @@ -1896,7 +1900,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1911,7 +1915,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1928,7 +1932,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1946,7 +1950,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1962,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( 
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 4f8ef2d8a1..7821c98f1c 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -30,7 +30,6 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.core.errors import improved_errors -from tensorflow.contrib.autograph.core.errors import rewrite_graph_construction_error from tensorflow.contrib.autograph.core.errors import GraphConstructionError from tensorflow.contrib.autograph.core.errors import TfRuntimeError from tensorflow.contrib.autograph.impl.api import to_graph @@ -46,12 +45,14 @@ _allowed_symbols = [ 'convert', 'converted_call', 'do_not_convert', - 'improved_errors', 'to_code', 'to_graph', # Overloaded operators 'operators', - 'rewrite_graph_construction_error', + # Errors + 'improved_errors', + 'GraphConstructionError', + 'TfRuntimeError', # Python language "extensions" 'set_element_type', 'set_loop_options', diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index b2e2e27673..33d8d517a5 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -24,6 +24,7 @@ py_library( "continue_statements.py", "control_flow.py", "decorators.py", + "error_handlers.py", "ifexp.py", "list_comprehension.py", "lists.py", @@ -216,6 +217,18 @@ py_test( ) py_test( + name = "error_handlers_test", + srcs = ["error_handlers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + 
"//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "slices_test", srcs = ["slices_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py new file mode 100644 index 0000000000..3f23662152 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers.py @@ -0,0 +1,52 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps function bodies with a try/except to rewrite error tracebacks. + +Only adds try/except wrappers to functions that have the anno.Basic.ORIGIN +annotation because these are the functions originally written by the user. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates + + +class ErrorRewritingTransformer(converter.Base): + """Possibly wraps the body of a function in a try/except. + + Only wraps functions that were originally defined by the user, detected by + checking for the anno.Basic.ORIGIN annotation. 
+ """ + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + + if anno.hasanno(node, anno.Basic.ORIGIN): + template = """ + try: + body + except: + ag__.rewrite_graph_construction_error(ag_source_map__) + """ + node.body = templates.replace(template, body=node.body) + return node + + +def transform(node, ctx): + return ErrorRewritingTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py new file mode 100644 index 0000000000..408e35b4b6 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for error_handlers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import error_handlers +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.core import errors +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.platform import test + + +class ErrorHandlersTest(converter_testing.TestCase): + + def compiled_fn(self, test_fn, add_origin=False): + node = self.parse_and_analyze(test_fn, {}) + if add_origin: + anno.setanno(node.body[0], anno.Basic.ORIGIN, + origin_info.OriginInfo(__file__, None, None, None, None)) + node = error_handlers.transform(node, self.ctx) + module = self.compiled(node,) + return module + + def test_no_origin_annotation(self): + + def test_fn(): + raise ValueError('Crash!') + + with self.compiled_fn(test_fn) as result: + with self.assertRaises(ValueError): + result.test_fn() + + def test_wraps_body(self): + + def test_fn(): + raise ValueError('Crash!') + + with self.compiled_fn(test_fn, add_origin=True) as result: + result.rewrite_graph_construction_error = None + with self.assertRaises(errors.GraphConstructionError): + result.test_fn() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py index a351cd81b8..3b9c9a06d8 100644 --- a/tensorflow/contrib/autograph/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -224,11 +224,6 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): self.generic_visit(node) self.cant_return = False - def visit_Try(self, node): - self.cant_return = True - self.generic_visit(node) - self.cant_return = False 
- def visit_Return(self, node): if self.cant_return: raise ValueError( diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 3f5fc57125..de04cc9184 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -56,8 +56,7 @@ class SliceTransformer(converter.Base): def visit_Subscript(self, node): node = self.generic_visit(node) if not isinstance(node.slice, gast.Index): - # TODO(mdan): It might make more sense to wave them through. - raise NotImplementedError('non-index slice') + return node if not isinstance(node.ctx, gast.Load): # Index writes are handled at a higher level, one at which the rvalue is diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py index 54e6aa0f3b..a93e4a8064 100644 --- a/tensorflow/contrib/autograph/core/converter.py +++ b/tensorflow/contrib/autograph/core/converter.py @@ -64,15 +64,29 @@ from __future__ import division from __future__ import print_function import collections +from enum import Enum + from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import liveness +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.contrib.autograph.pyct.static_analysis import type_info # TODO(mdan): These contexts can be refactored into first class 
objects. # For example, we could define Program and Entity abstractions that hold on # to the actual entity and have conversion methods. +# TODO(mdan): Add a test specific to this converter. + class ProgramContext(object): """ProgramContext keeps track of converting function hierarchies. @@ -197,6 +211,46 @@ class Base(transformer.Base): self._used = False self._ast_depth = 0 + def get_definition_directive(self, node, directive, arg, default): + """Returns the unique directive for a symbol, or a default if none exist. + + See lang/directives.py for details on directives. + + Args: + node: ast.AST + directive: Callable[..., Any] + arg: str + default: Any + + Raises: + ValueError: if conflicting annotations have been found + """ + defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ()) + if not defs: + return default + + # TODO(mdan): Simplify this. + arg_values = [] + for def_ in defs: + if (directive not in def_.directives or + arg not in arg not in def_.directives[directive]): + continue + arg_value = def_.directives[directive][arg] + for prev_value in arg_values: + if not ast_util.matches(arg_value, prev_value): + qn = anno.getanno(node, anno.Basic.QN) + raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % + (qn, directive.__name__, arg, + compiler.ast_to_source(arg_value).strip(), + compiler.ast_to_source(prev_value).strip())) + arg_values.append(arg_value) + + if not arg_values: + return default + + arg_value, = arg_values + return arg_value + def visit(self, node): if not self._ast_depth: if self._used: @@ -208,3 +262,69 @@ class Base(transformer.Base): return super(Base, self).visit(node) finally: self._ast_depth -= 1 + + +class AnnotatedDef(reaching_definitions.Definition): + + def __init__(self): + super(AnnotatedDef, self).__init__() + self.directives = {} + + +class AgAnno(Enum): + """Annotation labels specific to AutoGraph. See anno.py.""" + + DIRECTIVES = 'User directives associated with the annotated statement.' 
+ + def __repr__(self): + return self.name + + +def standard_analysis(node, context, is_initial=False): + """Performs a complete static analysis of the given code. + + Args: + node: ast.AST + context: converter.EntityContext + is_initial: bool, whether this is the initial analysis done on the input + source code + + Returns: + ast.AST, same as node, with the static analysis annotations added + """ + # TODO(mdan): Clear static analysis here. + # TODO(mdan): Consider not running all analyses every time. + # TODO(mdan): Don't return a node because it's modified by reference. + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, context.info, None) + node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef) + node = liveness.resolve(node, context.info, graphs) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + # This second call allows resolving first-order class attributes. + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + if is_initial: + anno.dup( + node, + { + anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, + }, + ) + return node + + +def apply_(node, context, converter_module): + """Applies a converter to an AST. 
+ + Args: + node: ast.AST + context: converter.EntityContext + converter_module: converter.Base + + Returns: + ast.AST, the result of applying converter to node + """ + node = standard_analysis(node, context) + node = converter_module.transform(node, context) + return node diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py index 0e46aacc12..c47b70f15c 100644 --- a/tensorflow/contrib/autograph/core/converter_testing.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -25,6 +25,7 @@ from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer @@ -89,6 +90,8 @@ class TestCase(test.TestCase): fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) fake_ag.__dict__['utils'] = utils + fake_ag.__dict__['rewrite_graph_construction_error'] = ( + errors.rewrite_graph_construction_error) result.__dict__['ag__'] = fake_ag yield result except Exception: # pylint:disable=broad-except diff --git a/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb index d6a29ea1ec..a64e266f6a 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb @@ -378,7 +378,7 @@ } ], "source": [ - "#@test {\"timeout\": 90} \n", + "#@test {\"timeout\": 90}\n", "with tf.Graph().as_default():\n", " hp = tf.contrib.training.HParams(\n", " 
learning_rate=0.05,\n", @@ -580,7 +580,7 @@ } ], "source": [ - "#@test {\"timeout\": 90} \n", + "#@test {\"timeout\": 90}\n", "with context.eager_mode():\n", " durations = []\n", " for t in range(burn_ins + trials):\n", @@ -628,10 +628,6 @@ "colab": { "collapsed_sections": [], "default_view": {}, - "last_runtime": { - "build_target": "", - "kind": "local" - }, "name": "Autograph vs. Eager MNIST benchmark", "provenance": [ { diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index c7401c7df1..f7fe3de5da 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): Returns: A decorator that wraps the original function. """ + def decorator(f): """Decorator implementation.""" @@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): @wraps(f) def py_func_wrapper(*args, **kwargs): if kwargs: - raise NotImplementedError( - 'RunMode.PY_FUNC does not yet support kwargs') + raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') # TODO(mdan): Add support for kwargs. return py_func.wrap_py_func( f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) @@ -231,7 +231,10 @@ def to_graph(e, Returns: A function with a signature identical to `o`, but which when executed it - creates TF a graph that has the same functionality as the original entity. + creates TF a graph that has the same functionality as the original entity. + Raises: + ValueError: If the converted function defines or refers to symbol names that + are reserved for AutoGraph. """ program_ctx = converter.ProgramContext( recursive=recursive, @@ -256,6 +259,19 @@ def to_graph(e, compiled_node.__dict__[key] = val compiled_fn = getattr(compiled_node, name) + # Need this so the source_mapping attribute is available for the context + # manager to access for runtime errors. 
+ # + # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' + # symbol to the compiled module. + source_map_attribute_name = 'ag_source_map' + if getattr(compiled_fn, source_map_attribute_name, None) is not None: + raise ValueError('cannot convert %s because is has an attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_fn, source_map_attribute_name)) + setattr(compiled_fn, source_map_attribute_name, + compiled_node.__dict__['ag_source_map__']) + if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) @@ -292,7 +308,7 @@ def to_code(e, conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) code = '\n'.join( - compiler.ast_to_source(dep, indentation) + compiler.ast_to_source(dep, indentation)[0] for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 9943093332..4de7df6572 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -206,8 +206,8 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call( - test_fn, False, False, {}, constant_op.constant(-1)) + x = api.converted_call(test_fn, False, False, {}, + constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) def test_converted_call_method(self): @@ -274,8 +274,8 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call( - TestClass, False, False, {}, constant_op.constant(-1)) + tc = api.converted_call(TestClass, False, False, {}, + constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() self.assertEqual(1, sess.run(x)) @@ -305,6 +305,13 @@ class ApiTest(test.TestCase): # Just check that it is parseable Python code. 
self.assertIsNotNone(parser.parse_str(compiled_code)) + def test_source_map_attribute_present(self): + + def test_fn(y): + return y**2 + + self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 776d19f672..bd14359356 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -31,6 +31,7 @@ from tensorflow.contrib.autograph.converters import call_trees from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.converters import error_handlers from tensorflow.contrib.autograph.converters import ifexp from tensorflow.contrib.autograph.converters import lists from tensorflow.contrib.autograph.converters import logical_expressions @@ -40,8 +41,10 @@ from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer @@ -231,6 +234,8 @@ def _add_self_references(namespace, autograph_module): ag_internal = imp.new_module('autograph') ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils + ag_internal.rewrite_graph_construction_error = ( + errors.rewrite_graph_construction_error) # TODO(mdan): 
Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be # accessible as ag__.<operator> @@ -241,9 +246,10 @@ def _add_self_references(namespace, autograph_module): def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" + node, source = parser.parse_entity(f) node = node.body[0] - + origin_info.resolve(node, source, f) namespace = inspect_utils.getnamespace(f) _add_self_references(namespace, program_ctx.autograph_module) namer = program_ctx.new_namer(namespace) @@ -319,4 +325,5 @@ def node_to_graph(node, context): node = _apply_transformer(node, context, logical_expressions) node = _apply_transformer(node, context, side_effect_guards) node = _apply_transformer(node, context, name_scopes) + node = _apply_transformer(node, context, error_handlers) return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index f5279298af..207225a1ac 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -79,10 +79,12 @@ class ConversionTest(test.TestCase): self.assertTrue(f in program_ctx.dependency_cache) self.assertTrue(g in program_ctx.dependency_cache) self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) - # need the extra .body[0] in order to step past the with tf.name_scope('f') - # that is added automatically + # need one extra .body[0] in order to step past the try/except wrapper that + # is added automatically, the other for the with tf.name_scope('f') that is + # added automatically self.assertEqual( - 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id) + 'tf__g', + program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id) self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) def test_entity_to_graph_class_hierarchy(self): diff --git 
a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index 0cf87dd8d3..86e3f56a64 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -20,6 +20,7 @@ from __future__ import print_function import ast +import collections import gast from tensorflow.contrib.autograph.pyct import anno @@ -184,7 +185,6 @@ class PatternMatcher(gast.NodeVisitor): if v != p: return self.no_match() - def matches(node, pattern): """Basic pattern matcher for AST. @@ -251,3 +251,32 @@ def apply_to_single_assignments(targets, values, apply_fn): apply_to_single_assignments(target_el, value_el, apply_fn) else: apply_fn(target, values) + + +def iter_fields(node): + for field in sorted(node._fields): + try: + yield getattr(node, field) + except AttributeError: + pass + + +def iter_child_nodes(node): + for field in iter_fields(node): + if isinstance(field, gast.AST): + yield field + elif isinstance(field, list): + for item in field: + if isinstance(item, gast.AST): + yield item + + +def parallel_walk(node_a, node_b): + todo_a = collections.deque([node_a]) + todo_b = collections.deque([node_b]) + while todo_a and todo_b: + node_a = todo_a.popleft() + node_b = todo_b.popleft() + todo_a.extend(iter_child_nodes(node_a)) + todo_b.extend(iter_child_nodes(node_b)) + yield node_a, node_b diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index bd546c7f48..981e398b93 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -44,7 +44,8 @@ class AstUtilTest(test.TestCase): node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) self.assertIsInstance(node.body[0].value.left.id, str) - self.assertEqual(compiler.ast_to_source(node).strip(), 'renamed_a + b') + source, _ = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_a + b') def 
test_rename_symbols_attributes(self): node = parser.parse_str('b.c = b.c.d') @@ -53,8 +54,8 @@ class AstUtilTest(test.TestCase): node = ast_util.rename_symbols( node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) - self.assertEqual( - compiler.ast_to_source(node).strip(), 'renamed_b_c = renamed_b_c.d') + source, _ = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') def test_rename_symbols_annotations(self): node = parser.parse_str('a[i]') @@ -129,9 +130,9 @@ class AstUtilTest(test.TestCase): 'super(Bar, _).__init__(_)') def _mock_apply_fn(self, target, source): - target = compiler.ast_to_source(target).strip() - source = compiler.ast_to_source(source).strip() - self._invocation_counts[(target, source)] += 1 + target, _ = compiler.ast_to_source(target) + source, _ = compiler.ast_to_source(source) + self._invocation_counts[(target.strip(), source.strip())] += 1 def test_apply_to_single_assignments_dynamic_unpack(self): node = parser.parse_str('a, b, c = d') @@ -155,6 +156,25 @@ class AstUtilTest(test.TestCase): ('c', 'f'): 1, }) + def test_parallel_walk(self): + ret = ast.Return( + ast.BinOp( + op=ast.Add(), + left=ast.Name(id='a', ctx=ast.Load()), + right=ast.Num(1))) + node = ast.FunctionDef( + name='f', + args=ast.arguments( + args=[ast.Name(id='a', ctx=ast.Param())], + vararg=None, + kwarg=None, + defaults=[]), + body=[ret], + decorator_list=[], + returns=None) + for child_a, child_b in ast_util.parallel_walk(node, node): + self.assertEqual(child_a, child_b) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py index 8ef234745c..cef6e95206 100644 --- a/tensorflow/contrib/autograph/pyct/cfg.py +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -67,8 +67,10 @@ class Node(object): if isinstance(self.ast_node, gast.FunctionDef): return 'def %s' % self.ast_node.name elif isinstance(self.ast_node, gast.withitem): - return 
compiler.ast_to_source(self.ast_node.context_expr).strip() - return compiler.ast_to_source(self.ast_node).strip() + source, _ = compiler.ast_to_source(self.ast_node.context_expr) + return source.strip() + source, _ = compiler.ast_to_source(self.ast_node) + return source.strip() class Graph( @@ -122,6 +124,8 @@ class _WalkMode(Enum): REVERSE = 2 +# TODO(mdan): Rename to DataFlowAnalyzer. +# TODO(mdan): Consider specializations that use gen/kill/transfer abstractions. class GraphVisitor(object): """Base class for a CFG visitors. @@ -159,6 +163,7 @@ class GraphVisitor(object): """ raise NotImplementedError('Subclasses must implement this.') + # TODO(mdan): Rename to flow? def visit_node(self, node): """Visitor function. diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py index 24c4517afa..c172ab21f6 100644 --- a/tensorflow/contrib/autograph/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -30,9 +30,49 @@ import tempfile import astor import gast +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.contrib.autograph.pyct import parser + + +def _build_source_map(node, code): + """Return the Python objects represented by given AST. + + Compiling the AST code this way ensures that the source code is readable by + e.g. `pdb` or `inspect`. + + Args: + node: An AST node of the original generated code, before the source code is + generated. + code: The string representation of the source code for the newly generated + code. + + Returns: + Dict[CodeLocation, OriginInfo], a mapping between the user and AutoGraph + generated code. + """ + # After we have the final generated code we reparse it to get the final line + # numbers. Then we walk through the generated and original ASTs in parallel + # to build the mapping between the user and generated code. 
+ new_node = parser.parse_str(code) + origin_info.resolve(new_node, code) + source_mapping = {} + for before, after in ast_util.parallel_walk(node, new_node): + # Need both checks because if origin information is ever copied over to new + # nodes then we need to rely on the fact that only the original user code + # has the origin annotation. + if (anno.hasanno(before, anno.Basic.ORIGIN) and + anno.hasanno(after, anno.Basic.ORIGIN)): + source_info = anno.getanno(before, anno.Basic.ORIGIN) + new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number + source_mapping[new_line_number] = source_info + return source_mapping + def ast_to_source(node, indentation=' '): """Return the source code of given AST.""" + original_node = node if isinstance(node, gast.AST): node = gast.gast_to_ast(node) generator = astor.codegen.SourceGenerator(indentation, False, @@ -42,11 +82,16 @@ def ast_to_source(node, indentation=' '): # In some versions of Python, literals may appear as actual values. This # ensures everything is string. code = map(str, generator.result) - return astor.source_repr.pretty_source(code).lstrip() + code = astor.source_repr.pretty_source(code).lstrip() + source_mapping = _build_source_map(original_node, code) + return code, source_mapping -def ast_to_object( - node, indentation=' ', source_prefix=None, delete_on_exit=True): + +def ast_to_object(node, + indentation=' ', + source_prefix=None, + delete_on_exit=True): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by @@ -56,15 +101,30 @@ def ast_to_object( node: The code to compile, as an AST object. indentation: The string to use for indentation. source_prefix: Optional string to print as-is into the source file. - delete_on_exit: Whether to delete the temporary file used for compilation - on exit. + delete_on_exit: Whether to delete the temporary file used for compilation on + exit. 
Returns: A module object containing the compiled source code. + Raises: + ValueError: If ag_source_map__ is already in the namespace of the compiled + node. """ - source = ast_to_source(node, indentation) + # code_source_mapping does not yet include the offsets from import statements. + source, code_source_mapping = ast_to_source(node, indentation=indentation) with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + # TODO(znado): move into an _offset_source_map() helper function. + # Need to offset the generated line numbers by the number of import lines. + if source_prefix: + num_import_lines = source_prefix.count('\n') + 1 + else: + num_import_lines = 0 + source_mapping = {} + for line_number, original_position in code_source_mapping.items(): + source_map_key = origin_info.CodeLocation( + file_path=f.name, line_number=line_number + num_import_lines) + source_mapping[source_map_key] = original_position module_name = os.path.basename(f.name[:-3]) if source_prefix: f.write(source_prefix) @@ -72,4 +132,27 @@ def ast_to_object( f.write(source) if delete_on_exit: atexit.register(lambda: os.remove(f.name)) - return imp.load_source(module_name, f.name), source + compiled_node = imp.load_source(module_name, f.name) + + # TODO(znado): Clean this up so we don't need to attach it to the namespace. + # TODO(znado): This does not work for classes because their methods share a + # namespace. + # This attaches the source map which is needed for error handling. Note that + # api.to_graph copies this source map into an attribute of the function. + # + # We need this so the ag_source_map__ variable is available to the call to + # rewrite_graph_construction_error in the except block inside each function + # that handles graph construction errors. 
+ # + # We cannot get the rewritten function name until it is too late so templating + # is hard, and this cleanly fixes the + # issues encountered with nested functions because this is attached to the + # outermost one. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_node.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_node, source_map_name)) + compiled_node.__dict__[source_map_name] = source_mapping + + return compiled_node, source diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py index 98cdc1506b..e29fa9324c 100644 --- a/tensorflow/contrib/autograph/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -59,14 +59,14 @@ class CompilerTest(test.TestCase): value=gast.Str('c')) ]) + source, _ = compiler.ast_to_source(node, indentation=' ') self.assertEqual( textwrap.dedent(""" if 1: a = b else: a = 'c' - """).strip(), - compiler.ast_to_source(node, indentation=' ').strip()) + """).strip(), source.strip()) def test_ast_to_object(self): node = gast.FunctionDef( diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py index b3c6a43d37..614e346634 100644 --- a/tensorflow/contrib/autograph/pyct/origin_info.py +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -17,10 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple +import collections +import gast -class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))): +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.python.util import tf_inspect + + +class CodeLocation( + collections.namedtuple('CodeLocation', ('file_path', 'line_number'))): """Location of a line of code. 
Attributes: @@ -31,8 +37,9 @@ class CodeLocation(namedtuple('CodeLocation', ('file_path', 'line_number'))): class OriginInfo( - namedtuple('OriginInfo', ('file_path', 'function_name', 'line_number', - 'column_offset', 'source_code_line'))): + collections.namedtuple('OriginInfo', + ('file_path', 'function_name', 'line_number', + 'column_offset', 'source_code_line'))): """Container for information about the source code before conversion. Instances of this class contain information about the source code that @@ -50,3 +57,44 @@ class OriginInfo( """ return (self.file_path, self.line_number, self.function_name, self.source_code_line) + + +# TODO(znado): Consider refactoring this into a Visitor. +def resolve(node, source, function=None): + """Adds an origin information to all nodes inside the body of function. + + Args: + node: The AST node for the function whose body nodes will be annotated. + source: Text, the source code string for the function whose body nodes will + be annotated. + function: Callable, the function that will have all nodes inside of it + annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If + it is None then only the line numbers and column offset will be set in the + annotation, with the rest of the information being None. + + Returns: + A tuple of the AST node for function and a String containing its source + code. + """ + if function: + _, function_lineno = tf_inspect.getsourcelines(function) + function_filepath = tf_inspect.getsourcefile(function) + else: + function_lineno = None + function_filepath = None + source_lines = source.split('\n') + for n in gast.walk(node): + if hasattr(n, 'lineno'): + # n.lineno is relative to the start of the enclosing function, so need to + # offset it by the line of the function. 
+ source_code_line = source_lines[n.lineno - 1] + if function: + source_lineno = n.lineno + function_lineno - 1 + function_name = function.__name__ + else: + source_lineno = n.lineno + function_name = None + anno.setanno( + n, anno.Basic.ORIGIN, + OriginInfo(function_filepath, function_name, source_lineno, + n.col_offset, source_code_line)) diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 9001e54e46..72d1d3b269 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -45,6 +45,7 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False self.preserved_annos = { + anno.Basic.ORIGIN, anno.Basic.SKIP_PROCESSING, anno.Static.ORIG_DEFINITIONS, } diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index d9a157aead..bbdfefc50a 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -396,7 +396,8 @@ class Base(gast.NodeTransformer): def _get_source(self, node): try: - return compiler.ast_to_source(node) + source, _ = compiler.ast_to_source(node) + return source except AssertionError: return '<could not convert AST to source>' diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index eb9482dc25..b2330c4e34 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -193,6 +193,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/map_dataset_op_test.py" # Fails because uses data dependencies with bazel 
"${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" @@ -216,7 +217,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py - + # Tests too large to run. + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py ) if (WIN32) set(tf_test_src_py_exclude diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index cc9cf53410..b33243021b 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -214,7 +214,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, total_generator_loss = 0.0 total_discriminator_loss = 0.0 - for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + for (batch_index, images) in enumerate(dataset): with tf.device('/cpu:0'): tf.assign_add(step_counter, 1) @@ -227,7 +227,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tf.GradientTape(persistent=True) as g: + # we can use 2 tapes or a single persistent tape. 
+ # Using two tapes is memory efficient since intermediate tensors can be + # released between the two .gradient() calls below + with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -243,9 +246,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, generator_loss_val = generator_loss(discriminator_gen_outputs) total_generator_loss += generator_loss_val - generator_grad = g.gradient(generator_loss_val, generator.variables) - discriminator_grad = g.gradient(discriminator_loss_val, - discriminator.variables) + generator_grad = gen_tape.gradient(generator_loss_val, + generator.variables) + discriminator_grad = disc_tape.gradient(discriminator_loss_val, + discriminator.variables) generator_optimizer.apply_gradients( zip(generator_grad, generator.variables)) diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4257e754ad..16a0e71624 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -36,12 +36,13 @@ struct AllocationInfo { ArenaPlanner::ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info, - bool preserve_inputs) + bool preserve_inputs, bool preserve_intermediates) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), persistent_arena_(kDefaultArenaAlignment), - preserve_inputs_(preserve_inputs) {} + preserve_inputs_(preserve_inputs), + preserve_intermediates_(preserve_intermediates) {} ArenaPlanner::~ArenaPlanner() {} int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { @@ -164,13 +165,15 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Then update the ref-counts of the node's inputs, and if necessary queue // them for deallocation. 
- TfLiteIntArray* node_inputs = node.inputs; - for (int j = 0; j < node_inputs->size; ++j) { - int tensor_index = node_inputs->data[j]; - if (tensor_index != kOptionalTensor) { - refcounts[tensor_index]--; - if (refcounts[tensor_index] == 0) { - TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + if (!preserve_intermediates_) { + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + } } } } diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 1d84950e91..82c866734f 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -47,7 +47,7 @@ class ArenaPlanner : public MemoryPlanner { // graph will not share memory with any other tensor, effectively preserving // them until the end of inference. ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info, - bool preserve_inputs); + bool preserve_inputs, bool preserve_intermediates); ~ArenaPlanner() override; ArenaPlanner(const ArenaPlanner&) = delete; ArenaPlanner& operator=(const ArenaPlanner&) = delete; @@ -104,7 +104,14 @@ class ArenaPlanner : public MemoryPlanner { // declared as kTfLiteArenaRwPersistent. SimpleMemoryArena persistent_arena_; + // Ensure that the memory self-allocated for inputs is never reused by the + // allocator. This allows for example, multiple runs without getting + // unpredictable results. bool preserve_inputs_; + + // If true, then no overlapping of memory areas is done, meaning intermediates + // results can be queried after running (modulo running delegates). 
+ bool preserve_intermediates_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index f5bd1932f9..1adb426d58 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -156,7 +156,7 @@ class ArenaPlannerTest : public ::testing::Test { context_.ReportError = ReportError; planner_.reset(new ArenaPlanner( &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)), - preserve_inputs)); + preserve_inputs, /*preserve intermediates*/ false)); CHECK(planner_->ResetAllocations() == kTfLiteOk); CHECK(planner_->PlanAllocations() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 11cc8185f6..066b106215 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -30,6 +30,6 @@ cc_test( ":util", "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:lib", - "//testing/base/public:gunit", + "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index c1c8ef049f..4e7d33a1b6 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -39,22 +39,22 @@ single thread large core. 
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance ------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: -Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms -Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms -Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms -Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms -Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms -Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms -Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms -Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms -Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms -Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms -Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms -Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms -Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms -Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms -Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms -Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.7% | 65.8% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 41.9% | 69.1% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.3% | 71.9% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 46.4% | 73.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.1% | 78.9% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.6% | 81.3% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.1% | 83.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.0% | 84.5% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 52.5% | 82.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.6% | 85.5% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 61.1% | 87.1% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.7% | 88.1% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 62.7% | 85.5% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.6% | 87.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.9% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.3% | 89.5% | 70.2 ms ## Other models diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 0641a08636..d103786694 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -593,7 +593,7 @@ TfLiteStatus Interpreter::PrepareOpsAndTensors() { if (!memory_planner_) { memory_planner_.reset(new ArenaPlanner( &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)), - /*preserve_inputs=*/true)); + /*preserve_inputs=*/true, /*preserve_intermediates*/ false)); 
memory_planner_->PlanAllocations(); } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 0ba170a4da..f550339d03 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -112,8 +112,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // TODO(alanchiao): refactor scalar multiply into separate function // for ease of adding a neon equivalent if ever necessary. for (int j = 0; j < col_size; j++) { + const int8_t* value_ptr = reinterpret_cast<int8_t*>(value->data.uint8); output->data.f[j + i * col_size] = - value->data.uint8[j + idx * col_size] * scaling_factor; + value_ptr[j + idx * col_size] * scaling_factor; } } } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 04657fd863..4a88d168c6 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 8}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 
0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, 
// Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index a0e382edb6..200f2f1515 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -255,14 +255,6 @@ void LstmStep( output_state_ptr); } -// TODO(alanchiao): move this to tensor_utils. -void VectorMultiply(const int8_t* vector, const int v_size, const float scale, - float* result) { - for (int i = 0; i < v_size; ++i) { - *result++ = scale * *vector++; - } -} - void LstmStep( const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, float input_to_input_weights_scale, @@ -415,8 +407,9 @@ void LstmStep( // For each batch and cell: update input gate. if (!use_cifg) { if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, input_gate_scratch); @@ -427,8 +420,9 @@ void LstmStep( // For each batch and cell: update forget gate. 
if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, forget_gate_scratch); @@ -459,8 +453,9 @@ void LstmStep( tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); // For each batch and cell: update the output gate. if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, output_gate_scratch); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 8c57c987d7..420bc68b43 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -342,6 +342,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit, } } +void NeonVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + // Here the assumption is that each buffer is 4-byte aligned. + const int kWeightsPerUint32 = 4; + TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0); + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. 
+ const int kWeightsPerNeonLane = 16; + const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1)); + + // Create a vector of 4 floats with the scale value. + const float32x4_t scale_f32x4 = vdupq_n_f32(scale); + int v = 0; + for (; v < postamble_start; v += kWeightsPerNeonLane) { + // Load int8 values, sixteen at a time. + const int8x16_t v_i8x16 = vld1q_s8(vector + v); + // Split it into two components of size eight. + const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16); + const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16); + // Convert both components to int16 first. + const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8); + const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8); + // Split each of them into two components each. + const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8); + const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8); + const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8); + // Convert these to int32 and then to float. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4)); + float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4); + v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + vst1q_f32(result + v + 8, v2_f32x4); + vst1q_f32(result + v + 12, v3_f32x4); + } + + if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) { + // Load eight int8 values, if there are at least eight remaining. + const int8x8_t v_i8x8 = vld1_s8(vector + v); + // Convert them to int16 first. + const int16x8_t v_i16x8 = vmovl_s8(v_i8x8); + // Split it into two components. 
+ const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8); + // Convert the components to floats. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + v += (kWeightsPerNeonLane >> 1); + } + + // Postamble loop. + for (; v < v_size; v++) { + result[v] = scale * vector[v]; + } +} + void NeonSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 7a5a8fc541..45c9f65b64 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -105,6 +105,10 @@ bool IsZeroVector(const float* vector, int v_size) { return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); } +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result); +} void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f14667090f..db7926df9a 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -124,6 +124,12 @@ void PortableCopyVector(const float* vector, int v_size, 
float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); +void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index ccf112c990..7ead449ca8 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -195,6 +195,13 @@ void PortableZeroVector(float* vector, int v_size) { memset(vector, 0, v_size * sizeof(float)); } +void PortableVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + for (int v = 0; v < v_size; ++v) { + *result++ = scale * *vector++; + } +} + void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result) { for (int v = 0; v < v_size; v++) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index d2e1fecd25..d3a4fa8507 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -96,6 +96,10 @@ void PortableSub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. 
void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result); @@ -199,6 +203,12 @@ void ZeroVector(float* vector, int v_size) { PortableZeroVector(vector, v_size); } +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + PortableVectorScalarMultiply(vector, v_size, scale, result); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { PortableClipVector(vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 5160e22307..82f4503127 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -124,6 +124,10 @@ void Sub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void ZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. 
void ClipVector(const float* vector, int v_size, float abs_limit, float* result); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index aa0d49ae4d..372a6efec5 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -32,6 +32,22 @@ TEST(uKernels, ClipTest) { {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); } +TEST(uKernels, VectorScalarMultiply) { + constexpr int kVectorSize = 29; + static int8_t input[kVectorSize]; + for (int i = 0; i < 29; ++i) { + input[i] = static_cast<int8_t>(i - 14); + } + const float scale = 0.1f; + std::vector<float> output(kVectorSize, 0.0f); + VectorScalarMultiply(input, kVectorSize, scale, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, + -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4}))); +} + TEST(uKernels, IsZeroTest) { constexpr int kVectorSize = 21; static float zeros[kVectorSize] = {0.0}; diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index febfd2dc56..556ec7117a 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -15,13 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +// Place `<locale>` before <Python.h> to avoid build failures in macOS. +#include <locale> #include <memory> #include <string> #include <vector> -// Place `<locale>` before <Python.h> to avoid build failures in macOS. 
#include <Python.h> -#include <locale> // We forward declare TFLite classes here to avoid exposing them to SWIG. namespace tflite { diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD index 3f53ef1707..3c6fde23d2 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/BUILD +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -10,33 +10,12 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") tf_py_test( - name = "decode_proto_fail_test", - size = "small", - srcs = ["decode_proto_fail_test.py"], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - ], - data = if_static( - [], - otherwise = [":libtestexample.so"], - ), - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], -) - -tf_py_test( name = "decode_proto_op_test", size = "small", srcs = ["decode_proto_op_test.py"], additional_deps = [ + ":decode_proto_op_test_base", ":py_test_deps", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", ], @@ -55,9 +34,8 @@ tf_py_test( size = "small", srcs = ["encode_proto_op_test.py"], additional_deps = [ + ":encode_proto_op_test_base", ":py_test_deps", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", @@ -73,8 +51,9 @@ tf_py_test( ) py_library( - name = "test_base", - srcs = ["test_base.py"], + name = "proto_op_test_base", + testonly = 1, + srcs = ["proto_op_test_base.py"], deps = [ ":test_example_proto_py", "//tensorflow/python:client_testlib", @@ -82,13 +61,31 @@ 
py_library( ) py_library( - name = "py_test_deps", + name = "decode_proto_op_test_base", + testonly = 1, + srcs = ["decode_proto_op_test_base.py"], + deps = [ + ":proto_op_test_base", + ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "encode_proto_op_test_base", + testonly = 1, + srcs = ["encode_proto_op_test_base.py"], deps = [ - ":test_base", + ":proto_op_test_base", ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) +py_library(name = "py_test_deps") + tf_proto_library( name = "test_example_proto", srcs = ["test_example.proto"], diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py deleted file mode 100644 index 3b982864bc..0000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# ============================================================================= -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# Python3 preparedness imports. 
-from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.proto.python.kernel_tests import test_base -from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DecodeProtoFailTest(test_base.ProtoOpTestBase): - """Test failure cases for DecodeToProto.""" - - def _TestCorruptProtobuf(self, sanitize): - """Test failure cases for DecodeToProto.""" - - # The goal here is to check the error reporting. - # Testing against a variety of corrupt protobufs is - # done by fuzzing. - corrupt_proto = 'This is not a binary protobuf' - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(corrupt_proto, dtype=object) - msg_type = 'tensorflow.contrib.proto.TestCase' - field_names = ['sizes'] - field_types = [dtypes.int32] - - with self.test_session() as sess: - ctensor, vtensor = decode_proto_op.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) - - def testCorrupt(self): - self._TestCorruptProtobuf(sanitize=False) - - def testSanitizerCorrupt(self): - self._TestCorruptProtobuf(sanitize=True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py index 2a07794499..934035ec4c 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -13,273 +13,22 @@ # See the License for the specific language governing 
permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for decode_proto op. +"""Tests for decode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. -""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized -import numpy as np - - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_base -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes from tensorflow.python.platform import test -class DecodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase): - - def _compareValues(self, fd, vs, evs): - """Compare lists/arrays of field values.""" - - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for i, ev in enumerate(evs): - # Special case fuzzy match for float32. TensorFlow seems to mess with - # MAX_FLT slightly and the test doesn't work otherwise. - # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. - if fd.cpp_type == fd.CPPTYPE_FLOAT: - # Numpy isclose() is better than assertIsClose() which uses an absolute - # value comparison. - self.assertTrue( - np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) - elif fd.cpp_type == fd.CPPTYPE_STRING: - # In Python3 string tensor values will be represented as bytes, so we - # reencode the proto values to match that. - self.assertEqual(vs[i], ev.encode('ascii')) - else: - # Doubles and other types pass through unscathed. 
- self.assertEqual(vs[i], ev) - - def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields, - field_dict): - """Compare protos of type RepeatedPrimitiveValue. - - Args: - batch_shape: the shape of the input tensor of serialized messages. - sizes: int matrix of repeat counts returned by decode_proto - fields: list of test_example_pb2.FieldSpec (types and expected values) - field_dict: map from field names to decoded numpy tensors of values - """ - - # Check that expected values match. - for field in fields: - values = field_dict[field.name] - self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - - fd = field.expected.DESCRIPTOR.fields_by_name[field.name] - - # Values has the same shape as the input plus an extra - # dimension for repeats. - self.assertEqual(list(values.shape)[:-1], batch_shape) - - # Nested messages are represented as TF strings, requiring - # some special handling. - if field.name == 'message_value': - vs = [] - for buf in values.flat: - msg = test_example_pb2.PrimitiveValue() - msg.ParseFromString(buf) - vs.append(msg) - evs = getattr(field.expected, field.name) - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for v, ev in zip(vs, evs): - self.assertEqual(v, ev) - continue - - # This can be a little confusing. For testing we are using - # RepeatedPrimitiveValue in two ways: it's the proto that we - # decode for testing, and it's used in the expected value as a - # union type. The two cases are slightly different: this is the - # second case. - # We may be fetching the uint64_value from the test proto, but - # in the expected proto we store it in the int64_value field - # because TensorFlow doesn't support unsigned int64. 
- tf_type_to_primitive_value_field = { - dtypes.float32: - 'float_value', - dtypes.float64: - 'double_value', - dtypes.int32: - 'int32_value', - dtypes.uint8: - 'uint8_value', - dtypes.int8: - 'int8_value', - dtypes.string: - 'string_value', - dtypes.int64: - 'int64_value', - dtypes.bool: - 'bool_value', - # Unhandled TensorFlow types: - # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 - } - tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) - if tf_field_name is None: - self.fail('Unhandled tensorflow type %d' % field.dtype) - - self._compareValues(fd, values.flat, - getattr(field.expected, tf_field_name)) - - def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, - message_type, message_format, sanitize, - force_disordered=False): - """Run decode tests on a batch of messages. - - Args: - fields: list of test_example_pb2.FieldSpec (types and expected values) - case_sizes: expected sizes array - batch_shape: the shape of the input tensor of serialized messages - batch: list of serialized messages - message_type: descriptor name for messages - message_format: format of messages, 'text' or 'binary' - sanitize: whether to sanitize binary protobuf inputs - force_disordered: whether to force fields encoded out of order. - """ - - if force_disordered: - # Exercise code path that handles out-of-order fields by prepending extra - # fields with tag numbers higher than any real field. Note that this won't - # work with sanitization because that forces reserialization using a - # trusted decoder and encoder. - assert not sanitize - extra_fields = test_example_pb2.ExtraFields() - extra_fields.string_value = 'IGNORE ME' - extra_fields.bool_value = False - extra_msg = extra_fields.SerializeToString() - batch = [extra_msg + msg for msg in batch] - - # Numpy silently truncates the strings if you don't specify dtype=object. 
- batch = np.array(batch, dtype=object) - batch = np.reshape(batch, batch_shape) - - field_names = [f.name for f in fields] - output_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, vtensor = decode_proto_op.decode_proto( - batch, - message_type=message_type, - field_names=field_names, - output_types=output_types, - message_format=message_format, - sanitize=sanitize) - - vlist = sess.run([sizes] + vtensor) - sizes = vlist[0] - # Values is a list of tensors, one for each field. - value_tensors = vlist[1:] - - # Check that the repeat sizes are correct. - self.assertTrue( - np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) - - # Check that the decoded sizes match the expected sizes. - self.assertEqual(len(sizes.flat), len(case_sizes)) - self.assertTrue( - np.all(sizes.flat == np.array( - case_sizes, dtype=np.int32))) - - field_dict = dict(zip(field_names, value_tensors)) - - self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, - field_dict) - - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testBinary(self, case): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False) - - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testBinaryDisordered(self, case): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False, - force_disordered=True) - - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testPacked(self, case): - # Now try with the packed serialization. 
- # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - packed_batch = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - packed_batch, - 'tensorflow.contrib.proto.PackedPrimitiveValue', - 'binary', - sanitize=False) - - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testText(self, case): - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. 
- text_batch = [ - text_format.MessageToString( - primitive, float_format='.17g') for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - text_batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'text', - sanitize=False) +class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase): - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testSanitizerGood(self, case): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=True) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py new file mode 100644 index 0000000000..5f7f510352 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -0,0 +1,310 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Tests for decode_proto op.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors + + +class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto decoding ops.""" + + def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name + """DecodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DecodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + + def _compareValues(self, fd, vs, evs): + """Compare lists/arrays of field values.""" + + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for i, ev in enumerate(evs): + # Special case fuzzy match for float32. TensorFlow seems to mess with + # MAX_FLT slightly and the test doesn't work otherwise. + # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. + if fd.cpp_type == fd.CPPTYPE_FLOAT: + # Numpy isclose() is better than assertIsClose() which uses an absolute + # value comparison. + self.assertTrue( + np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) + elif fd.cpp_type == fd.CPPTYPE_STRING: + # In Python3 string tensor values will be represented as bytes, so we + # reencode the proto values to match that. 
+ self.assertEqual(vs[i], ev.encode('ascii')) + else: + # Doubles and other types pass through unscathed. + self.assertEqual(vs[i], ev) + + def _compareProtos(self, batch_shape, sizes, fields, field_dict): + """Compare protos of type TestValue. + + Args: + batch_shape: the shape of the input tensor of serialized messages. + sizes: int matrix of repeat counts returned by decode_proto + fields: list of test_example_pb2.FieldSpec (types and expected values) + field_dict: map from field names to decoded numpy tensors of values + """ + + # Check that expected values match. + for field in fields: + values = field_dict[field.name] + self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) + + fd = field.value.DESCRIPTOR.fields_by_name[field.name] + + # Values has the same shape as the input plus an extra + # dimension for repeats. + self.assertEqual(list(values.shape)[:-1], batch_shape) + + # Nested messages are represented as TF strings, requiring + # some special handling. + if field.name == 'message_value': + vs = [] + for buf in values.flat: + msg = test_example_pb2.PrimitiveValue() + msg.ParseFromString(buf) + vs.append(msg) + evs = getattr(field.value, field.name) + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for v, ev in zip(vs, evs): + self.assertEqual(v, ev) + continue + + # This can be a little confusing. For testing we are using TestValue in + # two ways: it's the proto that we decode for testing, and it's used in + # the expected value as a union type. + # + # The two cases are slightly different: this is the second case. We may be + # fetching the uint64_value from the test proto, but in the expected proto + # we store it in the int64_value field because TensorFlow doesn't support + # unsigned int64. 
+ tf_type_to_primitive_value_field = { + dtypes.float32: + 'float_value', + dtypes.float64: + 'double_value', + dtypes.int32: + 'int32_value', + dtypes.uint8: + 'uint8_value', + dtypes.int8: + 'int8_value', + dtypes.string: + 'string_value', + dtypes.int64: + 'int64_value', + dtypes.bool: + 'bool_value', + # Unhandled TensorFlow types: + # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 + # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 + } + tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) + if tf_field_name is None: + self.fail('Unhandled tensorflow type %d' % field.dtype) + + self._compareValues(fd, values.flat, + getattr(field.value, tf_field_name)) + + def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, + message_type, message_format, sanitize, + force_disordered=False): + """Run decode tests on a batch of messages. + + Args: + fields: list of test_example_pb2.FieldSpec (types and expected values) + case_sizes: expected sizes array + batch_shape: the shape of the input tensor of serialized messages + batch: list of serialized messages + message_type: descriptor name for messages + message_format: format of messages, 'text' or 'binary' + sanitize: whether to sanitize binary protobuf inputs + force_disordered: whether to force fields encoded out of order. + """ + + if force_disordered: + # Exercise code path that handles out-of-order fields by prepending extra + # fields with tag numbers higher than any real field. Note that this won't + # work with sanitization because that forces reserialization using a + # trusted decoder and encoder. + assert not sanitize + extra_fields = test_example_pb2.ExtraFields() + extra_fields.string_value = 'IGNORE ME' + extra_fields.bool_value = False + extra_msg = extra_fields.SerializeToString() + batch = [extra_msg + msg for msg in batch] + + # Numpy silently truncates the strings if you don't specify dtype=object. 
+ batch = np.array(batch, dtype=object) + batch = np.reshape(batch, batch_shape) + + field_names = [f.name for f in fields] + output_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, vtensor = self._decode_module.decode_proto( + batch, + message_type=message_type, + field_names=field_names, + output_types=output_types, + message_format=message_format, + sanitize=sanitize) + + vlist = sess.run([sizes] + vtensor) + sizes = vlist[0] + # Values is a list of tensors, one for each field. + value_tensors = vlist[1:] + + # Check that the repeat sizes are correct. + self.assertTrue( + np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) + + # Check that the decoded sizes match the expected sizes. + self.assertEqual(len(sizes.flat), len(case_sizes)) + self.assertTrue( + np.all(sizes.flat == np.array( + case_sizes, dtype=np.int32))) + + field_dict = dict(zip(field_names, value_tensors)) + + self._compareProtos(batch_shape, sizes, fields, field_dict) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinary(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinaryDisordered(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False, + force_disordered=True) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testPacked(self, case): + # Now try with the packed serialization. + # + # We test the packed representations by loading the same test case using + # PackedTestValue instead of TestValue. 
To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the + # test message using the packed version of the proto. + packed_batch = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + packed_batch, + 'tensorflow.contrib.proto.PackedTestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testText(self, case): + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_batch = [ + text_format.MessageToString( + value, float_format='.17g') for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + text_batch, + 'tensorflow.contrib.proto.TestValue', + 'text', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testSanitizerGood(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=True) + + @parameterized.parameters((False), (True)) + def testCorruptProtobuf(self, sanitize): + corrupt_proto = 'This is not a binary protobuf' + + # Numpy silently truncates the strings if you don't specify dtype=object. 
+ batch = np.array(corrupt_proto, dtype=object) + msg_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['sizes'] + field_types = [dtypes.int32] + + with self.test_session() as sess: + ctensor, vtensor = self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize) + with self.assertRaisesRegexp(errors.DataLossError, + 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + _ = sess.run([ctensor] + vtensor) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py index fb33660554..fc5cd25d43 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -13,164 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for encode_proto op. +"""Tests for encode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. - -It tests that encode_proto is a lossless inverse of decode_proto -(for the specified fields). 
-""" # Python3 readiness boilerplate from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_base -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.contrib.proto.python.ops import encode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class EncodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase): - - def testBadInputs(self): - # Invalid field name - with self.test_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['non_existent_field']).eval() - - # Incorrect types. - with self.test_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval() - - # Incorrect shapes of sizes. 
- with self.test_session(): - with self.assertRaisesOpError( - r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - encode_proto_op.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) - - # Inconsistent shapes of values. - with self.test_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (encode_proto_op.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) - - def _testRoundtrip(self, in_bufs, message_type, fields): - - field_names = [f.name for f in fields] - out_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, field_tensors = decode_proto_op.decode_proto( - in_bufs, - message_type=message_type, - field_names=field_names, - output_types=out_types) - - out_tensors = encode_proto_op.encode_proto( - sizes, - field_tensors, - message_type=message_type, - field_names=field_names) - - out_bufs, = sess.run([out_tensors]) - - # Check that the re-encoded tensor has the same shape. - self.assertEqual(in_bufs.shape, out_bufs.shape) - - # Compare the input and output. - for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): - in_obj = test_example_pb2.RepeatedPrimitiveValue() - in_obj.ParseFromString(in_buf) - - out_obj = test_example_pb2.RepeatedPrimitiveValue() - out_obj.ParseFromString(out_buf) - - # Check that the deserialized objects are identical. 
- self.assertEqual(in_obj, out_obj) - - # Check that the input and output serialized messages are identical. - # If we fail here, there is a difference in the serialized - # representation but the new serialization still parses. This could - # be harmless (a change in map ordering?) or it could be bad (e.g. - # loss of packing in the encoding). - self.assertEqual(in_buf, out_buf) - - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testRoundtrip(self, case): - in_bufs = [primitive.SerializeToString() for primitive in case.primitive] - - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) - def testRoundtripPacked(self, case): - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - in_bufs = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] +class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase): - # np.array silently truncates strings if you don't specify dtype=object. 
class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase):
  """Runs the shared encode_proto test suite against the contrib proto ops."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Binds the suite to the `decode_proto_op` and `encode_proto_op` modules.

    Args:
      methodName: name of the test method to run; forwarded to the base class.
    """
    super(EncodeProtoOpTest, self).__init__(
        decode_module=decode_proto_op,
        encode_module=encode_proto_op,
        methodName=methodName)
class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
  """Reusable test suite for a module that provides an `encode_proto` op.

  The decode and encode modules under test are handed to the constructor. The
  round-trip tests decode with the former and re-encode with the latter.
  """

  def __init__(self, decode_module, encode_module, methodName='runTest'):  # pylint: disable=invalid-name
    """Remembers the modules whose ops the tests exercise.

    Args:
      decode_module: a module exposing a `decode_proto` function.
      encode_module: a module exposing an `encode_proto` function.
      methodName: name of the test method to run; forwarded to the base class.
    """
    super(EncodeProtoOpTestBase, self).__init__(methodName)
    self._decode_module = decode_module
    self._encode_module = encode_module

  def testBadInputs(self):
    """Checks the op errors raised for four kinds of malformed input."""
    encode_proto = self._encode_module.encode_proto
    message_type = 'tensorflow.contrib.proto.TestValue'

    # A field name that does not exist in the message.
    with self.test_session():
      with self.assertRaisesOpError('Unknown field: non_existent_field'):
        encoded = encode_proto(
            sizes=[[1]],
            values=[np.array([[0.0]], dtype=np.int32)],
            message_type=message_type,
            field_names=['non_existent_field'])
        encoded.eval()

    # An int32 tensor supplied for a double field.
    with self.test_session():
      with self.assertRaisesOpError(
          'Incompatible type for field double_value.'):
        encoded = encode_proto(
            sizes=[[1]],
            values=[np.array([[0.0]], dtype=np.int32)],
            message_type=message_type,
            field_names=['double_value'])
        encoded.eval()

    # A sizes tensor whose shape is not batch_size + [len(field_names)].
    with self.test_session():
      with self.assertRaisesOpError(
          r'sizes should be batch_size \+ \[len\(field_names\)\]'):
        sizes = array_ops.placeholder(dtypes.int32)
        values = array_ops.placeholder(dtypes.float64)
        encoded = encode_proto(
            sizes=sizes,
            values=[values],
            message_type=message_type,
            field_names=['double_value'])
        encoded.eval(feed_dict={
            sizes: [[[0, 0]]],
            values: [[0.0]]
        })

    # Value tensors whose leading dimensions disagree with each other.
    with self.test_session():
      with self.assertRaisesOpError(
          'Values must match up to the last dimension'):
        # This placeholder is never fed or passed to the op; it is created only
        # so the graph holds the same nodes as before.
        sizes = array_ops.placeholder(dtypes.int32)
        values1 = array_ops.placeholder(dtypes.float64)
        values2 = array_ops.placeholder(dtypes.int32)
        encoded = encode_proto(
            sizes=[[1, 1]],
            values=[values1, values2],
            message_type=message_type,
            field_names=['double_value', 'int32_value'])
        encoded.eval(feed_dict={
            values1: [[0.0]],
            values2: [[0], [0]]
        })

  def _testRoundtrip(self, in_bufs, message_type, fields):
    """Decodes `in_bufs`, re-encodes the result and compares with the input.

    Args:
      in_bufs: numpy object array of serialized messages.
      message_type: descriptor name of the messages in `in_bufs`.
      fields: FieldSpec-like objects whose `name` and `dtype` select the
        fields to decode and re-encode.
    """
    names = [spec.name for spec in fields]
    types = [spec.dtype for spec in fields]

    with self.test_session() as sess:
      sizes, field_tensors = self._decode_module.decode_proto(
          in_bufs,
          message_type=message_type,
          field_names=names,
          output_types=types)

      reencoded = self._encode_module.encode_proto(
          sizes,
          field_tensors,
          message_type=message_type,
          field_names=names)

      [out_bufs] = sess.run([reencoded])

      # The re-encoded tensor must keep the input's shape.
      self.assertEqual(in_bufs.shape, out_bufs.shape)

      for original, roundtripped in zip(in_bufs.flat, out_bufs.flat):
        original_msg = test_example_pb2.TestValue()
        original_msg.ParseFromString(original)

        roundtripped_msg = test_example_pb2.TestValue()
        roundtripped_msg.ParseFromString(roundtripped)

        # First compare the parsed messages.
        self.assertEqual(original_msg, roundtripped_msg)

        # Then compare the raw bytes. A failure here but not above means the
        # new serialization still parses but is laid out differently. This
        # could be harmless (a change in map ordering?) or it could be bad
        # (e.g. loss of packing in the encoding).
        self.assertEqual(original, roundtripped)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testRoundtrip(self, case):
    """Round-trips each test case serialized as TestValue."""
    serialized = [value.SerializeToString() for value in case.values]

    # dtype=object is required: np.array silently truncates strings otherwise.
    in_bufs = np.array(serialized, dtype=object)
    in_bufs = np.reshape(in_bufs, list(case.shapes))
    return self._testRoundtrip(
        in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testRoundtripPacked(self, case):
    """Round-trips each test case re-serialized as PackedTestValue."""
    # The text format is the same for packed and unpacked fields, so each
    # TestValue is printed as text and reparsed as a PackedTestValue to obtain
    # the packed binary representation.
    serialized = []
    for value in case.values:
      # float_format='.17g' is necessary to ensure preservation of doubles and
      # floats in text format.
      as_text = text_format.MessageToString(value, float_format='.17g')
      packed = text_format.Parse(as_text, test_example_pb2.PackedTestValue())
      serialized.append(packed.SerializeToString())

    # dtype=object is required: np.array silently truncates strings otherwise.
    in_bufs = np.array(serialized, dtype=object)
    in_bufs = np.reshape(in_bufs, list(case.shapes))
    return self._testRoundtrip(
        in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields)
class ProtoOpTestBase(test.TestCase):
  """Shared fixture and hand-built test cases for the proto op tests."""

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Initializes the test case and loads the test library when present.

    If a file named libtestexample.so sits next to this module it is loaded
    through ctypes; when the file is absent nothing is loaded.

    Args:
      methodName: name of the test method to run; forwarded to the base class.
    """
    super(ProtoOpTestBase, self).__init__(methodName)
    here = os.path.dirname(__file__)
    shared_lib = os.path.join(here, "libtestexample.so")
    if not os.path.isfile(shared_lib):
      return
    ct.cdll.LoadLibrary(shared_lib)

  @staticmethod
  def named_parameters():
    """Returns (name, TestCase proto) pairs for `named_parameters` decorators."""
    case_names = (
        "defaults",
        "minmax",
        "nested",
        "optional",
        "promote_unsigned",
        "ragged",
        "shaped_batch",
        "simple",
    )
    # Each name maps to the static builder called "<name>_test_case".
    return tuple(
        (name, getattr(ProtoOpTestBase, name + "_test_case")())
        for name in case_names)

  @staticmethod
  def _add_field(test_case, name, dtype, value_attr, expected, size=None):
    """Appends one FieldSpec (and optionally one size) to `test_case`.

    Args:
      test_case: the test_example_pb2.TestCase being built.
      name: name of the field to request.
      dtype: types_pb2 data type to request for the field.
      value_attr: name of the repeated field of the spec's `value` message
        that receives the expected values.
      expected: expected values, appended in order.
      size: if not None, appended to `test_case.sizes` before the field.
    """
    if size is not None:
      test_case.sizes.append(size)
    spec = test_case.fields.add()
    spec.name = name
    spec.dtype = dtype
    getattr(spec.value, value_attr).extend(expected)

  @staticmethod
  def defaults_test_case():
    """Builds a case with one empty message and *_with_default fields."""
    test_case = test_example_pb2.TestCase()
    test_case.values.add()  # No fields specified, so we get all defaults.
    test_case.shapes.append(1)
    # (field name, requested dtype, expected-value field, expected value)
    expectations = (
        ("double_value_with_default", types_pb2.DT_DOUBLE, "double_value",
         1.0),
        ("float_value_with_default", types_pb2.DT_FLOAT, "float_value", 2.0),
        ("int64_value_with_default", types_pb2.DT_INT64, "int64_value", 3),
        ("sfixed64_value_with_default", types_pb2.DT_INT64, "int64_value", 11),
        ("sint64_value_with_default", types_pb2.DT_INT64, "int64_value", 13),
        ("uint64_value_with_default", types_pb2.DT_INT64, "int64_value", 4),
        ("fixed64_value_with_default", types_pb2.DT_INT64, "int64_value", 6),
        ("int32_value_with_default", types_pb2.DT_INT32, "int32_value", 5),
        ("sfixed32_value_with_default", types_pb2.DT_INT32, "int32_value", 10),
        ("sint32_value_with_default", types_pb2.DT_INT32, "int32_value", 12),
        ("uint32_value_with_default", types_pb2.DT_INT32, "int32_value", 9),
        ("fixed32_value_with_default", types_pb2.DT_INT32, "int32_value", 7),
        ("bool_value_with_default", types_pb2.DT_BOOL, "bool_value", True),
        ("string_value_with_default", types_pb2.DT_STRING, "string_value",
         "a"),
        ("bytes_value_with_default", types_pb2.DT_STRING, "string_value",
         "a longer default string"),
    )
    for name, dtype, value_attr, default in expectations:
      # Every field is listed with a size of 0 and a single expected value.
      ProtoOpTestBase._add_field(
          test_case, name, dtype, value_attr, [default], size=0)
    return test_case

  @staticmethod
  def minmax_test_case():
    """Builds a case holding the extreme values of each primitive type."""
    double_limits = [
        -1.7976931348623158e+308,
        2.2250738585072014e-308,
        1.7976931348623158e+308,
    ]
    float_limits = [-3.402823466e+38, 1.175494351e-38, 3.402823466e+38]
    int64_limits = [-9223372036854775808, 9223372036854775807]
    uint64_limits = [0, 18446744073709551615]
    int32_limits = [-2147483648, 2147483647]
    uint32_limits = [0, 4294967295]
    # Expected values for the unsigned fields are stored in signed fields, so
    # the all-ones maximum is written as -1.
    unsigned_as_signed = [0, -1]
    bool_limits = [False, True]
    string_limits = ["", "I refer to the infinite."]

    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.double_value.extend(double_limits)
    value.float_value.extend(float_limits)
    value.int64_value.extend(int64_limits)
    value.sfixed64_value.extend(int64_limits)
    value.sint64_value.extend(int64_limits)
    value.uint64_value.extend(uint64_limits)
    value.fixed64_value.extend(uint64_limits)
    value.int32_value.extend(int32_limits)
    value.sfixed32_value.extend(int32_limits)
    value.sint32_value.extend(int32_limits)
    value.uint32_value.extend(uint32_limits)
    value.fixed32_value.extend(uint32_limits)
    value.bool_value.extend(bool_limits)
    value.string_value.extend(string_limits)
    test_case.shapes.append(1)

    # (field name, requested dtype, expected-value field, expected values)
    expectations = (
        ("double_value", types_pb2.DT_DOUBLE, "double_value", double_limits),
        ("float_value", types_pb2.DT_FLOAT, "float_value", float_limits),
        ("int64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sfixed64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sint64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("uint64_value", types_pb2.DT_INT64, "int64_value",
         unsigned_as_signed),
        ("fixed64_value", types_pb2.DT_INT64, "int64_value",
         unsigned_as_signed),
        ("int32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sfixed32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sint32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("uint32_value", types_pb2.DT_INT32, "int32_value",
         unsigned_as_signed),
        ("fixed32_value", types_pb2.DT_INT32, "int32_value",
         unsigned_as_signed),
        ("bool_value", types_pb2.DT_BOOL, "bool_value", bool_limits),
        ("string_value", types_pb2.DT_STRING, "string_value", string_limits),
    )
    for name, dtype, value_attr, expected in expectations:
      # The expected size of each field is its number of expected values.
      ProtoOpTestBase._add_field(
          test_case, name, dtype, value_attr, expected, size=len(expected))
    return test_case

  @staticmethod
  def nested_test_case():
    """Builds a case whose only field is a nested message."""
    test_case = test_example_pb2.TestCase()
    nested_input = test_case.values.add().message_value.add()
    nested_input.double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    spec = test_case.fields.add()
    spec.name = "message_value"
    spec.dtype = types_pb2.DT_STRING
    nested_expected = spec.value.message_value.add()
    nested_expected.double_value = 23.5
    return test_case

  @staticmethod
  def optional_test_case():
    """Builds a case with one populated field and one absent field."""
    test_case = test_example_pb2.TestCase()
    test_case.values.add().bool_value.append(True)
    test_case.shapes.append(1)
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", [True],
        size=1)
    # double_value is not set in the input: size 0, expected value 0.0.
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value", [0.0],
        size=0)
    return test_case

  @staticmethod
  def promote_unsigned_test_case():
    """Builds a case that requests 32-bit unsigned fields as DT_INT64."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.fixed32_value.append(4294967295)
    value.uint32_value.append(4294967295)
    test_case.shapes.append(1)
    ProtoOpTestBase._add_field(
        test_case, "fixed32_value", types_pb2.DT_INT64, "int64_value",
        [4294967295], size=1)
    ProtoOpTestBase._add_field(
        test_case, "uint32_value", types_pb2.DT_INT64, "int64_value",
        [4294967295], size=1)
    # Comes from an explicitly-specified default
    ProtoOpTestBase._add_field(
        test_case, "uint32_value_with_default", types_pb2.DT_INT64,
        "int64_value", [9], size=0)
    return test_case

  @staticmethod
  def ragged_test_case():
    """Builds a two-message case with differing repeat counts."""
    test_case = test_example_pb2.TestCase()
    first = test_case.values.add()
    first.double_value.extend([23.5, 123.0])
    first.bool_value.append(True)
    second = test_case.values.add()
    second.double_value.append(3.1)
    second.bool_value.append(False)
    test_case.shapes.append(2)
    test_case.sizes.extend([2, 1, 1, 1])
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
        [23.5, 123.0, 3.1, 0.0])
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
        [True, False])
    return test_case

  @staticmethod
  def shaped_batch_test_case():
    """Builds a six-message case with shape entries 3 and 2."""
    doubles = [23.5, 44.0, 3.14159, 1.414, -32.2, 0.0001]
    bools = [True, False, True, True, False, True]

    test_case = test_example_pb2.TestCase()
    for double_value, bool_value in zip(doubles, bools):
      value = test_case.values.add()
      value.double_value.append(double_value)
      value.bool_value.append(bool_value)
    test_case.shapes.extend([3, 2])
    test_case.sizes.extend([1] * 12)
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
        doubles)
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", bools)
    return test_case

  @staticmethod
  def simple_test_case():
    """Builds a one-message case with one double and one bool."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.double_value.append(23.5)
    value.bool_value.append(True)
    test_case.shapes.append(1)
    ProtoOpTestBase._add_field(
        test_case, "double_value", types_pb2.DT_DOUBLE, "double_value", [23.5],
        size=1)
    ProtoOpTestBase._add_field(
        test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", [True],
        size=1)
    return test_case
-from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ctypes as ct -import os - -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 -from tensorflow.core.framework import types_pb2 -from tensorflow.python.platform import test - - -class ProtoOpTestBase(test.TestCase): - """Base class for testing proto decoding and encoding ops.""" - - def __init__(self, methodName="runTest"): # pylint: disable=invalid-name - super(ProtoOpTestBase, self).__init__(methodName) - lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") - if os.path.isfile(lib): - ct.cdll.LoadLibrary(lib) - - @staticmethod - def named_parameters(): - return ( - ("defaults", ProtoOpTestBase.defaults_test_case()), - ("minmax", ProtoOpTestBase.minmax_test_case()), - ("nested", ProtoOpTestBase.nested_test_case()), - ("optional", ProtoOpTestBase.optional_test_case()), - ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()), - ("ragged", ProtoOpTestBase.ragged_test_case()), - ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), - ("simple", ProtoOpTestBase.simple_test_case()), - ) - - @staticmethod - def defaults_test_case(): - test_case = test_example_pb2.TestCase() - test_case.primitive.add() # No fields specified, so we get all defaults. 
- test_case.shape.append(1) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "double_default" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(1.0) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "float_default" - field.dtype = types_pb2.DT_FLOAT - field.expected.float_value.append(2.0) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "int64_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(3) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "sfixed64_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(11) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "sint64_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(13) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "uint64_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(4) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "fixed64_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(6) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "int32_default" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(5) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "sfixed32_default" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(10) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "sint32_default" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(12) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "uint32_default" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(-1) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "fixed32_default" - field.dtype = types_pb2.DT_INT32 - 
field.expected.int32_value.append(7) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "bool_default" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(True) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "string_default" - field.dtype = types_pb2.DT_STRING - field.expected.string_value.append("a") - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "bytes_default" - field.dtype = types_pb2.DT_STRING - field.expected.string_value.append("a longer default string") - return test_case - - @staticmethod - def minmax_test_case(): - test_case = test_example_pb2.TestCase() - primitive = test_case.primitive.add() - primitive.double_value.append(-1.7976931348623158e+308) - primitive.double_value.append(2.2250738585072014e-308) - primitive.double_value.append(1.7976931348623158e+308) - primitive.float_value.append(-3.402823466e+38) - primitive.float_value.append(1.175494351e-38) - primitive.float_value.append(3.402823466e+38) - primitive.int64_value.append(-9223372036854775808) - primitive.int64_value.append(9223372036854775807) - primitive.sfixed64_value.append(-9223372036854775808) - primitive.sfixed64_value.append(9223372036854775807) - primitive.sint64_value.append(-9223372036854775808) - primitive.sint64_value.append(9223372036854775807) - primitive.uint64_value.append(0) - primitive.uint64_value.append(18446744073709551615) - primitive.fixed64_value.append(0) - primitive.fixed64_value.append(18446744073709551615) - primitive.int32_value.append(-2147483648) - primitive.int32_value.append(2147483647) - primitive.sfixed32_value.append(-2147483648) - primitive.sfixed32_value.append(2147483647) - primitive.sint32_value.append(-2147483648) - primitive.sint32_value.append(2147483647) - primitive.uint32_value.append(0) - primitive.uint32_value.append(4294967295) - primitive.fixed32_value.append(0) - primitive.fixed32_value.append(4294967295) - primitive.bool_value.append(False) - 
primitive.bool_value.append(True) - primitive.string_value.append("") - primitive.string_value.append("I refer to the infinite.") - test_case.shape.append(1) - test_case.sizes.append(3) - field = test_case.field.add() - field.name = "double_value" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(-1.7976931348623158e+308) - field.expected.double_value.append(2.2250738585072014e-308) - field.expected.double_value.append(1.7976931348623158e+308) - test_case.sizes.append(3) - field = test_case.field.add() - field.name = "float_value" - field.dtype = types_pb2.DT_FLOAT - field.expected.float_value.append(-3.402823466e+38) - field.expected.float_value.append(1.175494351e-38) - field.expected.float_value.append(3.402823466e+38) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "int64_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(-9223372036854775808) - field.expected.int64_value.append(9223372036854775807) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "sfixed64_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(-9223372036854775808) - field.expected.int64_value.append(9223372036854775807) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "sint64_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(-9223372036854775808) - field.expected.int64_value.append(9223372036854775807) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "uint64_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(0) - field.expected.int64_value.append(-1) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "fixed64_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(0) - field.expected.int64_value.append(-1) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "int32_value" - field.dtype = 
types_pb2.DT_INT32 - field.expected.int32_value.append(-2147483648) - field.expected.int32_value.append(2147483647) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "sfixed32_value" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(-2147483648) - field.expected.int32_value.append(2147483647) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "sint32_value" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(-2147483648) - field.expected.int32_value.append(2147483647) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "uint32_value" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(0) - field.expected.int32_value.append(-1) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "fixed32_value" - field.dtype = types_pb2.DT_INT32 - field.expected.int32_value.append(0) - field.expected.int32_value.append(-1) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "bool_value" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(False) - field.expected.bool_value.append(True) - test_case.sizes.append(2) - field = test_case.field.add() - field.name = "string_value" - field.dtype = types_pb2.DT_STRING - field.expected.string_value.append("") - field.expected.string_value.append("I refer to the infinite.") - return test_case - - @staticmethod - def nested_test_case(): - test_case = test_example_pb2.TestCase() - primitive = test_case.primitive.add() - message_value = primitive.message_value.add() - message_value.double_value = 23.5 - test_case.shape.append(1) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "message_value" - field.dtype = types_pb2.DT_STRING - message_value = field.expected.message_value.add() - message_value.double_value = 23.5 - return test_case - - @staticmethod - def optional_test_case(): - test_case = test_example_pb2.TestCase() - primitive 
= test_case.primitive.add() - primitive.bool_value.append(True) - test_case.shape.append(1) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "bool_value" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(True) - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "double_value" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(0.0) - return test_case - - @staticmethod - def promote_unsigned_test_case(): - test_case = test_example_pb2.TestCase() - primitive = test_case.primitive.add() - primitive.fixed32_value.append(4294967295) - primitive.uint32_value.append(4294967295) - test_case.shape.append(1) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "fixed32_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(4294967295) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "uint32_value" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(4294967295) - # Comes from an explicitly-specified default - test_case.sizes.append(0) - field = test_case.field.add() - field.name = "uint32_default" - field.dtype = types_pb2.DT_INT64 - field.expected.int64_value.append(4294967295) - return test_case - - @staticmethod - def ragged_test_case(): - test_case = test_example_pb2.TestCase() - primitive = test_case.primitive.add() - primitive.double_value.append(23.5) - primitive.double_value.append(123.0) - primitive.bool_value.append(True) - primitive = test_case.primitive.add() - primitive.double_value.append(3.1) - primitive.bool_value.append(False) - test_case.shape.append(2) - test_case.sizes.append(2) - test_case.sizes.append(1) - test_case.sizes.append(1) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "double_value" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(23.5) - field.expected.double_value.append(123.0) - 
field.expected.double_value.append(3.1) - field.expected.double_value.append(0.0) - field = test_case.field.add() - field.name = "bool_value" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(True) - field.expected.bool_value.append(False) - return test_case - - @staticmethod - def shaped_batch_test_case(): - test_case = test_example_pb2.TestCase() - primitive = test_case.primitive.add() - primitive.double_value.append(23.5) - primitive.bool_value.append(True) - primitive = test_case.primitive.add() - primitive.double_value.append(44.0) - primitive.bool_value.append(False) - primitive = test_case.primitive.add() - primitive.double_value.append(3.14159) - primitive.bool_value.append(True) - primitive = test_case.primitive.add() - primitive.double_value.append(1.414) - primitive.bool_value.append(True) - primitive = test_case.primitive.add() - primitive.double_value.append(-32.2) - primitive.bool_value.append(False) - primitive = test_case.primitive.add() - primitive.double_value.append(0.0001) - primitive.bool_value.append(True) - test_case.shape.append(3) - test_case.shape.append(2) - for _ in range(12): - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "double_value" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(23.5) - field.expected.double_value.append(44.0) - field.expected.double_value.append(3.14159) - field.expected.double_value.append(1.414) - field.expected.double_value.append(-32.2) - field.expected.double_value.append(0.0001) - field = test_case.field.add() - field.name = "bool_value" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(True) - field.expected.bool_value.append(False) - field.expected.bool_value.append(True) - field.expected.bool_value.append(True) - field.expected.bool_value.append(False) - field.expected.bool_value.append(True) - return test_case - - @staticmethod - def simple_test_case(): - test_case = test_example_pb2.TestCase() - primitive = 
test_case.primitive.add() - primitive.double_value.append(23.5) - primitive.bool_value.append(True) - test_case.shape.append(1) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "double_value" - field.dtype = types_pb2.DT_DOUBLE - field.expected.double_value.append(23.5) - test_case.sizes.append(1) - field = test_case.field.add() - field.name = "bool_value" - field.dtype = types_pb2.DT_BOOL - field.expected.bool_value.append(True) - return test_case diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index a2c88e372b..674d881220 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -1,6 +1,4 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; @@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto"; package tensorflow.contrib.proto; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a proto and assertions about how it should decode. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; + // Batches of primitive values. + repeated TestValue values = 1; + // The batch shapes. + repeated int32 shapes = 2; // Expected sizes for each field. repeated int32 sizes = 3; // Expected values for each field. - repeated FieldSpec field = 4; + repeated FieldSpec fields = 4; }; // FieldSpec describes the expected output for a single field. message FieldSpec { optional string name = 1; optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; + optional TestValue value = 3; }; +// NOTE: This definition must be kept in sync with PackedTestValue. 
message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { repeated double double_value = 1; repeated float float_value = 2; repeated int64 int64_value = 3; @@ -74,30 +45,31 @@ message RepeatedPrimitiveValue { repeated PrimitiveValue message_value = 19; // Optional fields with explicitly-specified defaults. 
- optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. 
-// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { +// A PackedTestValue looks exactly the same as a TestValue in the text format, +// but the binary serialization is different. We test the packed representations +// by loading the same test cases using this definition instead of TestValue. +// +// NOTE: This definition must be kept in sync with TestValue in every way except +// the packed=true declaration. +message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; repeated int64 int64_value = 3 [packed = true]; @@ -115,23 +87,53 @@ message PackedPrimitiveValue { repeated sint64 sint64_value = 18 [packed = true]; repeated PrimitiveValue message_value = 19; - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 
3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } +message PrimitiveValue { + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional fixed64 fixed64_value = 6; + optional fixed32 fixed32_value = 7; + optional bool bool_value = 8; + optional string string_value = 9; + optional bytes bytes_value = 12; + optional uint32 uint32_value = 13; + optional sfixed32 sfixed32_value = 15; + optional sfixed64 sfixed64_value = 16; + optional sint32 sint32_value = 17; + optional sint64 sint64_value = 18; +} + +// Message containing fields with field numbers higher than any field above. +// An instance of this message is prepended to each binary message in the test +// to exercise the code path that handles fields encoded out of order of field +// number. +message ExtraFields { + optional string string_value = 1776; + optional bool bool_value = 1777; +} + +// The messages below are for yet-to-be created tests. 
+ message EnumValue { enum Color { RED = 0; @@ -171,12 +173,3 @@ message RepeatedMessageValue { repeated NestedMessageValue message_values = 11; } - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 178328619f..4073b390fc 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase): def test_gather_tree_from_array_2d(self): self._test_gather_tree_from_array(depth_ndims=2) + def test_gather_tree_from_array_complex_trajectory(self): + # Max. time = 7, batch = 1, beam = 5. 
+ array = np.expand_dims(np.array( + [[[25, 12, 114, 89, 97]], + [[9, 91, 64, 11, 162]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 2, 4]], + [[2, 3, 6, 2, 2]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + parent_ids = np.array( + [[[0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4]], + [[0, 0, 1, 2, 1]], + [[0, 1, 1, 2, 3]], + [[0, 1, 3, 1, 2]], + [[0, 1, 2, 3, 4]]]) + expected_array = np.expand_dims(np.array( + [[[25, 25, 25, 25, 25]], + [[9, 9, 91, 9, 9]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 4, 4]], + [[2, 3, 6, 3, 6]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + sequence_length = [[4, 6, 4, 7, 6]] + + array = ops.convert_to_tensor( + array, dtype=dtypes.float32) + parent_ids = ops.convert_to_tensor( + parent_ids, dtype=dtypes.int32) + expected_array = ops.convert_to_tensor( + expected_array, dtype=dtypes.float32) + + sorted_array = beam_search_decoder.gather_tree_from_array( + array, parent_ids, sequence_length) + + with self.test_session() as sess: + sorted_array, expected_array = sess.run([sorted_array, expected_array]) + self.assertAllEqual(expected_array, sorted_array) + class TestArrayShapeChecks(test.TestCase): diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index c7fbeea310..f17dbb0fe3 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - mask = array_ops.sequence_mask( - sequence_length, maxlen=max_time, dtype=dtypes.int32) - mask = array_ops.transpose(mask, perm=[2, 0, 1]) - - # Use beam_width + 1 to mark the end of beam. 
- masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) - max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(sequence_length, axis=1)) sorted_beam_ids = beam_search_ops.gather_tree( - step_ids=masked_beam_ids, + step_ids=beam_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=beam_width + 1) # For out of range steps, simply copy the same beam. + in_bound_steps = array_ops.transpose( + array_ops.sequence_mask(sequence_length, maxlen=max_time), + perm=[2, 0, 1]) sorted_beam_ids = array_ops.where( - math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + in_bound_steps, x=sorted_beam_ids, y=beam_ids) # Generate indices for gather_nd. time_ind = array_ops.tile(array_ops.reshape( diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 5bfc7f9109..6ebc30ca82 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -86,27 +86,48 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(jie): Segmentation shouldn't associated with op name. // Split it into a registration for each kernel. static const std::set<string> candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - // TODO(ben,jie): ... 
+ "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", +#if NV_TENSORRT_MAJOR > 3 + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", +#endif + // TODO(ben,jie): ... }; - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) return (candidate_ops.count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 146b9c7344..0ee708bc1c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -49,9 +49,29 @@ limitations under the License. #if GOOGLE_TENSORRT #include "tensorrt/include/NvInfer.h" -// Check if the types are equal. Cast to int first so that failure log message -// would work! -#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) +// Check if the types are equal. Cast to int first so that failure log message +// would work! 
+#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) + +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ + do { \ + if (status == false) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \ + do { \ + if (ptr == nullptr) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) namespace tensorflow { namespace tensorrt { @@ -75,13 +95,110 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, case tensorflow::DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; +#if NV_TENSORRT_MAJOR > 3 + case tensorflow::DataType::DT_INT32: + *trt_dtype = nvinfer1::DataType::kINT32; + break; +#endif default: return tensorflow::errors::InvalidArgument( - "Unsupported data type " + tensorflow::DataTypeString(tf_dtype)); + "Unsupported data type ", tensorflow::DataTypeString(tf_dtype)); } return tensorflow::Status::OK(); } +// Return whether or not the broadcast is feasible; +bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, + const bool operand_l_is_tensor, + const nvinfer1::Dims& operand_r, + const bool operand_r_is_tensor, + nvinfer1::Dims* operand_l_new_shape, + nvinfer1::Dims* operand_r_new_shape) { + // *************************************************************************** + // TensorRT Elementwise op supports broadcast but requires both tensor to be + // of Identical rank + // + // We consider case of: + // 1. operand_l to be a Tensor & operand_r to be a Const; + // 2. operand_l to be a Tensor & operand_r to be a Tensor; + // note: const op const (constant folding) should fallback to TensorFlow + // + // broadcast scheme: + // T: 1 3 5 (tensor would not have batch dimension) + // W: 1 1 3 1 (weight would have all explicit dimensions) + // i. 
fill in explicit dimensions + // -> T: -1 1 3 5 (we put a -1 for batch dimension) + // -> W: 1 1 3 1 + // ii. compare broadcast feasibility + // + // We cannot support the following since TensorRT does not allow manipulation + // on batch dimension, we cannot generate output with proper shape + // T: 3 5 1 + // W: 1 1 1 1 3 5 1 + // -> T: 1 1 1 -1 3 5 1 + // -> W: 1 1 1 1 3 5 1 + // *************************************************************************** + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + const size_t element_size = sizeof(operand_l.d[0]); + + // fill in dimensions + int l_s[max_nb_dims]; + std::fill(l_s, l_s + max_nb_dims, 1); + int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; + int r_s[max_nb_dims]; + std::fill(r_s, r_s + max_nb_dims, 1); + int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims; + + int max_d = std::max(l_d, r_d); + std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, + operand_l.nbDims * element_size); + std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, + operand_r.nbDims * element_size); + + // set -1 for batch dimension, since batch size is not supposed to be + // broadcasted + if (operand_l_is_tensor) { + if (max_d != l_d) { // if broadcast beyond batch dimension, fail + return false; + } + l_s[0] = -1; + } + if (operand_r_is_tensor) { + if (max_d != r_d) { // if broadcast beyond batch dimension, fail + return false; + } + r_s[0] = -1; + } + + // compare broadcast feasibility + for (int i = max_d - 1; i >= 0; i--) { + if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { + return false; + } + } + + // output new TensorRT Dimension (stripping the batch dimension) + operand_l_new_shape->nbDims = max_d - 1; + std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); + operand_r_new_shape->nbDims = max_d - 1; + std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); + + return true; +} + +inline bool DimsEqual(const 
nvinfer1::Dims& dim_l, + const nvinfer1::Dims& dim_r) { + if (dim_l.nbDims != dim_r.nbDims) { + return false; + } + for (int i = 0; i < dim_l.nbDims; i++) { + if (dim_l.d[i] != dim_r.d[i]) { + return false; + } + } + return true; +} + inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -91,7 +208,7 @@ inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { return dims; } -inline int64_t GetShapeSize(nvinfer1::Dims shape) { +inline int64_t GetShapeSize(const nvinfer1::Dims& shape) { // Returns total number of elements in shape int64_t count = 1; for (int d = 0; d < shape.nbDims; ++d) { @@ -104,7 +221,7 @@ static std::vector<std::pair<int, int>> CreateSamePadding( const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, const std::vector<int64_t>& input_dims) { std::vector<std::pair<int, int>> padding(input_dims.size()); - CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? + CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? for (size_t i = 0; i < input_dims.size(); ++i) { // Formula to calculate the padding @@ -134,6 +251,7 @@ string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { return op_name_a.substr(0, last_scope_separator); } +// Class to convert TF weight to TRT weight. class TRT_ShapedWeights { public: TRT_ShapedWeights(tensorflow::DataType type, const void* values, @@ -145,12 +263,14 @@ class TRT_ShapedWeights { explicit TRT_ShapedWeights(tensorflow::DataType type) : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} + // TODO(aaroey): use rvalue reference. TRT_ShapedWeights(const TRT_ShapedWeights& rhs) : shape_(rhs.shape_), type_(rhs.type_), values_(rhs.values_), empty_weight_flag_(rhs.empty_weight_flag_) {} + // TODO(aaroey): use GetShapeSize() instead. 
int64_t count() const { int64_t c = 1; for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; @@ -168,6 +288,7 @@ class TRT_ShapedWeights { const void* GetValues() const { return values_; } + // TODO(aaroey): get rid of this method. void SetValues(const void* values) { values_ = values; } size_t size_bytes() const { @@ -178,10 +299,12 @@ class TRT_ShapedWeights { // Default converter operator nvinfer1::Weights() const { return GetWeightsForTRT(); } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; tensorflow::DataType type_; private: + // TODO(aaroey): this should not be const as it's always from TRTWeightStore. const void* values_; bool empty_weight_flag_; }; @@ -192,6 +315,7 @@ class TRT_TensorOrWeights { : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {} explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights) : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {} + // TODO(aaroey): use rvalue reference. TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {} ~TRT_TensorOrWeights() {} @@ -200,19 +324,19 @@ class TRT_TensorOrWeights { bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; } nvinfer1::ITensor* tensor() { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } const nvinfer1::ITensor* tensor() const { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } TRT_ShapedWeights& weights() { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } const TRT_ShapedWeights& weights() const { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } nvinfer1::Dims shape() const { @@ -236,21 +360,25 @@ class TFAttrs { attrs_.insert({attr.first, &attr.second}); } } - bool count(string key) const { return attrs_.count(key); } - tensorflow::AttrValue const* at(string key) const { + + bool count(const string& key) const { return attrs_.count(key); } + + 
tensorflow::AttrValue const* at(const string& key) const { if (!attrs_.count(key)) { LOG(FATAL) << "Attribute not found: " << key; } return attrs_.at(key); } + template <typename T> T get(const string& key) const; + template <typename T> T get(const string& key, const T& default_value) const { return attrs_.count(key) ? this->get<T>(key) : default_value; } - std::vector<string> GetAllAttrKey() { + std::vector<string> GetAllAttrKeys() const { std::vector<string> attr_list; for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); @@ -285,15 +413,6 @@ std::vector<string> TFAttrs::get<std::vector<string>>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector<string>(attr.begin(), attr.end()); } -template <> -nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(const string& key) const { - auto values = this->get<std::vector<int>>(key); - nvinfer1::Dims dims; - dims.nbDims = values.size(); - std::copy(values.begin(), values.end(), dims.d); - // Note: No dimension type information is included - return dims; -} template <> nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const { @@ -319,10 +438,11 @@ bool TFAttrs::get<bool>(const string& key) const { } // TODO(jie): reorder4 & reorder2 should be merged? +// TODO(aaroey): fix the order of parameters. 
template <typename T> -void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, - nvinfer1::DimsNCHW istrides, T* odata, - nvinfer1::DimsNCHW ostrides) { +void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata, + const nvinfer1::DimsNCHW& istrides, T* odata, + const nvinfer1::DimsNCHW& ostrides) { for (int n = 0; n < shape.n(); ++n) { for (int c = 0; c < shape.c(); ++c) { for (int h = 0; h < shape.h(); ++h) { @@ -337,12 +457,13 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, } template <typename T> -void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, - T* odata, nvinfer1::DimsHW ostrides) { +void Reorder2(const nvinfer1::DimsHW& shape, const T* idata, + const nvinfer1::DimsHW& istrides, T* odata, + const nvinfer1::DimsHW& ostrides) { for (int h = 0; h < shape.h(); ++h) { for (int w = 0; w < shape.w(); ++w) { odata[h * ostrides.h() + w * ostrides.w()] = - idata[h * ostrides.h() + w * ostrides.w()]; + idata[h * istrides.h() + w * istrides.w()]; } } } @@ -350,16 +471,17 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, // TODO(jie): fallback to tensorflow!! void ReorderCKtoKC(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights) { - int c = iweights.shape_.d[0]; - int k = iweights.shape_.d[1]; + const int c = iweights.shape_.d[0]; + const int k = iweights.shape_.d[1]; oweights->shape_.d[0] = k; oweights->shape_.d[1] = c; - nvinfer1::DimsHW istrides = {1, k}; - nvinfer1::DimsHW ostrides = {c, 1}; + const nvinfer1::DimsHW istrides = {1, k}; + const nvinfer1::DimsHW ostrides = {c, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()), istrides, + // TODO(aaroey): get rid of all the const_cast like this. 
static_cast<float*>(const_cast<void*>(oweights->GetValues())), ostrides); break; @@ -382,21 +504,24 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, int num_groups) { CHECK_EQ(iweights.type_, oweights->type_); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); - int r = iweights.shape_.d[0]; - int s = iweights.shape_.d[1]; - // TRT requires GKcRS, while TF depthwise has RSCK - // where c=1, C=G + // K indexes over output channels, C over input channels, and R and S over the + // height and width of the convolution + const int r = iweights.shape_.d[0]; + const int s = iweights.shape_.d[1]; + // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G VLOG(2) << "num_groups: " << num_groups; - int c = iweights.shape_.d[2] / num_groups; + const int c = iweights.shape_.d[2] / num_groups; VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c; - int k = iweights.shape_.d[3] * num_groups; + const int k = iweights.shape_.d[3] * num_groups; VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k; + VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r; + VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; oweights->shape_.d[3] = s; - nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; - nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; + const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; + const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()), @@ -428,11 +553,14 @@ using OpConverter = std::vector<TRT_TensorOrWeights>*)>; class Converter { + // TODO(aaroey): fix the order of members. 
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_; std::unordered_map<string, OpConverter> op_registry_; OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list<std::vector<uint8_t>> temp_bufs_; + // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to + // operate the stored weights instead of operating it directly. TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); @@ -440,7 +568,7 @@ class Converter { std::vector<TRT_TensorOrWeights>* inputs) { for (auto const& input_name : node_def.input()) { /************************************************************************* - * TODO(jie) handle case 1) here + * TODO(jie): handle case 1) here. * Normalizes the inputs and extracts associated metadata: * 1) Inputs can contain a colon followed by a suffix of characters. * That suffix may be a single number (e.g. inputName:1) or several @@ -454,6 +582,7 @@ class Converter { if (input_name[0] == '^') continue; string name = input_name; auto first = name.find_first_of(':'); + // TODO(aaroey): why removing the colon but not the zero? A bug? if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') name.erase(first); @@ -462,12 +591,13 @@ class Converter { if (trt_tensors_.count(name)) { inputs->push_back(trt_tensors_.at(name)); } else { - string str("Node "); - StrAppend(&str, node_def.name(), " should have an input named '", name, + // TODO(aaroey): this should not happen, make it a CHECK. + // TODO(aaroey): use StrCat for pattern like this. 
+ string msg("Node "); + StrAppend(&msg, node_def.name(), " should have an input named '", name, "' but it is not available"); - LOG(WARNING) << "input: " << name << " not available for node at " - << node_def.name(); - return tensorflow::errors::InvalidArgument(str); + LOG(ERROR) << msg; + return tensorflow::errors::InvalidArgument(msg); } } return tensorflow::Status::OK(); @@ -488,6 +618,7 @@ class Converter { weights.SetValues(weight_store_->store_.back().data()); return weights; } + // TODO(aaroey): fix all the namings. bool isFP16() { return fp16_; } TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) { return this->get_temp_weights(weights.type_, weights.shape_); @@ -496,7 +627,7 @@ class Converter { tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) { std::vector<TRT_TensorOrWeights> inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); - string op = node_def.op(); + const string& op = node_def.op(); std::vector<TRT_TensorOrWeights> outputs; if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); @@ -509,7 +640,7 @@ class Converter { TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); } for (size_t i = 0; i < outputs.size(); ++i) { - TRT_TensorOrWeights output = outputs.at(i); + TRT_TensorOrWeights& output = outputs[i]; // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); @@ -527,26 +658,29 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } - TRT_TensorOrWeights get_tensor(string name) { + TRT_TensorOrWeights get_tensor(const string& name) { if (!trt_tensors_.count(name)) { return TRT_TensorOrWeights(nullptr); } return trt_tensors_.at(name); } - bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) { + bool insert_input_tensor(const string& name, nvinfer1::ITensor* tensor) 
{ return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second; } nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor, - std::vector<int> order) { - auto dims = input_tensor->getDimensions(); + const std::vector<int>& order) { + const auto dims = input_tensor->getDimensions(); // TODO(jie): change the return to status and properly exit if (order.size() - 1 != size_t(dims.nbDims)) LOG(ERROR) << "Dimension does not match, fail gracefully"; nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); + if (layer == nullptr) { + return nullptr; + } nvinfer1::Permutation permutation; for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order[i + 1] - 1; @@ -577,13 +711,14 @@ TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx, } return weights; } + // **************************************************************************** // Constant folding functions // TODO(jie): once optimizer kicks in, we should have done constant folding // there. 
-//*****************************************************************************/ +// ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; OP_CATEGORY op; template <typename T> @@ -595,6 +730,8 @@ struct LambdaFactory { } case OP_CATEGORY::NEG: return [](T t) -> T { return -t; }; + case OP_CATEGORY::RECIP: + return [](T t) -> T { return 1.0 / t; }; default: VLOG(2) << "Not supported op for unary: " << static_cast<int>(op); return nullptr; @@ -628,7 +765,6 @@ struct LambdaFactory { VLOG(2) << "LAMBDA VAL : " << val; return l + val; }; - // Return [val](T l)-> T {return l+val;}; case OP_CATEGORY::SUB: return [val](T l) -> T { VLOG(2) << "LAMBDA VAL : " << val; @@ -688,11 +824,13 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() { } case OP_CATEGORY::NEG: return [](Eigen::half t) -> Eigen::half { return -t; }; + // TODO(aaroey): can we support RECIP? default: VLOG(2) << "Not supported op for unary: " << static_cast<int>(op); return nullptr; } } + tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { @@ -738,6 +876,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, if (iweights_l.count() != iweights_r.count()) { // We only supports broadcast of RankZero if (iweights_l.count() == 1) { + // TODO(aaroey): Remove loggings like this. VLOG(2) << "I bet it is not working!" 
<< (*inp_l); std::transform(inp_r, inp_r + iweights_r.count(), oup, binary_op.broadcast_l<float>(*inp_l)); @@ -790,117 +929,21 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, return tensorflow::Status::OK(); } -tensorflow::Status ConstantFoldUnary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector<TRT_TensorOrWeights>& inputs, - std::vector<TRT_TensorOrWeights>* outputs) { - TRT_ShapedWeights weights_input = inputs.at(0).weights(); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); - - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later - // Check type consistency - CHECK_EQ(weights_input.type_, - TFAttrs(node_def).get<tensorflow::DataType>("T")); - - LambdaFactory unary_op; - if (node_def.op() == "Rsqrt") { - // Compute rsqrt - unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; - auto ret = UnaryCompute(weights_input, &weights_output, unary_op); - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - return ret; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } -} - -// TODO(jie,ben) broadcast is needed yet not implemented -// Let's get the simple stuff working first. 
Maybe we should fall back to TF -// approach for constant folding -tensorflow::Status ConstantFoldBinary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector<TRT_TensorOrWeights>& inputs, - std::vector<TRT_TensorOrWeights>* outputs) { - TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); - TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); - - // Check type consistency - CHECK_EQ(weights_input_l.type_, weights_input_r.type_); - - if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) - return tensorflow::errors::Unimplemented( - "Binary op implicit broadcast not supported: " + node_def.op()); - - // TODO(jie): constant fold should really fall back to TF. - int num_dims = weights_input_l.shape_.nbDims; - nvinfer1::Dims output_shape; - output_shape.nbDims = num_dims; - VLOG(2) << "nb_dims: " << num_dims - << ", the other: " << weights_input_r.shape_.nbDims; - for (int i = 0; i < num_dims; i++) { - if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { - output_shape.d[i] = weights_input_l.shape_.d[i]; - } else if (weights_input_l.shape_.d[i] == 1 || - weights_input_r.shape_.d[i] == 1) { - output_shape.d[i] = - std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); - } else { - return tensorflow::errors::Unimplemented( - "Binary op with incompatible shape at, " + node_def.op()); - } - VLOG(2) << "left: " << weights_input_l.shape_.d[i] - << "right: " << weights_input_r.shape_.d[i] - << "output: " << output_shape.d[i]; - } - - // FIXME assume type matches input weights - // Get trt type & shape - TFAttrs attrs(node_def); - // Maybe this part has to be moved into the block of rsqrt later - tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T"); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); - - LambdaFactory binary_op; - if (node_def.op() == "Sub") { - binary_op.op = LambdaFactory::OP_CATEGORY::SUB; - } else if 
(node_def.op() == "Mul") { - binary_op.op = LambdaFactory::OP_CATEGORY::MUL; - } else if (node_def.op() == "Add") { - binary_op.op = LambdaFactory::OP_CATEGORY::ADD; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } - auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output, - binary_op); - - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - - return ret; -} - // TODO(jie): broadcast is needed yet not implemented. // Only implemented channel wise for the time being tensorflow::Status BinaryTensorOpWeight( Converter& ctx, const tensorflow::NodeDef& node_def, const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, - std::vector<TRT_TensorOrWeights>* outputs) { - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later + bool swapped_inputs, std::vector<TRT_TensorOrWeights>* outputs) { + // tensor is the left operand while weights is the right operand; + // when swapped_inputs set to true, those two are swapped. + // TODO(aaroey): use a set. 
+ if (node_def.op() != "Sub" && node_def.op() != "Add" && + node_def.op() != "Mul" && node_def.op() != "Div" && + node_def.op() != "RealDiv") { + return tensorflow::errors::Unimplemented( + "op not supported: " + node_def.op() + ", at: " + node_def.name()); + } // Check type consistency nvinfer1::DataType ttype; @@ -910,6 +953,12 @@ tensorflow::Status BinaryTensorOpWeight( auto dims_w = weights.shape_; auto dims_t = tensor->getDimensions(); + // TODO(jie): addScale checks for input tensor dimension + if (dims_t.nbDims != 3) { + return tensorflow::errors::InvalidArgument( + "addScale requires tensor with rank 3, " + node_def.name()); + } + // default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; @@ -980,6 +1029,7 @@ tensorflow::Status BinaryTensorOpWeight( permutation[dims_t.nbDims] = 1; tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } else { return tensorflow::errors::InvalidArgument( "Transpose cannot be applied, " + node_def.name()); @@ -997,11 +1047,35 @@ tensorflow::Status BinaryTensorOpWeight( // Maybe I should do a switch if (node_def.op() == "Sub") { - TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::NEG; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); - shift_weights = neg_weights; + if (swapped_inputs) { + shift_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor), + nvinfer1::UnaryOperation::kNEG); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::NEG; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); + shift_weights = neg_weights; + } + } else if (node_def.op() 
== "Div" || node_def.op() == "RealDiv") { + if (swapped_inputs) { + scale_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor), + nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights recip_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::RECIP; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op)); + scale_weights = recip_weights; + } } else if (node_def.op() == "Mul") { scale_weights = weights; } else if (node_def.op() == "Add") { @@ -1014,11 +1088,13 @@ tensorflow::Status BinaryTensorOpWeight( nvinfer1::IScaleLayer* layer = ctx.network()->addScale( *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights, scale_weights, power_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // transpose back dimension if (permutation_flag) { output_tensor = ctx.TransposeTensor(output_tensor, permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } // Pass the output @@ -1042,20 +1118,31 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); h_index = 1; w_index = 2; // TODO(jie): transpose it } // tensor after transpose (NCHW) - auto tensor_dim = tensor->getDimensions(); + const auto tensor_dim = tensor->getDimensions(); int num_groups = group; - if (num_groups == 0) // depthwise convolution - num_groups = tensor_dim.d[0]; + if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + + VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims; + for (int i 
= 0; i < weights_rsck.shape_.nbDims; i++) { + VLOG(2) << weights_rsck.shape_.d[i]; + } + + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (ctx.isFP16()) { weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } @@ -1063,18 +1150,22 @@ tensorflow::Status ConvertConv2DHelper( TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - int noutput = weights.shape_.d[0] * num_groups; + const int noutput = weights.shape_.d[0] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; + VLOG(2) << "RSCK: "; + for (int i = 0; i < 4; i++) { + VLOG(2) << " " << weights.shape_.d[i]; + } VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); // TODO(jie): stride. (NHWC/NCHW) - auto tf_stride = attrs.get<std::vector<int>>("strides"); + const auto tf_stride = attrs.get<std::vector<int>>("strides"); VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2] << tf_stride[3]; - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); std::vector<std::pair<int, int>> padding; // TODO(jie): padding. 
@@ -1102,6 +1193,7 @@ tensorflow::Status ConvertConv2DHelper( *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); auto dim_after = tensor->getDimensions(); @@ -1112,6 +1204,7 @@ tensorflow::Status ConvertConv2DHelper( nvinfer1::IConvolutionLayer* layer = ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size, weights, biases); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1126,6 +1219,7 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } else { VLOG(2) << "NCHW !!!!"; } @@ -1147,35 +1241,91 @@ tensorflow::Status ConvertConv2DHelper( node_def.name()); } +// Helper function converts input into tensor with shape specified by dims. 
+bool PrepareTensorForShape(Converter& ctx, const TRT_TensorOrWeights& input, + const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor) { + if (input.is_tensor()) { + if (DimsEqual(input.shape(), dims)) { + *tensor = input.tensor(); + } else { + nvinfer1::IShuffleLayer* layer = ctx.network()->addShuffle( + *const_cast<nvinfer1::ITensor*>(input.tensor())); + if (layer != nullptr) { + layer->setReshapeDimensions(dims); + *tensor = layer->getOutput(0); + } else { + return false; + } + } + } else { +#if NV_TENSORRT_MAJOR > 3 + nvinfer1::IConstantLayer* layer = + ctx.network()->addConstant(dims, input.weights()); + if (layer != nullptr) { + *tensor = layer->getOutput(0); + } else { + return false; + } +#else + return false; +#endif + } + return true; +} + tensorflow::Status BinaryTensorOpTensor( Converter& ctx, const tensorflow::NodeDef& node_def, - const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, std::vector<TRT_TensorOrWeights>* outputs) { static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, {"Mul", nvinfer1::ElementWiseOperation::kPROD}, {"Sub", nvinfer1::ElementWiseOperation::kSUB}, {"Div", nvinfer1::ElementWiseOperation::kDIV}, + {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, + {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, + {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; - // FIXME assume type matches input weights + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + + nvinfer1::Dims dim_l; + nvinfer1::Dims dim_r; + + if (!TensorRTGetBroadcastShape(operand_l.shape(), operand_l.is_tensor(), + operand_r.shape(), operand_r.is_tensor(), + &dim_l, &dim_r)) { + return tensorflow::errors::InvalidArgument( + "Binary op broadcast scheme not supported by TensorRT op: " + + node_def.op() + ", at: " + node_def.name()); + } + + TFTRT_RETURN_ERROR_IF_FALSE( + 
PrepareTensorForShape(ctx, operand_l, dim_l, &tensor_l), node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, operand_r, dim_r, &tensor_r), node_def.name()); + // get trt type & shape TFAttrs attrs(node_def); // maybe this part has to be moved into the block of rsqrt later nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T"); // check type consistency - CHECK_EQ_TYPE(tensor_l->getType(), dtype); - CHECK_EQ_TYPE(tensor_r->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) + if (op_pair == ops.end()) { return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + "binary op: ", node_def.op(), " not supported at: ", node_def.name()); + } nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + // TODO(aaroey): will tensor_l/tensor_r get modified? *const_cast<nvinfer1::ITensor*>(tensor_l), *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -1202,7 +1352,7 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); - auto attr_key_vector = attrs.GetAllAttrKey(); + auto attr_key_vector = attrs.GetAllAttrKeys(); for (auto attr_key : attr_key_vector) { // TODO(jie): support only list of float for toy example here. 
auto data = attrs.get<std::vector<float>>(attr_key); @@ -1223,29 +1373,6 @@ tensorflow::Status ConvertPlugin(Converter& ctx, return tensorflow::Status::OK(); } -tensorflow::Status ConvertPlaceholder( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector<TRT_TensorOrWeights>& inputs, - std::vector<TRT_TensorOrWeights>* outputs) { - VLOG(2) << "Placeholder should have been replace already"; - return tensorflow::errors::Unimplemented("cannot convert Placeholder op"); - // OK this make sense since we are supposed to replace it with input - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype"); - nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape"); - - dims.nbDims--; - for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; - - nvinfer1::ITensor* output = - ctx.network()->addInput(node_def.name().c_str(), dtype, dims); - if (!output) { - return tensorflow::errors::InvalidArgument("Failed to create Input layer"); - } - outputs->push_back(TRT_TensorOrWeights(output)); - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertConv2D(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, @@ -1271,65 +1398,64 @@ tensorflow::Status ConvertPool(Converter& ctx, int h_index = 2; int w_index = 3; - auto data_format = attrs.get<string>("data_format"); + const auto data_format = attrs.get<string>("data_format"); if (data_format == "NHWC") { h_index = 1; w_index = 2; tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } + nvinfer1::PoolingType type; - // TODO(jie): support other pooling type - if (node_def.op() == "MaxPool") + if (node_def.op() == "MaxPool") { type = nvinfer1::PoolingType::kMAX; - else if (node_def.op() == "AvgPool") + } else if (node_def.op() == "AvgPool") { type = nvinfer1::PoolingType::kAVERAGE; - else - return 
tensorflow::errors::Unimplemented("Only supports Max pool"); + } else { + return tensorflow::errors::Unimplemented("Unsupported pool type: ", + node_def.op()); + } - // TODO(jie): NCHW - auto tf_stride = attrs.get<std::vector<int>>("strides"); - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const auto tf_stride = attrs.get<std::vector<int>>("strides"); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - auto tf_kernel = attrs.get<std::vector<int>>("ksize"); - nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); + const auto tf_kernel = attrs.get<std::vector<int>>("ksize"); + const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); auto tensor_dim = tensor->getDimensions(); std::vector<std::pair<int, int>> padding; - // TODO(jie): padding. - if (attrs.get<string>("padding") == "SAME") { + const string padding_type = attrs.get<string>("padding"); + if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. 
// 1 -> h // 2 -> w padding = CreateSamePadding( stride, ksize, {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])}); - } else if (attrs.get<string>("padding") == "VALID") { - // No padding for valid padding here - VLOG(2) << "No padding added for VALID padding in pool" << node_def.name(); + } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; } else { - return tensorflow::errors::Unimplemented( - "Current MaxPool cannot support padding other than SAME"); + return tensorflow::errors::Unimplemented("Unsupported padding type: ", + padding_type); } if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second << padding[1].first << padding[1].second; auto pad_layer = ctx.network()->addPadding( *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( *const_cast<nvinfer1::ITensor*>(tensor), type, ksize); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1337,10 +1463,8 @@ tensorflow::Status ConvertPool(Converter& ctx, nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { - // TODO(jie): transpose it back! 
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1353,6 +1477,7 @@ tensorflow::Status ConvertActivation( const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1363,40 +1488,61 @@ tensorflow::Status ConvertScale(Converter& ctx, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::Unimplemented( - "Only supports tensor op weight for now, at " + node_def.name()); - // Implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + "ConvertScale only supports tensor<op>weight: ", node_def.name()); + } + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights weights = inputs.at(1).weights(); if (ctx.isFP16()) { weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } TRT_ShapedWeights empty_weights(weights.type_); - TFAttrs attrs(node_def); - // Transpose NHWC - auto data_format = attrs.get<string>("data_format"); + const auto data_format = attrs.get<string>("data_format"); + int channel_index; + const auto dims = tensor->getDimensions(); if (data_format == "NHWC") { - tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), - {0, 3, 1, 2}); - // TODO(jie): transpose it + // 1). 
NHWC is really N+C + channel_index = dims.nbDims - 1; // batch dimension is implicit here! } else { - VLOG(2) << "NCHW !!!!"; + // 2). NCHW is really N+CHW + channel_index = dims.nbDims - 3; // batch dimension is implicit here! } - auto dims = tensor->getDimensions(); - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + nvinfer1::Permutation permutation; + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = i; } - dims = weights.shape_; - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + + if (channel_index >= 0) { + permutation.order[0] = channel_index; + permutation.order[channel_index] = 0; + } else { + return tensorflow::errors::Unimplemented( + "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + } + + // TensorRT addScale requires input to be of rank 3, we need to apply + // transpose as well as reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = + ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = 3; + reshape_dims.d[0] = 0; // 0 copy from the input + reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input + reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + if (channel_index != 0) { + // maybe we do not need this check. 
concerned about TRT optimization + shuffle_layer->setFirstTranspose(permutation); + } + shuffle_layer->setReshapeDimensions(reshape_dims); + tensor = shuffle_layer->getOutput(0); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -1407,14 +1553,26 @@ tensorflow::Status ConvertScale(Converter& ctx, nvinfer1::IScaleLayer* layer = ctx.network()->addScale(*const_cast<nvinfer1::ITensor*>(tensor), mode, weights, empty_weights, empty_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (data_format == "NHWC") { - // TODO(jie): transpose it back! - output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + + // restore transpose & reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = ctx.network()->addShuffle( + *const_cast<nvinfer1::ITensor*>(output_tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims = dims; + int tmp = reshape_dims.d[channel_index]; + reshape_dims.d[channel_index] = reshape_dims.d[0]; + reshape_dims.d[0] = tmp; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { + shuffle_layer->setSecondTranspose(permutation); + } + output_tensor = shuffle_layer->getOutput(0); } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1431,11 +1589,13 @@ tensorflow::Status ConvertConst(Converter& ctx, // Create shaped weights as output tensorflow::Tensor tensor; - if (!tensor.FromProto(weights_tensor)) - return tensorflow::errors::Internal("Cannot parse weight tensor proto: " + + if (!tensor.FromProto(weights_tensor)) { + return tensorflow::errors::Internal("Cannot parse weight tensor proto: ", node_def.name()); + } TRT_ShapedWeights weights(dtype); + // TODO(aaroey): we should choose the array using dtype and shape. 
if (!weights_tensor.float_val().empty()) { VLOG(2) << "SCALAR!!!" << node_def.name(); nvinfer1::Dims scalar_shape; @@ -1443,22 +1603,16 @@ tensorflow::Status ConvertConst(Converter& ctx, VLOG(2) << "dimensions: " << tensor.dims(); VLOG(2) << "size: " << weights_tensor.float_val_size(); scalar_shape = GetTensorShape(tensor); + VLOG(2) << "details: "; for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; - if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) { - if (weights_tensor.float_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.float_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. flatten it - scalar_shape.d[0] = weights_tensor.float_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size() && + weights_tensor.float_val_size() != 1) { + LOG(ERROR) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "Dimensions: " << tensor.dims(); @@ -1468,39 +1622,42 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { scalar_shape.d[i] = 0; - scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } + // TODO(aaroey): use GetShapeSize(). 
size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data)); void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0])); - std::vector<float> tensor_data( - weights_tensor.float_val().begin(), - weights_tensor.float_val() - .end()); // make a local copy first to flatten - memcpy(dst, tensor_data.data(), len_data); // store into weight store + if (weights_tensor.float_val_size() == 1) { + std::fill_n((float*)dst, GetShapeSize(scalar_shape), + *weights_tensor.float_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy first to flatten doesn't have to be + // contiguous + std::vector<float> tensor_data(weights_tensor.float_val().begin(), + weights_tensor.float_val().end()); + memcpy(dst, tensor_data.data(), len_data); // store into weight store + } + VLOG(2) << "create shape details: "; + for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.int_val().empty()) { + // TODO(aaroey): this is very similar to the above code for float, merge + // them. VLOG(2) << "int!!!" << node_def.name(); nvinfer1::Dims scalar_shape; if (tensor.dims() > 0) { VLOG(2) << "dimensions: " << tensor.dims(); scalar_shape = GetTensorShape(tensor); - if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) { - if (weights_tensor.int_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.int_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. 
flatten it - scalar_shape.d[0] = weights_tensor.int_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size() && + weights_tensor.int_val_size() != 1) { + LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "dimensions: " << tensor.dims(); @@ -1513,23 +1670,30 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } - // we should not have converted //if (ctx.isFP16()) { + // we should not have converted size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32); len_data = std::max(len_data, len_tensor); ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data)); void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0])); - std::vector<int32> tensor_data( - weights_tensor.int_val().begin(), - weights_tensor.int_val().end()); // make a local copy first to flatten - // doesn't have to be contigous - memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + if (weights_tensor.int_val_size() == 1) { + std::fill_n((int*)dst, GetShapeSize(scalar_shape), + *weights_tensor.int_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy 
first to flatten doesn't have to be + // contiguous + std::vector<int32> tensor_data(weights_tensor.int_val().begin(), + weights_tensor.int_val().end()); + memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + } weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.tensor_content().empty()) { - // obsolete method. - // After optimization path, we do not see weights in this format. - // fp16 conversion technically should be needed here. + // obsolete method. + // After optimization path, we do not see weights in this format. + // TODO(aaroey): why? + // fp16 conversion technically should be needed here. VLOG(2) << "TENSOR!!!" << node_def.name(); const auto& content = weights_tensor.tensor_content(); @@ -1543,8 +1707,8 @@ tensorflow::Status ConvertConst(Converter& ctx, content, static_cast<char*>(const_cast<void*>(weights.GetValues()))); } } else { - return tensorflow::errors::Unimplemented( - "Not supported constant type, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Not supported constant type, at ", + node_def.name()); } // Pass the output outputs->push_back(TRT_TensorOrWeights(weights)); @@ -1563,96 +1727,144 @@ tensorflow::Status ConvertBinary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { - if (inputs.size() != 2) + if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at " + node_def.name()); - - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) - return ConstantFoldBinary(ctx, node_def, inputs, outputs); - - if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) - return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).weights(), outputs); + "Binary ops require two tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) - return 
BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), - inputs.at(0).weights(), outputs); + // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) - return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).tensor(), outputs); + if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { + return tensorflow::errors::Unimplemented( + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: ", + node_def.name()); + } - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + // Try to convert into Scale layer first (for better performance) + // Since scale layer supports restricted broadcast policy and op types, we + // allow failure and try to handle it through Elementwise op + // (BinaryTensorOpTensor) + Status status = tensorflow::Status::OK(); + if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).weights(), false, outputs); + } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), + inputs.at(0).weights(), true, outputs); +#if NV_TENSORRT_MAJOR == 3 + } else { +#else + } + if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() || !status.ok()) { +#endif + status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1), + outputs); + } + return status; } tensorflow::Status ConvertUnary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { - if (inputs.size() != 1) + static const std::unordered_map<string, nvinfer1::UnaryOperation> ops{ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", 
nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, + }; + + if (inputs.size() != 1) { return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at " + node_def.name()); + "Unary ops require single tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights()) - return ConstantFoldUnary(ctx, node_def, inputs, outputs); - else if (inputs.at(0).is_tensor()) +#if NV_TENSORRT_MAJOR == 3 + if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( - "Unary op for tensor not supported, at " + node_def.name()); + "Constant folding for unary op is not supported", node_def.name()); + } +#endif + + // TODO(jie): check type + const nvinfer1::ITensor* tensor; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), inputs.at(0).shape(), &tensor), + node_def.name()); + + nvinfer1::IUnaryLayer* layer; + if (node_def.op() == "Rsqrt") { + layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor), + nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor), + nvinfer1::UnaryOperation::kRECIP); + } else if (ops.count(node_def.op()) != 0) { + layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor), + ops.at(node_def.op())); + } else { + return tensorflow::errors::InvalidArgument( + "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); + } - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); } -tensorflow::Status ConvertReduce(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector<TRT_TensorOrWeights>& inputs, - 
std::vector<TRT_TensorOrWeights>* outputs) { +#if NV_TENSORRT_MAJOR == 3 +tensorflow::Status ConvertReducePool( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector<TRT_TensorOrWeights>& inputs, + std::vector<TRT_TensorOrWeights>* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights index_list = inputs.at(1).weights(); - TFAttrs attrs(node_def); - // TODO(jie): handle data type. - // Index type here is done through TF type, so I can leverage their - // EnumToDataType for my cast auto index_type = attrs.get<tensorflow::DataType>("Tidx"); // Only expect to handle INT32 as attributes for now - if (index_type != tensorflow::DataType::DT_INT32) + if (index_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); - auto index_list_data = + } + const auto index_list_data = static_cast<int*>(const_cast<void*>(index_list.GetValues())); - // Hack warning: have to fall back to pool layer since reduce is not in public - // TRT yet. 
- if (nb_dims != 4) + if (nb_dims != 4) { return tensorflow::errors::InvalidArgument( - "TRT only support reduce on 4 dimensional tensors, at" + + "TRT only support reduce on 4 dimensional tensors, at", node_def.name()); - if (index_list.count() > 2) + } + if (index_list.count() > 2) { return tensorflow::errors::InvalidArgument( - "TRT cannot support reduce on more than 2 dimensions, at" + + "TRT cannot support reduce on more than 2 dimensions, at", node_def.name()); + } std::set<int> idx_set; // We cannot operate on Channel. permutation flag used to transpose tensor int permuted_index = -1; for (int i = 0; i < index_list.count(); i++) { - if (index_list_data[i] == 0) - return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at", node_def.name()); + } if (index_list_data[i] == 1) permuted_index = 1; - idx_set.emplace(index_list_data[i]); } @@ -1673,6 +1885,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, // Apply permutation before extracting dimension for pool_kernel tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } // Apply permutation before extracting dimension for pool_kernel @@ -1685,34 +1898,104 @@ tensorflow::Status ConvertReduce(Converter& ctx, nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::PoolingType::kAVERAGE, pool_kernel); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); output_tensor = layer->getOutput(0); } else { - return tensorflow::errors::Unimplemented( - "Op not supported " + node_def.op() + " , at " + node_def.name()); + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); } if (permuted_index != -1) { // Apply permutation before extracting dimension for pool_kernel output_tensor = 
ctx.TransposeTensor( const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } +#elif NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertReduce(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector<TRT_TensorOrWeights>& inputs, + std::vector<TRT_TensorOrWeights>* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at", node_def.name()); + } + + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TRT_ShapedWeights index_list = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + auto index_type = attrs.get<tensorflow::DataType>("Tidx"); + + // Only expect to handle INT32 as attributes for now + if (index_type != tensorflow::DataType::DT_INT32) { + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + } + + const auto keep_dims = attrs.get<bool>("keep_dims"); + auto index_list_data = + static_cast<int*>(const_cast<void*>(index_list.GetValues())); + + int axes = 0; + if (index_list.count() == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot support reduce on all (batch) dimensions, at", + node_def.name()); + } else { + for (int i = 0; i < index_list.count(); i++) { + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot reduce at batch dimension, at", node_def.name()); + } + axes |= (1 << (index_list_data[i] - 1)); + } + } + + nvinfer1::ReduceOperation reduce_operation; + if (node_def.op() == "Sum") { + reduce_operation = nvinfer1::ReduceOperation::kSUM; + } else if (node_def.op() == "Prod") { + reduce_operation = nvinfer1::ReduceOperation::kPROD; + } else if (node_def.op() == "Max") { + reduce_operation = nvinfer1::ReduceOperation::kMAX; + } else if (node_def.op() == "Min") 
{ + reduce_operation = nvinfer1::ReduceOperation::kMIN; + } else if (node_def.op() == "Mean") { + reduce_operation = nvinfer1::ReduceOperation::kAVG; + } else { + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); + } + + nvinfer1::ILayer* layer = + ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor), + reduce_operation, axes, keep_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +} +#endif tensorflow::Status ConvertPad(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { + // TODO(aaroey): make a routine for this check and reuse it. if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights pads = inputs.at(1).weights(); @@ -1722,21 +2005,24 @@ tensorflow::Status ConvertPad(Converter& ctx, auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings"); // TODO(jie): handle data type conversion for TRT? 
- if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) + if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) { return tensorflow::errors::InvalidArgument( - "Pad only supports explicit padding on 4 dimensional tensor, at " + + "Pad only supports explicit padding on 4 dimensional tensor, at ", node_def.name()); + } // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) + if (padding_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented( "Tpaddings supports only DT_INT32"); + } auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues())); std::vector<int32_t> pad_index; for (int i = 0; i < nb_dims; i++) { - if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) + if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) { pad_index.push_back(i); + } } // No padding at all, we should exit @@ -1746,20 +2032,23 @@ tensorflow::Status ConvertPad(Converter& ctx, } // Only supports padding on less than 2 axis GIE-2579 - if (pad_index.size() > 2) + if (pad_index.size() > 2) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on > 2"); + } // Padding on batch dimension is not supported - if (pad_index[0] == 0) + if (pad_index[0] == 0) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on batch dimension"); + } // Not doing the legit thing here. 
ignoring padding on dim 1 and 3; // TODO(jie): implement pad as uff parser - if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) + if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); + } bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -1770,6 +2059,7 @@ tensorflow::Status ConvertPad(Converter& ctx, legit_pad = false; tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); permuted_pad_index[0] = 3; } @@ -1786,11 +2076,14 @@ tensorflow::Status ConvertPad(Converter& ctx, nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (!legit_pad) + if (!legit_pad) { output_tensor = ctx.TransposeTensor( const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); + } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1803,9 +2096,10 @@ tensorflow::Status ConvertConcat(Converter& ctx, // not including the last input (axis) here int input_size = static_cast<int>(inputs.size()) - 1; - if (!inputs.at(0).is_tensor()) + if (!inputs.at(0).is_tensor()) { return tensorflow::errors::InvalidArgument( - "Concat in TRT support only Tensor input, at " + node_def.name()); + "Concat in TRT support only Tensor input, at ", node_def.name()); + } // We are retrieving the axis TRT_ShapedWeights axis = inputs.at(input_size).weights(); @@ -1816,8 +2110,8 @@ tensorflow::Status ConvertConcat(Converter& ctx, // TODO(jie): handle data type // Only expect to handle INT32 as index attributes for now if (index_type != tensorflow::DataType::DT_INT32) - return 
tensorflow::errors::Unimplemented( - "Tidx supports only DT_INT32, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ", + node_def.name()); int index = *(static_cast<int*>(const_cast<void*>(axis.GetValues()))); @@ -1825,23 +2119,29 @@ tensorflow::Status ConvertConcat(Converter& ctx, auto dim = inputs.at(0).tensor()->getDimensions(); // dimension check - if (index > dim.nbDims + 1) + if (index > dim.nbDims + 1) { return tensorflow::errors::InvalidArgument( - "Concatenate on axis out of dimension range, at " + node_def.name()); - - if (index == 0) + "Concatenate on axis out of dimension range, at ", node_def.name()); + } + if (index == 0) { return tensorflow::errors::InvalidArgument( - "Concatenate on batch dimension not supported, at " + node_def.name()); + "Concatenate on batch dimension not supported, at ", node_def.name()); + } + if (index < 0) { + index = dim.nbDims + index + 1; + } +#if NV_TENSORRT_MAJOR == 3 // incase we need permutation; std::vector<int> permutation_order(dim.nbDims + 1); for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i; if (index != 1) { - permutation_order[1] = index - 1; - permutation_order[index - 1] = 1; + permutation_order[1] = index; + permutation_order[index] = 1; } +#endif std::vector<nvinfer1::ITensor const*> inputs_vec; // Shap chack (all input tensor should have same shape) @@ -1849,24 +2149,28 @@ tensorflow::Status ConvertConcat(Converter& ctx, for (int i = 0; i < input_size; i++) { auto tensor_i = inputs.at(i).tensor(); auto dim_i = tensor_i->getDimensions(); - if (dim_i.nbDims != dim.nbDims) + if (dim_i.nbDims != dim.nbDims) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent dimensions, at " + + "Concatenate receives inputs with inconsistent dimensions, at ", node_def.name()); - + } for (int j = 0; j < dim.nbDims; j++) { // check dimension consistency on non-concatenate axis - if (j != index - 1 && dim_i.d[j] != 
dim.d[j]) + if (j != index - 1 && dim_i.d[j] != dim.d[j]) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent shape, at" + + "Concatenate receives inputs with inconsistent shape, at", node_def.name()); + } } - // TRT does concatenation only on channel! - if (index != 1) +#if NV_TENSORRT_MAJOR == 3 + // TRT3 does concatenation only on channel! + if (index != 1) { tensor_i = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor_i), permutation_order); - + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor_i, node_def.name()); + } +#endif inputs_vec.push_back(tensor_i); } @@ -1874,11 +2178,18 @@ tensorflow::Status ConvertConcat(Converter& ctx, nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation( const_cast<nvinfer1::ITensor* const*>(inputs_vec.data()), inputs_vec.size()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); +#if NV_TENSORRT_MAJOR > 3 + layer->setAxis(index - 1); +#endif nvinfer1::ITensor* output_tensor = layer->getOutput(0); +#if NV_TENSORRT_MAJOR == 3 if (index != 1) { output_tensor = ctx.TransposeTensor(output_tensor, permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } +#endif outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1997,112 +2308,249 @@ tensorflow::Status ConvertFusedBatchNorm( combined_offset_weights.GetWeightsForTRT(), combined_scale_weights.GetWeightsForTRT(), dummy_power_weights.GetWeightsForTRT()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertMatMulHelper( + Converter& ctx, TRT_TensorOrWeights tensor_input, + TRT_ShapedWeights weights_raw, bool transpose_weight, string node_name, + std::vector<TRT_TensorOrWeights>* outputs) { + nvinfer1::ITensor* output_tensor; + if 
(!tensor_input.is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor"); + } + const nvinfer1::ITensor* tensor = tensor_input.tensor(); + + TRT_ShapedWeights weights(weights_raw.type_); + if (transpose_weight) { + weights = weights_raw; + } else { + TRT_ShapedWeights weights_ck = weights_raw; + weights = ctx.get_temp_weights_like(weights_ck); + ReorderCKtoKC(weights_raw, &weights); + } + TRT_ShapedWeights biases(weights.type_); + + int noutput = weights.shape_.d[0]; + + auto input_dim = tensor->getDimensions(); + while (input_dim.nbDims != 3) { + input_dim.d[input_dim.nbDims++] = 1; + } + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, tensor_input, input_dim, &tensor), node_name); + + nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected( + *const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + output_tensor = layer->getOutput(0); + + const nvinfer1::ITensor* temp_tensor; + auto output_dim = output_tensor->getDimensions(); + output_dim.nbDims = 1; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, TRT_TensorOrWeights(output_tensor), output_dim, + &temp_tensor), + node_name); + output_tensor = const_cast<nvinfer1::ITensor*>(temp_tensor); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +// inputs are both two dimensional (tensorflow::ops::MatMul) tensorflow::Status ConvertMatMul(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor, at" + + node_def.name()); + } + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - // TODO(jie): transpose! 
TFAttrs attrs(node_def); - TRT_ShapedWeights weights_ck = inputs.at(1).weights(); - TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck); - ReorderCKtoKC(weights_ck, &weights); - TRT_ShapedWeights biases(weights.type_); + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } - int noutput = weights.shape_.d[0]; + bool transpose_a = attrs.get<bool>("transpose_a"); + bool transpose_b = attrs.get<bool>("transpose_b"); - nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected( - *const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases); + nvinfer1::ITensor* output_tensor; - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + // FullyConnected: + if (transpose_a) { + return tensorflow::errors::Internal( + "Transpose_a is not supported for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Internal( + "Operand 1 must be constant for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); } -tensorflow::Status ConvertReshape( +tensorflow::Status ConvertBatchMatMul( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs, std::vector<TRT_TensorOrWeights>* outputs) { - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + 
node_def.name()); + TFAttrs attrs(node_def); - // implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); - // restore implicit batch dimension + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } - TRT_ShapedWeights shape = inputs.at(1).weights(); + bool transpose_a = attrs.get<bool>("adj_x"); + bool transpose_b = attrs.get<bool>("adj_y"); - TFAttrs attrs(node_def); + auto dims = inputs.at(0).shape(); + if (dims.nbDims == 1) { // NC * CK is only supported through fully connected + if (transpose_a == false && inputs.at(0).is_tensor() && + inputs.at(1).is_weights()) { + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); + } else { + return tensorflow::errors::InvalidArgument( + "Invalid configuration for MatMul, at: " + node_def.name()); + } + } - auto padding_type = attrs.get<tensorflow::DataType>("Tshape"); + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + auto dims_l = inputs.at(0).shape(); + auto dims_r = inputs.at(1).shape(); + if (inputs.at(0).is_weights()) { + if (inputs.at(0).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 0 as weight assumes broadcast across batch for MatMul, at: " + + node_def.name()); + } else { + for (int i = 0; i < dims_l.nbDims - 1; i++) { + dims_l.d[i] = dims_l.d[i + 1]; + } + dims_l.nbDims--; + } + } + if (inputs.at(1).is_weights()) { + if (inputs.at(1).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 1 as weight assumes broadcast across batch for MatMul, at: " + + 
node_def.name()); + } else { + for (int i = 0; i < dims_r.nbDims - 1; i++) { + dims_r.d[i] = dims_r.d[i + 1]; + } + dims_r.nbDims--; + } + } - if (shape.shape_.nbDims != 1) - return tensorflow::errors::InvalidArgument( - "reshape new shape is not 1 dimensional, at " + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), dims_l, &tensor_l), + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(1), dims_r, &tensor_r), + node_def.name()); - // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented( - "reshape new shape supports only DT_INT32, at " + node_def.name()); + nvinfer1::IMatrixMultiplyLayer* layer = ctx.network()->addMatrixMultiply( + *const_cast<nvinfer1::ITensor*>(tensor_l), transpose_a, + *const_cast<nvinfer1::ITensor*>(tensor_r), transpose_b); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - auto shape_data = static_cast<int*>(const_cast<void*>(shape.GetValues())); +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertSoftmax( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector<TRT_TensorOrWeights>& inputs, + std::vector<TRT_TensorOrWeights>* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - if (shape_data[0] != -1) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "reshape new shape first dimension is not -1, at " + node_def.name()); + "TensorRT Softmax cannot apply on batch dimension, at" + + node_def.name()); + } + nvinfer1::ISoftMaxLayer* layer = + ctx.network()->addSoftMax(*const_cast<nvinfer1::ITensor*>(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Tensorflow SoftMax assumes applying 
softmax on the last dimension. + layer->setAxes(1 << (nbDims - 1)); - auto shape_num_dims = shape.shape_.d[0]; - VLOG(2) << "shape dimensions: " << shape_num_dims; - int volume_w = 1; - for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i]; + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - int volume_t = 1; - for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i]; +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertTopK(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector<TRT_TensorOrWeights>& inputs, + std::vector<TRT_TensorOrWeights>* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w; - if (volume_w != volume_t) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "volume does not agree between tensor and new shape, at " + - node_def.name()); + "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + } - nvinfer1::IShuffleLayer* layer = - ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor)); + TRT_ShapedWeights k_w = inputs.at(1).weights(); + int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues()))); - nvinfer1::Dims reshape_dims; - VLOG(2) << "new dimension: " << shape_num_dims - 1; - reshape_dims.nbDims = shape_num_dims - 1; - for (int32_t i = 0; i < reshape_dims.nbDims; ++i) { - reshape_dims.d[i] = shape_data[i + 1]; + nvinfer1::TopKOperation op; + uint32_t reducedAxes = 0; + if (node_def.op() == "TopKV2") { + op = nvinfer1::TopKOperation::kMAX; + reducedAxes |= 1 << (nbDims - 1); + } else { + return tensorflow::errors::Unimplemented( + "Operation: " + node_def.op() + + " not implemented, at: " + node_def.name()); } - layer->setReshapeDimensions(reshape_dims); - VLOG(2) << "new dimension: " << shape_num_dims - 1; - 
nvinfer1::ITensor* output_tensor = layer->getOutput(0); - auto dims_output = output_tensor->getDimensions(); - VLOG(2) << "output tensor dimension:" << dims_output.nbDims; - outputs->push_back(TRT_TensorOrWeights(output_tensor)); + nvinfer1::ITopKLayer* layer = ctx.network()->addTopK( + *const_cast<nvinfer1::ITensor*>(tensor), op, k, reducedAxes); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); + outputs->push_back(TRT_TensorOrWeights(output_value_tensor)); + outputs->push_back(TRT_TensorOrWeights(output_indices_tensor)); return tensorflow::Status::OK(); } +#endif void Converter::register_op_converters() { // vgg_16 slim implementation - op_registry_["Placeholder"] = ConvertPlaceholder; op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - // This could be really handled as ConvertBinary op_registry_["BiasAdd"] = ConvertScale; op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. 
@@ -2113,18 +2561,38 @@ void Converter::register_op_converters() { op_registry_["Add"] = ConvertBinary; op_registry_["Mul"] = ConvertBinary; op_registry_["Sub"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Mean"] = ConvertReduce; op_registry_["Pad"] = ConvertPad; - // TODO(ben,jie): Add more ops op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["MatMul"] = ConvertMatMul; - op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - plugin_converter_ = ConvertPlugin; + op_registry_["Div"] = ConvertBinary; + op_registry_["RealDiv"] = ConvertBinary; + + op_registry_["Rsqrt"] = ConvertUnary; + op_registry_["Reciprocal"] = ConvertUnary; + op_registry_["Exp"] = ConvertUnary; + op_registry_["Log"] = ConvertUnary; + op_registry_["Sqrt"] = ConvertUnary; + op_registry_["Abs"] = ConvertUnary; + op_registry_["Neg"] = ConvertUnary; +#if NV_TENSORRT_MAJOR == 3 + op_registry_["Mean"] = ConvertReducePool; +#endif +#if NV_TENSORRT_MAJOR > 3 + op_registry_["Sum"] = ConvertReduce; + op_registry_["Prod"] = ConvertReduce; + op_registry_["Max"] = ConvertReduce; + op_registry_["Min"] = ConvertReduce; + op_registry_["Mean"] = ConvertReduce; + op_registry_["Maximum"] = ConvertBinary; + op_registry_["Minimum"] = ConvertBinary; + op_registry_["Softmax"] = ConvertSoftmax; + op_registry_["MatMul"] = ConvertMatMul; + op_registry_["BatchMatMul"] = ConvertBatchMatMul; + op_registry_["TopKV2"] = ConvertTopK; +#endif } } // namespace diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8a17eb02f1..04d072f5d9 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -316,6 +316,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, ctx->SetStatus(tensorflow::errors::InvalidArgument( "INT8 inputs are not 
supported!")); return; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = (void*)(input_tensor.flat<int32>().data()); + break; +#endif default: LOG(ERROR) << "Unknown TRT data type: " << int(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( @@ -368,6 +373,12 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, ctx->SetStatus(tensorflow::errors::InvalidArgument( "INT8 outputs are not supported!")); return; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = + reinterpret_cast<void*>(output_tensor->flat<int32>().data()); + break; +#endif default: LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 383635f428..e0c7b62723 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -42,8 +42,14 @@ REGISTER_OP("TRTEngineOp") .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") .Attr("calibration_data: string = ''") .Input("in_tensor: InT") - .Output("out_tensor: OutT") - .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .Output("out_tensor: OutT"); +// TODO(jie): TF requires concrete output shape for concrete input shapes. +// This is tricky for batch dimension, since we cannot ensure which input +// would carry the correct batch dimension (for the current stage of the +// implementation, we do require all input tensor to carry the same batch +// size, but this could change in the future). Hence we disable shape +// inference function as a workaround. 
+// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 227ac120dd..f30dba59ad 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -28,36 +28,50 @@ limitations under the License. namespace tensorflow { namespace shape_inference { -tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - std::vector<tensorflow::TensorShape> shapes; - for (int i = 0; i < context->num_outputs(); ++i) { - context->set_output(i, context->UnknownShape()); +tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); } - auto status = context->GetAttr("input_shapes", &shapes); - // it is ok to not to have shapes - if (!status.ok()) return Status::OK(); - if ((int)shapes.size() != context->num_inputs()) return Status::OK(); - bool different_input = false; - for (int i = 0; i < context->num_inputs(); ++i) { - if (shapes.at(i) != context->input_tensor(i)->shape()) - different_input = true; + + // Check the sanity of the input shapes. + std::vector<tensorflow::TensorShape> input_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("input_shapes", &input_shapes)); + if (input_shapes.size() != c->num_inputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of inputs doesn't match the number of input " + "shapes set in the attr: ", + c->num_inputs(), " vs ", input_shapes.size()); + } + bool input_match = true; + for (int i = 0; i < c->num_inputs(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(input_shapes.at(i), &handle)); + ShapeHandle merged; + if (!c->Merge(c->input(i), handle, &merged).ok()) { + // Input shape doesn't match what was set in attr, fine. 
+ input_match = false; + } } - if (different_input) return Status::OK(); - shapes.resize(0); - status = context->GetAttr("output_shapes", &shapes); - if (!status.ok()) return Status::OK(); - if ((int)shapes.size() != context->num_outputs()) return Status::OK(); - std::vector<ShapeHandle> shape_handles(shapes.size()); - for (size_t i = 0; i < shapes.size(); ++i) { - status = - context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i)); - if (!status.ok()) return Status::OK(); + + // Check the sanity of the output shapes. + std::vector<tensorflow::TensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of outputs doesn't match the number of output " + "shapes set in the attr: ", + c->num_outputs(), " vs ", output_shapes.size()); } - for (int i = 0; i < context->num_outputs(); ++i) { - context->set_output(i, shape_handles.at(i)); + for (size_t i = 0; i < output_shapes.size(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(output_shapes.at(i), &handle)); + if (input_match) c->set_output(i, handle); } return Status::OK(); } + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 0044fde9d0..ef6c752851 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -16,7 +16,6 @@ package( "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", - "//third_party/cloud_tpu:__subpackages__", ], ) @@ -184,6 +183,7 @@ py_library( "//tensorflow/python:session", "//tensorflow/python:tensor_spec", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:engine", diff --git 
a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 722e31abb2..8292c920fc 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -45,6 +45,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import contextlib import re @@ -63,9 +64,11 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras import models @@ -202,7 +205,6 @@ class TPURewriteContext(object): caller_obj = caller_frame.f_locals.get('self') if (caller_obj is not None and isinstance(caller_obj, base_layer.Layer) and name is not None): - logging.info('Intercepted name_scope: %s', caller_obj) return variable_scope.variable_scope( name, default_name, values, reuse=variable_scope.AUTO_REUSE) @@ -269,6 +271,329 @@ class TPURewriteContext(object): gen_linalg_ops.qr = self._default_qr +class SizedInfeed(collections.namedtuple('SizedInfeed', + ['sharded_infeed_tensors', + 'infeed_ops'])): + """Represents an instantiation of the infeed ops for a concrete input shape. + + sharded_infeed_tensors: A data structure of Tensors used to represent the + placeholder tensors that must be fed when using feed_dicts. + + infeed_ops: the set of ops that will be run to drive infeed for a single step. 
+ """ + pass + + +class TPUInfeedInstance(object): + """TPUInfeedInstance represents the logic to manage feeding in a single step. + + See the comments on the `TPUInfeedManager` for a description for how infeed + is managed. + """ + + @abc.abstractmethod + def make_input_specs(self, input_tensors): + """Constructs the infeed_specs for the given Infeed instance. + + Args: + input_tensors: The inputs to the model. + + Returns: + A list of + """ + pass + + def make_feed_dict(self, tpu_model_op): + """Constructs a feed_dict for this instance, given the tpu_model_op. + + Args: + tpu_model_op: A `TPUModelOp` representing the TPU Model for this + instance's input spec. + + Returns: + A dictionary to use as the feed_dict of a `session.run` call. + """ + pass + + +class TPUInfeedManager(object): + """TPUInfeedManager manages the data infeeding of data to a TPU computation. + + Because there are multiple data sources (e.g. in-memory NumPy arrays, + `tf.data.Dataset`s), we abstract the different logic behind a single + interface: the `TPUInfeedManager`. + + (1) A `TPUFunction` is called with a set of inputs. Based on the inputs, + `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a + new one if required). + + (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager` + which returns a `TPUInfeedInstance`. + + (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of + the model based on the returned `input_specs` from `TPUInfeedInstance`. + + (4) [Optional.] If the model has not already been instantiated for the given + input spec, the `TPUFunction` compiles the model for the input spec (using the + `TPUInfeedManager`). + + (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the + compiled model instance corresponding to its shape. + """ + + @abc.abstractmethod + def make_infeed_instance(self, inputs): + """Given a single step's input, construct a `TPUInfeedInstance`. 
+ + Args: + inputs: The inputs to a given step. + + Returns: + A subclass of `TPUInfeedInstance`. + """ + pass + + @abc.abstractmethod + def build_infeed_from_input_specs(self, input_specs, execution_mode): + """For a given input specification (size, type), construct the infeed ops. + + This is called only once for a given input specification and builds the + graph ops. It does not have a pointer to the actual infeed data. + + Args: + input_specs: TODO(saeta): Document me! + execution_mode: TODO(saeta): Document me! + + Returns: + A `SizedInfeed` instance. + """ + pass + + +class TPUNumpyInfeedManager(TPUInfeedManager): + """TPU Infeed manager for Numpy inputs.""" + + class NumpyInfeedInstance(TPUInfeedInstance): + """Infeed instance for Numpy inputs.""" + + def __init__(self, sharded_inputs): + self._sharded_inputs = sharded_inputs + + def make_input_specs(self, input_tensors): + # Compute an input specification (used to generate infeed enqueue and + # dequeue operations). We use the shape from our input array and the + # dtype from our model. A user may pass in a float64 for a float32 + # input: for model compatibility we still must generate a float32 infeed. + input_specs = [] + # We use the shape and dtype from the first shard to compute the input + # metadata (`input_specs`); all replicas have the same type and shape. + for tensor, ary in zip(input_tensors, self._sharded_inputs[0]): + input_specs.append( + tensor_spec.TensorSpec(ary.shape, tensor.dtype, + _valid_name(tensor.name))) + + return input_specs + + def make_feed_dict(self, tpu_model_op): + infeed_dict = {} + for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors, + self._sharded_inputs): + for tensor, value in zip(infeed_tensors, inputs): + infeed_dict[tensor] = value + return infeed_dict + + def __init__(self, distribution_strategy): + self._strategy = distribution_strategy + + def _split_tensors(self, inputs): + """Split input data across shards. 
+ + Each input is sliced along the batch axis. + + Args: + inputs: List of Numpy arrays to run on the TPU. + + Returns: + List of lists containing the input to feed to each TPU shard. + """ + if self._strategy.num_towers == 1: + return [inputs] + + batch_size = inputs[0].shape[0] + assert batch_size % self._strategy.num_towers == 0, ( + 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % + (batch_size, self._strategy.num_towers)) + shard_size = batch_size // self._strategy.num_towers + input_list = [] + for index in range(self._strategy.num_towers): + shard_inputs = [ + x[index * shard_size:(index + 1) * shard_size] for x in inputs + ] + input_list.append(shard_inputs) + return input_list + + def make_infeed_instance(self, inputs): + sharded_inputs = self._split_tensors(inputs) + return self.NumpyInfeedInstance(sharded_inputs) + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + infeed_op = [] + shard_infeed_tensors = [] + + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:TPU:%d' % shard_id): + infeed_tensors = [] + for spec in input_specs: + # Construct placeholders for each of the inputs. + infeed_tensors.append( + array_ops.placeholder( + dtype=spec.dtype, + shape=spec.shape, + name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) + shard_infeed_tensors.append(infeed_tensors) + + infeed_op.append( + tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id))) + return SizedInfeed(infeed_ops=infeed_op, + sharded_infeed_tensors=shard_infeed_tensors) + + +class TPUDatasetInfeedManager(TPUInfeedManager): + """Manages infeed for a `tf.data.Dataset` into a TPU computation. 
+ """ + + class DatasetInfeedInstance(TPUInfeedInstance): + """An instance of the TPU infeed.""" + + def __init__(self, input_specs): + self._input_specs = input_specs + + def make_input_specs(self, input_tensors): + # TODO(saeta): Do error checking here! + return self._input_specs + + def make_feed_dict(self, tpu_model_op): + # TODO(saeta): Verify tpu_model_op is as expected! + return {} + + def __init__(self, dataset, distribution_strategy, tpu_session): + """Constructs a TPUDatasetInfeedManager. + + Must be called within a `KerasTPUModel.tpu_session` context! + + Args: + dataset: A `tf.data.Dataset` to infeed. + distribution_strategy: The `TPUDistributionStrategy` used to configure the + Keras TPU model. + tpu_session: The `tf.Session` object used for running the TPU model. + """ + self._verify_dataset_shape(dataset) + self._dataset = dataset + self._strategy = distribution_strategy + dummy_x_shape = dataset.output_shapes[0].as_list() + dummy_x_shape[0] *= distribution_strategy.num_towers + dummy_y_shape = dataset.output_shapes[1].as_list() + dummy_y_shape[0] *= distribution_strategy.num_towers + self._iterator = dataset.make_initializable_iterator() + tpu_session.run(self._iterator.initializer) + + self._get_next_ops = [] + ctrl_deps = [] + for i in range(distribution_strategy.num_towers): + with ops.control_dependencies(ctrl_deps): # Ensure deterministic + # TODO(saeta): Ensure correct placement! + get_next_op = self._iterator.get_next() + self._get_next_ops.append(get_next_op) + ctrl_deps.extend(get_next_op) + + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. 
+ self._dummy_x = np.zeros(dummy_x_shape, + dtype=dataset.output_types[0].as_numpy_dtype) + self._dummy_y = np.zeros(dummy_y_shape, + dtype=dataset.output_types[1].as_numpy_dtype) + + input_specs = [] + if isinstance(self._iterator.output_shapes, tuple): + assert isinstance(self._iterator.output_types, tuple) + assert len(self._iterator.output_shapes) == len( + self._iterator.output_types) + for i in range(len(self._iterator.output_shapes)): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes[i], + self._iterator.output_types[i]) + input_specs.append(spec) + elif isinstance(self._iterator.output_shapes, tensor_shape.TensorShape): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes, + self._iterator.output_types) + input_specs.append(spec) + + self._infeed_instance = self.DatasetInfeedInstance(input_specs) + + def _verify_dataset_shape(self, dataset): + """Verifies a dataset is of an appropriate shape for TPUs.""" + if not isinstance(dataset, dataset_ops.Dataset): + raise ValueError('The function passed as the `x` parameter did not ' + 'return a `tf.data.Dataset`.') + if not isinstance(dataset.output_classes, tuple): + raise ValueError('The dataset must return a tuple of tf.Tensors, ' + 'instead it returns: %s' % dataset.output_classes) + if len(dataset.output_classes) != 2: + raise ValueError( + 'The dataset must return a 2-element tuple, got ' + '%s output classes instead.' % (dataset.output_classes,)) + for i, cls in enumerate(dataset.output_classes): + if cls != ops.Tensor: + raise ValueError('The dataset returned a non-Tensor type (%s) at ' + 'index %d.' % (cls, i)) + for i, shape in enumerate(dataset.output_shapes): + if not shape: + raise ValueError('The dataset returns a scalar tensor in ' + 'tuple index %d. Did you forget to batch? ' + '(Output shapes: %s).' 
% (i, + dataset.output_shapes)) + for j, dim in enumerate(shape): + if dim.value is None: + if j == 0: + hint = (' Hint: did you use `ds.batch(BATCH_SIZE, ' + 'drop_remainder=True)`?') + else: + hint = '' + raise ValueError( + 'The Keras-TPU integration for `tf.data` ' + 'currently requires static shapes. The provided ' + 'dataset only has a partially defined shape. ' + '(Dimension %d of output tensor %d is not statically known ' + 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint)) + + @property + def dummy_x(self): + return self._dummy_x + + @property + def dummy_y(self): + return self._dummy_y + + def make_infeed_instance(self, inputs): + # TODO(saeta): Verify inputs is as expected. + return self._infeed_instance + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + shard_infeed_tensors = self._get_next_ops + assert len(shard_infeed_tensors) == self._strategy.num_towers + infeed_ops = [] + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:TPU:%d' % shard_id): + infeed_ops.append( + tpu_ops.infeed_enqueue_tuple( + shard_infeed_tensors[shard_id], + [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id))) + return SizedInfeed(infeed_ops=infeed_ops, + sharded_infeed_tensors=shard_infeed_tensors) + + class TPUFunction(object): """K.function compatible interface for invoking a TPU compiled function. @@ -294,7 +619,7 @@ class TPUFunction(object): if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): self._optimizer_config = self.model.optimizer.get_config() - def _specialize_model(self, input_specs): + def _specialize_model(self, input_specs, infeed_manager): """Specialize `self.model` (a Keras model) for the given input shapes.""" # Re-create our input and output layers inside our subgraph. They will be # attached to the true computation when we clone our model in `tpu_fn`. 
@@ -320,8 +645,8 @@ class TPUFunction(object): name='infeed-%s' % self.execution_mode) assert len(infeed_tensors) == len(infeed_layers), ( - 'Infeed inputs did not match model: %s vs %s', (infeed_layers, - infeed_tensors)) + 'Infeed inputs did not match model: %s vs %s' % (infeed_layers, + infeed_tensors)) tpu_targets = [] tpu_input_map = {} @@ -410,26 +735,12 @@ class TPUFunction(object): # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. - infeed_op = [] + sized_infeed = infeed_manager.build_infeed_from_input_specs( + input_specs, self.execution_mode) + # Build output ops. outfeed_op = [] - shard_infeed_tensors = [] - for shard_id in range(self._strategy.num_towers): with ops.device('/device:TPU:%d' % shard_id): - infeed_tensors = [] - for spec in input_specs: - infeed_tensors.append( - array_ops.placeholder( - dtype=spec.dtype, - shape=spec.shape, - name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) - shard_infeed_tensors.append(infeed_tensors) - - infeed_op.append( - tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) - outfeed_op.extend( tpu_ops.outfeed_dequeue_tuple( dtypes=[spec.dtype for spec in self._outfeed_spec], @@ -439,8 +750,8 @@ class TPUFunction(object): return TPUModelOp( compile_op, execute_op, - infeed_tensors=shard_infeed_tensors, - infeed_op=infeed_op, + infeed_tensors=sized_infeed.sharded_infeed_tensors, + infeed_op=sized_infeed.infeed_ops, outfeed_op=outfeed_op) def _test_model_compiles(self, tpu_model_ops): @@ -459,36 +770,17 @@ class TPUFunction(object): logging.info('Finished compiling. Time elapsed: %s secs', end_time - start_time) - def _split_tensors(self, inputs): - """Split input data across shards. - - Each input is sliced along the batch axis. - - Args: - inputs: List of Numpy arrays to run on the TPU. - - Returns: - List of lists containing the input to feed to each TPU shard. 
- """ - if self._strategy.num_towers == 1: - return [inputs] - - batch_size = inputs[0].shape[0] - assert batch_size % self._strategy.num_towers == 0, ( - 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % - (batch_size, self._strategy.num_towers)) - shard_size = batch_size // self._strategy.num_towers - input_list = [] - for index in range(self._strategy.num_towers): - shard_inputs = [ - x[index * shard_size:(index + 1) * shard_size] for x in inputs - ] - input_list.append(shard_inputs) - return input_list - def __call__(self, inputs): assert isinstance(inputs, list) + infeed_manager = None + for x, mgr in self.model._numpy_to_infeed_manager_list: + if inputs[0] is x: + infeed_manager = mgr + break + if infeed_manager is None: + infeed_manager = TPUNumpyInfeedManager(self.model._strategy) + # Strip sample weight from inputs if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or self.execution_mode == model_fn_lib.ModeKeys.EVAL): @@ -497,21 +789,9 @@ class TPUFunction(object): else: input_tensors = self.model._feed_inputs - shard_inputs = self._split_tensors(inputs) + infeed_instance = infeed_manager.make_infeed_instance(inputs) del inputs # To avoid accident usage. - - # Compute an input specification (used to generate infeed enqueue and - # dequeue operations). We use the shape from our input array and the - # dtype from our model. A user may pass in a float64 for a float32 - # input: for model compatibility we still must generate a float32 infeed. - input_specs = [] - - # We use the shape and dtype from the first shard to compute the input - # metadata (`input_specs`); all replicas have the same type and shape. - for tensor, ary in zip(input_tensors, shard_inputs[0]): - input_specs.append( - tensor_spec.TensorSpec(ary.shape, tensor.dtype, - _valid_name(tensor.name))) + input_specs = infeed_instance.make_input_specs(input_tensors) # XLA requires every operation in the graph has a fixed shape. 
To # handle varying batch sizes we recompile a new sub-graph for each @@ -522,7 +802,8 @@ class TPUFunction(object): with self.model.tpu_session(): logging.info('New input shapes; (re-)compiling: mode=%s, %s', self.execution_mode, input_specs) - new_tpu_model_ops = self._specialize_model(input_specs) + new_tpu_model_ops = self._specialize_model(input_specs, + infeed_manager) self._compilation_cache[shape_key] = new_tpu_model_ops self._test_model_compiles(new_tpu_model_ops) @@ -530,11 +811,7 @@ class TPUFunction(object): self.model._initialize_weights(self._cloned_model) tpu_model_ops = self._compilation_cache[shape_key] - infeed_dict = {} - for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors, - shard_inputs): - for tensor, value in zip(infeed_tensors, inputs): - infeed_dict[tensor] = value + infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops) with self.model.tpu_session() as session: _, _, outfeed_outputs = session.run([ @@ -568,6 +845,11 @@ class KerasTPUModel(models.Model): name=cpu_model.name, ) + # Create a mapping from numpy arrays to infeed managers. + # Note: uses a list of tuples instead of a map because numpy arrays are + # not hashable. + self._numpy_to_infeed_manager_list = [] + self.predict_function = None self.test_function = None self.train_function = None @@ -640,6 +922,92 @@ class KerasTPUModel(models.Model): sample_weight_mode, weighted_metrics, target_tensors, **kwargs) + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): + assert not self._numpy_to_infeed_manager_list # Ensure empty. + + infeed_managers = [] # Managers to clean up at the end of the fit call. + if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. 
+ raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + with self.tpu_session() as sess: + dataset = x() + if steps_per_epoch is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps_per_epoch argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must be ' + 'None') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + if isinstance(validation_data, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(validation_data): + with self.tpu_session() as sess: + dataset = validation_data() + if validation_steps is None: + raise ValueError('When using tf.data as validation for a model, you ' + 'should specify the validation_steps argument.') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. 
+ val_x = infeed_manager.dummy_x + val_y = infeed_manager.dummy_y + infeed_managers.append((val_x, infeed_manager)) + validation_data = (val_x, val_y) + + self._numpy_to_infeed_manager_list = infeed_managers + try: + return super(KerasTPUModel, self).fit( + x, + y, + batch_size, + epochs, + verbose, + callbacks, + validation_split, + validation_data, + shuffle, + class_weight, + sample_weight, + initial_epoch, + steps_per_epoch, + validation_steps, + **kwargs) + finally: + self._numpy_to_infeed_manager_list = [] + def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 211c59cb90..e54395f05d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -234,7 +234,7 @@ class _InternalTPUContext(object): def mode(self): return self._assert_mode() - def _get_master_address(self): + def master_address(self): mode = self._assert_mode() config = self._config master = ( @@ -244,7 +244,7 @@ class _InternalTPUContext(object): def _get_tpu_system_metadata(self): """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() + master = self.master_address() tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) if tpu_system_metadata is not None: return tpu_system_metadata @@ -261,7 +261,7 @@ class _InternalTPUContext(object): def _get_device_assignment(self): """Gets the (maybe cached) TPU device assignment.""" - master = self._get_master_address() + master = self.master_address() device_assignment = self._lazy_device_assignment_dict.get(master) if device_assignment is not None: return device_assignment @@ -589,7 +589,7 @@ class _InternalTPUContext(object): 'model-parallelism, the total number of TPU cores should be ' 'num_cores_per_replica * num_replicas. 
Please set it ' 'accordingly or leave it as `None`'.format( - self._get_master_address(), num_replicas, + self.master_address(), num_replicas, user_provided_num_replicas)) raise ValueError(message) @@ -644,7 +644,7 @@ class _OneCoreTPUContext(_InternalTPUContext): def _get_tpu_system_metadata(self): """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() + master = self.master_address() tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) if tpu_system_metadata is not None: return tpu_system_metadata diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 74157a6193..aa407cf4d8 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -43,6 +43,7 @@ from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib @@ -67,6 +68,7 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import evaluation +from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -382,7 +384,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = 
[tpu.initialize_system(job=self._master_job)] + self._init_ops = [] + # For distributed sessions, we can't run initialize_system in a separate + # graph here because 'begin' is only invoked when the MonitoredSession is + # created. We need to reinitialize the system every time MonitoredSession + # creates an underlying tf.Session, so we initialize from Scaffold.finalize. + # See _get_and_wrap_scaffold for more details. + if self._master_job is None: + self._init_ops.append(tpu.initialize_system(job=self._master_job)) self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() @@ -484,7 +493,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return _OpQueueContext(name=name, target=target, args=args) def after_create_session(self, session, coord): - logging.info('Init TPU system') + logging.info('Running init_ops') session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) @@ -2700,7 +2709,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): outputs_from_all_shards=False, device_assignment=ctx.device_assignment) - scaffold = _get_scaffold(captured_scaffold_fn) + scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx) return loss, host_calls, scaffold @@ -2723,7 +2732,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): outputs_from_all_shards=False, device_assignment=ctx.device_assignment) - scaffold = _get_scaffold(captured_scaffold_fn) + scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx) return loss, host_call, scaffold @@ -2751,7 +2760,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): num_shards=num_cores, outputs_from_all_shards=False) - scaffold = _get_scaffold(captured_scaffold_fn) + scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx) return dummy_predict_op, host_calls, scaffold @@ -2841,8 +2850,20 @@ class _CapturedObject(object): return self._object -def 
_get_scaffold(captured_scaffold_fn): - """Retrieves the Scaffold from `captured_scaffold_fn`.""" +def _get_and_wrap_scaffold(captured_scaffold_fn, ctx): + """Retrieves the Scaffold from `captured_scaffold_fn`. + + Also wraps the scaffold's finalize method to initialize the TPU after the + graph is finalized. + + Args: + captured_scaffold_fn: a `_CapturedObject` containing a scaffold_fn. + ctx: A `_InternalTPUContext` instance used to initialize the TPU. + + Returns: + The Scaffold produced by captured_scaffold_fn, wrapped to initialize the TPU + after the graph is finalized. + """ with _CapturingContext(message='Inside scaffold_fn'): scaffold_fn = captured_scaffold_fn.get() if scaffold_fn: @@ -2853,14 +2874,64 @@ def _get_scaffold(captured_scaffold_fn): else: scaffold = None - if scaffold: - wrapped_finalize = scaffold.finalize - - def _finalize(): - with _CapturingContext('Inside Scaffold.finalize'): - wrapped_finalize() - - scaffold.finalize = _finalize + if scaffold is None: + # When master_address is None, we are using DirectSession, so we can't + # invoke initialize_system from finalize. See comments below. + if ctx.master_address() is None: + return scaffold + scaffold = monitored_session.Scaffold() + + wrapped_finalize = scaffold.finalize + + def _finalize(): + """Invoke wrapped_finalize and initialize the TPU.""" + with _CapturingContext('Inside Scaffold.finalize'): + wrapped_finalize() + # Run tpu.initialize_system in its own graph after finalizing the main graph + # for distributed sessions. This is necessary because the TPU must be + # initialized before the TPU graph rewrite pass runs. We can't put the + # initialization op in the main graph because the main graph also contains + # replicate ops created by tpu.shard. If we tried to run initialization from + # the main graph, the TPU graph rewrite pass would rewrite the replicate ops + # before actually evaluating the initialization ops. 
+ # + # For distributed sessions, the master may independently restart. After a + # master restarts, the rewrite pass runs again when any op in the main graph + # runs, so we must reinitialize the system every time the main graph is + # finalized. + # + # Special case: When master_address is unset, we're using DirectSession. + # DirectSession resets device state between sessions, and uses + # place_pruned_graph. Initialization currently passes state to replication + # through the TPU_SYSTEM resource manager. Under DirectSession, this + # resource manager gets reset when init_session is closed, so DirectSession + # can't initialize here, and must instead initialize from the main graph's + # init_ops. This is possible with DirectSession because it uses + # place_pruned_graph, which removes unreferenced ops before invoking the + # rewrite pass. This makes it possible to run init_ops from the main graph, + # which contains both tpu.initialize_system and tpu.shard ops, without first + # triggering the TPU graph rewrite. We can't do this for distributed + # sessions because they don't support place_pruned_graph. + # + # TODO(b/110943344) Clean this up as part of the initialize_system dataflow + # cleanup. It should be possible to remove the special case for + # DirectSession and the other call to initialize_system from + # _obtain_topology, when topology info is always explicitly passed from + # tpu.initialize_system to tpu.shard, though this requires editing or + # rebuilding the main graph each time the master restarts. 
+ if ctx.master_address() is None: + return + with ops.Graph().as_default(): + logging.info('Init TPU system master_address %s', ctx.master_address()) + with session_lib.Session( + ctx.master_address(), + config=ctx.config.session_config) as init_session: + run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000) + init_session.run( + tpu.initialize_system(job=ctx.master_job), options=run_options) + logging.info('TPU system initialized') + + scaffold.finalize = _finalize return scaffold diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index dbe87a6dbb..8a43220ec5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2464,6 +2464,7 @@ tf_cuda_library( "framework/resource_handle.cc", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/stats_calculator.*", "util/version_info.cc", ], ) + select({ @@ -2490,6 +2491,7 @@ tf_cuda_library( ":protos_all_proto_text", ":error_codes_proto_text", ":protos_all_cc", + ":stats_calculator_portable", ":version_lib", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index f903faf1bd..1732553abd 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1188,12 +1188,11 @@ Status DirectSession::CreateExecutors( delete kernel; } }; - params.node_outputs_cb = node_outputs_callback_; optimizer.Optimize(lib, options_.env, device, &iter->second, /*shape_map=*/nullptr); - // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. + // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. 
const DebugOptions& debug_options = options.callable_options.run_options().debug_options(); if (!debug_options.debug_tensor_watch_opts().empty()) { diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 5f3809ddd6..8096139d90 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1966,14 +1966,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, device_context = device_context_map_[node->id()]; } - // Experimental: debugger (tfdb) access to intermediate node completion. - if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) { - // If the node has no output, invoke the callback with output slot set to - // -1, signifying that this is a no-output node. - s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr, - false, ctx)); - } - for (int i = 0; i < item.num_outputs; ++i) { const TensorValue val = ctx->release_output(i); if (val.tensor == nullptr) { @@ -2018,13 +2010,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, to_log); } - - // Experimental: debugger (tfdb) access to intermediate node - // outputs. - if (impl_->params_.node_outputs_cb != nullptr) { - s.Update(impl_->params_.node_outputs_cb(item.node->name(), i, - out->ref, true, ctx)); - } } else { // NOTE that std::move is used here, so val.tensor goes to // uninitialized state (val.tensor->IsInitialized return false). @@ -2036,12 +2021,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, *out->val); } - - // Experimental: debugger access to intermediate node outputs. 
- if (impl_->params_.node_outputs_cb != nullptr) { - s.Update(impl_->params_.node_outputs_cb( - item.node->name(), i, out->val.get(), false, ctx)); - } } } else { s.Update(errors::Internal("Output ", i, " of type ", diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index e5d7b7c53c..cd01b43aea 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -103,7 +103,6 @@ class Executor { const Tensor* tensor, const bool is_ref, OpKernelContext* ctx)> NodeOutputsCallback; - NodeOutputsCallback node_outputs_cb = nullptr; }; typedef std::function<void(const Status&)> DoneCallback; virtual void RunAsync(const Args& args, DoneCallback done) = 0; @@ -139,8 +138,6 @@ struct LocalExecutorParams { // when the executor is deleted. std::function<Status(const NodeDef&, OpKernel**)> create_kernel; std::function<void(OpKernel*)> delete_kernel; - - Executor::Args::NodeOutputsCallback node_outputs_cb; }; ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, std::unique_ptr<const Graph> graph, diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index 36e9b3455a..591c22b8f6 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -82,25 +82,6 @@ cc_library( ) tf_cuda_library( - name = "debug_gateway_internal", - srcs = ["debug_gateway.cc"], - hdrs = ["debug_gateway.h"], - copts = tf_copts(), - linkstatic = 1, - deps = [ - ":debug", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:proto_text", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) - -tf_cuda_library( name = "debugger_state_impl", srcs = ["debugger_state_impl.cc"], hdrs = ["debugger_state_impl.h"], @@ -187,42 +168,6 @@ tf_cuda_library( ], ) -# TODO(cais): Fix flakiness on GPU and change this back to a 
tf_cc_test_gpu. -# See b/34081273. -tf_cc_test( - name = "debug_gateway_test", - size = "small", - srcs = ["debug_gateway_test.cc"], - args = ["--heap_check=local"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = [ - "no_cuda_on_cpu_tap", - "no_gpu", - ], - deps = [ - ":debug", - ":debug_gateway_internal", - ":debug_graph_utils", - "//tensorflow/cc:cc_ops", - "//tensorflow/core:all_kernels", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session", - "//tensorflow/core:direct_session_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:debug_ops", - "//tensorflow/core/kernels:ops_util", - ], -) - tf_cc_test( name = "debug_io_utils_test", size = "small", diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc deleted file mode 100644 index 2e1aabd1cc..0000000000 --- a/tensorflow/core/debug/debug_gateway.cc +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/debug/debug_gateway.h" - -#include <utility> - -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/session_factory.h" -#include "tensorflow/core/framework/tensor.h" - -namespace tensorflow { - -DebugGateway::DebugGateway(DirectSession* session) : session_(session) { - session_->node_outputs_callback_ = - [this](const string& node_name, const int output_slot, - const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) { - if (comp_cb_ != nullptr && output_slot <= 0) { - // The node completion callback is invoked once for a node regardless - // of whether the node has zero, one or more outputs. - // The output_slot can be negative (-1, or kControlSlot) if - // node_outputs_callback_ is invoked for a node with no output. If - // that is the case, notify the callback that the node in question has - // no output. - comp_cb_(node_name, output_slot == 0); - } - - // Copy tensor values (e.g., from GPU to host) only if the - // value callback is not nullptr. 
- if (val_cb_ != nullptr && output_slot >= 0) { - CopyTensor(node_name, output_slot, tensor, ctx, - [this, node_name, output_slot, - is_ref](const Tensor* copied_tensor) { - val_cb_(node_name, output_slot, *copied_tensor, is_ref); - }); - } - - return Status::OK(); - }; -} - -DebugGateway::~DebugGateway() { - if (session_ != nullptr) { - session_->node_outputs_callback_ = nullptr; - } -} - -void DebugGateway::SetNodeCompletionCallback(NodeCompletionCallback callback) { - comp_cb_ = std::move(callback); -} - -void DebugGateway::SetNodeValueCallback(NodeValueCallback callback) { - val_cb_ = std::move(callback); -} - -void DebugGateway::CopyTensor(const string& node_name, const int output_slot, - const Tensor* src_tensor, OpKernelContext* ctx, - CopyDoneCallback copy_done_cb) { - Device* device = static_cast<Device*>(ctx->device()); - - // Determine if the tensor is initialized properly. - // The second part of the check is necessary because in some cases, a - // tensor can pass the IsInitialized() check, but the dtype is not set, - // e.g., tf.FIFOQueue. - if (src_tensor->IsInitialized() && DataTypeSize(src_tensor->dtype()) > 0) { - // Tensor is initialized. - - string tensor_tag = strings::StrCat(node_name, ":", output_slot); - - // Create copied tensor on host - Allocator* cpu_allocator = tensorflow::cpu_allocator(); - Tensor cpu_tensor(cpu_allocator, src_tensor->dtype(), src_tensor->shape()); - - // Determine if the tensor is on device (GPU) or host (CPU). - // The second part of the check is necessary because even an OpKernel on - // may have output tensors allocated on CPU. - if ((device->name().find("GPU:") != string::npos || - device->name().find("SYCL:") != string::npos) && - !ctx->output_alloc_attr(output_slot).on_host()) { - // GPU tensors: Copy it to host (CPU). - DeviceContext* device_ctxt = ctx->op_device_context(); - - // Copy device (e.g., GPU) tensor to host and when done, invoke the - // callback. 
- device_ctxt->CopyDeviceTensorToCPU( - src_tensor, "TensorCopy", device, &cpu_tensor, - [node_name, cpu_tensor, copy_done_cb](const Status& s) { - if (s.ok()) { - copy_done_cb(&cpu_tensor); - } else { - LOG(ERROR) << "Copying of device Tensor " << node_name - << " to CPU for debugging failed."; - } - }); - } else { - // For CPU tensors, copy the source tensor and own the copy, because the - // value callback may outlive the life time of the tensor and the tensor - // may shared the underlying buffer with other tensors. - cpu_tensor.UnsafeCopyFromInternal(*src_tensor, src_tensor->dtype(), - src_tensor->shape()); - - copy_done_cb(&cpu_tensor); - } - } else { - // Tensor is not initialized: No need to copy. - copy_done_cb(src_tensor); - } -} - -} // namespace tensorflow diff --git a/tensorflow/core/debug/debug_gateway.h b/tensorflow/core/debug/debug_gateway.h deleted file mode 100644 index bf5b6e08db..0000000000 --- a/tensorflow/core/debug/debug_gateway.h +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_DEBUG_DEBUG_SESSION_H_ -#define TENSORFLOW_DEBUG_DEBUG_SESSION_H_ - -#include <unordered_map> - -#include "tensorflow/core/common_runtime/direct_session.h" -#include "tensorflow/core/common_runtime/executor.h" - -namespace tensorflow { - -// Experimental. 
tfdb (TensorFlow Debugger): Gateway to intermediate node -// outputs during Session Run calls. Currently limited to DirectSession. -class DebugGateway { - public: - DebugGateway(DirectSession* session); - virtual ~DebugGateway(); - - // Callback for node completion. This callback is invoked only once for - // a node regardless of whether it has one or more outputs. The value(s) of - // the output tensor(s) are not necessarily available when this callback is - // invoked. They may need to be asynchronously copied from device (e.g., - // GPU) to host, hence the need for the NodeValueCallback below. - // - // Args: - // node_name: Name of the node that has just completed execution - // any_output: Whether the node has any output(s) - typedef std::function<void(const string& node_name, const bool any_output)> - NodeCompletionCallback; - void SetNodeCompletionCallback(NodeCompletionCallback callback); - - // Callback for node value. This is invoked when the value of a node's - // output tensor is available on the host, possibly after copying from - // a device (e.g., GPU). - // - // Args: - // node_name: Name of the node of which the output has become available - // output_slot: Output slot number of the output Tensor - // tensor_value: Reference to the tensor value - // is_ref: Whether the output of the reference type - typedef std::function<void(const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref)> - NodeValueCallback; - void SetNodeValueCallback(NodeValueCallback callback); - - // TODO(cais): Add whitelists for ops/tensors (e.g., {"A:0", "B:0"}) - // for node completion callback (whitelist_comp_) and node value callback - // (whitelist_val_). If whitelist_comp_ is non-empty, the gateway will - // invoke the NodeCompletionCallback only for the nodes specified in the - // whitelist. And so forth for whitelist_val_. - - private: - DirectSession* session_; - // TODO(cais): DebugGateway currently supports only DirectSession. 
Add - // support for GrpcSession. - - NodeCompletionCallback comp_cb_ = nullptr; - NodeValueCallback val_cb_ = nullptr; - - typedef std::function<void(const Tensor* dst_tensor)> CopyDoneCallback; - - void CopyTensor(const string& node_name, const int output_slot, - const Tensor* src_tensor, OpKernelContext* ctx, - CopyDoneCallback copy_done_cb); -}; - -} // end namespace tensorflow - -#endif // TENSORFLOW_DEBUG_DEBUG_SESSION_H_ diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc deleted file mode 100644 index b1bbd3f698..0000000000 --- a/tensorflow/core/debug/debug_gateway_test.cc +++ /dev/null @@ -1,1011 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/debug/debug_gateway.h" - -#include <algorithm> -#include <cstdlib> -#include <memory> -#include <unordered_map> - -#include "tensorflow/core/debug/debug_graph_utils.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/protobuf/rewriter_config.pb.h" - -namespace tensorflow { -namespace { - -std::unique_ptr<DirectSession> CreateSession() { - SessionOptions options; - // Turn off graph optimizer so we can observe intermediate node states. - options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(OptimizerOptions_Level_L0); - options.config.mutable_graph_options() - ->mutable_rewrite_options() - ->set_constant_folding(RewriterConfig::OFF); - options.config.mutable_graph_options() - ->mutable_rewrite_options() - ->set_dependency_optimization(RewriterConfig::OFF); - - return std::unique_ptr<DirectSession>( - dynamic_cast<DirectSession*>(NewSession(options))); -} - -class SessionDebugMinusAXTest : public ::testing::Test { - public: - void Initialize(std::initializer_list<float> a_values) { - Graph graph(OpRegistry::Global()); - -#if GOOGLE_CUDA - const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0"; -#elif defined(TENSORFLOW_USE_SYCL) - const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; -#else - const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0"; -#endif - - Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); - test::FillValues<float>(&a_tensor, a_values); - Node* a = test::graph::Constant(&graph, a_tensor); - a->set_assigned_device_name(kDeviceName); - a_ = a->name(); - - Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); - 
test::FillValues<float>(&x_tensor, {1, 1}); - Node* x = test::graph::Constant(&graph, x_tensor); - x->set_assigned_device_name(kDeviceName); - x_ = x->name(); - - // y = A * x - Node* y = test::graph::Matmul(&graph, a, x, false, false); - y->set_assigned_device_name(kDeviceName); - y_ = y->name(); - - Node* y_neg = test::graph::Unary(&graph, "Neg", y); - y_neg_ = y_neg->name(); - y_neg->set_assigned_device_name(kDeviceName); - - test::graph::ToGraphDef(&graph, &def_); - } - - string a_; - string x_; - string y_; - string y_neg_; - GraphDef def_; -}; - -TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) { - Initialize({3, 2, -1, 0}); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - // Supply completion and value callbacks - mutex mu; - // Completed nodes with and without outputs - std::vector<string> completed_nodes_w_outputs; - std::vector<string> completed_nodes_wo_outputs; - - Notification callbacks_done; - debug_gateway.SetNodeCompletionCallback( - [&mu, &completed_nodes_w_outputs, &completed_nodes_wo_outputs]( - const string& node_name, const bool any_output) { - mutex_lock l(mu); - if (any_output) { - completed_nodes_w_outputs.push_back(node_name); - } else { - completed_nodes_wo_outputs.push_back(node_name); - } - }); - - std::vector<bool> tensors_initialized; - std::unordered_map<string, Tensor> tensor_vals; - // output_slot values recorded in value callbacks - std::vector<int> output_slots_val; - // is_ref values recorded in value callbacks - std::vector<bool> is_refs_val; - - debug_gateway.SetNodeValueCallback( - [this, &mu, &tensors_initialized, &tensor_vals, &output_slots_val, - &is_refs_val, - &callbacks_done](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - tensors_initialized.push_back(tensor_value.IsInitialized()); - tensor_vals.insert(std::make_pair(node_name, tensor_value)); - 
output_slots_val.push_back(output_slot); - is_refs_val.push_back(is_ref); - - // Set the notification once we have the value from the target node. - if (node_name == y_neg_ && !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - TF_ASSERT_OK(session->Create(def_)); - - std::vector<std::pair<string, Tensor>> inputs; - - // Request two targets: one fetch output and one non-fetched output. - std::vector<string> output_names = {y_ + ":0"}; - std::vector<string> target_nodes = {y_neg_}; - std::vector<Tensor> outputs; - Status s = session->Run(inputs, output_names, target_nodes, &outputs); - TF_ASSERT_OK(s); - - // Wait for callbacks to complete. - callbacks_done.WaitForNotification(); - - ASSERT_EQ(1, outputs.size()); - // The first output should be initialized and have the correct - // output. - auto mat = outputs[0].matrix<float>(); - ASSERT_TRUE(outputs[0].IsInitialized()); - EXPECT_FLOAT_EQ(5.0, mat(0, 0)); - - // Verify the calling history of the completion callback - // The following verifies each node with output(s) invoked the callback - // exactly once. - ASSERT_GE(completed_nodes_w_outputs.size(), 4); // There may be added nodes. - - ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(), - completed_nodes_w_outputs.end(), a_)); - ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(), - completed_nodes_w_outputs.end(), x_)); - ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(), - completed_nodes_w_outputs.end(), y_)); - ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(), - completed_nodes_w_outputs.end(), y_neg_)); - - // Apart from nodes with outputs, there are also no-output (control) nodes. - // They ought to be captured by the DebugGateway through - // NodeOutputCallback as well. - ASSERT_GT(completed_nodes_wo_outputs.size(), 0); - - // The DebugGateway should have captured the _SOURCE node. 
- ASSERT_LE(1, std::count(completed_nodes_wo_outputs.begin(), - completed_nodes_wo_outputs.end(), "_SOURCE")); - - // Verify the calling history of the value callabck - ASSERT_EQ(completed_nodes_w_outputs.size(), tensors_initialized.size()); - - // In this graph, there is no uninitialized node value. - ASSERT_EQ( - tensors_initialized.end(), - std::find(tensors_initialized.begin(), tensors_initialized.end(), false)); - - ASSERT_EQ(completed_nodes_w_outputs.size(), tensor_vals.size()); - ASSERT_EQ(completed_nodes_w_outputs.size(), output_slots_val.size()); - ASSERT_EQ(completed_nodes_w_outputs.size(), is_refs_val.size()); - - // Verify the intermediate tensor values captured through the value callback - auto mat_a = tensor_vals[a_].matrix<float>(); - ASSERT_EQ(3.0, mat_a(0, 0)); - ASSERT_EQ(2.0, mat_a(0, 1)); - ASSERT_EQ(-1.0, mat_a(1, 0)); - ASSERT_EQ(0.0, mat_a(1, 1)); - - auto mat_x = tensor_vals[x_].matrix<float>(); - ASSERT_EQ(1.0, mat_x(0, 0)); - ASSERT_EQ(1.0, mat_x(1, 0)); - - auto mat_y = tensor_vals[y_].matrix<float>(); - ASSERT_EQ(5.0, mat_y(0, 0)); - ASSERT_EQ(-1.0, mat_y(1, 0)); - - auto mat_y_neg = tensor_vals[y_neg_].matrix<float>(); - ASSERT_EQ(-5.0, mat_y_neg(0, 0)); - ASSERT_EQ(1.0, mat_y_neg(1, 0)); - - // In this graph, all outputs are on the first slot - ASSERT_EQ(output_slots_val.size(), - std::count_if(output_slots_val.begin(), output_slots_val.end(), - [](int slot) { return slot == 0; })); - - // In this graph, there is no ref-type tensor. 
- ASSERT_EQ(is_refs_val.end(), - std::find(is_refs_val.begin(), is_refs_val.end(), true)); -} - -TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) { - // Tensor contains one count of NaN - Initialize({3, std::numeric_limits<float>::quiet_NaN(), -1, 0}); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - // Create debug tensor watch options with two debug ops: - // DebugIdentity and DebugNanCount - RunOptions run_opts; - run_opts.set_output_partition_graphs(true); - - const string debug_identity = "DebugIdentity"; - const string debug_nan_count = "DebugNanCount"; - DebugTensorWatch* tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_node_name(y_); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops(debug_identity); - tensor_watch_opts->add_debug_ops(debug_nan_count); - - // Expected name of the inserted debug node - string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(y_, ":", 0), 0, debug_identity); - string debug_nan_count_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(y_, ":", 0), 1, debug_nan_count); - - // Supply completion and value callbacks - mutex mu; - // Completed nodes with and without outputs - std::vector<string> completed_debug_nodes; - - Notification callbacks_done; - debug_gateway.SetNodeCompletionCallback( - [&mu, &debug_identity_node_name, &debug_nan_count_node_name, - &completed_debug_nodes](const string& node_name, const bool any_output) { - mutex_lock l(mu); - if (any_output && (node_name == debug_identity_node_name || - node_name == debug_nan_count_node_name)) { - completed_debug_nodes.push_back(node_name); - } - }); - - std::vector<Tensor> watched_tensor_vals; - std::vector<Tensor> debug_identity_tensor_vals; - std::vector<Tensor> debug_nan_count_tensor_vals; - - debug_gateway.SetNodeValueCallback( - [this, &mu, 
&debug_identity_node_name, &debug_nan_count_node_name, - &watched_tensor_vals, &debug_identity_tensor_vals, - &debug_nan_count_tensor_vals, - &callbacks_done](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - if (node_name == y_) { - watched_tensor_vals.push_back(tensor_value); - } else if (node_name == debug_identity_node_name && output_slot == 0) { - // output_slot == 0 carries the debug signal. Same below. - debug_identity_tensor_vals.push_back(tensor_value); - } else if (node_name == debug_nan_count_node_name && output_slot == 0) { - debug_nan_count_tensor_vals.push_back(tensor_value); - } - - // Set the notification once we have the value from the target node. - if (node_name == y_neg_ && !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - TF_ASSERT_OK(session->Create(def_)); - - std::vector<std::pair<string, Tensor>> inputs; - - // Request two targets: one fetch output and one non-fetched output. - std::vector<string> output_names = {y_ + ":0"}; - std::vector<string> target_nodes = {y_neg_}; - std::vector<Tensor> outputs; - - RunMetadata run_metadata; - Status s = session->Run(run_opts, inputs, output_names, target_nodes, - &outputs, &run_metadata); - TF_ASSERT_OK(s); - -// Verify the correct number of partition graphs (GraphDefs) outputted -// through RunMetadata, given whether GPU is involved. -#if GOOGLE_CUDA - ASSERT_EQ(2, run_metadata.partition_graphs().size()); -#elif defined(TENSORFLOW_USE_SYCL) - ASSERT_EQ(2, run_metadata.partition_graphs().size()); -#else - ASSERT_EQ(1, run_metadata.partition_graphs().size()); -#endif - - // Wait for callbacks to complete. - callbacks_done.WaitForNotification(); - - // Verify that each of the two debug nodes has completed exactly once. 
- ASSERT_EQ(2, completed_debug_nodes.size()); - ASSERT_EQ( - 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(), - debug_identity_node_name)); - ASSERT_EQ( - 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(), - debug_nan_count_node_name)); - - // Verify that the tensor values from the watched node and the identity - // debug node are received and they are equal (owing to the debug op being - // "DebugIdentity") - ASSERT_EQ(1, watched_tensor_vals.size()); - ASSERT_EQ(1, debug_identity_tensor_vals.size()); - auto mat_y = watched_tensor_vals[0].matrix<float>(); - auto mat_identity = debug_identity_tensor_vals[0].matrix<float>(); - // ASSERT_EQ doesn't work for nan == nan - ASSERT_TRUE(std::isnan(mat_y(0, 0))); - ASSERT_TRUE(std::isnan(mat_identity(0, 0))); - ASSERT_EQ(-1, mat_identity(1, 0)); - - // Verify that the output from the NaN-count debug node indicates exactly - // one NaN. - ASSERT_EQ(1, debug_nan_count_tensor_vals.size()); - ASSERT_EQ(1, debug_nan_count_tensor_vals[0].scalar<int64>()()); -} - -#if !defined(GOOGLE_CUDA) && !defined(TENSORFLOW_USE_SYCL) -// TODO(cais): Reinstate the following test for concurrent debugged runs on -// a GPU once the root cause of the ~0.5% flakiness has been addressed. -// (b/34081273) -TEST_F(SessionDebugMinusAXTest, - RunSimpleNetworkConcurrentlyWithDifferentDebugTensorWatches) { - // Test concurrent Run() calls on a graph with different debug watches. - - Initialize({3, 2, -1, 0}); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - TF_ASSERT_OK(session->Create(def_)); - - // Number of concurrent Run() calls to launch. 
- const int kConcurrentRuns = 3; - thread::ThreadPool* tp = - new thread::ThreadPool(Env::Default(), "test", kConcurrentRuns); - - std::vector<string> output_names = {y_ + ":0"}; - std::vector<string> target_nodes = {y_neg_}; - - mutex mu; - DebugGateway debug_gateway(session.get()); - std::unordered_map<string, Tensor> debug_identity_tensor_vals; - - const string debug_identity = "DebugIdentity"; - - const string a_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(a_, ":", 0), 0, debug_identity); - const string x_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(x_, ":", 0), 0, debug_identity); - const string y_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(y_, ":", 0), 0, debug_identity); - - Notification callbacks_done; - volatile int val_callback_count = 0; - - debug_gateway.SetNodeValueCallback( - [this, &mu, &val_callback_count, &a_debug_identity_node_name, - &x_debug_identity_node_name, &y_debug_identity_node_name, - &debug_identity_tensor_vals, &callbacks_done, - &kConcurrentRuns](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - - if (node_name == a_debug_identity_node_name && output_slot == 0) { - debug_identity_tensor_vals["a"] = tensor_value; - val_callback_count++; - } else if (node_name == x_debug_identity_node_name && - output_slot == 0) { - // output_slot == 0 carries the debug signal. - debug_identity_tensor_vals["x"] = tensor_value; - val_callback_count++; - } else if (node_name == y_debug_identity_node_name && - output_slot == 0) { - debug_identity_tensor_vals["y"] = tensor_value; - val_callback_count++; - } - - // Set the notification once we have the value from the callbacks from - // all the concurrent Run() calls. 
- if (val_callback_count == kConcurrentRuns && - !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - int run_counter = 0; - mutex run_lock; - - // Function to be executed concurrently. - auto fn = [this, &run_lock, &run_counter, &session, output_names, - target_nodes, &debug_identity]() { - // Create unique debug tensor watch options for each of the concurrent - // run calls. - RunOptions run_opts; - run_opts.set_output_partition_graphs(true); - - DebugTensorWatch* tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops(debug_identity); - - { - // Let the concurrent runs watch different tensors. - - mutex_lock l(run_lock); - - if (run_counter == 0) { - // Let the 1st concurrent run watch a. - tensor_watch_opts->set_node_name(a_); - } else if (run_counter == 1) { - // Let the 2nd concurrent watch x. - tensor_watch_opts->set_node_name(x_); - } else if (run_counter == 2) { - // Let the 3rd concurrent watch y. - tensor_watch_opts->set_node_name(y_); - } - - run_counter++; - } - - // Run the graph. - RunMetadata run_metadata; - std::vector<std::pair<string, Tensor>> inputs; - std::vector<Tensor> outputs; - Status s = session->Run(run_opts, inputs, output_names, target_nodes, - &outputs, &run_metadata); - TF_ASSERT_OK(s); - - ASSERT_EQ(1, run_metadata.partition_graphs().size()); - - ASSERT_EQ(1, outputs.size()); - ASSERT_TRUE(outputs[0].IsInitialized()); - ASSERT_EQ(TensorShape({2, 1}), outputs[0].shape()); - auto mat = outputs[0].matrix<float>(); - EXPECT_FLOAT_EQ(5.0, mat(0, 0)); - EXPECT_FLOAT_EQ(-1.0, mat(1, 0)); - }; - - for (int i = 0; i < kConcurrentRuns; ++i) { - tp->Schedule(fn); - } - - // Wait for the debug callbacks to finish. - callbacks_done.WaitForNotification(); - - // Wait for the concurrent functions with Run() calls to finish. 
- delete tp; - - { - mutex_lock l(mu); - - ASSERT_EQ(kConcurrentRuns, val_callback_count); - ASSERT_EQ(kConcurrentRuns, debug_identity_tensor_vals.size()); - - ASSERT_EQ(TensorShape({2, 2}), debug_identity_tensor_vals["a"].shape()); - auto a_mat_identity = debug_identity_tensor_vals["a"].matrix<float>(); - ASSERT_EQ(3.0, a_mat_identity(0, 0)); - ASSERT_EQ(2.0, a_mat_identity(0, 1)); - ASSERT_EQ(-1.0, a_mat_identity(1, 0)); - ASSERT_EQ(0.0, a_mat_identity(1, 1)); - - ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["x"].shape()); - auto x_mat_identity = debug_identity_tensor_vals["x"].matrix<float>(); - ASSERT_EQ(1.0, x_mat_identity(0, 0)); - ASSERT_EQ(1.0, x_mat_identity(1, 0)); - - ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["y"].shape()); - auto y_mat_identity = debug_identity_tensor_vals["y"].matrix<float>(); - ASSERT_EQ(5.0, y_mat_identity(0, 0)); - ASSERT_EQ(-1.0, y_mat_identity(1, 0)); - } -} -#endif - -class SessionDebugOutputSlotWithoutOutgoingEdgeTest : public ::testing::Test { - public: - void Initialize() { - Graph graph(OpRegistry::Global()); - -#if GOOGLE_CUDA - const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0"; -#elif defined(TENSORFLOW_USE_SYCL) - const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; -#else - const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0"; -#endif - - Tensor a_tensor(DT_FLOAT, TensorShape({1, 1})); - test::FillValues<float>(&a_tensor, {42.0}); - Node* a = test::graph::Constant(&graph, a_tensor); - a->set_assigned_device_name(kDeviceName); - - Node* c = test::graph::Constant(&graph, a_tensor); - c->set_assigned_device_name(kDeviceName); - c_ = c->name(); - - // Node c will be executed only because of the control edge from c to y. - // Its output slot (slot 0) does not have an outgoing edge. This test - // is for testing that the debugger can watch that slot properly. 
- Node* y = test::graph::NoOp(&graph, {c}); - y->set_assigned_device_name(kDeviceName); - y_ = y->name(); - - test::graph::ToGraphDef(&graph, &def_); - } - - string c_; - string y_; - GraphDef def_; -}; - -TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest, - WatchSlotWithoutOutgoingEdge) { - Initialize(); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - // Supply completion and value callbacks - mutex mu; - - string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(c_, ":", 0), 0, "DebugIdentity"); - - Notification callbacks_done; - - std::vector<Tensor> debug_identity_tensor_vals; - debug_gateway.SetNodeValueCallback( - [this, &mu, &callbacks_done, &debug_identity_node_name, - &debug_identity_tensor_vals]( - const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - - if (node_name == debug_identity_node_name && output_slot == 0) { - debug_identity_tensor_vals.push_back(tensor_value); - - if (!callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - } - }); - - // Add DebugIdentity watch on c:0, which does not have an outgoing edge. - RunOptions run_opts; - run_opts.set_output_partition_graphs(true); - - DebugTensorWatch* tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_node_name(c_); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops("DebugIdentity"); - - TF_ASSERT_OK(session->Create(def_)); - - // Invoke Session::Run() on y. - std::vector<std::pair<string, Tensor>> inputs; - std::vector<string> output_names; - std::vector<string> target_nodes = {y_}; - std::vector<Tensor> outputs; - - RunMetadata run_metadata; - Status s = session->Run(run_opts, inputs, output_names, target_nodes, - &outputs, &run_metadata); - TF_ASSERT_OK(s); - - // Wait for callbacks to complete. 
- callbacks_done.WaitForNotification(); - - // Assert that DebugIdentity node watching the control edge has been run. - ASSERT_EQ(1, debug_identity_tensor_vals.size()); - auto mat_identity = debug_identity_tensor_vals[0].matrix<float>(); - ASSERT_EQ(42.0, mat_identity(0, 0)); -} - -class SessionDebugVariableTest : public ::testing::Test { - public: - void Initialize() { - Graph graph(OpRegistry::Global()); - -#if GOOGLE_CUDA - const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0"; -#elif defined(TENSORFLOW_USE_SYCL) - const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; -#else - const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0"; -#endif - - // Define variable node. - var_node_name_ = "var"; - Node* var = - test::graph::Var(&graph, DT_FLOAT, TensorShape({3}), var_node_name_); - var->set_assigned_device_name(kDeviceName); - - // Define the initial value and the initial-value node. - Tensor nan_nan_seven(DT_FLOAT, TensorShape({3})); - nan_nan_seven.flat<float>()(0) = std::numeric_limits<float>::quiet_NaN(); - nan_nan_seven.flat<float>()(1) = std::numeric_limits<float>::quiet_NaN(); - nan_nan_seven.flat<float>()(2) = 7.0; - - init_val_node_name_ = "init_val"; - Node* init_val = - test::graph::Constant(&graph, nan_nan_seven, init_val_node_name_); - init_val->set_assigned_device_name(kDeviceName); - - // Define node for variable value initialization - Node* init = test::graph::Assign(&graph, var, init_val); - init->set_assigned_device_name(kDeviceName); - init_node_name_ = init->name(); - - // Define new value node - Tensor nan_eight_eight(DT_FLOAT, TensorShape({3})); - nan_eight_eight.flat<float>()(0) = std::numeric_limits<float>::quiet_NaN(); - nan_eight_eight.flat<float>()(1) = 8.0; - nan_eight_eight.flat<float>()(2) = 8.0; - - Node* new_val = test::graph::Constant(&graph, nan_eight_eight); - new_val->set_assigned_device_name(kDeviceName); - new_val_node_name_ = new_val->name(); - - // Define node for 
assigning new value - Node* assign = test::graph::Assign(&graph, var, new_val); - assign->set_assigned_device_name(kDeviceName); - assign_node_name_ = assign->name(); - - test::graph::ToGraphDef(&graph, &def_); - } - - string var_node_name_; - string init_val_node_name_; - string init_node_name_; - string new_val_node_name_; - string assign_node_name_; - GraphDef def_; -}; - -TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) { - Initialize(); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - TF_ASSERT_OK(session->Create(def_)); - - // Set up DebugTensorWatch for an uninitialized tensor (in node var). - RunOptions run_opts; - const string debug_identity = "DebugIdentity"; - DebugTensorWatch* tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_node_name(var_node_name_); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops(debug_identity); - - // Expected name of the inserted debug node - string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(var_node_name_, ":", 0), 0, debug_identity); - - // Supply completion and value callbacks - mutex mu; - // Completed nodes with and without outputs - std::vector<string> completed_debug_nodes; - - Notification callbacks_done; - debug_gateway.SetNodeCompletionCallback( - [this, &mu, &debug_identity_node_name, &completed_debug_nodes, - &callbacks_done](const string& node_name, const bool any_output) { - mutex_lock l(mu); - if (any_output && (node_name == debug_identity_node_name)) { - completed_debug_nodes.push_back(node_name); - } - }); - - std::vector<Tensor> debug_identity_tensor_vals; - - debug_gateway.SetNodeValueCallback( - [this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals, - &callbacks_done](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock 
l(mu); - if (node_name == debug_identity_node_name && output_slot == 0) { - // output_slot == 0 carries the debug signal. Same below. - debug_identity_tensor_vals.push_back(tensor_value); - } - - // Set the notification once we have the value from the target node. - if (node_name == init_node_name_ && !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - // First run the initialization op - std::vector<std::pair<string, Tensor>> inputs_init; - std::vector<Tensor> outputs_init; - - RunMetadata run_metadata; - Status s = session->Run(run_opts, inputs_init, {init_node_name_}, {}, - &outputs_init, &run_metadata); - TF_ASSERT_OK(s); - - callbacks_done.WaitForNotification(); - - ASSERT_EQ(1, completed_debug_nodes.size()); - ASSERT_EQ( - 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(), - debug_identity_node_name)); - - // Assert the output reflects the uninitialized nature of var's tensor. - ASSERT_EQ(1, debug_identity_tensor_vals.size()); - ASSERT_FALSE(debug_identity_tensor_vals[0].IsInitialized()); - ASSERT_EQ(DT_FLOAT, debug_identity_tensor_vals[0].dtype()); - ASSERT_EQ(TensorShape({3}), debug_identity_tensor_vals[0].shape()); -} - -TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) { - // Tensor contains one count of NaN - Initialize(); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - TF_ASSERT_OK(session->Create(def_)); - - // First run the initialization op - std::vector<std::pair<string, Tensor>> inputs_init; - std::vector<Tensor> outputs_init; - Status s = session->Run(inputs_init, {init_node_name_}, {}, &outputs_init); - TF_ASSERT_OK(s); - - // Create debug tensor watch options with two ref-type debug ops: - // DebugIdentity and DebugNanCount - RunOptions run_opts; - run_opts.set_output_partition_graphs(true); - const string debug_identity = "DebugIdentity"; - const string debug_nan_count = "DebugNanCount"; - DebugTensorWatch* 
tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_node_name(var_node_name_); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops(debug_identity); - tensor_watch_opts->add_debug_ops(debug_nan_count); - - char tempdir_template[] = "/tmp/tfdbg_XXXXXX"; - string temp_dir(mkdtemp(tempdir_template)); - tensor_watch_opts->add_debug_urls(strings::StrCat("file://", temp_dir)); - - // Expected name of the inserted debug node - string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(var_node_name_, ":", 0), 0, debug_identity); - string debug_nan_count_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(var_node_name_, ":", 0), 1, debug_nan_count); - - // Supply completion and value callbacks - mutex mu; - // Completed nodes with and without outputs - std::vector<string> completed_debug_nodes; - - Notification callbacks_done; - debug_gateway.SetNodeCompletionCallback( - [this, &mu, &debug_identity_node_name, &debug_nan_count_node_name, - &completed_debug_nodes, - &callbacks_done](const string& node_name, const bool any_output) { - mutex_lock l(mu); - if (any_output && (node_name == debug_identity_node_name || - node_name == debug_nan_count_node_name)) { - completed_debug_nodes.push_back(node_name); - } - }); - - std::vector<Tensor> debug_identity_tensor_vals; - std::vector<Tensor> debug_nan_count_tensor_vals; - - debug_gateway.SetNodeValueCallback( - [this, &mu, &debug_identity_node_name, &debug_nan_count_node_name, - &debug_identity_tensor_vals, &debug_nan_count_tensor_vals, - &callbacks_done](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - if (node_name == debug_identity_node_name && output_slot == 0) { - // output_slot == 0 carries the debug signal. Same below. 
- debug_identity_tensor_vals.push_back(tensor_value); - } else if (node_name == debug_nan_count_node_name && output_slot == 0) { - debug_nan_count_tensor_vals.push_back(tensor_value); - } - - // Set the notification once we have the value from the target node. - if (node_name == assign_node_name_ && - !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - // // Request two targets: one fetch output and one non-fetched output. - std::vector<std::pair<string, Tensor>> inputs; - std::vector<string> output_names = {assign_node_name_ + ":0"}; - std::vector<string> target_nodes = {assign_node_name_}; - std::vector<Tensor> outputs; - - // Run with RunOptions that has tensor watches - RunMetadata run_metadata; - s = session->Run(run_opts, inputs, output_names, target_nodes, &outputs, - &run_metadata); - TF_ASSERT_OK(s); - -#if GOOGLE_CUDA - ASSERT_EQ(2, run_metadata.partition_graphs().size()); -#elif defined(TENSORFLOW_USE_SYCL) - ASSERT_EQ(2, run_metadata.partition_graphs().size()); -#else - ASSERT_EQ(1, run_metadata.partition_graphs().size()); -#endif - - // Wait for callbacks to complete. - callbacks_done.WaitForNotification(); - - // Verify that the update has happened properly. - ASSERT_EQ(1, outputs.size()); - ASSERT_TRUE(std::isnan(outputs[0].vec<float>()(0))); - ASSERT_EQ(8.0, outputs[0].vec<float>()(1)); // Expect new value - ASSERT_EQ(8.0, outputs[0].vec<float>()(2)); // Expect new value - - // Verify that each of the two debug nodes has completed exactly once. - ASSERT_EQ(2, completed_debug_nodes.size()); - ASSERT_EQ( - 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(), - debug_identity_node_name)); - ASSERT_EQ( - 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(), - debug_nan_count_node_name)); - - // Verify that the values from the ref identity node reflects the value - // before the new assign. 
- ASSERT_EQ(1, debug_identity_tensor_vals.size()); - - auto vec_identity = debug_identity_tensor_vals[0].vec<float>(); - ASSERT_TRUE(std::isnan(vec_identity(0))); - ASSERT_TRUE(std::isnan(vec_identity(1))); - ASSERT_EQ(7.0, vec_identity(2)); - - // Verify that the output from the NaN-count debug node indicates exactly - // two NaNs, i.e., reflecting the value before the new assign. - ASSERT_EQ(1, debug_nan_count_tensor_vals.size()); - ASSERT_EQ(2, debug_nan_count_tensor_vals[0].scalar<int64>()()); -} - -#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_SYCL) -class SessionDebugGPUSwitchTest : public ::testing::Test { - public: - void Initialize() { - Graph graph(OpRegistry::Global()); - -#ifdef GOOGLE_CUDA - const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0"; -#elif TENSORFLOW_USE_SYCL - const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; -#endif - - Tensor vb(DT_BOOL, TensorShape({})); - vb.scalar<bool>()() = true; - Tensor vi(DT_INT64, TensorShape({})); - vi.scalar<int>()() = 42; - // So vi is expected to be forwarded to the second output port of sw. - - Node* pred = test::graph::Constant(&graph, vb); - pred->set_assigned_device_name(kDeviceName); - pred_node_name_ = pred->name(); - - Node* value = test::graph::Constant(&graph, vi); - pred->set_assigned_device_name(kDeviceName); - value_node_name_ = value->name(); - - Node* sw = test::graph::Switch(&graph, value, pred); - sw->set_assigned_device_name(kDeviceName); - sw_node_name_ = sw->name(); - - Node* z = test::graph::Identity(&graph, sw, 1); - sw->set_assigned_device_name(kDeviceName); - z_node_name_ = z->name(); - - test::graph::ToGraphDef(&graph, &def_); - } - - string pred_node_name_; - string value_node_name_; - string sw_node_name_; - string z_node_name_; - GraphDef def_; -}; - -// Test for debug-watching tensors marked as HOST_MEMORY on GPU. 
-TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) { - Initialize(); - auto session = CreateSession(); - ASSERT_TRUE(session != nullptr); - - DebugGateway debug_gateway(session.get()); - - RunOptions run_opts; - run_opts.set_output_partition_graphs(true); - // This is the name of the boolean tensor fed as pred to the Switch node. - // On GPU, this edge is HOST_MEMORY. - const string watched_tensor = strings::StrCat(pred_node_name_, "/_1"); - - const string debug_identity = "DebugIdentity"; - DebugTensorWatch* tensor_watch_opts = - run_opts.mutable_debug_options()->add_debug_tensor_watch_opts(); - tensor_watch_opts->set_node_name(watched_tensor); - tensor_watch_opts->set_output_slot(0); - tensor_watch_opts->add_debug_ops(debug_identity); - - // Expected name of the inserted debug node - string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName( - strings::StrCat(watched_tensor, ":", 0), 0, debug_identity); - - // Supply completion and value callbacks - mutex mu; - // Completed nodes with and without outputs - std::vector<string> completed_nodes_w_outputs; - std::vector<string> completed_nodes_wo_outputs; - - Notification callbacks_done; - debug_gateway.SetNodeCompletionCallback( - [&mu, &completed_nodes_w_outputs, &completed_nodes_wo_outputs]( - const string& node_name, const bool any_output) { - mutex_lock l(mu); - if (any_output) { - completed_nodes_w_outputs.push_back(node_name); - } else { - completed_nodes_wo_outputs.push_back(node_name); - } - }); - - std::vector<Tensor> debug_identity_tensor_vals; - - debug_gateway.SetNodeValueCallback( - [this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals, - &callbacks_done](const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); - if (node_name == debug_identity_node_name && output_slot == 0) { - debug_identity_tensor_vals.push_back(tensor_value); - } - - // Set the notification once we have the value from the target 
node. - if (node_name == z_node_name_ && !callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - }); - - TF_ASSERT_OK(session->Create(def_)); - - std::vector<std::pair<string, Tensor>> inputs; - - // Request two targets: one fetch output and one non-fetched output. - std::vector<string> output_names = {z_node_name_ + ":0"}; - std::vector<string> target_nodes = {z_node_name_}; - std::vector<Tensor> outputs; - - RunMetadata run_metadata; - Status s = session->Run(run_opts, inputs, output_names, target_nodes, - &outputs, &run_metadata); - TF_ASSERT_OK(s); - - ASSERT_EQ(2, run_metadata.partition_graphs().size()); - - // Wait for callbacks to complete. - callbacks_done.WaitForNotification(); - - ASSERT_EQ(1, debug_identity_tensor_vals.size()); - ASSERT_TRUE(debug_identity_tensor_vals[0].scalar<bool>()()); -} -#endif // GOOGLE_CUDA - -} // end namespace -} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index e2f13df19f..6c146036ae 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -261,7 +261,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph, /*shape_map=*/nullptr); - // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph. + // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. if (!debug_options.debug_tensor_watch_opts().empty()) { TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug( debug_options, subgraph.get(), params.device)); diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc index 5aebdca1ea..2d9b4360de 100644 --- a/tensorflow/core/ops/debug_ops.cc +++ b/tensorflow/core/ops/debug_ops.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -// EXPERIMENTAL: tfdbg debugger-inserted ops. +// TensorFlow Debugger-inserted ops. 
// These ops are used only internally by tfdbg. There is no API for users to // direct create them. Users can create them indirectly by using // RunOptions.debug_options during Session::Run() call. See tfdbg documentation diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto index 499900f965..811cf406b9 100644 --- a/tensorflow/core/protobuf/debug.proto +++ b/tensorflow/core/protobuf/debug.proto @@ -7,7 +7,7 @@ option java_multiple_files = true; option java_package = "org.tensorflow.framework"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"; -// EXPERIMENTAL. Option for watching a node. +// Option for watching a node in TensorFlow Debugger (tfdbg). message DebugTensorWatch { // Name of the node to watch. string node_name = 1; @@ -51,7 +51,7 @@ message DebugTensorWatch { bool tolerate_debug_op_creation_failures = 5; } -// EXPERIMENTAL. Options for initializing DebuggerState. +// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). 
message DebugOptions { // Debugging options repeated DebugTensorWatch debug_tensor_watch_opts = 4; diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md index d1d1f69766..abbf47910e 100644 --- a/tensorflow/docs_src/extend/new_data_formats.md +++ b/tensorflow/docs_src/extend/new_data_formats.md @@ -77,18 +77,24 @@ can be used as a starting point for your implementation: #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { +namespace myproject { namespace { -class MyReaderDatasetOp : public DatasetOpKernel { +using ::tensorflow::DT_STRING; +using ::tensorflow::PartialTensorShape; +using ::tensorflow::Status; + +class MyReaderDatasetOp : public tensorflow::DatasetOpKernel { public: - MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { // Parse and validate any attrs that define the dataset using // `ctx->GetAttr()`, and store them in member variables. } - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + void MakeDataset(tensorflow::OpKernelContext* ctx, + tensorflow::DatasetBase** output) override { // Parse and validate any input tensors 0that define the dataset using // `ctx->input()` or the utility function // `ParseScalarArgument<T>(ctx, &arg)`. 
@@ -99,14 +105,14 @@ class MyReaderDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public tensorflow::GraphDatasetBase { public: - Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {} + Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {} - std::unique_ptr<IteratorBase> MakeIteratorInternal( + std::unique_ptr<tensorflow::IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::MyReader")})); + return std::unique_ptr<tensorflow::IteratorBase>(new Iterator( + {this, tensorflow::strings::StrCat(prefix, "::MyReader")})); } // Record structure: Each record is represented by a scalar string tensor. @@ -114,8 +120,8 @@ class MyReaderDatasetOp : public DatasetOpKernel { // Dataset elements can have a fixed number of components of different // types and shapes; replace the following two methods to customize this // aspect of the dataset. - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + const tensorflow::DataTypeVector& output_dtypes() const override { + static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING}); return *dtypes; } const std::vector<PartialTensorShape>& output_shapes() const override { @@ -132,16 +138,16 @@ class MyReaderDatasetOp : public DatasetOpKernel { // Implement this method if you want to be able to save and restore // instances of this dataset (and any iterators over it). Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** output) const override { + tensorflow::Node** output) const override { // Construct nodes to represent any of the input tensors from this // object's member variables using `b->AddScalar()` and `b->AddVector()`. 
- std::vector<Node*> input_tensors; + std::vector<tensorflow::Node*> input_tensors; TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); return Status::OK(); } private: - class Iterator : public DatasetIterator<Dataset> { + class Iterator : public tensorflow::DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), i_(0) {} @@ -158,15 +164,15 @@ class MyReaderDatasetOp : public DatasetOpKernel { // return `Status::OK()`. // 3. If an error occurs, return an error status using one of the helper // functions from "tensorflow/core/lib/core/errors.h". - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, + Status GetNextInternal(tensorflow::IteratorContext* ctx, + std::vector<tensorflow::Tensor>* out_tensors, bool* end_of_sequence) override { // NOTE: `GetNextInternal()` may be called concurrently, so it is // recommended that you protect the iterator state with a mutex. - mutex_lock l(mu_); + tensorflow::mutex_lock l(mu_); if (i_ < 10) { // Create a scalar string tensor and add it to the output. - Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); + tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); record_tensor.scalar<string>()() = "MyReader!"; out_tensors->emplace_back(std::move(record_tensor)); ++i_; @@ -183,20 +189,20 @@ class MyReaderDatasetOp : public DatasetOpKernel { // // Implement these two methods if you want to be able to save and restore // instances of this iterator. 
- Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + Status SaveInternal(tensorflow::IteratorStateWriter* writer) override { + tensorflow::mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); return Status::OK(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); + Status RestoreInternal(tensorflow::IteratorContext* ctx, + tensorflow::IteratorStateReader* reader) override { + tensorflow::mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); return Status::OK(); } private: - mutex mu_; + tensorflow::mutex mu_; int64 i_ GUARDED_BY(mu_); }; }; @@ -211,14 +217,14 @@ class MyReaderDatasetOp : public DatasetOpKernel { REGISTER_OP("MyReaderDataset") .Output("handle: variant") .SetIsStateful() - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn(tensorflow::shape_inference::ScalarShape); // Register the kernel implementation for MyReaderDataset. -REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU), MyReaderDatasetOp); } // namespace -} // namespace tensorflow +} // namespace myproject ``` The last step is to build the C++ code and add a Python wrapper. The easiest way diff --git a/tensorflow/docs_src/guide/index.md b/tensorflow/docs_src/guide/index.md index eefdb9ceae..f78dfc9a89 100644 --- a/tensorflow/docs_src/guide/index.md +++ b/tensorflow/docs_src/guide/index.md @@ -16,15 +16,12 @@ works. The units are as follows: ## Estimators -* @{$estimators} provides an introduction. -* @{$premade_estimators}, introduces Estimators for machine learning. -* @{$custom_estimators}, which demonstrates how to build and train models you - design yourself. -* @{$feature_columns}, which shows how an Estimator can handle a variety of input - data types without changes to the model. 
-* @{$datasets_for_estimators} describes using tf.data with estimators. -* @{$checkpoints}, which explains how to save training progress and resume where - you left off. +* @{$estimators}, learn how to use Estimators for machine learning. +* @{$premade_estimators}, the basics of premade Estimators. +* @{$checkpoints}, save training progress and resume where you left off. +* @{$feature_columns}, handle a variety of input data types without changes to the model. +* @{$datasets_for_estimators}, use `tf.data` to input data. +* @{$custom_estimators}, write your own Estimator. ## Accelerators diff --git a/tensorflow/docs_src/guide/leftnav_files b/tensorflow/docs_src/guide/leftnav_files index 357a2a1cb9..b3324278c1 100644 --- a/tensorflow/docs_src/guide/leftnav_files +++ b/tensorflow/docs_src/guide/leftnav_files @@ -8,10 +8,10 @@ datasets.md ### Estimators estimators.md: Introduction to Estimators premade_estimators.md -custom_estimators.md +checkpoints.md feature_columns.md datasets_for_estimators.md -checkpoints.md +custom_estimators.md ### Accelerators using_gpu.md diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index f21c073a1b..541a55e184 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -511,6 +511,8 @@ on your system: list of supported GPU cards. * [GPU drivers](http://nvidia.com/drivers) that support your version of the CUDA Toolkit. +* NCCL 2.2 to use TensorFlow with multiple GPUs. For details, see [NVIDIA's + documentation](https://developer.nvidia.com/nccl). * The `libcupti-dev` library is the NVIDIA CUDA Profile Tools Interface. This library provides advanced profiling support. 
To install this library, use the following command for CUDA Toolkit >= 8.0: diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml index 07d561b8a2..c74fe58089 100644 --- a/tensorflow/docs_src/tutorials/_index.yaml +++ b/tensorflow/docs_src/tutorials/_index.yaml @@ -175,7 +175,7 @@ landing_page: <a href="/guide/estimators">Estimators guide</a>. </p> <ol style="padding-left: 20px;"> - <li><a href="/guide/premade_estimators">Premade Estimators guide</a></li> + <li><a href="/tutorials/estimators/linear">Build a linear model with Estimators</a></li> <li><a href="https://github.com/tensorflow/models/tree/master/official/wide_deep" class="external">Wide and deep learning with Estimators</a></li> <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees" class="external">Boosted trees</a></li> <li><a href="/hub/tutorials/text_classification_with_tf_hub">How to build a simple text classifier with TF-Hub</a></li> diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml index 4db97e35fc..d33869af6e 100644 --- a/tensorflow/docs_src/tutorials/_toc.yaml +++ b/tensorflow/docs_src/tutorials/_toc.yaml @@ -44,6 +44,8 @@ toc: - title: ML at production scale style: accordion section: + - title: Linear model with Estimators + path: /tutorials/estimators/linear - title: Wide and deep learning path: https://github.com/tensorflow/models/tree/master/official/wide_deep status: external diff --git a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md index b45fbefac0..b564a27ecf 100644 --- a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md +++ b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md @@ -1,3 +1,3 @@ # Custom training: walkthrough -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/eager.ipynb) 
+[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/eager/custom_training_walkthrough.ipynb) diff --git a/tensorflow/docs_src/tutorials/estimators/linear.md b/tensorflow/docs_src/tutorials/estimators/linear.md new file mode 100644 index 0000000000..067a33ac03 --- /dev/null +++ b/tensorflow/docs_src/tutorials/estimators/linear.md @@ -0,0 +1,3 @@ +# Build a linear model with Estimators + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/estimators/linear.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_classification.md b/tensorflow/docs_src/tutorials/keras/basic_classification.md index 91bbd85b24..e028af99b9 100644 --- a/tensorflow/docs_src/tutorials/keras/basic_classification.md +++ b/tensorflow/docs_src/tutorials/keras/basic_classification.md @@ -1,3 +1,3 @@ # Basic Classification -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_classification.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_regression.md b/tensorflow/docs_src/tutorials/keras/basic_regression.md index a535f22f5a..8721b7aca1 100644 --- a/tensorflow/docs_src/tutorials/keras/basic_regression.md +++ b/tensorflow/docs_src/tutorials/keras/basic_regression.md @@ -1,3 +1,3 @@ # Basic Regression -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_regression.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_regression.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md index 7c5d4f7896..c2a16bdd20 100644 
--- a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md +++ b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md @@ -1,3 +1,3 @@ # Basic Text Classification -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_text_classification.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_text_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md index e5b5ae7b5a..f07f3addd8 100644 --- a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md +++ b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md @@ -1,3 +1,3 @@ # Overfitting and Underfitting -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/overfit_and_underfit.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/overfit_and_underfit.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md index 44b3772945..a799b379a0 100644 --- a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md +++ b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md @@ -1,3 +1,3 @@ # Save and restore Models -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/save_and_restore_models.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/save_and_restore_models.ipynb) diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java index 7922f3329c..b063b6f1cd 100644 --- 
a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -47,7 +47,7 @@ public class SavedModelBundleTest { fail("not expected"); } catch (org.tensorflow.TensorFlowException e) { // expected exception - assertTrue(e.getMessage().contains("SavedModel not found")); + assertTrue(e.getMessage().contains("Could not find SavedModel")); } } } diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a34a6fc053..a6906f9efd 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -669,6 +669,10 @@ def _trace_and_define_function(name, func, compiled, args, kwds): for collection in curr_graph.collections: tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( collection) + if context.executing_eagerly(): + tmp_graph.seed = context.global_seed() + else: + tmp_graph.seed = curr_graph.seed with tmp_graph.as_default(), AutomaticControlDependencies() as a: func_args = _get_defun_inputs(args) func_kwds = _get_defun_inputs(kwds) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a3e63c3153..cdd9fe1760 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function as tf_function from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import convolutional @@ -39,6 +40,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops 
import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -136,6 +138,18 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def disabled_testRandomSeed(self): + + @function.defun + def f(): + return random_ops.random_normal(()) + + random_seed.set_random_seed(1) + x = f() + self.assertNotEqual(x, f()) + random_seed.set_random_seed(1) + self.assertAllEqual(f(), x) + def testNestedInputsDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 8ee38d35cc..6c415b1bf2 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -707,6 +707,14 @@ py_library( ) py_library( + name = "expect_h5py_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect h5py to already be installed on the system, e.g. via + # `pip install h5py' + visibility = ["//visibility:public"], +) + +py_library( name = "expect_six_installed", # This is a dummy rule used as a numpy dependency in open-source. # We expect six to already be installed on the system, e.g. 
via diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 2a0e4e7617..495d019f26 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -2304,6 +2304,43 @@ class EstimatorExportTest(test.TestCase): with self.assertRaisesRegexp(ValueError, err_regex): est._export_all_saved_models(export_dir_base, input_receiver_fn_map) + def test_export_all_saved_models_metric_operation(self): + """Ensures metrics ops.Operations can be expoerted (b/109740581).""" + + def _model_fn(features, labels, mode): + del features, labels # Unused + metrics = {'metrics': (constant_op.constant([0]), + control_flow_ops.no_op())} + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export')) + + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()} + + export_dir = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. 
+ with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir) + sig_outputs = meta_graph.signature_def[ + model_fn_lib.ModeKeys.EVAL].outputs + self.assertEqual( + sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0') + def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 6c26d29985..20382a58d8 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -23,6 +23,7 @@ import abc import six +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils @@ -338,8 +339,16 @@ class _SupervisedOutput(ExportOutput): raise ValueError( '{} update_op must be a Tensor or Operation; got {}.'.format( key, metric_op)) + + # We must wrap any ops in a Tensor before export, as the SignatureDef + # proto expects tensors only. 
See b/109740581 + metric_op_tensor = metric_op + if isinstance(metric_op, ops.Operation): + with ops.control_dependencies([metric_op]): + metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') + outputs[val_name] = metric_val - outputs[op_name] = metric_op + outputs[op_name] = metric_op_tensor return outputs diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index b21ba91b0f..d94c764fd7 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -24,8 +24,10 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -335,5 +337,18 @@ class SupervisedOutputTest(test.TestCase): self.assertTrue("predictions/output1" in sig_def.outputs) self.assertTrue("features" in sig_def.inputs) + def test_metric_op_is_operation(self): + """Tests that ops.Operation is wrapped by a tensor for metric_ops.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0") + self.assertTrue( + isinstance(outputter.metrics["metrics/update_op"], ops.Tensor)) + if __name__ == "__main__": test.main() 
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 7a4457f5a4..7a3c5a9bf1 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -32,7 +32,6 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.applications import mobilenet from tensorflow.python.keras.optimizers import SGD from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile @@ -60,9 +59,9 @@ def simple_sequential_model(): return model -def simple_functional_model(): +def simple_functional_model(activation='relu'): a = keras.layers.Input(shape=_INPUT_SIZE) - b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dense(16, activation=activation)(a) b = keras.layers.Dropout(0.1)(b) b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) model = keras.models.Model(inputs=[a], outputs=[b]) @@ -474,21 +473,25 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): est_keras.train(input_fn=invald_output_name_input_fn, steps=100) def test_custom_objects(self): - keras_mobile = mobilenet.MobileNet(weights=None) - keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam') + + def relu6(x): + return keras.backend.relu(x, max_value=6) + + keras_model = simple_functional_model(activation=relu6) + keras_model.compile(loss='categorical_crossentropy', optimizer='adam') custom_objects = { - 'relu6': mobilenet.relu6, - 'DepthwiseConv2D': mobilenet.DepthwiseConv2D + 'relu6': relu6 } + with self.assertRaisesRegexp(ValueError, 'relu6'): with self.test_session(): keras_lib.model_to_estimator( - keras_model=keras_mobile, + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) with self.test_session(): keras_lib.model_to_estimator( - 
keras_model=keras_mobile, + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir), custom_objects=custom_objects) diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index e56c695a28..7285e03963 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -72,13 +72,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend as K -from tensorflow.python.keras import constraints -from tensorflow.python.keras import initializers -from tensorflow.python.keras import regularizers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import BatchNormalization from tensorflow.python.keras.layers import Conv2D @@ -87,10 +83,10 @@ from tensorflow.python.keras.layers import Dropout from tensorflow.python.keras.layers import GlobalAveragePooling2D from tensorflow.python.keras.layers import GlobalMaxPooling2D from tensorflow.python.keras.layers import Input +from tensorflow.python.keras.layers import ReLU from tensorflow.python.keras.layers import Reshape from tensorflow.python.keras.layers import ZeroPadding2D from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging @@ -100,10 +96,6 @@ from tensorflow.python.util.tf_export import tf_export BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' -def relu6(x): - return 
K.relu(x, max_value=6) - - @tf_export('keras.applications.mobilenet.preprocess_input') def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. @@ -130,12 +122,6 @@ def MobileNet(input_shape=None, classes=1000): """Instantiates the MobileNet architecture. - To load a MobileNet model via `load_model`, import the custom - objects `relu6` and pass them to the `custom_objects` parameter. - E.g. - model = load_model('mobilenet.h5', custom_objects={ - 'relu6': mobilenet.relu6}) - Arguments: input_shape: optional shape tuple, only to be specified if `include_top` is False (otherwise the input shape @@ -412,7 +398,7 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): strides=strides, name='conv1')(x) x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x) - return Activation(relu6, name='conv1_relu')(x) + return ReLU(6, name='conv1_relu')(x) def _depthwise_conv_block(inputs, @@ -479,7 +465,7 @@ def _depthwise_conv_block(inputs, use_bias=False, name='conv_dw_%d' % block_id)(x) x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x) - x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) + x = ReLU(6, name='conv_dw_%d_relu' % block_id)(x) x = Conv2D( pointwise_conv_filters, (1, 1), @@ -489,4 +475,4 @@ def _depthwise_conv_block(inputs, name='conv_pw_%d' % block_id)( x) x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x) - return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x) + return ReLU(6, name='conv_pw_%d_relu' % block_id)(x) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 5d66db232a..53d907a2cc 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -32,10 +32,8 @@ import numpy as np import six from tensorflow.python.keras import backend as K -from tensorflow.python.keras import optimizers from tensorflow.python.keras.utils.generic_utils import Progbar from 
tensorflow.python.ops import array_ops -from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary from tensorflow.python.util.tf_export import tf_export @@ -644,35 +642,17 @@ class LearningRateScheduler(Callback): self.verbose = verbose def on_epoch_begin(self, epoch, logs=None): - # TODO(yashkatariya): Change the property checking when the learning - # rate attribute is unified across all TF Optimizers. - if isinstance(self.model.optimizer, optimizers.TFOptimizer): - if not hasattr(self.model.optimizer.optimizer, '_lr') and not hasattr( - self.model.optimizer.optimizer, '_learning_rate'): - raise ValueError( - 'TF Optimizer must have a "_lr" or "_learning_rate" attribute.') - else: - opt = self.model.optimizer.optimizer - if hasattr(opt, '_lr'): - opt_lr = Variable(opt._lr) # pylint: disable=protected-access - elif hasattr(opt, '_learning_rate'): - opt_lr = Variable(opt._learning_rate) # pylint: disable=protected-access - else: - if not hasattr(self.model.optimizer, 'lr'): - raise ValueError('Optimizer must have a "lr" attribute.') - else: - opt = self.model.optimizer - opt_lr = opt.lr - + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') try: # new API - lr = float(K.get_value(opt_lr)) + lr = float(K.get_value(self.model.optimizer.lr)) lr = self.schedule(epoch, lr) except TypeError: # Support for old API for backward compatibility lr = self.schedule(epoch) if not isinstance(lr, (float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') - K.set_value(opt_lr, lr) + K.set_value(self.model.optimizer.lr, lr) if self.verbose > 0: print('\nEpoch %05d: LearningRateScheduler reducing learning ' 'rate to %s.' 
% (epoch + 1, lr)) diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 244d48591c..45598cafd3 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -29,16 +29,10 @@ import numpy as np from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras -from tensorflow.python.eager import context -from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils -from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training.adam import AdamOptimizer -from tensorflow.python.training.gradient_descent import GradientDescentOptimizer - try: import h5py # pylint:disable=g-import-not-at-top @@ -376,76 +370,6 @@ class KerasCallbacksTest(test.TestCase): float(keras.backend.get_value( model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon() - @test_util.run_in_graph_and_eager_modes - def test_TF_LearningRateScheduler_Adam(self): - with self.test_session(): - with context.eager_mode(): - np.random.seed(1337) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=TRAIN_SAMPLES, - test_samples=TEST_SAMPLES, - input_shape=(INPUT_DIM,), - num_classes=NUM_CLASSES) - y_test = keras.utils.to_categorical(y_test) - y_train = keras.utils.to_categorical(y_train) - model = keras.models.Sequential() - model.add( - keras.layers.Dense( - NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) - model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) - model.compile( - loss='categorical_crossentropy', - optimizer=AdamOptimizer(), - metrics=['accuracy']) - cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. 
+ x))] - model.fit( - x_train, - y_train, - batch_size=BATCH_SIZE, - validation_data=(x_test, y_test), - callbacks=cbks, - epochs=5, - verbose=0) - opt_lr = model.optimizer.optimizer._lr - self.assertLess( - float(keras.backend.get_value( - Variable(opt_lr))) - 0.2, keras.backend.epsilon()) - - @test_util.run_in_graph_and_eager_modes - def test_TF_LearningRateScheduler_GradientDescent(self): - with self.test_session(): - with context.eager_mode(): - np.random.seed(1337) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=TRAIN_SAMPLES, - test_samples=TEST_SAMPLES, - input_shape=(INPUT_DIM,), - num_classes=NUM_CLASSES) - y_test = keras.utils.to_categorical(y_test) - y_train = keras.utils.to_categorical(y_train) - model = keras.models.Sequential() - model.add( - keras.layers.Dense( - NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) - model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) - model.compile( - loss='categorical_crossentropy', - optimizer=GradientDescentOptimizer(1e-3), - metrics=['accuracy']) - cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. 
+ x))] - model.fit( - x_train, - y_train, - batch_size=BATCH_SIZE, - validation_data=(x_test, y_test), - callbacks=cbks, - epochs=5, - verbose=0) - opt_lr = model.optimizer.optimizer._learning_rate - self.assertLess( - float(keras.backend.get_value( - Variable(opt_lr))) - 0.2, keras.backend.epsilon()) - def test_ReduceLROnPlateau(self): with self.test_session(): np.random.seed(1337) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 69d3aa4017..487418e694 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -197,7 +197,7 @@ cuda_py_test( cuda_py_test( name = "linear_operator_low_rank_update_test", - size = "medium", + size = "large", srcs = ["linear_operator_low_rank_update_test.py"], additional_deps = [ "//tensorflow/python/ops/linalg", diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index d4621d61ee..0098d938a0 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -30,9 +30,10 @@ limitations under the License. #endif // Place `<locale>` before <Python.h> to avoid build failure in macOS. -#include <Python.h> #include <locale> +#include <Python.h> + #include "numpy/arrayobject.h" #include "numpy/ufuncobject.h" diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 6b6c82015f..2ee898ea1d 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -16,9 +16,10 @@ limitations under the License. #include "tensorflow/python/lib/core/py_util.h" // Place `<locale>` before <Python.h> to avoid build failure in macOS. 
-#include <Python.h> #include <locale> +#include <Python.h> + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 5b384fd596..9440bab9ee 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1753,6 +1753,22 @@ def is_jpeg(contents, name=None): return math_ops.equal(substr, b'\xff\xd8\xff', name=name) +def _is_png(contents, name=None): + r"""Convenience function to check if the 'contents' encodes a PNG image. + + Args: + contents: 0-D `string`. The encoded image bytes. + name: A name for the operation (optional) + + Returns: + A scalar boolean tensor indicating if 'contents' may be a PNG image. + is_png is susceptible to false positives. + """ + with ops.name_scope(name, 'is_png'): + substr = string_ops.substr(contents, 0, 3) + return math_ops.equal(substr, b'\211PN', name=name) + + @tf_export('image.decode_image') def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, @@ -1830,8 +1846,8 @@ def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): def check_png(): """Checks if an image is PNG.""" - is_png = math_ops.equal(substr, b'\211PN', name='is_png') - return control_flow_ops.cond(is_png, _png, check_gif, name='cond_png') + return control_flow_ops.cond( + _is_png(contents), _png, check_gif, name='cond_png') def _jpeg(): """Decodes a jpeg image.""" diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md index bb97543a21..1966789c84 100644 --- a/tensorflow/security/advisory/tfsa-2018-001.md +++ b/tensorflow/security/advisory/tfsa-2018-001.md @@ -22,7 +22,7 @@ TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0 ### Mitigation We have patched the vulnerability in GitHub commit 
-[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55). +[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae433). If users are running TensorFlow in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md index ea39e17ab2..0f176151c2 100644 --- a/tensorflow/security/index.md +++ b/tensorflow/security/index.md @@ -4,7 +4,7 @@ We regularly publish security advisories about using TensorFlow. *Note*: In conjunction with these security advisories, we strongly encourage TensorFlow users to read and understand TensorFlow's security model as outlined -in (https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)[SECURITY.md]. +in [SECURITY.md](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md). | Advisory Number | Type | Versions affected | Reported by | Additional Information | |-----------------|--------------------|:-----------------:|-----------------------|-----------------------------| diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e4241667ad..9259ebe869 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -24,7 +24,10 @@ load( "if_mkl", "if_mkl_lnx_x64" ) - +load( + "//third_party/mkl_dnn:build_defs.bzl", + "if_mkl_open_source_only", +) def register_extension_info(**kwargs): pass @@ -214,6 +217,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): + if_cuda(["-DGOOGLE_CUDA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) + + if_mkl_open_source_only(["-DDO_NOT_USE_ML"]) + if_mkl_lnx_x64(["-fopenmp"]) + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh index d0816c92b7..75da9bb835 100755 --- 
a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh +++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh @@ -35,6 +35,30 @@ elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then exit 1 fi +function is_absolute { + [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]] +} + +RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST" +function rlocation() { + if is_absolute "$1" ; then + # If the file path is already fully specified, simply return it. + echo "$1" + elif [[ -e "$TEST_SRCDIR/$1" ]]; then + # If the file exists in the $TEST_SRCDIR then just use it. + echo "$TEST_SRCDIR/$1" + elif [[ -e "$RUNFILES_MANIFEST_FILE" ]]; then + # If a runfiles manifest file exists then use it. + echo "$(grep "^$1 " "$RUNFILES_MANIFEST_FILE" | sed 's/[^ ]* //')" + fi +} + +TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})" +shift + +# Make sure /var/lock exists, this may not be true under MSYS +mkdir -p /var/lock + TF_GPU_COUNT=${TF_GPU_COUNT:-8} for i in `seq 0 $((TF_GPU_COUNT-1))`; do @@ -45,8 +69,8 @@ for i in `seq 0 $((TF_GPU_COUNT-1))`; do # This export only works within the brackets, so it is isolated to one # single command. export CUDA_VISIBLE_DEVICES=$i - echo "Running test $* on GPU $CUDA_VISIBLE_DEVICES" - $@ + echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES" + "$TEST_BINARY" $@ ) return_code=$? 
flock -u "$lock_fd" diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh index fe3bce428f..36b2142d95 100644 --- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh @@ -105,14 +105,18 @@ create_python_test_dir "${PY_TEST_DIR}" PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl) reinstall_tensorflow_pip ${PIP_NAME} +TF_GPU_COUNT=${TF_GPU_COUNT:-8} + # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow # GPU tests are very flaky when running concurrently, so set local_test_jobs=1 bazel test --announce_rc --config=opt -k --test_output=errors \ + --test_env=TF_GPU_COUNT \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \ --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \ - --local_test_jobs=1 --test_timeout="300,450,1200,3600" \ + --local_test_jobs=$TF_GPU_COUNT --test_timeout="300,450,1200,3600" \ --flaky_test_attempts=3 \ //${PY_TEST_DIR}/tensorflow/python/... \ //${PY_TEST_DIR}/tensorflow/contrib/... 
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 5ec43b8cb8..2818b822b8 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -15,6 +15,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git \ libcudnn7=7.1.4.18-1+cuda9.0 \ libcudnn7-dev=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libnccl-dev=2.2.13-1+cuda9.0 \ libcurl3-dev \ libfreetype6-dev \ libhdf5-serial-dev \ diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index 9197651ff4..28d4371da3 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusparse-9-0 \ curl \ libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ libfreetype6-dev \ libhdf5-serial-dev \ libpng12-dev \ diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 173f418dc8..44d8a37a8f 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -143,6 +143,7 @@ genrule( "@zlib_archive//:zlib.h", ] + if_mkl([ "//third_party/mkl:LICENSE", + "//third_party/mkl_dnn:LICENSE", ]), outs = ["include/tensorflow/c/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", @@ -182,6 +183,7 @@ genrule( "@zlib_archive//:zlib.h", ] + if_mkl([ "//third_party/mkl:LICENSE", + "//third_party/mkl_dnn:LICENSE", ]), outs = ["include/tensorflow/jni/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index ac252143d7..6d876b786a 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -12,6 +12,7 @@ load( load("//third_party/mkl:build_defs.bzl", "if_mkl") load("//tensorflow:tensorflow.bzl", "if_cuda") 
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., @@ -145,7 +146,6 @@ filegroup( "@gast_archive//:PKG-INFO", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", - "@grpc//:LICENSE", "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", @@ -154,8 +154,6 @@ filegroup( "@lmdb//:LICENSE", "@local_config_nccl//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", - "@grpc//third_party/nanopb:LICENSE.txt", - "@grpc//third_party/address_sorting:LICENSE", "@nasm//:LICENSE", "@nsync//:LICENSE", "@pcre//:LICENCE", @@ -169,7 +167,15 @@ filegroup( "@org_python_pypi_backports_weakref//:LICENSE", ] + if_mkl([ "//third_party/mkl:LICENSE", - ]) + tf_additional_license_deps(), + "//third_party/mkl_dnn:LICENSE", + ]) + if_not_system_lib( + "grpc", + [ + "@grpc//:LICENSE", + "@grpc//third_party/nanopb:LICENSE.txt", + "@grpc//third_party/address_sorting:LICENSE", + ], + ) + tf_additional_license_deps(), ) sh_binary( diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index b0089d3360..4101b34a11 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -27,7 +27,7 @@ function cp_external() { pushd . cd "$src_dir" - for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_cuda*' ! -path '*local_config_tensorrt*' ! -path '*org_tensorflow*'`; do + for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_cuda*' ! -path '*local_config_tensorrt*' ! -path '*local_config_syslibs*' ! 
-path '*org_tensorflow*'`; do mkdir -p "${dest_dir}/$(dirname ${f})" cp "${f}" "${dest_dir}/$(dirname ${f})/" done diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index cd4f17a5ff..378de4261c 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -8,6 +8,7 @@ load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") +load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") load("//third_party:repo.bzl", "tf_http_archive") @@ -35,6 +36,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): nccl_configure(name="local_config_nccl") git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") + syslibs_configure(name="local_config_syslibs") python_configure(name="local_config_python") # For windows bazel build @@ -161,6 +163,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912", strip_prefix = "re2-2018-04-01", + system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"), ) tf_http_archive( @@ -226,6 +229,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011", strip_prefix = "nasm-2.13.03", build_file = clean_dep("//third_party:nasm.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"), ) tf_http_archive( @@ -237,6 +241,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde", strip_prefix = "libjpeg-turbo-1.5.3", build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"), ) 
tf_http_archive( @@ -249,6 +254,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "libpng-1.6.34", build_file = clean_dep("//third_party:png.BUILD"), patch_file = clean_dep("//third_party:png_fix_rpi.patch"), + system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"), ) tf_http_archive( @@ -260,6 +266,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6", strip_prefix = "sqlite-amalgamation-3240000", build_file = clean_dep("//third_party:sqlite.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"), ) tf_http_archive( @@ -271,6 +278,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", strip_prefix = "giflib-5.1.4", build_file = clean_dep("//third_party:gif.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"), ) tf_http_archive( @@ -282,6 +290,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", strip_prefix = "six-1.10.0", build_file = clean_dep("//third_party:six.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"), ) tf_http_archive( @@ -293,6 +302,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d", strip_prefix = "astor-0.6.2", build_file = clean_dep("//third_party:astor.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"), ) tf_http_archive( @@ -315,6 +325,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b", strip_prefix = "termcolor-1.1.0", build_file = clean_dep("//third_party:termcolor.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"), ) tf_http_archive( @@ -421,6 +432,7 @@ def 
tf_workspace(path_prefix="", tf_repo_name=""): ], strip_prefix = "pcre-8.42", build_file = clean_dep("//third_party:pcre.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"), ) tf_http_archive( @@ -433,6 +445,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], strip_prefix = "swig-3.0.8", build_file = clean_dep("//third_party:swig.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"), ) tf_http_archive( @@ -444,6 +457,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], strip_prefix = "curl-7.60.0", build_file = clean_dep("//third_party:curl.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"), ) tf_http_archive( @@ -454,6 +468,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44", strip_prefix = "grpc-1.13.0", + system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"), ) tf_http_archive( @@ -472,11 +487,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/ae80745b73e435d07e7fb9c12589304ee29e7f59.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/ae80745b73e435d07e7fb9c12589304ee29e7f59.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz", ], - sha256 = "de69b6f92a634b4d12b9e03ebd8eb34c28f997d9480c28358d6efd4c433fe853", - strip_prefix = "llvm-ae80745b73e435d07e7fb9c12589304ee29e7f59", + sha256 = "0c63e8583b213543309e8577ffe87a0cf34cc22269630d2c5c2f0a2345fda4a8", + strip_prefix = "llvm-bd8c8d759852871609ba2e4e79868420f751949d", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) @@ -489,6 +504,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = 
"f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28", strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb", build_file = clean_dep("//third_party:lmdb.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"), ) tf_http_archive( @@ -500,6 +516,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6", strip_prefix = "jsoncpp-1.8.4", build_file = clean_dep("//third_party:jsoncpp.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"), ) tf_http_archive( @@ -521,6 +538,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", strip_prefix = "zlib-1.2.11", build_file = clean_dep("//third_party:zlib.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"), ) tf_http_archive( @@ -542,6 +560,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4", strip_prefix = "snappy-1.1.7", build_file = clean_dep("//third_party:snappy.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"), ) tf_http_archive( @@ -612,6 +631,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", strip_prefix = "jemalloc-4.4.0", build_file = clean_dep("//third_party:jemalloc.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"), ) java_import_external( @@ -690,6 +710,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "cython-0.28.4", build_file = clean_dep("//third_party:cython.BUILD"), delete = ["BUILD.bazel"], + system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"), ) tf_http_archive( @@ -722,6 +743,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz", ], 
build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:flatbuffers.BUILD"), ) native.new_http_archive( diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD index 5b01f6e3e4..d075809ee9 100644 --- a/third_party/mkl_dnn/BUILD +++ b/third_party/mkl_dnn/BUILD @@ -1 +1,11 @@ licenses(["notice"]) + +exports_files(["LICENSE"]) + +config_setting( + name = "using_mkl_dnn_only", + values = { + "define": "using_mkl_dnn_only=true", + }, + visibility = ["//visibility:public"], +) diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl new file mode 100644 index 0000000000..7ce2a7d9b0 --- /dev/null +++ b/third_party/mkl_dnn/build_defs.bzl @@ -0,0 +1,13 @@ +def if_mkl_open_source_only(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with + MKL-DNN open source lib only, without depending on MKL binary form. + + Returns a select statement which evaluates to if_true if we're building + with MKL-DNN open source lib only. Otherwise, + the select statement evaluates to if_false. 
+ + """ + return select({ + str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true, + "//conditions:default": if_false, + }) diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD index 68f24aabae..57d2e1292b 100644 --- a/third_party/mkl_dnn/mkldnn.BUILD +++ b/third_party/mkl_dnn/mkldnn.BUILD @@ -1,5 +1,10 @@ exports_files(["LICENSE"]) +load( + "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", + "if_mkl_open_source_only", +) + config_setting( name = "clang_linux_x86_64", values = { @@ -15,7 +20,14 @@ cc_library( "src/cpu/*.cpp", ]), hdrs = glob(["include/*"]), - copts = ["-fexceptions"] + select({ + copts = [ + "-fexceptions", + "-DUSE_MKL", + "-DUSE_CBLAS", + ] + if_mkl_open_source_only([ + "-UUSE_MKL", + "-UUSE_CBLAS", + ]) + select({ "@org_tensorflow//tensorflow:linux_x86_64": [ "-fopenmp", # only works with gcc ], @@ -33,4 +45,19 @@ cc_library( ], nocopts = "-fno-exceptions", visibility = ["//visibility:public"], + deps = select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@mkl_linux//:mkl_headers", + "@mkl_linux//:mkl_libs_linux", + ], + "@org_tensorflow//tensorflow:darwin": [ + "@mkl_darwin//:mkl_headers", + "@mkl_darwin//:mkl_libs_darwin", + ], + "@org_tensorflow//tensorflow:windows": [ + "@mkl_windows//:mkl_headers", + "@mkl_windows//:mkl_libs_windows", + ], + "//conditions:default": [], + }), ) diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 9cee1fcc4b..5cb42691c5 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -35,6 +35,15 @@ def _get_env_var(ctx, name): else: return None +# Checks if we should use the system lib instead of the bundled one +def _use_system_lib(ctx, name): + syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") + if syslibenv: + for n in syslibenv.strip().split(","): + if n.strip() == name: + return True + return False + # Executes specified command with arguments and calls 'fail' if it exited with # non-zero code def _execute_and_check_ret_code(repo_ctx, 
cmd_and_args): @@ -75,17 +84,28 @@ def _tf_http_archive(ctx): "Even if you don't have permission to mirror the file, please " + "put the correctly formatted mirror URL there anyway, because " + "someone will come along shortly thereafter and mirror the file.") - ctx.download_and_extract( - ctx.attr.urls, - "", - ctx.attr.sha256, - ctx.attr.type, - ctx.attr.strip_prefix) - if ctx.attr.delete: - _apply_delete(ctx, ctx.attr.delete) - if ctx.attr.patch_file != None: - _apply_patch(ctx, ctx.attr.patch_file) - if ctx.attr.build_file != None: + + use_syslib = _use_system_lib(ctx, ctx.attr.name) + if not use_syslib: + ctx.download_and_extract( + ctx.attr.urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + + if use_syslib and ctx.attr.system_build_file != None: + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.system_build_file, { + "%prefix%": ".." if _repos_are_siblings() else "external", + }, False) + + elif ctx.attr.build_file != None: # Use BUILD.bazel to avoid conflict with third party projects with # BUILD or build (directory) underneath. ctx.template("BUILD.bazel", ctx.attr.build_file, { @@ -102,7 +122,11 @@ tf_http_archive = repository_rule( "delete": attr.string_list(), "patch_file": attr.label(), "build_file": attr.label(), - }) + "system_build_file": attr.label(), + }, + environ=[ + "TF_SYSTEM_LIBS", + ]) """Downloads and creates Bazel repos for dependencies. 
This is a swappable replacement for both http_archive() and diff --git a/third_party/systemlibs/BUILD b/third_party/systemlibs/BUILD new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/third_party/systemlibs/BUILD diff --git a/third_party/systemlibs/BUILD.tpl b/third_party/systemlibs/BUILD.tpl new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/third_party/systemlibs/BUILD.tpl diff --git a/third_party/systemlibs/astor.BUILD b/third_party/systemlibs/astor.BUILD new file mode 100644 index 0000000000..497ec4bcea --- /dev/null +++ b/third_party/systemlibs/astor.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # New BSD + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +py_library( + name = "astor", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/build_defs.bzl.tpl b/third_party/systemlibs/build_defs.bzl.tpl new file mode 100644 index 0000000000..3faa46c581 --- /dev/null +++ b/third_party/systemlibs/build_defs.bzl.tpl @@ -0,0 +1,32 @@ +# -*- Python -*- +"""Skylark macros for system libraries. 
+""" + +SYSTEM_LIBS_ENABLED = %{syslibs_enabled} + +SYSTEM_LIBS_LIST = [ +%{syslibs_list} +] + + +def if_any_system_libs(a, b=[]): + """Conditional which evaluates to 'a' if any system libraries are configured.""" + if SYSTEM_LIBS_ENABLED: + return a + else: + return b + + +def if_system_lib(lib, a, b=[]): + """Conditional which evaluates to 'a' if we're using the system version of lib""" + + if SYSTEM_LIBS_ENABLED and lib in SYSTEM_LIBS_LIST: + return a + else: + return b + + +def if_not_system_lib(lib, a, b=[]): + """Conditional which evaluates to 'a' if we're using the system version of lib""" + + return if_system_lib(lib, b, a) diff --git a/third_party/systemlibs/curl.BUILD b/third_party/systemlibs/curl.BUILD new file mode 100644 index 0000000000..c5f125caa9 --- /dev/null +++ b/third_party/systemlibs/curl.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # MIT/X derivative license + +filegroup( + name = "COPYING", + visibility = ["//visibility:public"], +) + +cc_library( + name = "curl", + linkopts = ["-lcurl"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/cython.BUILD b/third_party/systemlibs/cython.BUILD new file mode 100644 index 0000000000..1d52587676 --- /dev/null +++ b/third_party/systemlibs/cython.BUILD @@ -0,0 +1,13 @@ +licenses(["notice"]) # Apache-2.0 + +genrule( + name = "lncython", + outs = ["cython"], + cmd = "ln -s $$(which cython) $@", +) + +sh_binary( + name = "cython_binary", + srcs = ["cython"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/flatbuffers.BUILD b/third_party/systemlibs/flatbuffers.BUILD new file mode 100644 index 0000000000..14fceada82 --- /dev/null +++ b/third_party/systemlibs/flatbuffers.BUILD @@ -0,0 +1,38 @@ +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "LICENSE.txt", + visibility = ["//visibility:public"], +) + +# Public flatc library to compile flatbuffer files at runtime. 
+cc_library( + name = "flatbuffers", + linkopts = ["-lflatbuffers"], + visibility = ["//visibility:public"], +) + +# Public flatc compiler library. +cc_library( + name = "flatc_library", + linkopts = ["-lflatbuffers"], + visibility = ["//visibility:public"], +) + +genrule( + name = "lnflatc", + outs = ["flatc.bin"], + cmd = "ln -s $$(which flatc) $@", +) + +# Public flatc compiler. +sh_binary( + name = "flatc", + srcs = ["flatc.bin"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "runtime_cc", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/gif.BUILD b/third_party/systemlibs/gif.BUILD new file mode 100644 index 0000000000..5eb2c918ba --- /dev/null +++ b/third_party/systemlibs/gif.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # MIT + +filegroup( + name = "COPYING", + visibility = ["//visibility:public"], +) + +cc_library( + name = "gif", + linkopts = ["-lgif"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/grpc.BUILD b/third_party/systemlibs/grpc.BUILD new file mode 100644 index 0000000000..fd90eb0dd3 --- /dev/null +++ b/third_party/systemlibs/grpc.BUILD @@ -0,0 +1,54 @@ +licenses(["notice"]) # Apache v2 + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "grpc", + linkopts = ["-lgrpc"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "grpc++", + linkopts = ["-lgrpc++"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "grpc_unsecure", + linkopts = ["-lgrpc_unsecure"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "grpc++_unsecure", + linkopts = ["-lgrpc++_unsecure"], + visibility = ["//visibility:public"], +) + +genrule( + name = "ln_grpc_cpp_plugin", + outs = ["grpc_cpp_plugin.bin"], + cmd = "ln -s $$(which grpc_cpp_plugin) $@", +) + +sh_binary( + name = "grpc_cpp_plugin", + srcs = ["grpc_cpp_plugin.bin"], + visibility = ["//visibility:public"], +) + +genrule( + name = 
"ln_grpc_python_plugin", + outs = ["grpc_python_plugin.bin"], + cmd = "ln -s $$(which grpc_python_plugin) $@", +) + +sh_binary( + name = "grpc_python_plugin", + srcs = ["grpc_python_plugin.bin"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/jemalloc.BUILD b/third_party/systemlibs/jemalloc.BUILD new file mode 100644 index 0000000000..6a48d582ba --- /dev/null +++ b/third_party/systemlibs/jemalloc.BUILD @@ -0,0 +1,30 @@ +licenses(["notice"]) # BSD + +filegroup( + name = "COPYING", + visibility = ["//visibility:public"], +) + +cc_library( + name = "jemalloc_headers", + defines = [ + "jemalloc_posix_memalign=posix_memalign", + "jemalloc_malloc=malloc", + "jemalloc_realloc=realloc", + "jemalloc_free=free", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "jemalloc_impl", + linkopts = ["-ljemalloc"], + defines = [ + "jemalloc_posix_memalign=posix_memalign", + "jemalloc_malloc=malloc", + "jemalloc_realloc=realloc", + "jemalloc_free=free", + ], + visibility = ["//visibility:public"], + deps = [":jemalloc_headers"], +) diff --git a/third_party/systemlibs/jpeg.BUILD b/third_party/systemlibs/jpeg.BUILD new file mode 100644 index 0000000000..f4f52da9bd --- /dev/null +++ b/third_party/systemlibs/jpeg.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # custom notice-style license, see LICENSE.md + +filegroup( + name = "LICENSE.md", + visibility = ["//visibility:public"], +) + +cc_library( + name = "jpeg", + linkopts = ["-ljpeg"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/jsoncpp.BUILD b/third_party/systemlibs/jsoncpp.BUILD new file mode 100644 index 0000000000..cf91917cfb --- /dev/null +++ b/third_party/systemlibs/jsoncpp.BUILD @@ -0,0 +1,37 @@ +licenses(["unencumbered"]) # Public Domain or MIT + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +HEADERS = [ + "include/json/autolink.h", + "include/json/config.h", + "include/json/features.h", + 
"include/json/forwards.h", + "include/json/json.h", + "include/json/reader.h", + "include/json/value.h", + "include/json/version.h", + "include/json/writer.h", +] + +genrule( + name = "link_headers", + outs = HEADERS, + cmd = """ + for i in $(OUTS); do + i=$${i##*/} + ln -vsf /usr/include/jsoncpp/json/$$i $(@D)/include/json/$$i + done + """, +) + +cc_library( + name = "jsoncpp", + hdrs = HEADERS, + includes = ["."], + linkopts = ["-ljsoncpp"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/lmdb.BUILD b/third_party/systemlibs/lmdb.BUILD new file mode 100644 index 0000000000..6177b095ec --- /dev/null +++ b/third_party/systemlibs/lmdb.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # OpenLDAP Public License + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "lmdb", + linkopts = ["-llmdb"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/nasm.BUILD b/third_party/systemlibs/nasm.BUILD new file mode 100644 index 0000000000..10ef8d8832 --- /dev/null +++ b/third_party/systemlibs/nasm.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD 2-clause + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +sh_binary( + name = "nasm", + srcs = ["nasm"], + visibility = ["@jpeg//:__pkg__"], +) diff --git a/third_party/systemlibs/pcre.BUILD b/third_party/systemlibs/pcre.BUILD new file mode 100644 index 0000000000..df74238847 --- /dev/null +++ b/third_party/systemlibs/pcre.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD + +filegroup( + name = "LICENCE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "pcre", + linkopts = ["-lpcre"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/png.BUILD b/third_party/systemlibs/png.BUILD new file mode 100644 index 0000000000..fc6b6f2d8b --- /dev/null +++ b/third_party/systemlibs/png.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD/MIT-like license + +filegroup( + 
name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "png", + linkopts = ["-lpng"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/re2.BUILD b/third_party/systemlibs/re2.BUILD new file mode 100644 index 0000000000..c18e252dbc --- /dev/null +++ b/third_party/systemlibs/re2.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD/MIT-like license + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "re2", + linkopts = ["-lre2"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/six.BUILD b/third_party/systemlibs/six.BUILD new file mode 100644 index 0000000000..ff9b1a540b --- /dev/null +++ b/third_party/systemlibs/six.BUILD @@ -0,0 +1,11 @@ +licenses(["notice"]) # MIT + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +py_library( + name = "six", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/snappy.BUILD b/third_party/systemlibs/snappy.BUILD new file mode 100644 index 0000000000..fd2db9e2df --- /dev/null +++ b/third_party/systemlibs/snappy.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD 3-Clause + +filegroup( + name = "COPYING", + visibility = ["//visibility:public"], +) + +cc_library( + name = "snappy", + linkopts = ["-lsnappy"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/sqlite.BUILD b/third_party/systemlibs/sqlite.BUILD new file mode 100644 index 0000000000..20ee1ebbef --- /dev/null +++ b/third_party/systemlibs/sqlite.BUILD @@ -0,0 +1,15 @@ +licenses(["unencumbered"]) # Public Domain + +# Production build of SQLite library that's baked into TensorFlow. +cc_library( + name = "org_sqlite", + linkopts = ["-lsqlite3"], + visibility = ["//visibility:public"], +) + +# This is a Copybara sync helper for Google. 
+py_library( + name = "python", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/swig.BUILD b/third_party/systemlibs/swig.BUILD new file mode 100644 index 0000000000..4c9b74dadb --- /dev/null +++ b/third_party/systemlibs/swig.BUILD @@ -0,0 +1,23 @@ +licenses(["restricted"]) # GPLv3 + +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +filegroup( + name = "templates", + visibility = ["//visibility:public"], +) + +genrule( + name = "lnswiglink", + outs = ["swiglink"], + cmd = "ln -s $$(which swig) $@", +) + +sh_binary( + name = "swig", + srcs = ["swiglink"], + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl new file mode 100644 index 0000000000..07a44c317e --- /dev/null +++ b/third_party/systemlibs/syslibs_configure.bzl @@ -0,0 +1,160 @@ +# -*- Python -*- +"""Repository rule for system library autoconfiguration. 
+ +`syslibs_configure` depends on the following environment variables: + + * `TF_SYSTEM_LIBS`: list of third party dependencies that should use + the system version instead +""" + +_TF_SYSTEM_LIBS="TF_SYSTEM_LIBS" + +VALID_LIBS=[ + "astor_archive", + "com_googlesource_code_re2", + "curl", + "cython", + "flatbuffers", + "gif_archive", + "grpc", + "jemalloc", + "jpeg", + "jsoncpp_git", + "lmdb", + "nasm", + "org_sqlite", + "pcre", + "png_archive", + "six_archive", + "snappy", + "swig", + "termcolor_archive", + "zlib_archive", +] + + +def auto_configure_fail(msg): + """Output failure message when syslibs configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sSystem Library Configuration Error:%s %s\n" % (red, no_color, msg)) + + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + + +def _enable_syslibs(repository_ctx): + s = repository_ctx.os.environ.get(_TF_SYSTEM_LIBS, '').strip() + if not _is_windows(repository_ctx) and s != None and s != '': + return True + return False + + +def _get_system_lib_list(repository_ctx): + """Gets the list of deps that should use the system lib. + + Args: + repository_ctx: The repository context. + + Returns: + A list of the names of deps that should use the system lib. + """ + if _TF_SYSTEM_LIBS not in repository_ctx.os.environ: + return [] + + libenv = repository_ctx.os.environ[_TF_SYSTEM_LIBS].strip() + libs = [] + + for lib in list(libenv.split(',')): + lib = lib.strip() + if lib == "": + continue + if lib not in VALID_LIBS: + auto_configure_fail("Invalid system lib set: %s" % lib) + return [] + libs.append(lib) + + return libs + + +def _format_system_lib_list(repository_ctx): + """Formats the list of deps that should use the system lib. + + Args: + repository_ctx: The repository context. + + Returns: + A string with one quoted, comma-terminated lib name per line. 
+ """ + libs = _get_system_lib_list(repository_ctx) + ret = '' + for lib in libs: + ret += "'%s',\n" % lib + + return ret + + +def _tpl(repository_ctx, tpl, substitutions={}, out=None): + if not out: + out = tpl.replace(":", "") + repository_ctx.template( + out, + Label("//third_party/systemlibs%s.tpl" % tpl), + substitutions, + False) + + +def _create_dummy_repository(repository_ctx): + """Creates the dummy repository to build with all bundled libraries.""" + + _tpl(repository_ctx, ":BUILD") + _tpl(repository_ctx, ":build_defs.bzl", + { + "%{syslibs_enabled}": 'False', + "%{syslibs_list}": '', + }) + + +def _create_local_repository(repository_ctx): + """Creates the repository to build with system libraries.""" + + _tpl(repository_ctx, ":BUILD") + _tpl(repository_ctx, ":build_defs.bzl", + { + "%{syslibs_enabled}": 'True', + "%{syslibs_list}": _format_system_lib_list(repository_ctx), + }) + + +def _syslibs_autoconf_impl(repository_ctx): + """Implementation of the syslibs_configure repository rule.""" + if not _enable_syslibs(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_repository(repository_ctx) + + +syslibs_configure = repository_rule( + implementation = _syslibs_autoconf_impl, + environ = [ + _TF_SYSTEM_LIBS, + ], +) + +"""Configures the build to link to system libraries +instead of using bundled versions. + +Add the following to your WORKSPACE FILE: + +```python +syslibs_configure(name = "local_config_syslibs") +``` + +Args: + name: A unique name for this workspace rule. 
+""" diff --git a/third_party/systemlibs/termcolor.BUILD b/third_party/systemlibs/termcolor.BUILD new file mode 100644 index 0000000000..915eb621d5 --- /dev/null +++ b/third_party/systemlibs/termcolor.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # MIT + +filegroup( + name = "COPYING.txt", + visibility = ["//visibility:public"], +) + +py_library( + name = "termcolor", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/third_party/systemlibs/zlib.BUILD b/third_party/systemlibs/zlib.BUILD new file mode 100644 index 0000000000..69462ae6cb --- /dev/null +++ b/third_party/systemlibs/zlib.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) # BSD/MIT-like license (for zlib) + +filegroup( + name = "zlib.h", + visibility = ["//visibility:public"], +) + +cc_library( + name = "zlib", + linkopts = ["-lz"], + visibility = ["//visibility:public"], +) diff --git a/tools/bazel.rc b/tools/bazel.rc index 3559375d5c..913c4bc333 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -27,6 +27,10 @@ build --define framework_shared_object=true build:mkl --define=using_mkl=true build:mkl -c opt +# This config option is used to enable MKL-DNN open source library only, +# without depending on MKL binary version. +build:mkl_open_source_only --define=using_mkl_dnn_only=true + build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true |