diff options
46 files changed, 1307 insertions, 289 deletions
diff --git a/eigen.BUILD b/eigen.BUILD index 806b6d36b9..3e92d5887a 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -archive_dir = "eigen-eigen-88444e025a5c" +archive_dir = "eigen-eigen-2f482bcc8b95" cc_library( name = "eigen", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5ee2337647..901071e0f9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -738,6 +738,7 @@ cc_library( "lib/random/weighted_picker.h", "lib/strings/ordered_code.h", "lib/strings/regexp.h", + "lib/strings/scanner.h", "platform/denormal.h", "platform/platform.h", "platform/tensor_coding.h", diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index c43f1d0973..6ee5f7d446 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1635,14 +1635,14 @@ void ExecutorState::DumpActiveNodeState(const int node_id, void ExecutorState::DumpIterationState(IterationState* iteration) { // Dump any waiting nodes that are holding on to tensors. - for (size_t i = 0; i < impl_->graph_->num_node_ids(); ++i) { + for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) { if (iteration->node_state(i) == PendingCounts::PENDING_NOTREADY || iteration->node_state(i) == PendingCounts::PENDING_READY) { DumpPendingNodeState(i, iteration->input_tensors, false); } } // Then the active nodes. - for (size_t i = 0; i < impl_->graph_->num_node_ids(); ++i) { + for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) { if (iteration->node_state(i) == PendingCounts::STARTED) { DumpActiveNodeState(i, iteration->input_tensors); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index a25d072eea..cd064dd4c4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -388,7 +388,7 @@ string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) { string buf; buf.resize(num_bytes); DeviceMemoryBase gpu_ptr(ptr, num_bytes); - Status s = dev_info->stream->parent()->SynchronousMemcpyD2H( + auto s = dev_info->stream->parent()->SynchronousMemcpyD2H( gpu_ptr, num_bytes, gtl::string_as_array(&buf)); strings::StrAppend(&ret, PrintMemory(gtl::string_as_array(&buf), num_bytes)); diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index df86046c45..be17ad4c4d 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -329,7 +329,6 @@ tf_cc_tests( tags = tf_cuda_tests_tags() + ["exclusive"], tests = [ "grpc_channel_test.cc", - "grpc_server_lib_test.cc", "grpc_session_test.cc", "rpc_rendezvous_mgr_test.cc", ], diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc deleted file mode 100644 index a56afb05a6..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2016 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/distributed_runtime/server_lib.h" - -#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { - -// Tests that a server can be cleanly started, stopped, and joined -// when no calls are made against the server. -TEST(Server, StopAfterNoop) { - ServerDef def; - def.set_protocol("grpc"); - def.set_job_name("localhost"); - def.set_task_index(0); - JobDef* job_def = def.mutable_cluster()->add_job(); - job_def->set_name("localhost"); - (*job_def->mutable_tasks())[0] = - strings::StrCat("localhost:", testing::PickUnusedPortOrDie()); - std::unique_ptr<ServerInterface> svr; - TF_EXPECT_OK(NewServer(def, &svr)); - TF_EXPECT_OK(svr->Start()); - TF_EXPECT_OK(svr->Stop()); - TF_EXPECT_OK(svr->Join()); -} - -// Tests that a server can be cleanly started, stopped, and joined -// when a simple call is made against the server. -TEST(Server, StopAfterCall) { - ServerDef def; - def.set_protocol("grpc"); - def.set_job_name("localhost"); - def.set_task_index(0); - JobDef* job_def = def.mutable_cluster()->add_job(); - job_def->set_name("localhost"); - int port = testing::PickUnusedPortOrDie(); - (*job_def->mutable_tasks())[0] = strings::StrCat("localhost:", port); - std::unique_ptr<ServerInterface> svr; - TF_EXPECT_OK(NewServer(def, &svr)); - TF_EXPECT_OK(svr->Start()); - { - SessionOptions options; - options.target = strings::StrCat("grpc://localhost:", port); - std::unique_ptr<GrpcSession> sess(new GrpcSession(options)); - const std::vector<DeviceAttributes> devices = sess->ListDevices(); - EXPECT_GT(devices.size(), 0); - } - TF_EXPECT_OK(svr->Stop()); - TF_EXPECT_OK(svr->Join()); -} - -} // namespace tensorflow diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 93823b154e..05e1f01a0c 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/regexp.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,10 +249,16 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { if (is_list) { // TextFormat parser considers "i: 7" to be the same as "i: [7]", // but we only want to allow list values with []. - if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) { + StringPiece cleaned = text; + str_util::RemoveLeadingWhitespace(&cleaned); + str_util::RemoveTrailingWhitespace(&cleaned); + if (cleaned.size() < 2 || cleaned[0] != '[' || + cleaned[cleaned.size() - 1] != ']') { return false; } - if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) { + cleaned.remove_prefix(1); + str_util::RemoveLeadingWhitespace(&cleaned); + if (cleaned.size() == 1) { // User wrote "[]", so return empty list without invoking the TextFormat // parse which returns an error for "i: []". out->Clear(); diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 71281cf7b2..ccfbec662f 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -84,8 +84,9 @@ Status RemoveNewDefaultAttrsFromGraphDef( if (!s.ok()) return s; for (const auto& attr : node_def->attr()) { - // If the attr is not in consumer_op_def... - if (FindAttr(attr.first, *consumer_op_def) == nullptr) { + // If the attr is not in consumer_op_def and doesn't start with '_'... + if (!StringPiece(attr.first).starts_with("_") && + FindAttr(attr.first, *consumer_op_def) == nullptr) { const OpDef::AttrDef* producer_attr_def = FindAttr(attr.first, *producer_op_def); if (producer_attr_def == nullptr) { diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 96001c8f6d..2b51a0b1de 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -133,6 +133,36 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { EXPECT_TRUE(op_attr_removed.empty()); } +// Attrs starting with underscores should not be removed. +TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { + OpList consumer_op_list; + TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(producer_op_list.add_op())); + // Add the _underscore attr manually since OpDefBuilder would complain + OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr(); + attr->set_name("_underscore"); + attr->set_type("int"); + attr->mutable_default_value()->set_i(17); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry) + .Attr("_underscore", 17) + .Finalize(produced_graph_def.add_node())); + GraphDef expected_graph_def = produced_graph_def; + + std::set<std::pair<string, string>> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + EXPECT_EQ(op_attr_removed.size(), 0); +} + TEST(StrippedOpListForGraphTest, FlatTest) { // Make four ops OpList op_list; diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 0f08d391ac..641411892d 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" namespace tensorflow { @@ -381,19 +381,50 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { namespace { -static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); -static RE2* valid_data_input_pattern = - new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?"); -static RE2* valid_control_input_pattern = - new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); +using ::tensorflow::strings::Scanner; + +bool IsValidOpName(StringPiece sp) { + return Scanner(sp) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} + +bool IsValidDataInputName(StringPiece sp) { + // Data inputs are op_name, op_name:0, or op_name:12345. + Scanner scan(sp); + scan.One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); + if (scan.Peek() == ':') { + scan.OneLiteral(":"); + if (scan.Peek() == '0') { + scan.OneLiteral("0"); // :0 + } else { + scan.Many(Scanner::DIGIT); // :[1-9][0-9]* + } + } + scan.Eos(); + + return scan.GetResult(); +} + +bool IsValidControlInputName(StringPiece sp) { + return Scanner(sp) + .OneLiteral("^") + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} } // namespace Status ValidateOpInput(const string& input_name, bool* is_control_input) { *is_control_input = false; - if (RE2::FullMatch(input_name, *valid_data_input_pattern)) { + if (IsValidDataInputName(input_name)) { return Status::OK(); - } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) { + } else if (IsValidControlInputName(input_name)) { *is_control_input = true; return Status::OK(); } else { @@ -402,7 +433,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) { } Status ValidateOpName(const string& op_name) { - if (RE2::FullMatch(op_name, *valid_op_name_pattern)) { + if (IsValidOpName(op_name)) { return Status::OK(); } else { return errors::InvalidArgument("Illegal op name '", op_name, "'"); diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 3a405fc275..07bd60f3b7 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -308,6 +308,12 @@ TEST(NodeDefUtilTest, ValidSyntax) { )proto"); ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'"); + const NodeDef node_def_slash_in_name = ToNodeDef(R"proto( + name:'n\\' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'"); + const NodeDef node_def_internal_input_name = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'_a' input:'b' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } @@ -315,6 +321,12 @@ TEST(NodeDefUtilTest, ValidSyntax) { ExpectInvalidSyntax(node_def_internal_input_name, "Illegal op input name '_a'"); + const NodeDef node_def_input_name_slash = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a\\' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'"); + const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'a' input:'^b:0' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } @@ -322,12 +334,33 @@ TEST(NodeDefUtilTest, ValidSyntax) { ExpectInvalidSyntax(node_def_invalid_control_input_name, "Illegal op input name '^b:0'"); + const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b\\' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_control_input_name_slash, + "Illegal op input name '^b\\'"); + const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto( name:'n' op:'AnyIn' input:'^a' input:'b' attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } )proto"); ExpectInvalidSyntax(node_def_data_input_after_control, "All control inputs must follow all data inputs"); + + const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:b' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port, + "Illegal op input name 'a:b"); + + const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:00' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port2, + "Illegal op input name 'a:00"); } TEST(NameRangesForNodeTest, Simple) { diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 8983371503..2cd6770f6c 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -15,80 +15,99 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" +#include <limits> #include <vector> #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/regexp.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +using ::tensorflow::strings::Scanner; + namespace tensorflow { namespace { -bool RE2Consume(StringPiece* sp, const RE2& pattern) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - bool r = RE2::Consume(&base_sp, pattern); - *sp = FromRegexpStringPiece(base_sp); - return r; -} - -bool RE2Consume(StringPiece* sp, const RE2& pattern, StringPiece* out) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - RegexpStringPiece base_out; - bool r = RE2::Consume(&base_sp, pattern, &base_out); - *sp = FromRegexpStringPiece(base_sp); - *out = FromRegexpStringPiece(base_out); - return r; -} - -bool RE2Consume(StringPiece* sp, const RE2& pattern, int64* out) { - RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); - bool r = RE2::Consume(&base_sp, pattern, out); - *sp = FromRegexpStringPiece(base_sp); - return r; -} - string AttrError(StringPiece orig, const string& op_name) { return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); } -const RE2& AttrNameRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*"); - return pattern; -} - -const RE2& AttrListPrefixRE() { - static RE2 pattern("list\\s*\\(\\s*"); - return pattern; -} - -const RE2& SpacesRE() { - static RE2 pattern("\\s*"); - return pattern; -} - -const RE2& AttrDoubleQuotedRE() { - static RE2 pattern(R"xx("((?:[^"\\]|\\.)*)"\s*)xx"); - return pattern; -} - -const RE2& AttrSingleQuotedRE() { - static RE2 pattern(R"xx('((?:[^'\\]|\\.)*)'\s*)xx"); - return pattern; -} - -const RE2& AttrTypeRE() { - static RE2 pattern("([a-z0-9]+)\\s*"); - return pattern; -} - -const RE2& AttrNumberRE() { - static RE2 pattern("\\s*(-?\\d+)\\s*"); - return pattern; +bool ConsumeAttrName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeListPrefix(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("list") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); +} + +bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) { + const string quote_str(1, quote_ch); + return Scanner(*sp) + .OneLiteral(quote_str.c_str()) + .RestartCapture() + .ScanEscapedUntil(quote_ch) + .StopCapture() + .OneLiteral(quote_str.c_str()) + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .Many(Scanner::LOWERLETTER_DIGIT) + .StopCapture() + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrNumber(StringPiece* sp, int64* out) { + Scanner scan(*sp); + StringPiece match; + StringPiece remaining; + + scan.AnySpace(); + bool is_negative = false; + if (scan.Peek() == '-') { + is_negative = true; + scan.OneLiteral("-"); + } + if (!scan.RestartCapture() + .Many(Scanner::DIGIT) + .StopCapture() + .AnySpace() + .GetResult(&remaining, &match)) { + return false; + } + uint64 val = 0; + if (!str_util::ConsumeLeadingDigits(&match, &val)) return false; + if (is_negative) { + const int64 final_val = static_cast<int64>(val) * -1; + if (final_val > 0) return false; + *out = final_val; + } else { + if (val > static_cast<uint64>(std::numeric_limits<int64>::max())) { + return false; + } + *out = val; + } + *sp = remaining; + return true; } #define VERIFY(expr, ...) \ @@ -107,12 +126,11 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Parse "<name>:" at the beginning. StringPiece tmp_name; - VERIFY(RE2Consume(&spec, AttrNameRE(), &tmp_name), - "Trouble parsing '<name>:'"); + VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'"); attr->set_name(tmp_name.data(), tmp_name.size()); // Read "<type>" or "list(<type>)". - bool is_list = RE2Consume(&spec, AttrListPrefixRE()); + bool is_list = ConsumeListPrefix(&spec); string type; if (spec.Consume("string")) { type = "string"; @@ -151,14 +169,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } } else if (spec.Consume("{")) { // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); AttrValue* allowed = attr->mutable_allowed_values(); if (spec.starts_with("\"") || spec.starts_with("'")) { type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" while (true) { StringPiece escaped_string; - VERIFY((RE2Consume(&spec, AttrDoubleQuotedRE(), &escaped_string) || - RE2Consume(&spec, AttrSingleQuotedRE(), &escaped_string)), + VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || + ConsumeQuotedString('\'', &spec, &escaped_string), "Trouble parsing allowed string at '", spec, "'"); string unescaped; string error; @@ -167,7 +185,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, error); allowed->mutable_list()->add_s(unescaped); if (spec.Consume(",")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); if (spec.Consume("}")) break; // Allow ending with ", }". } else { VERIFY(spec.Consume("}"), @@ -179,14 +197,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, type = "type"; while (true) { StringPiece type_string; - VERIFY(RE2Consume(&spec, AttrTypeRE(), &type_string), + VERIFY(ConsumeAttrType(&spec, &type_string), "Trouble parsing type string at '", spec, "'"); DataType dt; VERIFY(DataTypeFromString(type_string, &dt), "Unrecognized type string '", type_string, "'"); allowed->mutable_list()->add_type(dt); if (spec.Consume(",")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); if (spec.Consume("}")) break; // Allow ending with ", }". } else { VERIFY(spec.Consume("}"), @@ -198,12 +216,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, } else { VERIFY(false, "Trouble parsing type string at '", spec, "'"); } - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); // Write the type into *attr. if (is_list) { VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); attr->set_type(strings::StrCat("list(", type, ")")); } else { attr->set_type(type); @@ -212,7 +230,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Read optional minimum constraint at the end. if ((is_list || type == "int") && spec.Consume(">=")) { int64 min_limit = -999; - VERIFY(RE2Consume(&spec, AttrNumberRE(), &min_limit), + VERIFY(ConsumeAttrNumber(&spec, &min_limit), "Could not parse integer lower limit after '>=', found '", spec, "' instead"); attr->set_has_minimum(true); @@ -221,7 +239,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, // Parse default value, if present. if (spec.Consume("=")) { - RE2Consume(&spec, SpacesRE()); + str_util::RemoveLeadingWhitespace(&spec); VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), "Could not parse default value '", spec, "'"); } else { @@ -236,29 +254,49 @@ string InOutError(bool is_output, StringPiece orig, const string& op_name) { "\") for Op ", op_name); } -const RE2& InOutNameRE() { - static RE2 pattern("([a-z][a-z0-9_]*)\\s*:\\s*"); - return pattern; +bool ConsumeInOutName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LOWERLETTER) + .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); } -const RE2& InOutRefOpenRE() { - static RE2 pattern("Ref\\s*\\(\\s*"); - return pattern; +bool ConsumeInOutRefOpen(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("Ref") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); } -const RE2& InOutRefCloseRE() { - static RE2 pattern("\\)\\s*"); - return pattern; +bool ConsumeInOutRefClose(StringPiece* sp) { + return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp); } -const RE2& InOutNameOrTypeRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*"); - return pattern; +bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); } -const RE2& InOutTimesTypeRE() { - static RE2 pattern("[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*"); - return pattern; +bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .OneLiteral("*") + .AnySpace() + .RestartCapture() + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); } #define VERIFY(expr, ...) \ @@ -279,20 +317,19 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, // Parse "<name>:" at the beginning. StringPiece tmp_name; - VERIFY(RE2Consume(&spec, InOutNameRE(), &tmp_name), - "Trouble parsing 'name:'"); + VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'"); arg->set_name(tmp_name.data(), tmp_name.size()); // Detect "Ref(...)". - if (RE2Consume(&spec, InOutRefOpenRE())) { + if (ConsumeInOutRefOpen(&spec)) { arg->set_is_ref(true); } { // Parse "<name|type>" or "<name>*<name|type>". StringPiece first, second, type_or_attr; - VERIFY(RE2Consume(&spec, InOutNameOrTypeRE(), &first), + VERIFY(ConsumeInOutNameOrType(&spec, &first), "Trouble parsing either a type or an attr name at '", spec, "'"); - if (RE2Consume(&spec, InOutTimesTypeRE(), &second)) { + if (ConsumeInOutTimesType(&spec, &second)) { arg->set_number_attr(first.data(), first.size()); type_or_attr = second; } else { @@ -317,7 +354,7 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, // Closing ) for Ref(. if (arg->is_ref()) { - VERIFY(RE2Consume(&spec, InOutRefCloseRE()), + VERIFY(ConsumeInOutRefClose(&spec), "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); } @@ -354,14 +391,19 @@ int num_leading_spaces(StringPiece s) { return i; } -const RE2& DocNameColonRE() { - static RE2 pattern("^[a-zA-Z][a-zA-Z0-9_]*\\s*:"); - return pattern; +bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); } -const RE2& DocNameColonSpacesRE() { - static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*"); - return pattern; +bool IsDocNameColon(StringPiece s) { + return ConsumeDocNameColon(&s, nullptr /* out */); } void FinalizeDoc(const string& text, OpDef* op_def, @@ -384,8 +426,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, // Lines until we see name: -> description. int start_l = l; - while (static_cast<size_t>(l) < lines.size() && - !RE2::PartialMatch(lines[l], DocNameColonRE())) { + while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { ++l; } int end_l = l; @@ -403,10 +444,9 @@ void FinalizeDoc(const string& text, OpDef* op_def, while (static_cast<size_t>(l) < lines.size()) { description.clear(); description.push_back(lines[l]); - RE2Consume(&description.back(), DocNameColonSpacesRE(), &name); + ConsumeDocNameColon(&description.back(), &name); ++l; - while (static_cast<size_t>(l) < lines.size() && - !RE2::PartialMatch(lines[l], DocNameColonRE())) { + while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { description.push_back(lines[l]); ++l; } diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index bca67120bb..2d6a7f01ae 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -140,6 +140,12 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) { ExpectSuccess( b().Attr("i: int >= -5"), "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }"); + ExpectSuccess(b().Attr("i: int >= 9223372036854775807"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: 9223372036854775807 }")); + ExpectSuccess(b().Attr("i: int >= -9223372036854775808"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: -9223372036854775808 }")); } TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { @@ -164,6 +170,20 @@ TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { ExpectFailure(b().Attr("a:{float,,}"), "Trouble parsing type string at ',}' from " "Attr(\"a:{float,,}\") for Op Test"); + ExpectFailure(b().Attr("i: int >= a"), + "Could not parse integer lower limit after '>=', " + "found ' a' instead from Attr(\"i: int >= a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -a"), + "Could not parse integer lower limit after '>=', found ' -a' " + "instead from Attr(\"i: int >= -a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= 9223372036854775808"), + "Could not parse integer lower limit after '>=', found " + "' 9223372036854775808' instead from " + "Attr(\"i: int >= 9223372036854775808\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -9223372036854775809"), + "Could not parse integer lower limit after '>=', found " + "' -9223372036854775809' instead from " + "Attr(\"i: int >= -9223372036854775809\") for Op Test"); } TEST_F(OpDefBuilderTest, AttrListOfRestricted) { @@ -241,6 +261,9 @@ TEST_F(OpDefBuilderTest, AttrListWithDefaults) { ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"), "attr: { name: 'a' type: 'list(int)' " "default_value { list { i: [0, -1, 2, -4, 8] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[ ])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [] } } }"); } TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { @@ -259,6 +282,12 @@ TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { ExpectFailure(b().Attr(R"(a:list(string)='foo')"), "Could not parse default value ''foo'' from " "Attr(\"a:list(string)='foo'\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = ["), + "Could not parse default value '[' from " + "Attr(\"a:list(float) = [\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = "), + "Could not parse default value '' from " + "Attr(\"a:list(float) = \") for Op Test"); } TEST_F(OpDefBuilderTest, InputOutput) { @@ -268,7 +297,7 @@ TEST_F(OpDefBuilderTest, InputOutput) { "output_arg: { name: 'b' type: DT_STRING }"); ExpectSuccess(b().Input("c: float "), "input_arg: { name: 'c' type: DT_FLOAT }"); - ExpectSuccess(b().Output("d: Ref(bool)"), + ExpectSuccess(b().Output("d: Ref ( bool ) "), "output_arg: { name: 'd' type: DT_BOOL is_ref: true }"); ExpectOrdered(b().Input("a: bool") .Output("c: complex64") @@ -326,6 +355,12 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) { ExpectFailure( b().Input("CAPS: int32"), "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"); + ExpectFailure( + b().Input("_underscore: int32"), + "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test"); + ExpectFailure( + b().Input("0digit: int32"), + "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test"); ExpectFailure(b().Input("a: _"), "Trouble parsing either a type or an attr name at '_' from " "Input(\"a: _\") for Op Test"); @@ -344,6 +379,9 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) { ExpectFailure(b().Input("a: Ref(int32"), "Did not find closing ')' for 'Ref(', instead found: '' from " "Input(\"a: Ref(int32\") for Op Test"); + ExpectFailure( + b().Input("a: Ref"), + "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test"); ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"), "Did not find closing ')' for 'Ref(', instead found: 'y' from " "Input(\"a: Ref(x y\") for Op Test"); diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index b94207e2e8..f7e4f1f05a 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -221,8 +221,16 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, } Status ValidateOpDef(const OpDef& op_def) { - VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"), - "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + using ::tensorflow::strings::Scanner; + + if (!StringPiece(op_def.name()).starts_with("_")) { + VALIDATE(Scanner(op_def.name()) + .One(Scanner::UPPERLETTER) + .Any(Scanner::LETTER_DIGIT) + .Eos() + .GetResult(), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + } std::set<string> names; // for detecting duplicate names for (const auto& attr : op_def.attr()) { diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index f0c9085d48..60425eedb0 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { @@ -117,15 +117,22 @@ Status ResourceMgr::Cleanup(const string& container) { return Status::OK(); } +static bool IsValidContainerName(StringPiece s) { + using ::tensorflow::strings::Scanner; + return Scanner(s) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH) + .Eos() + .GetResult(); +} + Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, bool use_node_name_as_default) { CHECK(rmgr); rmgr_ = rmgr; string attr_container; TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container)); - static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); - if (!attr_container.empty() && - !RE2::FullMatch(attr_container, container_re)) { + if (!attr_container.empty() && !IsValidContainerName(attr_container)) { return errors::InvalidArgument("container contains invalid characters: ", attr_container); } diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index f776d1ebcc..56bc76d384 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -161,6 +161,8 @@ TEST(ContainerInfo, Basic) { EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]"); EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]"); EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat.0-dog", "bar", true), "[cat.0-dog,bar,public]"); + EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]"); } Status WrongPolicy(const string& attr_container, const string& attr_shared_name, @@ -180,6 +182,7 @@ TEST(ContainerInfo, Error) { // Invalid container. HasError(WrongPolicy("12$%", "", false), "container contains invalid char"); + HasError(WrongPolicy("-cat", "", false), "container contains invalid char"); // Invalid shared name. HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'"); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6c5873d0c1..db3b46863f 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -126,20 +126,21 @@ void GraphConstructor::SetError(const string& error) { status_->Update(errors::InvalidArgument(error)); } -void GraphConstructor::BuildNodeIndex() { - // Initialized outside the loop for efficiency - const char* pattern; - if (opts_.allow_internal_ops) { - pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*"; - } else { - pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"; - } - RE2 node_name_re(pattern); +bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { + using ::tensorflow::strings::Scanner; + return Scanner(s) + .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE + : Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} +void GraphConstructor::BuildNodeIndex() { // Validate the node names and add them to name_index_. for (int n = 0; n < gdef_->node_size(); ++n) { const NodeDef& node_def(gdef_->node(n)); - if (!RE2::FullMatch(node_def.name(), node_name_re)) { + if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { SetNodeError(node_def, "Node name contains invalid characters"); return; } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 8e391d6510..ea8ae5dc06 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" @@ -127,12 +128,24 @@ REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); REGISTER_OP("TestInt").Input("a: int32"); TEST_F(GraphConstructorTest, InvalidNodeName) { - ExpectError("node { name: 'a:b' op: 'ABC' }", - "Node 'a:b': Node name contains invalid characters"); - ExpectError("node { name: '_abc' op: 'ABC' }", - // Can't start with '_' - "Node '_abc': Node name contains invalid characters"); + auto expect_invalid_name = [this](const char* name) { + ExpectError(strings::StrCat("node { name: '", name, "' op: 'ABC' }"), + strings::StrCat("Node '", name, + "': Node name contains invalid characters")); + }; + + expect_invalid_name("a:b"); + expect_invalid_name("_abc"); // Can't start with '_' + // Name is a\b, but proto text format escapes slashes so we use a\\b here. + // This works for ExpectError too, since re2 also treats \\ as one slash. + expect_invalid_name(R"(a\\b)"); + expect_invalid_name("/a"); + expect_invalid_name("-a"); + ExpectOK("node { name: 'a-bc_' op: 'ABC' }"); + ExpectOK("node { name: 'a-B.0/.c_' op: 'ABC' }"); + ExpectOK("node { name: '0123' op: 'ABC' }"); + ExpectOK("node { name: '.0123' op: 'ABC' }"); } TEST_F(GraphConstructorTest, InvalidSourceNodeName) { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7ca186eaec..9dc64a6cef 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -540,6 +540,7 @@ tf_cc_tests( "adjust_contrast_op_benchmark_test", "adjust_contrast_op_test", "colorspace_op_test", + "resize_bicubic_op_test", "resize_bilinear_op_test", "resize_nearest_neighbor_op_test", ], @@ -1104,6 +1105,7 @@ filegroup( "cwise_op_minimum.cc", "cwise_op_mul.cc", "cwise_op_neg.cc", + "cwise_op_rsqrt.cc", "cwise_op_select.cc", "cwise_op_sigmoid.cc", "cwise_op_sqrt.cc", diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc index 42674be967..172bb95661 100644 --- a/tensorflow/core/kernels/decode_raw_op.cc +++ b/tensorflow/core/kernels/decode_raw_op.cc @@ -35,9 +35,9 @@ class DecodeRawOp : public OpKernel { void Compute(OpKernelContext* context) override { const auto& input = context->input(0); - int str_size = -1; + int64 str_size = -1; auto flat_in = input.flat<string>(); - for (int i = 0; i < flat_in.size(); ++i) { + for (int64 i = 0; i < flat_in.size(); ++i) { const string& in_str = flat_in(i); if (str_size == -1) { str_size = in_str.size(); @@ -62,7 +62,7 @@ class DecodeRawOp : public OpKernel { errors::InvalidArgument("Input to DecodeRaw has length ", str_size, " that is not a multiple of ", sizeof(T), ", the size of ", DataTypeString(out_type_))); - const int added_dim = str_size / sizeof(T); + const int64 added_dim = str_size / sizeof(T); out_shape.AddDim(added_dim); Tensor* output_tensor = nullptr; OP_REQUIRES_OK( @@ -76,7 +76,7 @@ class DecodeRawOp : public OpKernel { little_endian_ ? "true" : "false")); // Endianness matches, so just copy each string byte-for-byte. T* out_data = out.data(); - for (int i = 0; i < flat_in.size(); ++i) { + for (int64 i = 0; i < flat_in.size(); ++i) { const T* in_data = reinterpret_cast<const T*>(flat_in(i).data()); memcpy(out_data, in_data, str_size); out_data += added_dim; diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc index 64955ab0b7..8b03c570de 100644 --- a/tensorflow/core/kernels/ops_util.cc +++ b/tensorflow/core/kernels/ops_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/padding.h" namespace tensorflow { @@ -119,9 +119,17 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize, } string SanitizeThreadSuffix(string suffix) { - static RE2 re("[^A-Za-z0-9_-]"); - re.GlobalReplace(&suffix, re, "_"); - return suffix; + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; } } // namespace tensorflow diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index b92ec6e004..a34703402c 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -28,17 +28,18 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { +namespace { -static const int64 tab_size = (1 << 10); +static const int64 kTableSize = (1 << 10); const float* InitCoeffsTable() { // Allocate and initialize coefficients table using Bicubic // convolution algorithm. // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffs_tab = new float[(tab_size + 1) * 2]; + float* coeffs_tab = new float[(kTableSize + 1) * 2]; static const double A = -0.75; - for (int i = 0; i <= tab_size; ++i) { - float x = i * 1.0 / tab_size; + for (int i = 0; i <= kTableSize; ++i) { + float x = i * 1.0 / kTableSize; coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1; x += 1.0; coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; @@ -52,6 +53,32 @@ const float* GetCoeffsTable() { return coeffs_tab; } +inline int64 Bound(int64 val, int64 limit) { + return std::min(limit - 1ll, std::max(0ll, val)); +} + +inline void GetWeightsAndIndices(float scale, int64 out_loc, int64 limit, + std::array<float, 4>* weights, + std::array<int64, 4>* indices) { + const int64 in_loc = floor(scale * out_loc); + const float delta = scale * out_loc - in_loc; + const int64 offset = round(delta * kTableSize); + const float* coeffs_tab = GetCoeffsTable(); + *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2], + coeffs_tab[(kTableSize - offset) * 2], + coeffs_tab[(kTableSize - offset) * 2 + 1]}}; + *indices = {{Bound(in_loc - 1, limit), Bound(in_loc, limit), + Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)}}; +} + +inline float Interpolate1D(const std::array<float, 4>& weights, + const std::array<float, 4>& values) { + return values[0] * weights[0] + values[1] * weights[1] + + values[2] * weights[2] + values[3] * weights[3]; +} + +} // namespace + typedef Eigen::ThreadPoolDevice CPUDevice; template <typename Device, typename T> @@ -106,40 +133,34 @@ class ResizeBicubicOp : public OpKernel { ? (in_width - 1) / static_cast<float>(out_width - 1) : in_width / static_cast<float>(out_width); - const float* coeffs_tab = GetCoeffsTable(); - - auto cal = [](const float* coeffs_tab, float v0, float v1, float v2, - float v3, float dx) { - const int64 offset = round(dx * tab_size); - const float a0 = coeffs_tab[offset * 2 + 1]; - const float a1 = coeffs_tab[offset * 2]; - const float a2 = coeffs_tab[(tab_size - offset) * 2]; - const float a3 = coeffs_tab[(tab_size - offset) * 2 + 1]; - return a0 * v0 + a1 * v1 + a2 * v2 + a3 * v3; - }; - - float coeff[4] = {0.0}; + std::array<float, 4> coeff = {{0.0, 0.0, 0.0, 0.0}}; for (int64 b = 0; b < batch_size; ++b) { for (int64 y = 0; y < out_height; ++y) { - const int64 in_y = floor(height_scale * y); - const float dy = height_scale * y - in_y; + std::array<float, 4> y_weights; + std::array<int64, 4> y_indices; + GetWeightsAndIndices(height_scale, y, in_height, &y_weights, + &y_indices); for (int64 x = 0; x < out_width; ++x) { - const int64 in_x = floor(width_scale * x); - const float dx = width_scale * x - in_x; + std::array<float, 4> x_weights; + std::array<int64, 4> x_indices; + GetWeightsAndIndices(width_scale, x, in_width, &x_weights, + &x_indices); for (int64 c = 0; c < channels; ++c) { + // Use a 4x4 patch to compute the interpolated output value at + // (b, y, x, c). for (int64 i = 0; i < 4; ++i) { -#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) - int64 bound_y = BOUND(in_y - 1 + i, in_height); - coeff[i] = - cal(coeffs_tab, - input_data(b, bound_y, BOUND(in_x - 1, in_width), c), - input_data(b, bound_y, BOUND(in_x, in_width), c), - input_data(b, bound_y, BOUND(in_x + 1, in_width), c), - input_data(b, bound_y, BOUND(in_x + 2, in_width), c), dx); -#undef BOUND + const std::array<float, 4> values = { + {static_cast<float>( + input_data(b, y_indices[i], x_indices[0], c)), + static_cast<float>( + input_data(b, y_indices[i], x_indices[1], c)), + static_cast<float>( + input_data(b, y_indices[i], x_indices[2], c)), + static_cast<float>( + input_data(b, y_indices[i], x_indices[3], c))}}; + coeff[i] = Interpolate1D(x_weights, values); } - output_data(b, y, x, c) = - cal(coeffs_tab, coeff[0], coeff[1], coeff[2], coeff[3], dy); + output_data(b, y, x, c) = Interpolate1D(y_weights, coeff); } } } diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc new file mode 100644 index 0000000000..d9d68844eb --- /dev/null +++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static Graph* ResizeBicubic(int batch_size, int size, int channels) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor input(DT_FLOAT, TensorShape({batch_size, size, size, channels})); + input.flat<float>().setRandom(); + Tensor shape(DT_INT32, TensorShape({2})); + auto shape_t = shape.flat<int32>(); + shape_t(0) = 0.3 * size; + shape_t(1) = 0.7 * size; + test::graph::Binary(g, "ResizeBicubic", test::graph::Constant(g, input), + test::graph::Constant(g, shape)); + return g; +} + +#define BM_ResizeBicubicDev(BATCH, SIZE, CHANNELS) \ + static void BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS(int iters) { \ + testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * SIZE * SIZE * \ + CHANNELS); \ + test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS)).Run(iters); \ + } \ + BENCHMARK(BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS); + +BM_ResizeBicubicDev(8, 32, 3); +BM_ResizeBicubicDev(8, 128, 3); +BM_ResizeBicubicDev(8, 512, 3); +BM_ResizeBicubicDev(8, 1024, 3); +BM_ResizeBicubicDev(16, 32, 3); +BM_ResizeBicubicDev(16, 128, 3); +BM_ResizeBicubicDev(16, 512, 3); +BM_ResizeBicubicDev(16, 1024, 3); +BM_ResizeBicubicDev(32, 32, 3); +BM_ResizeBicubicDev(32, 128, 3); +BM_ResizeBicubicDev(32, 512, 3); +BM_ResizeBicubicDev(32, 1024, 3); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h index 379f877fcb..3dc5deb286 100644 --- a/tensorflow/core/kernels/tile_ops.h +++ b/tensorflow/core/kernels/tile_ops.h @@ -28,7 +28,12 @@ struct Tile { void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out, typename TTypes<T, NDIM>::ConstTensor in, const Eigen::array<int32, NDIM>& broadcast_array) const { - out.device(d) = in.broadcast(broadcast_array); + if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) { + // Use 32bit indexing to speed up the computations + To32Bit(out).device(d) = To32Bit(in).broadcast(broadcast_array); + } else { + out.device(d) = in.broadcast(broadcast_array); + } } }; diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 7a99cbd5d6..a24b7c704a 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -576,7 +576,6 @@ class SparseApplyFtrlOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - const Device& device = ctx->template eigen_device<Device>(); mutex* mu_var = ctx->input_ref_mutex(0); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. @@ -666,21 +665,42 @@ class SparseApplyFtrlOp : public OpKernel { auto accum_flat = accum.flat_outer_dims<T>(); auto linear_flat = linear.flat_outer_dims<T>(); auto grad_flat = grad.flat_outer_dims<T>(); + T lr_scalar = lr.scalar<T>()(); + T l1_scalar = l1.scalar<T>()(); + T l2_scalar = l2.scalar<T>()(); + T lr_power_scalar = lr_power.scalar<T>()(); for (Tindex i = 0; i < N; i++) { const Tindex index = indices_vec(i); - typename TTypes<T>::Flat accum(&accum_flat(index, 0), - accum_flat.dimension(1)); - typename TTypes<T>::Flat linear(&linear_flat(index, 0), - linear_flat.dimension(1)); - typename TTypes<T>::Flat var(&var_flat(index, 0), - var_flat.dimension(1)); - typename TTypes<T>::ConstFlat grad(&grad_flat(i, 0), - grad_flat.dimension(1)); - - functor::ApplyFtrl<Device, T>()(device, var, accum, linear, grad, - lr.scalar<T>(), l1.scalar<T>(), - l2.scalar<T>(), lr_power.scalar<T>()); + auto accum = accum_flat.template chip<0>(index); + auto linear = linear_flat.template chip<0>(index); + auto grad = grad_flat.template chip<0>(i); + auto var = var_flat.template chip<0>(index); + + auto new_accum = accum + grad.square(); + if (lr_power_scalar == -0.5) { + linear += + grad - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; + } else { + linear += grad - + (new_accum.pow(-lr_power_scalar) - + accum.pow(-lr_power_scalar)) / + lr_scalar * var; + } + auto x = (linear.constant(l1_scalar) * linear.sign() - linear); + if (lr_power_scalar == -0.5) { + auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + + linear.constant(2 * l2_scalar); + var = x / y; + } else { + auto y = new_accum.pow(-lr_power_scalar) / + new_accum.constant(lr_scalar) + + linear.constant(2 * l2_scalar); + var = x / y; + } + var = (linear.abs() > linear.constant(l1_scalar)) + .select(var, var.constant(0)); + accum += grad.square(); } } else { CHECK_EQ(1, inner_dim); diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc index 8727f1367b..d681958162 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -141,6 +141,21 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { jpeg_start_decompress(&cinfo); + int64 total_size = static_cast<int64>(cinfo.output_height) * + static_cast<int64>(cinfo.output_width); + // Some of the internal routines do not gracefully handle ridiculously + // large images, so fail fast. + if (cinfo.output_width <= 0 || cinfo.output_height <= 0) { + LOG(ERROR) << "Invalid image size: " << cinfo.output_width << " x " + << cinfo.output_height; + return nullptr; + } + if (total_size >= (1LL << 29)) { + LOG(ERROR) << "Image too large: " << total_size; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + // check for compatible stride const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE); if (stride == 0) { @@ -405,6 +420,19 @@ bool CompressInternal(const uint8* srcdata, int width, int height, const CompressFlags& flags, string* output) { output->clear(); const int components = (static_cast<int>(flags.format) & 0xff); + + int64 total_size = static_cast<int64>(width) * static_cast<int64>(height); + // Some of the internal routines do not gracefully handle ridiculously + // large images, so fail fast. + if (width <= 0 || height <= 0) { + LOG(ERROR) << "Invalid image size: " << width << " x " << height; + return false; + } + if (total_size >= (1LL << 29)) { + LOG(ERROR) << "Image too large: " << total_size; + return false; + } + int in_stride = flags.stride; if (in_stride == 0) { in_stride = width * (static_cast<int>(flags.format) & 0xff); diff --git a/tensorflow/core/lib/strings/scanner.cc b/tensorflow/core/lib/strings/scanner.cc new file mode 100644 index 0000000000..b05400c97d --- /dev/null +++ b/tensorflow/core/lib/strings/scanner.cc @@ -0,0 +1,59 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/scanner.h" + +namespace tensorflow { +namespace strings { + +void Scanner::ScanEscapedUntilImpl(char end_ch) { + for (;;) { + if (cur_.empty()) { + Error(); + return; + } + const char ch = cur_[0]; + if (ch == end_ch) { + return; + } + + cur_.remove_prefix(1); + if (ch == '\\') { + // Escape character, skip next character. + if (cur_.empty()) { + Error(); + return; + } + cur_.remove_prefix(1); + } + } +} + +bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) { + if (error_) { + return false; + } + if (remaining != nullptr) { + *remaining = cur_; + } + if (capture != nullptr) { + const char* end = capture_end_ == nullptr ? cur_.data() : capture_end_; + *capture = StringPiece(capture_start_, end - capture_start_); + } + return true; +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h new file mode 100644 index 0000000000..ecbb139d60 --- /dev/null +++ b/tensorflow/core/lib/strings/scanner.h @@ -0,0 +1,218 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LIB_STRINGS_SCANNER_H_ +#define TENSORFLOW_LIB_STRINGS_SCANNER_H_ + +#include <string> +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace strings { + +// Scanner provides simplified string parsing, in which a string is parsed as a +// series of scanning calls (e.g. One, Any, Many, OneLiteral, Eos), and then +// finally GetResult is called. If GetResult returns true, then it also returns +// the remaining characters and any captured substring. +// +// The range to capture can be controlled with RestartCapture and StopCapture; +// by default, all processed characters are captured. +class Scanner { + public: + // Classes of characters. Each enum name is to be read as the union of the + // parts - e.g., class LETTER_DIGIT means the class includes all letters and + // all digits. + // + // LETTER means ascii letter a-zA-Z. + // DIGIT means ascii digit: 0-9. + enum CharClass { + // NOTE: When adding a new CharClass, update the AllCharClasses ScannerTest + // in scanner_test.cc + DIGIT, + LETTER, + LETTER_DIGIT, + LETTER_DIGIT_DASH_DOT_SLASH, // SLASH is / only, not backslash + LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE, // SLASH is / only, not backslash + LETTER_DIGIT_DOT, + LETTER_DIGIT_DOT_UNDERSCORE, + LETTER_DIGIT_UNDERSCORE, + LOWERLETTER, + LOWERLETTER_DIGIT, + LOWERLETTER_DIGIT_UNDERSCORE, + NON_ZERO_DIGIT, + SPACE, + UPPERLETTER, + }; + + explicit Scanner(StringPiece source) : cur_(source) { RestartCapture(); } + + // Consume the next character of the given class from input. If the next + // character is not in the class, then GetResult will ultimately return false. + Scanner& One(CharClass clz) { + if (cur_.empty() || !Matches(clz, cur_[0])) { + return Error(); + } + cur_.remove_prefix(1); + return *this; + } + + // Consume the next s.size() characters of the input, if they match <s>. If + // they don't match <s>, this is a no-op. + Scanner& ZeroOrOneLiteral(StringPiece s) { + cur_.Consume(s); + return *this; + } + + // Consume the next s.size() characters of the input, if they match <s>. If + // they don't match <s>, then GetResult will ultimately return false. + Scanner& OneLiteral(StringPiece s) { + if (!cur_.Consume(s)) { + error_ = true; + } + return *this; + } + + // Consume characters from the input as long as they match <clz>. + Scanner& Any(CharClass clz) { + while (!cur_.empty() && Matches(clz, cur_[0])) { + cur_.remove_prefix(1); + } + return *this; + } + + // Shorthand for One(clz).Any(clz). + Scanner& Many(CharClass clz) { return One(clz).Any(clz); } + + // Reset the capture start point. + // + // Later, when GetResult is called and if it returns true, the capture + // returned will start at the position at the time this was called. + Scanner& RestartCapture() { + capture_start_ = cur_.data(); + return *this; + } + + // Stop capturing input. + // + // Later, when GetResult is called and if it returns true, the capture + // returned will end at the position at the time this was called. + Scanner& StopCapture() { + capture_end_ = cur_.data(); + return *this; + } + + // If not at the input of input, then GetResult will ultimately return false. + Scanner& Eos() { + if (!cur_.empty()) error_ = true; + return *this; + } + + // Shorthand for Any(SPACE). + Scanner& AnySpace() { return Any(SPACE); } + + // This scans input until <end_ch> is reached. <end_ch> is NOT consumed. + // Backslash escape sequences are skipped. + // Used for implementing quoted string scanning. + Scanner& ScanEscapedUntil(char end_ch) { + ScanEscapedUntilImpl(end_ch); + return *this; + } + + // Return the next character that will be scanned, or <default_value> if there + // are no more characters to scan. + // Note that if a scan operation has failed (so GetResult() returns false), + // then the value of Peek may or may not have advanced since the scan + // operation that failed. + char Peek(char default_value = '\0') const { + return cur_.empty() ? default_value : cur_[0]; + } + + // Returns true if the input string successfully matched. When true is + // returned, the remaining string is returned in <remaining> and the captured + // string returned in <capture>, if non-NULL. + bool GetResult(StringPiece* remaining = nullptr, + StringPiece* capture = nullptr); + + private: + void ScanEscapedUntilImpl(char end_ch); + + Scanner& Error() { + error_ = true; + return *this; + } + + static bool IsLetter(char ch) { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); + } + + static bool IsLowerLetter(char ch) { return ch >= 'a' && ch <= 'z'; } + + static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; } + + static bool IsSpace(char ch) { + return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\v' || ch == '\f' || + ch == '\r'); + } + + static bool Matches(CharClass clz, char ch) { + switch (clz) { + case DIGIT: + return IsDigit(ch); + case LETTER: + return IsLetter(ch); + case LETTER_DIGIT: + return IsLetter(ch) || IsDigit(ch); + case LETTER_DIGIT_DASH_DOT_SLASH: + return IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' || + ch == '/'; + case LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE: + return (IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' || + ch == '/' || ch == '_'); + case LETTER_DIGIT_DOT: + return IsLetter(ch) || IsDigit(ch) || ch == '.'; + case LETTER_DIGIT_DOT_UNDERSCORE: + return IsLetter(ch) || IsDigit(ch) || ch == '.' || ch == '_'; + case LETTER_DIGIT_UNDERSCORE: + return IsLetter(ch) || IsDigit(ch) || ch == '_'; + case LOWERLETTER: + return ch >= 'a' && ch <= 'z'; + case LOWERLETTER_DIGIT: + return IsLowerLetter(ch) || IsDigit(ch); + case LOWERLETTER_DIGIT_UNDERSCORE: + return IsLowerLetter(ch) || IsDigit(ch) || ch == '_'; + case NON_ZERO_DIGIT: + return IsDigit(ch) && ch != '0'; + case SPACE: + return IsSpace(ch); + case UPPERLETTER: + return ch >= 'A' && ch <= 'Z'; + } + return false; + } + + StringPiece cur_; + const char* capture_start_ = nullptr; + const char* capture_end_ = nullptr; + bool error_ = false; + + friend class ScannerTest; + TF_DISALLOW_COPY_AND_ASSIGN(Scanner); +}; + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_STRINGS_SCANNER_H_ diff --git a/tensorflow/core/lib/strings/scanner_test.cc b/tensorflow/core/lib/strings/scanner_test.cc new file mode 100644 index 0000000000..98028ae516 --- /dev/null +++ b/tensorflow/core/lib/strings/scanner_test.cc @@ -0,0 +1,266 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/scanner.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace strings { + +class ScannerTest : public ::testing::Test { + protected: + // Returns a string with all chars that are in <clz>, in byte value order. + string ClassStr(Scanner::CharClass clz) { + string s; + for (int i = 0; i < 256; ++i) { + char ch = i; + if (Scanner::Matches(clz, ch)) { + s += ch; + } + } + return s; + } +}; + +TEST_F(ScannerTest, Any) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(" horse0123") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ(" horse", match.ToString()); + EXPECT_EQ("0123", remaining.ToString()); + + EXPECT_TRUE(Scanner("") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("", match.ToString()); + + EXPECT_TRUE(Scanner("----") + .Any(Scanner::SPACE) + .Any(Scanner::DIGIT) + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("----", remaining.ToString()); + EXPECT_EQ("", match.ToString()); +} + +TEST_F(ScannerTest, AnySpace) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(" a b ") + .AnySpace() + .One(Scanner::LETTER) + .AnySpace() + .GetResult(&remaining, &match)); + EXPECT_EQ(" a ", match.ToString()); + EXPECT_EQ("b ", remaining.ToString()); +} + +TEST_F(ScannerTest, Eos) { + EXPECT_FALSE(Scanner("a").Eos().GetResult()); + EXPECT_TRUE(Scanner("").Eos().GetResult()); + EXPECT_FALSE(Scanner("abc").OneLiteral("ab").Eos().GetResult()); + EXPECT_TRUE(Scanner("abc").OneLiteral("abc").Eos().GetResult()); +} + +TEST_F(ScannerTest, Many) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner("abc").Many(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("0").Many(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("").Many(Scanner::LETTER).GetResult()); + + EXPECT_TRUE( + Scanner("abc ").Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ(" ", remaining); + EXPECT_EQ("abc", match); + EXPECT_TRUE( + Scanner("abc").Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining); + EXPECT_EQ("abc", match); +} + +TEST_F(ScannerTest, One) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner("abc").One(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("0").One(Scanner::LETTER).GetResult()); + EXPECT_FALSE(Scanner("").One(Scanner::LETTER).GetResult()); + + EXPECT_TRUE(Scanner("abc") + .One(Scanner::LETTER) + .One(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("c", remaining); + EXPECT_EQ("ab", match); + EXPECT_TRUE(Scanner("a").One(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining); + EXPECT_EQ("a", match); +} + +TEST_F(ScannerTest, OneLiteral) { + EXPECT_FALSE(Scanner("abc").OneLiteral("abC").GetResult()); + EXPECT_TRUE(Scanner("abc").OneLiteral("ab").OneLiteral("c").GetResult()); +} + +TEST_F(ScannerTest, ScanEscapedUntil) { + StringPiece remaining, match; + EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)") + .OneLiteral("'") + .ScanEscapedUntil('\'') + .OneLiteral("'") + .GetResult(&remaining, &match)); + EXPECT_EQ("rest", remaining.ToString()); + EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match.ToString()); + + // The "scan until" character is not present. + remaining = match = "unset"; + EXPECT_FALSE(Scanner(R"(' \1 \2 \3 \' \\rest)") + .OneLiteral("'") + .ScanEscapedUntil('\'') + .GetResult(&remaining, &match)); + EXPECT_EQ("unset", remaining.ToString()); + EXPECT_EQ("unset", match.ToString()); +} + +TEST_F(ScannerTest, ZeroOrOneLiteral) { + StringPiece remaining, match; + EXPECT_TRUE( + Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match)); + EXPECT_EQ("abc", remaining.ToString()); + EXPECT_EQ("", match.ToString()); + + EXPECT_TRUE( + Scanner("abcd").ZeroOrOneLiteral("ab").ZeroOrOneLiteral("c").GetResult( + &remaining, &match)); + EXPECT_EQ("d", remaining.ToString()); + EXPECT_EQ("abc", match.ToString()); + + EXPECT_TRUE( + Scanner("").ZeroOrOneLiteral("abc").GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("", match.ToString()); +} + +// Test output of GetResult (including the forms with optional params), +// and that it can be called multiple times. +TEST_F(ScannerTest, CaptureAndGetResult) { + StringPiece remaining, match; + + Scanner scan(" first second"); + EXPECT_TRUE(scan.Any(Scanner::SPACE) + .RestartCapture() + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT) + .StopCapture() + .Any(Scanner::SPACE) + .GetResult(&remaining, &match)); + EXPECT_EQ("second", remaining.ToString()); + EXPECT_EQ("first", match.ToString()); + EXPECT_TRUE(scan.GetResult()); + remaining = ""; + EXPECT_TRUE(scan.GetResult(&remaining)); + EXPECT_EQ("second", remaining.ToString()); + remaining = ""; + match = ""; + EXPECT_TRUE(scan.GetResult(&remaining, &match)); + EXPECT_EQ("second", remaining.ToString()); + EXPECT_EQ("first", match.ToString()); +} + +// Tests that if StopCapture is not called, then calling GetResult, then +// scanning more, then GetResult again will update the capture. +TEST_F(ScannerTest, MultipleGetResultExtendsCapture) { + StringPiece remaining, match; + + Scanner scan("one2three"); + EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("2three", remaining.ToString()); + EXPECT_EQ("one", match.ToString()); + EXPECT_TRUE(scan.Many(Scanner::DIGIT).GetResult(&remaining, &match)); + EXPECT_EQ("three", remaining.ToString()); + EXPECT_EQ("one2", match.ToString()); + EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("one2three", match.ToString()); +} + +TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { + // A failed match doesn't change pointers passed to GetResult. + Scanner scan("name"); + StringPiece remaining = "rem"; + StringPiece match = "match"; + EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match)); + EXPECT_EQ("rem", remaining.ToString()); + EXPECT_EQ("match", match.ToString()); +} + +TEST_F(ScannerTest, DefaultCapturesAll) { + // If RestartCapture() is not called, the whole string is used. + Scanner scan("a b"); + StringPiece remaining = "rem"; + StringPiece match = "match"; + EXPECT_TRUE(scan.Any(Scanner::LETTER) + .AnySpace() + .Any(Scanner::LETTER) + .GetResult(&remaining, &match)); + EXPECT_EQ("", remaining.ToString()); + EXPECT_EQ("a b", match.ToString()); +} + +TEST_F(ScannerTest, AllCharClasses) { + EXPECT_EQ("0123456789", ClassStr(Scanner::DIGIT)); + EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER)); + EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT)); + EXPECT_EQ( + "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)); + EXPECT_EQ( + "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_" + "abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)); + EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DOT)); + EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_DOT_UNDERSCORE)); + EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LETTER_DIGIT_UNDERSCORE)); + EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", ClassStr(Scanner::LOWERLETTER)); + EXPECT_EQ("0123456789abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LOWERLETTER_DIGIT)); + EXPECT_EQ("0123456789_abcdefghijklmnopqrstuvwxyz", + ClassStr(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)); + EXPECT_EQ("123456789", ClassStr(Scanner::NON_ZERO_DIGIT)); + EXPECT_EQ("\t\n\v\f\r ", ClassStr(Scanner::SPACE)); + EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZ", ClassStr(Scanner::UPPERLETTER)); +} + +TEST_F(ScannerTest, Peek) { + EXPECT_EQ('a', Scanner("abc").Peek()); + EXPECT_EQ('a', Scanner("abc").Peek('b')); + EXPECT_EQ('\0', Scanner("").Peek()); + EXPECT_EQ('z', Scanner("").Peek('z')); + EXPECT_EQ('A', Scanner("0123A").Any(Scanner::DIGIT).Peek()); + EXPECT_EQ('\0', Scanner("0123A").Any(Scanner::LETTER_DIGIT).Peek()); +} + +} // namespace strings +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6fcb33d39b..b7d1767e00 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -876,6 +876,25 @@ py_test( ], ) +py_library( + name = "device_lib", + srcs = ["client/device_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":pywrap_tensorflow", + ], +) + +cuda_py_tests( + name = "device_lib_test", + srcs = [ + "client/device_lib_test.py", + ], + additional_deps = [ + ":device_lib", + ], +) + tf_cuda_library( name = "tf_session_helper", srcs = ["client/tf_session_helper.cc"], @@ -897,6 +916,7 @@ tf_py_wrap_cc( name = "pywrap_tensorflow", srcs = ["tensorflow.i"], swig_includes = [ + "client/device_lib.i", "client/events_writer.i", "client/server_lib.i", "client/tf_session.i", diff --git a/tensorflow/python/client/device_lib.i b/tensorflow/python/client/device_lib.i new file mode 100644 index 0000000000..a651605fc0 --- /dev/null +++ b/tensorflow/python/client/device_lib.i @@ -0,0 +1,81 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/public/session_options.h" +%} + +%typemap(in, numinputs=0) const tensorflow::SessionOptions& options ( + tensorflow::SessionOptions temp) { + $1 = &temp; +} + +%typemap(in, numinputs=0) std::vector<tensorflow::Device*>* devices ( + std::vector<tensorflow::Device*> temp) { + $1 = &temp; +} + +%typemap(argout) std::vector<tensorflow::Device*>* devices { + std::vector< std::unique_ptr<tensorflow::Device> > safe_devices; + for (auto* device : *$1) safe_devices.emplace_back(device); + + auto temp_string_list = tensorflow::make_safe(PyList_New(0)); + if (!temp_string_list) { + SWIG_fail; + } + + for (const auto& device : safe_devices) { + const tensorflow::DeviceAttributes& attr = device->attributes(); + string attr_serialized; + if (!attr.SerializeToString(&attr_serialized)) { + PyErr_SetString(PyExc_RuntimeError, + "Unable to serialize DeviceAttributes"); + SWIG_fail; + } + + tensorflow::Safe_PyObjectPtr safe_attr_string = tensorflow::make_safe( + %#if PY_MAJOR_VERSION < 3 + PyString_FromStringAndSize( + %#else + PyUnicode_FromStringAndSize( + %#endif + reinterpret_cast<const char*>( + attr_serialized.data()), attr_serialized.size())); + + if (PyList_Append(temp_string_list.get(), safe_attr_string.get()) == -1) { + SWIG_fail; + } + } + + $result = temp_string_list.release(); +} + + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::DeviceFactory; +%unignore tensorflow::DeviceFactory::AddDevices; + +%include "tensorflow/core/common_runtime/device_factory.h" + +%unignoreall + +%newobject tensorflow::SessionOptions; diff --git a/tensorflow/python/client/device_lib.py b/tensorflow/python/client/device_lib.py new file mode 100644 index 0000000000..75872c463a --- /dev/null +++ b/tensorflow/python/client/device_lib.py @@ -0,0 +1,37 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Python interface for creating TensorFlow servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six # pylint: disable=unused-import + +from tensorflow.core.framework import device_attributes_pb2 +from tensorflow.python import pywrap_tensorflow + + +def list_local_devices(): + """List the available devices available in the local process. + + Returns: + A list of `DeviceAttribute` protocol buffers. + """ + def _convert(pb_str): + m = device_attributes_pb2.DeviceAttributes() + m.ParseFromString(pb_str) + return m + return [_convert(s) for s in pywrap_tensorflow.DeviceFactory_AddDevices("")] diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py new file mode 100644 index 0000000000..ee028573aa --- /dev/null +++ b/tensorflow/python/client/device_lib_test.py @@ -0,0 +1,42 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for the SWIG-wrapped device lib.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.client import device_lib +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class DeviceLibTest(test_util.TensorFlowTestCase): + + def testListLocalDevices(self): + devices = device_lib.list_local_devices() + self.assertGreater(len(devices), 0) + self.assertEqual(devices[0].device_type, "CPU") + + # GPU test + if tf.test.is_built_with_cuda(): + self.assertGreater(len(devices), 1) + self.assertTrue("GPU" in [d.device_type for d in devices]) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5fbdf19229..217dc1f14c 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -20,11 +20,9 @@ from __future__ import division from __future__ import print_function import re -import sys import threading import numpy as np -import six from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.framework import errors @@ -570,22 +568,21 @@ class BaseSession(SessionInterface): try: return fn(*args) except tf_session.StatusNotOK as e: - e_type, e_value, e_traceback = sys.exc_info() error_message = compat.as_text(e.error_message) m = BaseSession._NODEDEF_NAME_RE.search(error_message) + node_def = None + op = None if m is not None: node_name = m.group(1) - node_def = None try: op = self._graph.get_operation_by_name(node_name) node_def = op.node_def except KeyError: - op = None - # pylint: disable=protected-access - raise errors._make_specific_exception(node_def, op, error_message, - e.code) - # pylint: enable=protected-access - six.reraise(e_type, e_value, e_traceback) + pass + # pylint: disable=protected-access + raise errors._make_specific_exception(node_def, op, error_message, + e.code) + # pylint: enable=protected-access def _extend_graph(self): # Ensure any changes to the graph are reflected in the runtime. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 9a59562842..09a9935c5e 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -108,6 +108,19 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesOpError(lambda e: e.op == a.op): a.eval() + def testErrorCodeWithNoNodeDef(self): + with session.Session() as s: + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + + def exc_predicate(e): + return (e.op is None and e.node_def is None and + e.error_code == error_codes_pb2.INVALID_ARGUMENT) + with self.assertRaisesOpError(exc_predicate): + # Run with a bogus handle. + s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) + def testOpConstructionErrorPayload(self): with session.Session(): failing_op = ops.get_default_graph().create_op( diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py index f2ffbb9876..f7aaa63792 100644 --- a/tensorflow/python/framework/errors.py +++ b/tensorflow/python/framework/errors.py @@ -38,7 +38,8 @@ class OpError(Exception): """Creates a new `OpError` indicating that a particular op failed. Args: - node_def: The `graph_pb2.NodeDef` proto representing the op that failed. + node_def: The `graph_pb2.NodeDef` proto representing the op that failed, + if known; otherwise None. op: The `ops.Operation` that failed, if known; otherwise None. message: The message string describing the failure. error_code: The `error_codes_pb2.Code` describing the error. diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 77da519fcc..9c25ea555c 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import logging_ops -from tensorflow.python.pywrap_tensorflow import StatusNotOK def check_op_order(graph): """Sanity check on the ordering of op id.""" @@ -137,7 +136,8 @@ class ControlFlowTest(tf.test.TestCase): dead_branch = tf.identity(switch_op[0]) with self.assertRaisesWithPredicateMatch( - StatusNotOK, lambda e: "The tensor returned for" in str(e)): + tf.errors.InvalidArgumentError, + lambda e: "The tensor returned for" in str(e)): dead_branch.eval() def testSwitchMergeLess(self): diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index af232f65cc..2504dfaa31 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -25,7 +25,6 @@ import time import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.pywrap_tensorflow import StatusNotOK class FIFOQueueTest(tf.test.TestCase): @@ -1161,7 +1160,7 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase): # Intentionally do not run any enqueue_ops so that dequeue will block # until operation_timeout_in_ms. - with self.assertRaisesRegexp(StatusNotOK, + with self.assertRaisesRegexp(tf.errors.DeadlineExceededError, "Timed out waiting for notification"): sess.run(dequeued_t) diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 766bdf7dd3..fb2dd00b69 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -29,5 +29,6 @@ limitations under the License. %include "tensorflow/python/client/tf_session.i" %include "tensorflow/python/client/server_lib.i" +%include "tensorflow/python/client/device_lib.i" %include "tensorflow/python/framework/python_op_gen.i" diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 6debeabd97..24d7e16831 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -10,8 +10,8 @@ def tf_workspace(path_prefix = ""): native.new_http_archive( name = "eigen_archive", - url = "https://bitbucket.org/eigen/eigen/get/88444e025a5c.tar.gz", - sha256 = "42e6f6de56b3ff010531a2bbf3e2db1db46be30d3965efb1eaa5634c5db013dd", + url = "https://bitbucket.org/eigen/eigen/get/2f482bcc8b95.tar.gz", + sha256 = "1a80b0e348c3f3608fc5ce1c222902da37a229a01d3965c4622ec3d287617456", build_file = path_prefix + "eigen.BUILD", ) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index 95a503d611..e3e05b0425 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/Eigen/Cholesky" +#include "eigen-eigen-2f482bcc8b95/Eigen/Cholesky" diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index b4a10f6ed1..87e7c9813e 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/Eigen/Core" +#include "eigen-eigen-2f482bcc8b95/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index 56657aa837..6e1a765be7 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/Eigen/Eigenvalues" +#include "eigen-eigen-2f482bcc8b95/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index 3c491eeef9..be0e8c7eb8 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/Eigen/LU" +#include "eigen-eigen-2f482bcc8b95/Eigen/LU" diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR index 5a97880470..17a6a1e34e 100644 --- a/third_party/eigen3/Eigen/QR +++ b/third_party/eigen3/Eigen/QR @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/Eigen/QR" +#include "eigen-eigen-2f482bcc8b95/Eigen/QR" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index 20150d0594..57404ad17d 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1 @@ -#include "eigen-eigen-88444e025a5c/unsupported/Eigen/CXX11/Tensor" +#include "eigen-eigen-2f482bcc8b95/unsupported/Eigen/CXX11/Tensor" |