aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--eigen.BUILD2
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/common_runtime/executor.cc4
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc67
-rw-r--r--tensorflow/core/framework/attr_value_util.cc11
-rw-r--r--tensorflow/core/framework/graph_def_util.cc5
-rw-r--r--tensorflow/core/framework/graph_def_util_test.cc30
-rw-r--r--tensorflow/core/framework/node_def_util.cc49
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc33
-rw-r--r--tensorflow/core/framework/op_def_builder.cc244
-rw-r--r--tensorflow/core/framework/op_def_builder_test.cc40
-rw-r--r--tensorflow/core/framework/op_def_util.cc14
-rw-r--r--tensorflow/core/framework/resource_mgr.cc15
-rw-r--r--tensorflow/core/framework/resource_mgr_test.cc3
-rw-r--r--tensorflow/core/graph/graph_constructor.cc23
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc23
-rw-r--r--tensorflow/core/kernels/BUILD2
-rw-r--r--tensorflow/core/kernels/decode_raw_op.cc8
-rw-r--r--tensorflow/core/kernels/ops_util.cc16
-rw-r--r--tensorflow/core/kernels/resize_bicubic_op.cc85
-rw-r--r--tensorflow/core/kernels/resize_bicubic_op_test.cc57
-rw-r--r--tensorflow/core/kernels/tile_ops.h7
-rw-r--r--tensorflow/core/kernels/training_ops.cc46
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc28
-rw-r--r--tensorflow/core/lib/strings/scanner.cc59
-rw-r--r--tensorflow/core/lib/strings/scanner.h218
-rw-r--r--tensorflow/core/lib/strings/scanner_test.cc266
-rw-r--r--tensorflow/python/BUILD20
-rw-r--r--tensorflow/python/client/device_lib.i81
-rw-r--r--tensorflow/python/client/device_lib.py37
-rw-r--r--tensorflow/python/client/device_lib_test.py42
-rw-r--r--tensorflow/python/client/session.py17
-rw-r--r--tensorflow/python/client/session_test.py13
-rw-r--r--tensorflow/python/framework/errors.py3
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py4
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py3
-rw-r--r--tensorflow/python/tensorflow.i1
-rw-r--r--tensorflow/workspace.bzl4
-rw-r--r--third_party/eigen3/Eigen/Cholesky2
-rw-r--r--third_party/eigen3/Eigen/Core2
-rw-r--r--third_party/eigen3/Eigen/Eigenvalues2
-rw-r--r--third_party/eigen3/Eigen/LU2
-rw-r--r--third_party/eigen3/Eigen/QR2
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/Tensor2
46 files changed, 1307 insertions, 289 deletions
diff --git a/eigen.BUILD b/eigen.BUILD
index 806b6d36b9..3e92d5887a 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-88444e025a5c"
+archive_dir = "eigen-eigen-2f482bcc8b95"
cc_library(
name = "eigen",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5ee2337647..901071e0f9 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -738,6 +738,7 @@ cc_library(
"lib/random/weighted_picker.h",
"lib/strings/ordered_code.h",
"lib/strings/regexp.h",
+ "lib/strings/scanner.h",
"platform/denormal.h",
"platform/platform.h",
"platform/tensor_coding.h",
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c43f1d0973..6ee5f7d446 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1635,14 +1635,14 @@ void ExecutorState::DumpActiveNodeState(const int node_id,
void ExecutorState::DumpIterationState(IterationState* iteration) {
// Dump any waiting nodes that are holding on to tensors.
- for (size_t i = 0; i < impl_->graph_->num_node_ids(); ++i) {
+ for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(i) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(i, iteration->input_tensors, false);
}
}
// Then the active nodes.
- for (size_t i = 0; i < impl_->graph_->num_node_ids(); ++i) {
+ for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::STARTED) {
DumpActiveNodeState(i, iteration->input_tensors);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index a25d072eea..cd064dd4c4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -388,7 +388,7 @@ string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) {
string buf;
buf.resize(num_bytes);
DeviceMemoryBase gpu_ptr(ptr, num_bytes);
- Status s = dev_info->stream->parent()->SynchronousMemcpyD2H(
+ auto s = dev_info->stream->parent()->SynchronousMemcpyD2H(
gpu_ptr, num_bytes, gtl::string_as_array(&buf));
strings::StrAppend(&ret,
PrintMemory(gtl::string_as_array(&buf), num_bytes));
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index df86046c45..be17ad4c4d 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -329,7 +329,6 @@ tf_cc_tests(
tags = tf_cuda_tests_tags() + ["exclusive"],
tests = [
"grpc_channel_test.cc",
- "grpc_server_lib_test.cc",
"grpc_session_test.cc",
"rpc_rendezvous_mgr_test.cc",
],
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc
deleted file mode 100644
index a56afb05a6..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2016 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/distributed_runtime/server_lib.h"
-
-#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-// Tests that a server can be cleanly started, stopped, and joined
-// when no calls are made against the server.
-TEST(Server, StopAfterNoop) {
- ServerDef def;
- def.set_protocol("grpc");
- def.set_job_name("localhost");
- def.set_task_index(0);
- JobDef* job_def = def.mutable_cluster()->add_job();
- job_def->set_name("localhost");
- (*job_def->mutable_tasks())[0] =
- strings::StrCat("localhost:", testing::PickUnusedPortOrDie());
- std::unique_ptr<ServerInterface> svr;
- TF_EXPECT_OK(NewServer(def, &svr));
- TF_EXPECT_OK(svr->Start());
- TF_EXPECT_OK(svr->Stop());
- TF_EXPECT_OK(svr->Join());
-}
-
-// Tests that a server can be cleanly started, stopped, and joined
-// when a simple call is made against the server.
-TEST(Server, StopAfterCall) {
- ServerDef def;
- def.set_protocol("grpc");
- def.set_job_name("localhost");
- def.set_task_index(0);
- JobDef* job_def = def.mutable_cluster()->add_job();
- job_def->set_name("localhost");
- int port = testing::PickUnusedPortOrDie();
- (*job_def->mutable_tasks())[0] = strings::StrCat("localhost:", port);
- std::unique_ptr<ServerInterface> svr;
- TF_EXPECT_OK(NewServer(def, &svr));
- TF_EXPECT_OK(svr->Start());
- {
- SessionOptions options;
- options.target = strings::StrCat("grpc://localhost:", port);
- std::unique_ptr<GrpcSession> sess(new GrpcSession(options));
- const std::vector<DeviceAttributes> devices = sess->ListDevices();
- EXPECT_GT(devices.size(), 0);
- }
- TF_EXPECT_OK(svr->Stop());
- TF_EXPECT_OK(svr->Join());
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 93823b154e..05e1f01a0c 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/regexp.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -250,10 +249,16 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
if (is_list) {
// TextFormat parser considers "i: 7" to be the same as "i: [7]",
// but we only want to allow list values with [].
- if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) {
+ StringPiece cleaned = text;
+ str_util::RemoveLeadingWhitespace(&cleaned);
+ str_util::RemoveTrailingWhitespace(&cleaned);
+ if (cleaned.size() < 2 || cleaned[0] != '[' ||
+ cleaned[cleaned.size() - 1] != ']') {
return false;
}
- if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) {
+ cleaned.remove_prefix(1);
+ str_util::RemoveLeadingWhitespace(&cleaned);
+ if (cleaned.size() == 1) {
// User wrote "[]", so return empty list without invoking the TextFormat
// parse which returns an error for "i: []".
out->Clear();
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index 71281cf7b2..ccfbec662f 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -84,8 +84,9 @@ Status RemoveNewDefaultAttrsFromGraphDef(
if (!s.ok()) return s;
for (const auto& attr : node_def->attr()) {
- // If the attr is not in consumer_op_def...
- if (FindAttr(attr.first, *consumer_op_def) == nullptr) {
+ // If the attr is not in consumer_op_def and doesn't start with '_'...
+ if (!StringPiece(attr.first).starts_with("_") &&
+ FindAttr(attr.first, *consumer_op_def) == nullptr) {
const OpDef::AttrDef* producer_attr_def =
FindAttr(attr.first, *producer_op_def);
if (producer_attr_def == nullptr) {
diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc
index 96001c8f6d..2b51a0b1de 100644
--- a/tensorflow/core/framework/graph_def_util_test.cc
+++ b/tensorflow/core/framework/graph_def_util_test.cc
@@ -133,6 +133,36 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
EXPECT_TRUE(op_attr_removed.empty());
}
+// Attrs starting with underscores should not be removed.
+TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(producer_op_list.add_op()));
+ // Add the _underscore attr manually since OpDefBuilder would complain
+ OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
+ attr->set_name("_underscore");
+ attr->set_type("int");
+ attr->mutable_default_value()->set_i(17);
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
+ .Attr("_underscore", 17)
+ .Finalize(produced_graph_def.add_node()));
+ GraphDef expected_graph_def = produced_graph_def;
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+ EXPECT_EQ(op_attr_removed.size(), 0);
+}
+
TEST(StrippedOpListForGraphTest, FlatTest) {
// Make four ops
OpList op_list;
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 0f08d391ac..641411892d 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -24,9 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
@@ -381,19 +381,50 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
namespace {
-static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
-static RE2* valid_data_input_pattern =
- new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?");
-static RE2* valid_control_input_pattern =
- new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+using ::tensorflow::strings::Scanner;
+
+bool IsValidOpName(StringPiece sp) {
+ return Scanner(sp)
+ .One(Scanner::LETTER_DIGIT_DOT)
+ .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .Eos()
+ .GetResult();
+}
+
+bool IsValidDataInputName(StringPiece sp) {
+ // Data inputs are op_name, op_name:0, or op_name:12345.
+ Scanner scan(sp);
+ scan.One(Scanner::LETTER_DIGIT_DOT)
+ .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+ if (scan.Peek() == ':') {
+ scan.OneLiteral(":");
+ if (scan.Peek() == '0') {
+ scan.OneLiteral("0"); // :0
+ } else {
+ scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
+ }
+ }
+ scan.Eos();
+
+ return scan.GetResult();
+}
+
+bool IsValidControlInputName(StringPiece sp) {
+ return Scanner(sp)
+ .OneLiteral("^")
+ .One(Scanner::LETTER_DIGIT_DOT)
+ .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .Eos()
+ .GetResult();
+}
} // namespace
Status ValidateOpInput(const string& input_name, bool* is_control_input) {
*is_control_input = false;
- if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+ if (IsValidDataInputName(input_name)) {
return Status::OK();
- } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+ } else if (IsValidControlInputName(input_name)) {
*is_control_input = true;
return Status::OK();
} else {
@@ -402,7 +433,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) {
}
Status ValidateOpName(const string& op_name) {
- if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+ if (IsValidOpName(op_name)) {
return Status::OK();
} else {
return errors::InvalidArgument("Illegal op name '", op_name, "'");
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 3a405fc275..07bd60f3b7 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -308,6 +308,12 @@ TEST(NodeDefUtilTest, ValidSyntax) {
)proto");
ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
+ const NodeDef node_def_slash_in_name = ToNodeDef(R"proto(
+ name:'n\\' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'");
+
const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'_a' input:'b'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
@@ -315,6 +321,12 @@ TEST(NodeDefUtilTest, ValidSyntax) {
ExpectInvalidSyntax(node_def_internal_input_name,
"Illegal op input name '_a'");
+ const NodeDef node_def_input_name_slash = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a\\' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'");
+
const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'a' input:'^b:0'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
@@ -322,12 +334,33 @@ TEST(NodeDefUtilTest, ValidSyntax) {
ExpectInvalidSyntax(node_def_invalid_control_input_name,
"Illegal op input name '^b:0'");
+ const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'^b\\'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_control_input_name_slash,
+ "Illegal op input name '^b\\'");
+
const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
name:'n' op:'AnyIn' input:'^a' input:'b'
attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
)proto");
ExpectInvalidSyntax(node_def_data_input_after_control,
"All control inputs must follow all data inputs");
+
+ const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:b' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_invalid_port,
+ "Illegal op input name 'a:b");
+
+ const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:00' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_invalid_port2,
+ "Illegal op input name 'a:00");
}
TEST(NameRangesForNodeTest, Simple) {
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 8983371503..2cd6770f6c 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -15,80 +15,99 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_builder.h"
+#include <limits>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/regexp.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+using ::tensorflow::strings::Scanner;
+
namespace tensorflow {
namespace {
-bool RE2Consume(StringPiece* sp, const RE2& pattern) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- bool r = RE2::Consume(&base_sp, pattern);
- *sp = FromRegexpStringPiece(base_sp);
- return r;
-}
-
-bool RE2Consume(StringPiece* sp, const RE2& pattern, StringPiece* out) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- RegexpStringPiece base_out;
- bool r = RE2::Consume(&base_sp, pattern, &base_out);
- *sp = FromRegexpStringPiece(base_sp);
- *out = FromRegexpStringPiece(base_out);
- return r;
-}
-
-bool RE2Consume(StringPiece* sp, const RE2& pattern, int64* out) {
- RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
- bool r = RE2::Consume(&base_sp, pattern, out);
- *sp = FromRegexpStringPiece(base_sp);
- return r;
-}
-
string AttrError(StringPiece orig, const string& op_name) {
return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
}
-const RE2& AttrNameRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*");
- return pattern;
-}
-
-const RE2& AttrListPrefixRE() {
- static RE2 pattern("list\\s*\\(\\s*");
- return pattern;
-}
-
-const RE2& SpacesRE() {
- static RE2 pattern("\\s*");
- return pattern;
-}
-
-const RE2& AttrDoubleQuotedRE() {
- static RE2 pattern(R"xx("((?:[^"\\]|\\.)*)"\s*)xx");
- return pattern;
-}
-
-const RE2& AttrSingleQuotedRE() {
- static RE2 pattern(R"xx('((?:[^'\\]|\\.)*)'\s*)xx");
- return pattern;
-}
-
-const RE2& AttrTypeRE() {
- static RE2 pattern("([a-z0-9]+)\\s*");
- return pattern;
-}
-
-const RE2& AttrNumberRE() {
- static RE2 pattern("\\s*(-?\\d+)\\s*");
- return pattern;
+bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .OneLiteral(":")
+ .AnySpace()
+ .GetResult(sp, out);
+}
+
+bool ConsumeListPrefix(StringPiece* sp) {
+ return Scanner(*sp)
+ .OneLiteral("list")
+ .AnySpace()
+ .OneLiteral("(")
+ .AnySpace()
+ .GetResult(sp);
+}
+
+bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
+ const string quote_str(1, quote_ch);
+ return Scanner(*sp)
+ .OneLiteral(quote_str.c_str())
+ .RestartCapture()
+ .ScanEscapedUntil(quote_ch)
+ .StopCapture()
+ .OneLiteral(quote_str.c_str())
+ .AnySpace()
+ .GetResult(sp, out);
+}
+
+bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .Many(Scanner::LOWERLETTER_DIGIT)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(sp, out);
+}
+
+bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
+ Scanner scan(*sp);
+ StringPiece match;
+ StringPiece remaining;
+
+ scan.AnySpace();
+ bool is_negative = false;
+ if (scan.Peek() == '-') {
+ is_negative = true;
+ scan.OneLiteral("-");
+ }
+ if (!scan.RestartCapture()
+ .Many(Scanner::DIGIT)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(&remaining, &match)) {
+ return false;
+ }
+ uint64 val = 0;
+ if (!str_util::ConsumeLeadingDigits(&match, &val)) return false;
+ if (is_negative) {
+ const int64 final_val = static_cast<int64>(val) * -1;
+ if (final_val > 0) return false;
+ *out = final_val;
+ } else {
+ if (val > static_cast<uint64>(std::numeric_limits<int64>::max())) {
+ return false;
+ }
+ *out = val;
+ }
+ *sp = remaining;
+ return true;
}
#define VERIFY(expr, ...) \
@@ -107,12 +126,11 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Parse "<name>:" at the beginning.
StringPiece tmp_name;
- VERIFY(RE2Consume(&spec, AttrNameRE(), &tmp_name),
- "Trouble parsing '<name>:'");
+ VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
attr->set_name(tmp_name.data(), tmp_name.size());
// Read "<type>" or "list(<type>)".
- bool is_list = RE2Consume(&spec, AttrListPrefixRE());
+ bool is_list = ConsumeListPrefix(&spec);
string type;
if (spec.Consume("string")) {
type = "string";
@@ -151,14 +169,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
} else if (spec.Consume("{")) {
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
AttrValue* allowed = attr->mutable_allowed_values();
if (spec.starts_with("\"") || spec.starts_with("'")) {
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
while (true) {
StringPiece escaped_string;
- VERIFY((RE2Consume(&spec, AttrDoubleQuotedRE(), &escaped_string) ||
- RE2Consume(&spec, AttrSingleQuotedRE(), &escaped_string)),
+ VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
+ ConsumeQuotedString('\'', &spec, &escaped_string),
"Trouble parsing allowed string at '", spec, "'");
string unescaped;
string error;
@@ -167,7 +185,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
error);
allowed->mutable_list()->add_s(unescaped);
if (spec.Consume(",")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
} else {
VERIFY(spec.Consume("}"),
@@ -179,14 +197,14 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
type = "type";
while (true) {
StringPiece type_string;
- VERIFY(RE2Consume(&spec, AttrTypeRE(), &type_string),
+ VERIFY(ConsumeAttrType(&spec, &type_string),
"Trouble parsing type string at '", spec, "'");
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
if (spec.Consume(",")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
} else {
VERIFY(spec.Consume("}"),
@@ -198,12 +216,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
} else {
VERIFY(false, "Trouble parsing type string at '", spec, "'");
}
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
// Write the type into *attr.
if (is_list) {
VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
attr->set_type(strings::StrCat("list(", type, ")"));
} else {
attr->set_type(type);
@@ -212,7 +230,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Read optional minimum constraint at the end.
if ((is_list || type == "int") && spec.Consume(">=")) {
int64 min_limit = -999;
- VERIFY(RE2Consume(&spec, AttrNumberRE(), &min_limit),
+ VERIFY(ConsumeAttrNumber(&spec, &min_limit),
"Could not parse integer lower limit after '>=', found '", spec,
"' instead");
attr->set_has_minimum(true);
@@ -221,7 +239,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Parse default value, if present.
if (spec.Consume("=")) {
- RE2Consume(&spec, SpacesRE());
+ str_util::RemoveLeadingWhitespace(&spec);
VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
"Could not parse default value '", spec, "'");
} else {
@@ -236,29 +254,49 @@ string InOutError(bool is_output, StringPiece orig, const string& op_name) {
"\") for Op ", op_name);
}
-const RE2& InOutNameRE() {
- static RE2 pattern("([a-z][a-z0-9_]*)\\s*:\\s*");
- return pattern;
+bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LOWERLETTER)
+ .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .OneLiteral(":")
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& InOutRefOpenRE() {
- static RE2 pattern("Ref\\s*\\(\\s*");
- return pattern;
+bool ConsumeInOutRefOpen(StringPiece* sp) {
+ return Scanner(*sp)
+ .OneLiteral("Ref")
+ .AnySpace()
+ .OneLiteral("(")
+ .AnySpace()
+ .GetResult(sp);
}
-const RE2& InOutRefCloseRE() {
- static RE2 pattern("\\)\\s*");
- return pattern;
+bool ConsumeInOutRefClose(StringPiece* sp) {
+ return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
}
-const RE2& InOutNameOrTypeRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*");
- return pattern;
+bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& InOutTimesTypeRE() {
- static RE2 pattern("[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*");
- return pattern;
+bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .OneLiteral("*")
+ .AnySpace()
+ .RestartCapture()
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .GetResult(sp, out);
}
#define VERIFY(expr, ...) \
@@ -279,20 +317,19 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
// Parse "<name>:" at the beginning.
StringPiece tmp_name;
- VERIFY(RE2Consume(&spec, InOutNameRE(), &tmp_name),
- "Trouble parsing 'name:'");
+ VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
arg->set_name(tmp_name.data(), tmp_name.size());
// Detect "Ref(...)".
- if (RE2Consume(&spec, InOutRefOpenRE())) {
+ if (ConsumeInOutRefOpen(&spec)) {
arg->set_is_ref(true);
}
{ // Parse "<name|type>" or "<name>*<name|type>".
StringPiece first, second, type_or_attr;
- VERIFY(RE2Consume(&spec, InOutNameOrTypeRE(), &first),
+ VERIFY(ConsumeInOutNameOrType(&spec, &first),
"Trouble parsing either a type or an attr name at '", spec, "'");
- if (RE2Consume(&spec, InOutTimesTypeRE(), &second)) {
+ if (ConsumeInOutTimesType(&spec, &second)) {
arg->set_number_attr(first.data(), first.size());
type_or_attr = second;
} else {
@@ -317,7 +354,7 @@ void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
// Closing ) for Ref(.
if (arg->is_ref()) {
- VERIFY(RE2Consume(&spec, InOutRefCloseRE()),
+ VERIFY(ConsumeInOutRefClose(&spec),
"Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
}
@@ -354,14 +391,19 @@ int num_leading_spaces(StringPiece s) {
return i;
}
-const RE2& DocNameColonRE() {
- static RE2 pattern("^[a-zA-Z][a-zA-Z0-9_]*\\s*:");
- return pattern;
+bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
+ return Scanner(*sp)
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
+ .StopCapture()
+ .AnySpace()
+ .OneLiteral(":")
+ .AnySpace()
+ .GetResult(sp, out);
}
-const RE2& DocNameColonSpacesRE() {
- static RE2 pattern("([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*");
- return pattern;
+bool IsDocNameColon(StringPiece s) {
+ return ConsumeDocNameColon(&s, nullptr /* out */);
}
void FinalizeDoc(const string& text, OpDef* op_def,
@@ -384,8 +426,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
// Lines until we see name: -> description.
int start_l = l;
- while (static_cast<size_t>(l) < lines.size() &&
- !RE2::PartialMatch(lines[l], DocNameColonRE())) {
+ while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
++l;
}
int end_l = l;
@@ -403,10 +444,9 @@ void FinalizeDoc(const string& text, OpDef* op_def,
while (static_cast<size_t>(l) < lines.size()) {
description.clear();
description.push_back(lines[l]);
- RE2Consume(&description.back(), DocNameColonSpacesRE(), &name);
+ ConsumeDocNameColon(&description.back(), &name);
++l;
- while (static_cast<size_t>(l) < lines.size() &&
- !RE2::PartialMatch(lines[l], DocNameColonRE())) {
+ while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
description.push_back(lines[l]);
++l;
}
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
index bca67120bb..2d6a7f01ae 100644
--- a/tensorflow/core/framework/op_def_builder_test.cc
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -140,6 +140,12 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
ExpectSuccess(
b().Attr("i: int >= -5"),
"attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
+ ExpectSuccess(b().Attr("i: int >= 9223372036854775807"),
+ ("attr: { name: 'i' type: 'int' has_minimum: true "
+ "minimum: 9223372036854775807 }"));
+ ExpectSuccess(b().Attr("i: int >= -9223372036854775808"),
+ ("attr: { name: 'i' type: 'int' has_minimum: true "
+ "minimum: -9223372036854775808 }"));
}
TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
@@ -164,6 +170,20 @@ TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
ExpectFailure(b().Attr("a:{float,,}"),
"Trouble parsing type string at ',}' from "
"Attr(\"a:{float,,}\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= a"),
+ "Could not parse integer lower limit after '>=', "
+ "found ' a' instead from Attr(\"i: int >= a\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= -a"),
+ "Could not parse integer lower limit after '>=', found ' -a' "
+ "instead from Attr(\"i: int >= -a\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= 9223372036854775808"),
+ "Could not parse integer lower limit after '>=', found "
+ "' 9223372036854775808' instead from "
+ "Attr(\"i: int >= 9223372036854775808\") for Op Test");
+ ExpectFailure(b().Attr("i: int >= -9223372036854775809"),
+ "Could not parse integer lower limit after '>=', found "
+ "' -9223372036854775809' instead from "
+ "Attr(\"i: int >= -9223372036854775809\") for Op Test");
}
TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
@@ -241,6 +261,9 @@ TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"),
"attr: { name: 'a' type: 'list(int)' "
"default_value { list { i: [0, -1, 2, -4, 8] } } }");
+ ExpectSuccess(b().Attr(R"(a:list(int)=[ ])"),
+ "attr: { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [] } } }");
}
TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
@@ -259,6 +282,12 @@ TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
ExpectFailure(b().Attr(R"(a:list(string)='foo')"),
"Could not parse default value ''foo'' from "
"Attr(\"a:list(string)='foo'\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = ["),
+ "Could not parse default value '[' from "
+ "Attr(\"a:list(float) = [\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = "),
+ "Could not parse default value '' from "
+ "Attr(\"a:list(float) = \") for Op Test");
}
TEST_F(OpDefBuilderTest, InputOutput) {
@@ -268,7 +297,7 @@ TEST_F(OpDefBuilderTest, InputOutput) {
"output_arg: { name: 'b' type: DT_STRING }");
ExpectSuccess(b().Input("c: float "),
"input_arg: { name: 'c' type: DT_FLOAT }");
- ExpectSuccess(b().Output("d: Ref(bool)"),
+ ExpectSuccess(b().Output("d: Ref ( bool ) "),
"output_arg: { name: 'd' type: DT_BOOL is_ref: true }");
ExpectOrdered(b().Input("a: bool")
.Output("c: complex64")
@@ -326,6 +355,12 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) {
ExpectFailure(
b().Input("CAPS: int32"),
"Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test");
+ ExpectFailure(
+ b().Input("_underscore: int32"),
+ "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test");
+ ExpectFailure(
+ b().Input("0digit: int32"),
+ "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test");
ExpectFailure(b().Input("a: _"),
"Trouble parsing either a type or an attr name at '_' from "
"Input(\"a: _\") for Op Test");
@@ -344,6 +379,9 @@ TEST_F(OpDefBuilderTest, InputOutputFailure) {
ExpectFailure(b().Input("a: Ref(int32"),
"Did not find closing ')' for 'Ref(', instead found: '' from "
"Input(\"a: Ref(int32\") for Op Test");
+ ExpectFailure(
+ b().Input("a: Ref"),
+ "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test");
ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"),
"Did not find closing ')' for 'Ref(', instead found: 'y' from "
"Input(\"a: Ref(x y\") for Op Test");
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index b94207e2e8..f7e4f1f05a 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -221,8 +221,16 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
}
Status ValidateOpDef(const OpDef& op_def) {
- VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"),
- "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+ using ::tensorflow::strings::Scanner;
+
+ if (!StringPiece(op_def.name()).starts_with("_")) {
+ VALIDATE(Scanner(op_def.name())
+ .One(Scanner::UPPERLETTER)
+ .Any(Scanner::LETTER_DIGIT)
+ .Eos()
+ .GetResult(),
+ "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+ }
std::set<string> names; // for detecting duplicate names
for (const auto& attr : op_def.attr()) {
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index f0c9085d48..60425eedb0 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
@@ -117,15 +117,22 @@ Status ResourceMgr::Cleanup(const string& container) {
return Status::OK();
}
+static bool IsValidContainerName(StringPiece s) {
+ using ::tensorflow::strings::Scanner;
+ return Scanner(s)
+ .One(Scanner::LETTER_DIGIT_DOT)
+ .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)
+ .Eos()
+ .GetResult();
+}
+
Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
bool use_node_name_as_default) {
CHECK(rmgr);
rmgr_ = rmgr;
string attr_container;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
- static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
- if (!attr_container.empty() &&
- !RE2::FullMatch(attr_container, container_re)) {
+ if (!attr_container.empty() && !IsValidContainerName(attr_container)) {
return errors::InvalidArgument("container contains invalid characters: ",
attr_container);
}
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
index f776d1ebcc..56bc76d384 100644
--- a/tensorflow/core/framework/resource_mgr_test.cc
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -161,6 +161,8 @@ TEST(ContainerInfo, Basic) {
EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]");
EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]");
EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]");
+ EXPECT_EQ(Policy("cat.0-dog", "bar", true), "[cat.0-dog,bar,public]");
+ EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]");
}
Status WrongPolicy(const string& attr_container, const string& attr_shared_name,
@@ -180,6 +182,7 @@ TEST(ContainerInfo, Error) {
// Invalid container.
HasError(WrongPolicy("12$%", "", false), "container contains invalid char");
+ HasError(WrongPolicy("-cat", "", false), "container contains invalid char");
// Invalid shared name.
HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'");
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 6c5873d0c1..db3b46863f 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -28,8 +28,8 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -126,20 +126,21 @@ void GraphConstructor::SetError(const string& error) {
status_->Update(errors::InvalidArgument(error));
}
-void GraphConstructor::BuildNodeIndex() {
- // Initialized outside the loop for efficiency
- const char* pattern;
- if (opts_.allow_internal_ops) {
- pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
- } else {
- pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
- }
- RE2 node_name_re(pattern);
+bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
+ using ::tensorflow::strings::Scanner;
+ return Scanner(s)
+ .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
+ : Scanner::LETTER_DIGIT_DOT)
+ .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .Eos()
+ .GetResult();
+}
+void GraphConstructor::BuildNodeIndex() {
// Validate the node names and add them to name_index_.
for (int n = 0; n < gdef_->node_size(); ++n) {
const NodeDef& node_def(gdef_->node(n));
- if (!RE2::FullMatch(node_def.name(), node_name_re)) {
+ if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
SetNodeError(node_def, "Node name contains invalid characters");
return;
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 8e391d6510..ea8ae5dc06 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
@@ -127,12 +128,24 @@ REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
REGISTER_OP("TestInt").Input("a: int32");
TEST_F(GraphConstructorTest, InvalidNodeName) {
- ExpectError("node { name: 'a:b' op: 'ABC' }",
- "Node 'a:b': Node name contains invalid characters");
- ExpectError("node { name: '_abc' op: 'ABC' }",
- // Can't start with '_'
- "Node '_abc': Node name contains invalid characters");
+ auto expect_invalid_name = [this](const char* name) {
+ ExpectError(strings::StrCat("node { name: '", name, "' op: 'ABC' }"),
+ strings::StrCat("Node '", name,
+ "': Node name contains invalid characters"));
+ };
+
+ expect_invalid_name("a:b");
+ expect_invalid_name("_abc"); // Can't start with '_'
+ // Name is a\b, but proto text format escapes slashes so we use a\\b here.
+ // This works for ExpectError too, since re2 also treats \\ as one slash.
+ expect_invalid_name(R"(a\\b)");
+ expect_invalid_name("/a");
+ expect_invalid_name("-a");
+
ExpectOK("node { name: 'a-bc_' op: 'ABC' }");
+ ExpectOK("node { name: 'a-B.0/.c_' op: 'ABC' }");
+ ExpectOK("node { name: '0123' op: 'ABC' }");
+ ExpectOK("node { name: '.0123' op: 'ABC' }");
}
TEST_F(GraphConstructorTest, InvalidSourceNodeName) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7ca186eaec..9dc64a6cef 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -540,6 +540,7 @@ tf_cc_tests(
"adjust_contrast_op_benchmark_test",
"adjust_contrast_op_test",
"colorspace_op_test",
+ "resize_bicubic_op_test",
"resize_bilinear_op_test",
"resize_nearest_neighbor_op_test",
],
@@ -1104,6 +1105,7 @@ filegroup(
"cwise_op_minimum.cc",
"cwise_op_mul.cc",
"cwise_op_neg.cc",
+ "cwise_op_rsqrt.cc",
"cwise_op_select.cc",
"cwise_op_sigmoid.cc",
"cwise_op_sqrt.cc",
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index 42674be967..172bb95661 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -35,9 +35,9 @@ class DecodeRawOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const auto& input = context->input(0);
- int str_size = -1;
+ int64 str_size = -1;
auto flat_in = input.flat<string>();
- for (int i = 0; i < flat_in.size(); ++i) {
+ for (int64 i = 0; i < flat_in.size(); ++i) {
const string& in_str = flat_in(i);
if (str_size == -1) {
str_size = in_str.size();
@@ -62,7 +62,7 @@ class DecodeRawOp : public OpKernel {
errors::InvalidArgument("Input to DecodeRaw has length ", str_size,
" that is not a multiple of ", sizeof(T),
", the size of ", DataTypeString(out_type_)));
- const int added_dim = str_size / sizeof(T);
+ const int64 added_dim = str_size / sizeof(T);
out_shape.AddDim(added_dim);
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(
@@ -76,7 +76,7 @@ class DecodeRawOp : public OpKernel {
little_endian_ ? "true" : "false"));
// Endianness matches, so just copy each string byte-for-byte.
T* out_data = out.data();
- for (int i = 0; i < flat_in.size(); ++i) {
+ for (int64 i = 0; i < flat_in.size(); ++i) {
const T* in_data = reinterpret_cast<const T*>(flat_in(i).data());
memcpy(out_data, in_data, str_size);
out_data += added_dim;
diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc
index 64955ab0b7..8b03c570de 100644
--- a/tensorflow/core/kernels/ops_util.cc
+++ b/tensorflow/core/kernels/ops_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
@@ -119,9 +119,17 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize,
}
string SanitizeThreadSuffix(string suffix) {
- static RE2 re("[^A-Za-z0-9_-]");
- re.GlobalReplace(&suffix, re, "_");
- return suffix;
+ string clean;
+ for (int i = 0; i < suffix.size(); ++i) {
+ const char ch = suffix[i];
+ if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
+ (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
+ clean += ch;
+ } else {
+ clean += '_';
+ }
+ }
+ return clean;
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index b92ec6e004..a34703402c 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -28,17 +28,18 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
+namespace {
-static const int64 tab_size = (1 << 10);
+static const int64 kTableSize = (1 << 10);
const float* InitCoeffsTable() {
// Allocate and initialize coefficients table using Bicubic
// convolution algorithm.
// https://en.wikipedia.org/wiki/Bicubic_interpolation
- float* coeffs_tab = new float[(tab_size + 1) * 2];
+ float* coeffs_tab = new float[(kTableSize + 1) * 2];
static const double A = -0.75;
- for (int i = 0; i <= tab_size; ++i) {
- float x = i * 1.0 / tab_size;
+ for (int i = 0; i <= kTableSize; ++i) {
+ float x = i * 1.0 / kTableSize;
coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1;
x += 1.0;
coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
@@ -52,6 +53,32 @@ const float* GetCoeffsTable() {
return coeffs_tab;
}
+inline int64 Bound(int64 val, int64 limit) {
+ return std::min(limit - 1ll, std::max(0ll, val));
+}
+
+inline void GetWeightsAndIndices(float scale, int64 out_loc, int64 limit,
+ std::array<float, 4>* weights,
+ std::array<int64, 4>* indices) {
+ const int64 in_loc = floor(scale * out_loc);
+ const float delta = scale * out_loc - in_loc;
+ const int64 offset = round(delta * kTableSize);
+ const float* coeffs_tab = GetCoeffsTable();
+ *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2],
+ coeffs_tab[(kTableSize - offset) * 2],
+ coeffs_tab[(kTableSize - offset) * 2 + 1]}};
+ *indices = {{Bound(in_loc - 1, limit), Bound(in_loc, limit),
+ Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)}};
+}
+
+inline float Interpolate1D(const std::array<float, 4>& weights,
+ const std::array<float, 4>& values) {
+ return values[0] * weights[0] + values[1] * weights[1] +
+ values[2] * weights[2] + values[3] * weights[3];
+}
+
+} // namespace
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename T>
@@ -106,40 +133,34 @@ class ResizeBicubicOp : public OpKernel {
? (in_width - 1) / static_cast<float>(out_width - 1)
: in_width / static_cast<float>(out_width);
- const float* coeffs_tab = GetCoeffsTable();
-
- auto cal = [](const float* coeffs_tab, float v0, float v1, float v2,
- float v3, float dx) {
- const int64 offset = round(dx * tab_size);
- const float a0 = coeffs_tab[offset * 2 + 1];
- const float a1 = coeffs_tab[offset * 2];
- const float a2 = coeffs_tab[(tab_size - offset) * 2];
- const float a3 = coeffs_tab[(tab_size - offset) * 2 + 1];
- return a0 * v0 + a1 * v1 + a2 * v2 + a3 * v3;
- };
-
- float coeff[4] = {0.0};
+ std::array<float, 4> coeff = {{0.0, 0.0, 0.0, 0.0}};
for (int64 b = 0; b < batch_size; ++b) {
for (int64 y = 0; y < out_height; ++y) {
- const int64 in_y = floor(height_scale * y);
- const float dy = height_scale * y - in_y;
+ std::array<float, 4> y_weights;
+ std::array<int64, 4> y_indices;
+ GetWeightsAndIndices(height_scale, y, in_height, &y_weights,
+ &y_indices);
for (int64 x = 0; x < out_width; ++x) {
- const int64 in_x = floor(width_scale * x);
- const float dx = width_scale * x - in_x;
+ std::array<float, 4> x_weights;
+ std::array<int64, 4> x_indices;
+ GetWeightsAndIndices(width_scale, x, in_width, &x_weights,
+ &x_indices);
for (int64 c = 0; c < channels; ++c) {
+ // Use a 4x4 patch to compute the interpolated output value at
+ // (b, y, x, c).
for (int64 i = 0; i < 4; ++i) {
-#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
- int64 bound_y = BOUND(in_y - 1 + i, in_height);
- coeff[i] =
- cal(coeffs_tab,
- input_data(b, bound_y, BOUND(in_x - 1, in_width), c),
- input_data(b, bound_y, BOUND(in_x, in_width), c),
- input_data(b, bound_y, BOUND(in_x + 1, in_width), c),
- input_data(b, bound_y, BOUND(in_x + 2, in_width), c), dx);
-#undef BOUND
+ const std::array<float, 4> values = {
+ {static_cast<float>(
+ input_data(b, y_indices[i], x_indices[0], c)),
+ static_cast<float>(
+ input_data(b, y_indices[i], x_indices[1], c)),
+ static_cast<float>(
+ input_data(b, y_indices[i], x_indices[2], c)),
+ static_cast<float>(
+ input_data(b, y_indices[i], x_indices[3], c))}};
+ coeff[i] = Interpolate1D(x_weights, values);
}
- output_data(b, y, x, c) =
- cal(coeffs_tab, coeff[0], coeff[1], coeff[2], coeff[3], dy);
+ output_data(b, y, x, c) = Interpolate1D(y_weights, coeff);
}
}
}
diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc
new file mode 100644
index 0000000000..d9d68844eb
--- /dev/null
+++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* ResizeBicubic(int batch_size, int size, int channels) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor input(DT_FLOAT, TensorShape({batch_size, size, size, channels}));
+ input.flat<float>().setRandom();
+ Tensor shape(DT_INT32, TensorShape({2}));
+ auto shape_t = shape.flat<int32>();
+ shape_t(0) = 0.3 * size;
+ shape_t(1) = 0.7 * size;
+ test::graph::Binary(g, "ResizeBicubic", test::graph::Constant(g, input),
+ test::graph::Constant(g, shape));
+ return g;
+}
+
+#define BM_ResizeBicubicDev(BATCH, SIZE, CHANNELS) \
+ static void BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * SIZE * SIZE * \
+ CHANNELS); \
+ test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS)).Run(iters); \
+ } \
+ BENCHMARK(BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS);
+
+BM_ResizeBicubicDev(8, 32, 3);
+BM_ResizeBicubicDev(8, 128, 3);
+BM_ResizeBicubicDev(8, 512, 3);
+BM_ResizeBicubicDev(8, 1024, 3);
+BM_ResizeBicubicDev(16, 32, 3);
+BM_ResizeBicubicDev(16, 128, 3);
+BM_ResizeBicubicDev(16, 512, 3);
+BM_ResizeBicubicDev(16, 1024, 3);
+BM_ResizeBicubicDev(32, 32, 3);
+BM_ResizeBicubicDev(32, 128, 3);
+BM_ResizeBicubicDev(32, 512, 3);
+BM_ResizeBicubicDev(32, 1024, 3);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
index 379f877fcb..3dc5deb286 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -28,7 +28,12 @@ struct Tile {
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
typename TTypes<T, NDIM>::ConstTensor in,
const Eigen::array<int32, NDIM>& broadcast_array) const {
- out.device(d) = in.broadcast(broadcast_array);
+ if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
+ // Use 32bit indexing to speed up the computations
+ To32Bit(out).device(d) = To32Bit(in).broadcast(broadcast_array);
+ } else {
+ out.device(d) = in.broadcast(broadcast_array);
+ }
}
};
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 7a99cbd5d6..a24b7c704a 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -576,7 +576,6 @@ class SparseApplyFtrlOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
- const Device& device = ctx->template eigen_device<Device>();
mutex* mu_var = ctx->input_ref_mutex(0);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
@@ -666,21 +665,42 @@ class SparseApplyFtrlOp : public OpKernel {
auto accum_flat = accum.flat_outer_dims<T>();
auto linear_flat = linear.flat_outer_dims<T>();
auto grad_flat = grad.flat_outer_dims<T>();
+ T lr_scalar = lr.scalar<T>()();
+ T l1_scalar = l1.scalar<T>()();
+ T l2_scalar = l2.scalar<T>()();
+ T lr_power_scalar = lr_power.scalar<T>()();
for (Tindex i = 0; i < N; i++) {
const Tindex index = indices_vec(i);
- typename TTypes<T>::Flat accum(&accum_flat(index, 0),
- accum_flat.dimension(1));
- typename TTypes<T>::Flat linear(&linear_flat(index, 0),
- linear_flat.dimension(1));
- typename TTypes<T>::Flat var(&var_flat(index, 0),
- var_flat.dimension(1));
- typename TTypes<T>::ConstFlat grad(&grad_flat(i, 0),
- grad_flat.dimension(1));
-
- functor::ApplyFtrl<Device, T>()(device, var, accum, linear, grad,
- lr.scalar<T>(), l1.scalar<T>(),
- l2.scalar<T>(), lr_power.scalar<T>());
+ auto accum = accum_flat.template chip<0>(index);
+ auto linear = linear_flat.template chip<0>(index);
+ auto grad = grad_flat.template chip<0>(i);
+ auto var = var_flat.template chip<0>(index);
+
+ auto new_accum = accum + grad.square();
+ if (lr_power_scalar == -0.5) {
+ linear +=
+ grad - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;
+ } else {
+ linear += grad -
+ (new_accum.pow(-lr_power_scalar) -
+ accum.pow(-lr_power_scalar)) /
+ lr_scalar * var;
+ }
+ auto x = (linear.constant(l1_scalar) * linear.sign() - linear);
+ if (lr_power_scalar == -0.5) {
+ auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +
+ linear.constant(2 * l2_scalar);
+ var = x / y;
+ } else {
+ auto y = new_accum.pow(-lr_power_scalar) /
+ new_accum.constant(lr_scalar) +
+ linear.constant(2 * l2_scalar);
+ var = x / y;
+ }
+ var = (linear.abs() > linear.constant(l1_scalar))
+ .select(var, var.constant(0));
+ accum += grad.square();
}
} else {
CHECK_EQ(1, inner_dim);
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
index 8727f1367b..d681958162 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -141,6 +141,21 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
jpeg_start_decompress(&cinfo);
+ int64 total_size = static_cast<int64>(cinfo.output_height) *
+ static_cast<int64>(cinfo.output_width);
+ // Some of the internal routines do not gracefully handle ridiculously
+ // large images, so fail fast.
+ if (cinfo.output_width <= 0 || cinfo.output_height <= 0) {
+ LOG(ERROR) << "Invalid image size: " << cinfo.output_width << " x "
+ << cinfo.output_height;
+ return nullptr;
+ }
+ if (total_size >= (1LL << 29)) {
+ LOG(ERROR) << "Image too large: " << total_size;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+
// check for compatible stride
const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE);
if (stride == 0) {
@@ -405,6 +420,19 @@ bool CompressInternal(const uint8* srcdata, int width, int height,
const CompressFlags& flags, string* output) {
output->clear();
const int components = (static_cast<int>(flags.format) & 0xff);
+
+ int64 total_size = static_cast<int64>(width) * static_cast<int64>(height);
+ // Some of the internal routines do not gracefully handle ridiculously
+ // large images, so fail fast.
+ if (width <= 0 || height <= 0) {
+ LOG(ERROR) << "Invalid image size: " << width << " x " << height;
+ return false;
+ }
+ if (total_size >= (1LL << 29)) {
+ LOG(ERROR) << "Image too large: " << total_size;
+ return false;
+ }
+
int in_stride = flags.stride;
if (in_stride == 0) {
in_stride = width * (static_cast<int>(flags.format) & 0xff);
diff --git a/tensorflow/core/lib/strings/scanner.cc b/tensorflow/core/lib/strings/scanner.cc
new file mode 100644
index 0000000000..b05400c97d
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner.cc
@@ -0,0 +1,59 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace strings {
+
+void Scanner::ScanEscapedUntilImpl(char end_ch) {
+ for (;;) {
+ if (cur_.empty()) {
+ Error();
+ return;
+ }
+ const char ch = cur_[0];
+ if (ch == end_ch) {
+ return;
+ }
+
+ cur_.remove_prefix(1);
+ if (ch == '\\') {
+ // Escape character, skip next character.
+ if (cur_.empty()) {
+ Error();
+ return;
+ }
+ cur_.remove_prefix(1);
+ }
+ }
+}
+
+bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) {
+ if (error_) {
+ return false;
+ }
+ if (remaining != nullptr) {
+ *remaining = cur_;
+ }
+ if (capture != nullptr) {
+ const char* end = capture_end_ == nullptr ? cur_.data() : capture_end_;
+ *capture = StringPiece(capture_start_, end - capture_start_);
+ }
+ return true;
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h
new file mode 100644
index 0000000000..ecbb139d60
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner.h
@@ -0,0 +1,218 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LIB_STRINGS_SCANNER_H_
+#define TENSORFLOW_LIB_STRINGS_SCANNER_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace strings {
+
+// Scanner provides simplified string parsing, in which a string is parsed as a
+// series of scanning calls (e.g. One, Any, Many, OneLiteral, Eos), and then
+// finally GetResult is called. If GetResult returns true, then it also returns
+// the remaining characters and any captured substring.
+//
+// The range to capture can be controlled with RestartCapture and StopCapture;
+// by default, all processed characters are captured.
+class Scanner {
+ public:
+ // Classes of characters. Each enum name is to be read as the union of the
+ // parts - e.g., class LETTER_DIGIT means the class includes all letters and
+ // all digits.
+ //
+ // LETTER means ascii letter a-zA-Z.
+ // DIGIT means ascii digit: 0-9.
+ enum CharClass {
+ // NOTE: When adding a new CharClass, update the AllCharClasses ScannerTest
+ // in scanner_test.cc
+ DIGIT,
+ LETTER,
+ LETTER_DIGIT,
+ LETTER_DIGIT_DASH_DOT_SLASH, // SLASH is / only, not backslash
+ LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE, // SLASH is / only, not backslash
+ LETTER_DIGIT_DOT,
+ LETTER_DIGIT_DOT_UNDERSCORE,
+ LETTER_DIGIT_UNDERSCORE,
+ LOWERLETTER,
+ LOWERLETTER_DIGIT,
+ LOWERLETTER_DIGIT_UNDERSCORE,
+ NON_ZERO_DIGIT,
+ SPACE,
+ UPPERLETTER,
+ };
+
+ explicit Scanner(StringPiece source) : cur_(source) { RestartCapture(); }
+
+ // Consume the next character of the given class from input. If the next
+ // character is not in the class, then GetResult will ultimately return false.
+ Scanner& One(CharClass clz) {
+ if (cur_.empty() || !Matches(clz, cur_[0])) {
+ return Error();
+ }
+ cur_.remove_prefix(1);
+ return *this;
+ }
+
+ // Consume the next s.size() characters of the input, if they match <s>. If
+ // they don't match <s>, this is a no-op.
+ Scanner& ZeroOrOneLiteral(StringPiece s) {
+ cur_.Consume(s);
+ return *this;
+ }
+
+ // Consume the next s.size() characters of the input, if they match <s>. If
+ // they don't match <s>, then GetResult will ultimately return false.
+ Scanner& OneLiteral(StringPiece s) {
+ if (!cur_.Consume(s)) {
+ error_ = true;
+ }
+ return *this;
+ }
+
+ // Consume characters from the input as long as they match <clz>.
+ Scanner& Any(CharClass clz) {
+ while (!cur_.empty() && Matches(clz, cur_[0])) {
+ cur_.remove_prefix(1);
+ }
+ return *this;
+ }
+
+ // Shorthand for One(clz).Any(clz).
+ Scanner& Many(CharClass clz) { return One(clz).Any(clz); }
+
+ // Reset the capture start point.
+ //
+ // Later, when GetResult is called and if it returns true, the capture
+ // returned will start at the position at the time this was called.
+ Scanner& RestartCapture() {
+ capture_start_ = cur_.data();
+ return *this;
+ }
+
+ // Stop capturing input.
+ //
+ // Later, when GetResult is called and if it returns true, the capture
+ // returned will end at the position at the time this was called.
+ Scanner& StopCapture() {
+ capture_end_ = cur_.data();
+ return *this;
+ }
+
+ // If not at the input of input, then GetResult will ultimately return false.
+ Scanner& Eos() {
+ if (!cur_.empty()) error_ = true;
+ return *this;
+ }
+
+ // Shorthand for Any(SPACE).
+ Scanner& AnySpace() { return Any(SPACE); }
+
+ // This scans input until <end_ch> is reached. <end_ch> is NOT consumed.
+ // Backslash escape sequences are skipped.
+ // Used for implementing quoted string scanning.
+ Scanner& ScanEscapedUntil(char end_ch) {
+ ScanEscapedUntilImpl(end_ch);
+ return *this;
+ }
+
+ // Return the next character that will be scanned, or <default_value> if there
+ // are no more characters to scan.
+ // Note that if a scan operation has failed (so GetResult() returns false),
+ // then the value of Peek may or may not have advanced since the scan
+ // operation that failed.
+ char Peek(char default_value = '\0') const {
+ return cur_.empty() ? default_value : cur_[0];
+ }
+
+ // Returns true if the input string successfully matched. When true is
+ // returned, the remaining string is returned in <remaining> and the captured
+ // string returned in <capture>, if non-NULL.
+ bool GetResult(StringPiece* remaining = nullptr,
+ StringPiece* capture = nullptr);
+
+ private:
+ void ScanEscapedUntilImpl(char end_ch);
+
+ Scanner& Error() {
+ error_ = true;
+ return *this;
+ }
+
+ static bool IsLetter(char ch) {
+ return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z');
+ }
+
+ static bool IsLowerLetter(char ch) { return ch >= 'a' && ch <= 'z'; }
+
+ static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; }
+
+ static bool IsSpace(char ch) {
+ return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\v' || ch == '\f' ||
+ ch == '\r');
+ }
+
+ static bool Matches(CharClass clz, char ch) {
+ switch (clz) {
+ case DIGIT:
+ return IsDigit(ch);
+ case LETTER:
+ return IsLetter(ch);
+ case LETTER_DIGIT:
+ return IsLetter(ch) || IsDigit(ch);
+ case LETTER_DIGIT_DASH_DOT_SLASH:
+ return IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' ||
+ ch == '/';
+ case LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE:
+ return (IsLetter(ch) || IsDigit(ch) || ch == '-' || ch == '.' ||
+ ch == '/' || ch == '_');
+ case LETTER_DIGIT_DOT:
+ return IsLetter(ch) || IsDigit(ch) || ch == '.';
+ case LETTER_DIGIT_DOT_UNDERSCORE:
+ return IsLetter(ch) || IsDigit(ch) || ch == '.' || ch == '_';
+ case LETTER_DIGIT_UNDERSCORE:
+ return IsLetter(ch) || IsDigit(ch) || ch == '_';
+ case LOWERLETTER:
+ return ch >= 'a' && ch <= 'z';
+ case LOWERLETTER_DIGIT:
+ return IsLowerLetter(ch) || IsDigit(ch);
+ case LOWERLETTER_DIGIT_UNDERSCORE:
+ return IsLowerLetter(ch) || IsDigit(ch) || ch == '_';
+ case NON_ZERO_DIGIT:
+ return IsDigit(ch) && ch != '0';
+ case SPACE:
+ return IsSpace(ch);
+ case UPPERLETTER:
+ return ch >= 'A' && ch <= 'Z';
+ }
+ return false;
+ }
+
+ StringPiece cur_;
+ const char* capture_start_ = nullptr;
+ const char* capture_end_ = nullptr;
+ bool error_ = false;
+
+ friend class ScannerTest;
+ TF_DISALLOW_COPY_AND_ASSIGN(Scanner);
+};
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_SCANNER_H_
diff --git a/tensorflow/core/lib/strings/scanner_test.cc b/tensorflow/core/lib/strings/scanner_test.cc
new file mode 100644
index 0000000000..98028ae516
--- /dev/null
+++ b/tensorflow/core/lib/strings/scanner_test.cc
@@ -0,0 +1,266 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/scanner.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace strings {
+
+class ScannerTest : public ::testing::Test {
+ protected:
+ // Returns a string with all chars that are in <clz>, in byte value order.
+ string ClassStr(Scanner::CharClass clz) {
+ string s;
+ for (int i = 0; i < 256; ++i) {
+ char ch = i;
+ if (Scanner::Matches(clz, ch)) {
+ s += ch;
+ }
+ }
+ return s;
+ }
+};
+
+TEST_F(ScannerTest, Any) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(" horse0123")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ(" horse", match.ToString());
+ EXPECT_EQ("0123", remaining.ToString());
+
+ EXPECT_TRUE(Scanner("")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+
+ EXPECT_TRUE(Scanner("----")
+ .Any(Scanner::SPACE)
+ .Any(Scanner::DIGIT)
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("----", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+}
+
+TEST_F(ScannerTest, AnySpace) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(" a b ")
+ .AnySpace()
+ .One(Scanner::LETTER)
+ .AnySpace()
+ .GetResult(&remaining, &match));
+ EXPECT_EQ(" a ", match.ToString());
+ EXPECT_EQ("b ", remaining.ToString());
+}
+
+TEST_F(ScannerTest, Eos) {
+ EXPECT_FALSE(Scanner("a").Eos().GetResult());
+ EXPECT_TRUE(Scanner("").Eos().GetResult());
+ EXPECT_FALSE(Scanner("abc").OneLiteral("ab").Eos().GetResult());
+ EXPECT_TRUE(Scanner("abc").OneLiteral("abc").Eos().GetResult());
+}
+
+TEST_F(ScannerTest, Many) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner("abc").Many(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("0").Many(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("").Many(Scanner::LETTER).GetResult());
+
+ EXPECT_TRUE(
+ Scanner("abc ").Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ(" ", remaining);
+ EXPECT_EQ("abc", match);
+ EXPECT_TRUE(
+ Scanner("abc").Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining);
+ EXPECT_EQ("abc", match);
+}
+
+TEST_F(ScannerTest, One) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner("abc").One(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("0").One(Scanner::LETTER).GetResult());
+ EXPECT_FALSE(Scanner("").One(Scanner::LETTER).GetResult());
+
+ EXPECT_TRUE(Scanner("abc")
+ .One(Scanner::LETTER)
+ .One(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("c", remaining);
+ EXPECT_EQ("ab", match);
+ EXPECT_TRUE(Scanner("a").One(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining);
+ EXPECT_EQ("a", match);
+}
+
+TEST_F(ScannerTest, OneLiteral) {
+ EXPECT_FALSE(Scanner("abc").OneLiteral("abC").GetResult());
+ EXPECT_TRUE(Scanner("abc").OneLiteral("ab").OneLiteral("c").GetResult());
+}
+
+TEST_F(ScannerTest, ScanEscapedUntil) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)")
+ .OneLiteral("'")
+ .ScanEscapedUntil('\'')
+ .OneLiteral("'")
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("rest", remaining.ToString());
+ EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match.ToString());
+
+ // The "scan until" character is not present.
+ remaining = match = "unset";
+ EXPECT_FALSE(Scanner(R"(' \1 \2 \3 \' \\rest)")
+ .OneLiteral("'")
+ .ScanEscapedUntil('\'')
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("unset", remaining.ToString());
+ EXPECT_EQ("unset", match.ToString());
+}
+
+TEST_F(ScannerTest, ZeroOrOneLiteral) {
+ StringPiece remaining, match;
+ EXPECT_TRUE(
+ Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match));
+ EXPECT_EQ("abc", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+
+ EXPECT_TRUE(
+ Scanner("abcd").ZeroOrOneLiteral("ab").ZeroOrOneLiteral("c").GetResult(
+ &remaining, &match));
+ EXPECT_EQ("d", remaining.ToString());
+ EXPECT_EQ("abc", match.ToString());
+
+ EXPECT_TRUE(
+ Scanner("").ZeroOrOneLiteral("abc").GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("", match.ToString());
+}
+
+// Test output of GetResult (including the forms with optional params),
+// and that it can be called multiple times.
+TEST_F(ScannerTest, CaptureAndGetResult) {
+ StringPiece remaining, match;
+
+ Scanner scan(" first second");
+ EXPECT_TRUE(scan.Any(Scanner::SPACE)
+ .RestartCapture()
+ .One(Scanner::LETTER)
+ .Any(Scanner::LETTER_DIGIT)
+ .StopCapture()
+ .Any(Scanner::SPACE)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("second", remaining.ToString());
+ EXPECT_EQ("first", match.ToString());
+ EXPECT_TRUE(scan.GetResult());
+ remaining = "";
+ EXPECT_TRUE(scan.GetResult(&remaining));
+ EXPECT_EQ("second", remaining.ToString());
+ remaining = "";
+ match = "";
+ EXPECT_TRUE(scan.GetResult(&remaining, &match));
+ EXPECT_EQ("second", remaining.ToString());
+ EXPECT_EQ("first", match.ToString());
+}
+
+// Tests that if StopCapture is not called, then calling GetResult, then
+// scanning more, then GetResult again will update the capture.
+TEST_F(ScannerTest, MultipleGetResultExtendsCapture) {
+ StringPiece remaining, match;
+
+ Scanner scan("one2three");
+ EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("2three", remaining.ToString());
+ EXPECT_EQ("one", match.ToString());
+ EXPECT_TRUE(scan.Many(Scanner::DIGIT).GetResult(&remaining, &match));
+ EXPECT_EQ("three", remaining.ToString());
+ EXPECT_EQ("one2", match.ToString());
+ EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("one2three", match.ToString());
+}
+
+TEST_F(ScannerTest, FailedMatchDoesntChangeResult) {
+ // A failed match doesn't change pointers passed to GetResult.
+ Scanner scan("name");
+ StringPiece remaining = "rem";
+ StringPiece match = "match";
+ EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match));
+ EXPECT_EQ("rem", remaining.ToString());
+ EXPECT_EQ("match", match.ToString());
+}
+
+TEST_F(ScannerTest, DefaultCapturesAll) {
+ // If RestartCapture() is not called, the whole string is used.
+ Scanner scan("a b");
+ StringPiece remaining = "rem";
+ StringPiece match = "match";
+ EXPECT_TRUE(scan.Any(Scanner::LETTER)
+ .AnySpace()
+ .Any(Scanner::LETTER)
+ .GetResult(&remaining, &match));
+ EXPECT_EQ("", remaining.ToString());
+ EXPECT_EQ("a b", match.ToString());
+}
+
+TEST_F(ScannerTest, AllCharClasses) {
+ EXPECT_EQ("0123456789", ClassStr(Scanner::DIGIT));
+ EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER));
+ EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT));
+ EXPECT_EQ(
+ "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH));
+ EXPECT_EQ(
+ "-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_"
+ "abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE));
+ EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DOT));
+ EXPECT_EQ(".0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_DOT_UNDERSCORE));
+ EXPECT_EQ("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LETTER_DIGIT_UNDERSCORE));
+ EXPECT_EQ("abcdefghijklmnopqrstuvwxyz", ClassStr(Scanner::LOWERLETTER));
+ EXPECT_EQ("0123456789abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LOWERLETTER_DIGIT));
+ EXPECT_EQ("0123456789_abcdefghijklmnopqrstuvwxyz",
+ ClassStr(Scanner::LOWERLETTER_DIGIT_UNDERSCORE));
+ EXPECT_EQ("123456789", ClassStr(Scanner::NON_ZERO_DIGIT));
+ EXPECT_EQ("\t\n\v\f\r ", ClassStr(Scanner::SPACE));
+ EXPECT_EQ("ABCDEFGHIJKLMNOPQRSTUVWXYZ", ClassStr(Scanner::UPPERLETTER));
+}
+
+TEST_F(ScannerTest, Peek) {
+ EXPECT_EQ('a', Scanner("abc").Peek());
+ EXPECT_EQ('a', Scanner("abc").Peek('b'));
+ EXPECT_EQ('\0', Scanner("").Peek());
+ EXPECT_EQ('z', Scanner("").Peek('z'));
+ EXPECT_EQ('A', Scanner("0123A").Any(Scanner::DIGIT).Peek());
+ EXPECT_EQ('\0', Scanner("0123A").Any(Scanner::LETTER_DIGIT).Peek());
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 6fcb33d39b..b7d1767e00 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -876,6 +876,25 @@ py_test(
],
)
+py_library(
+ name = "device_lib",
+ srcs = ["client/device_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pywrap_tensorflow",
+ ],
+)
+
+cuda_py_tests(
+ name = "device_lib_test",
+ srcs = [
+ "client/device_lib_test.py",
+ ],
+ additional_deps = [
+ ":device_lib",
+ ],
+)
+
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -897,6 +916,7 @@ tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
swig_includes = [
+ "client/device_lib.i",
"client/events_writer.i",
"client/server_lib.i",
"client/tf_session.i",
diff --git a/tensorflow/python/client/device_lib.i b/tensorflow/python/client/device_lib.i
new file mode 100644
index 0000000000..a651605fc0
--- /dev/null
+++ b/tensorflow/python/client/device_lib.i
@@ -0,0 +1,81 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/public/session_options.h"
+%}
+
+%typemap(in, numinputs=0) const tensorflow::SessionOptions& options (
+ tensorflow::SessionOptions temp) {
+ $1 = &temp;
+}
+
+%typemap(in, numinputs=0) std::vector<tensorflow::Device*>* devices (
+ std::vector<tensorflow::Device*> temp) {
+ $1 = &temp;
+}
+
+%typemap(argout) std::vector<tensorflow::Device*>* devices {
+ std::vector< std::unique_ptr<tensorflow::Device> > safe_devices;
+ for (auto* device : *$1) safe_devices.emplace_back(device);
+
+ auto temp_string_list = tensorflow::make_safe(PyList_New(0));
+ if (!temp_string_list) {
+ SWIG_fail;
+ }
+
+ for (const auto& device : safe_devices) {
+ const tensorflow::DeviceAttributes& attr = device->attributes();
+ string attr_serialized;
+ if (!attr.SerializeToString(&attr_serialized)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unable to serialize DeviceAttributes");
+ SWIG_fail;
+ }
+
+ tensorflow::Safe_PyObjectPtr safe_attr_string = tensorflow::make_safe(
+ %#if PY_MAJOR_VERSION < 3
+ PyString_FromStringAndSize(
+ %#else
+ PyUnicode_FromStringAndSize(
+ %#endif
+ reinterpret_cast<const char*>(
+ attr_serialized.data()), attr_serialized.size()));
+
+ if (PyList_Append(temp_string_list.get(), safe_attr_string.get()) == -1) {
+ SWIG_fail;
+ }
+ }
+
+ $result = temp_string_list.release();
+}
+
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::DeviceFactory;
+%unignore tensorflow::DeviceFactory::AddDevices;
+
+%include "tensorflow/core/common_runtime/device_factory.h"
+
+%unignoreall
+
+%newobject tensorflow::SessionOptions;
diff --git a/tensorflow/python/client/device_lib.py b/tensorflow/python/client/device_lib.py
new file mode 100644
index 0000000000..75872c463a
--- /dev/null
+++ b/tensorflow/python/client/device_lib.py
@@ -0,0 +1,37 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A Python interface for creating TensorFlow servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six # pylint: disable=unused-import
+
+from tensorflow.core.framework import device_attributes_pb2
+from tensorflow.python import pywrap_tensorflow
+
+
+def list_local_devices():
+ """List the available devices available in the local process.
+
+ Returns:
+ A list of `DeviceAttribute` protocol buffers.
+ """
+ def _convert(pb_str):
+ m = device_attributes_pb2.DeviceAttributes()
+ m.ParseFromString(pb_str)
+ return m
+ return [_convert(s) for s in pywrap_tensorflow.DeviceFactory_AddDevices("")]
diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py
new file mode 100644
index 0000000000..ee028573aa
--- /dev/null
+++ b/tensorflow/python/client/device_lib_test.py
@@ -0,0 +1,42 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for the SWIG-wrapped device lib."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.client import device_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class DeviceLibTest(test_util.TensorFlowTestCase):
+
+ def testListLocalDevices(self):
+ devices = device_lib.list_local_devices()
+ self.assertGreater(len(devices), 0)
+ self.assertEqual(devices[0].device_type, "CPU")
+
+ # GPU test
+ if tf.test.is_built_with_cuda():
+ self.assertGreater(len(devices), 1)
+ self.assertTrue("GPU" in [d.device_type for d in devices])
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5fbdf19229..217dc1f14c 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -20,11 +20,9 @@ from __future__ import division
from __future__ import print_function
import re
-import sys
import threading
import numpy as np
-import six
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import errors
@@ -570,22 +568,21 @@ class BaseSession(SessionInterface):
try:
return fn(*args)
except tf_session.StatusNotOK as e:
- e_type, e_value, e_traceback = sys.exc_info()
error_message = compat.as_text(e.error_message)
m = BaseSession._NODEDEF_NAME_RE.search(error_message)
+ node_def = None
+ op = None
if m is not None:
node_name = m.group(1)
- node_def = None
try:
op = self._graph.get_operation_by_name(node_name)
node_def = op.node_def
except KeyError:
- op = None
- # pylint: disable=protected-access
- raise errors._make_specific_exception(node_def, op, error_message,
- e.code)
- # pylint: enable=protected-access
- six.reraise(e_type, e_value, e_traceback)
+ pass
+ # pylint: disable=protected-access
+ raise errors._make_specific_exception(node_def, op, error_message,
+ e.code)
+ # pylint: enable=protected-access
def _extend_graph(self):
# Ensure any changes to the graph are reflected in the runtime.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 9a59562842..09a9935c5e 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -108,6 +108,19 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError(lambda e: e.op == a.op):
a.eval()
+ def testErrorCodeWithNoNodeDef(self):
+ with session.Session() as s:
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+
+ def exc_predicate(e):
+ return (e.op is None and e.node_def is None and
+ e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ with self.assertRaisesOpError(exc_predicate):
+ # Run with a bogus handle.
+ s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
+
def testOpConstructionErrorPayload(self):
with session.Session():
failing_op = ops.get_default_graph().create_op(
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
index f2ffbb9876..f7aaa63792 100644
--- a/tensorflow/python/framework/errors.py
+++ b/tensorflow/python/framework/errors.py
@@ -38,7 +38,8 @@ class OpError(Exception):
"""Creates a new `OpError` indicating that a particular op failed.
Args:
- node_def: The `graph_pb2.NodeDef` proto representing the op that failed.
+ node_def: The `graph_pb2.NodeDef` proto representing the op that failed,
+ if known; otherwise None.
op: The `ops.Operation` that failed, if known; otherwise None.
message: The message string describing the failure.
error_code: The `error_codes_pb2.Code` describing the error.
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 77da519fcc..9c25ea555c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import logging_ops
-from tensorflow.python.pywrap_tensorflow import StatusNotOK
def check_op_order(graph):
"""Sanity check on the ordering of op id."""
@@ -137,7 +136,8 @@ class ControlFlowTest(tf.test.TestCase):
dead_branch = tf.identity(switch_op[0])
with self.assertRaisesWithPredicateMatch(
- StatusNotOK, lambda e: "The tensor returned for" in str(e)):
+ tf.errors.InvalidArgumentError,
+ lambda e: "The tensor returned for" in str(e)):
dead_branch.eval()
def testSwitchMergeLess(self):
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index af232f65cc..2504dfaa31 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -25,7 +25,6 @@ import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-from tensorflow.python.pywrap_tensorflow import StatusNotOK
class FIFOQueueTest(tf.test.TestCase):
@@ -1161,7 +1160,7 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase):
# Intentionally do not run any enqueue_ops so that dequeue will block
# until operation_timeout_in_ms.
- with self.assertRaisesRegexp(StatusNotOK,
+ with self.assertRaisesRegexp(tf.errors.DeadlineExceededError,
"Timed out waiting for notification"):
sess.run(dequeued_t)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 766bdf7dd3..fb2dd00b69 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -29,5 +29,6 @@ limitations under the License.
%include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/client/server_lib.i"
+%include "tensorflow/python/client/device_lib.i"
%include "tensorflow/python/framework/python_op_gen.i"
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6debeabd97..24d7e16831 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -10,8 +10,8 @@ def tf_workspace(path_prefix = ""):
native.new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/88444e025a5c.tar.gz",
- sha256 = "42e6f6de56b3ff010531a2bbf3e2db1db46be30d3965efb1eaa5634c5db013dd",
+ url = "https://bitbucket.org/eigen/eigen/get/2f482bcc8b95.tar.gz",
+ sha256 = "1a80b0e348c3f3608fc5ce1c222902da37a229a01d3965c4622ec3d287617456",
build_file = path_prefix + "eigen.BUILD",
)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index 95a503d611..e3e05b0425 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/Eigen/Cholesky"
+#include "eigen-eigen-2f482bcc8b95/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index b4a10f6ed1..87e7c9813e 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/Eigen/Core"
+#include "eigen-eigen-2f482bcc8b95/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 56657aa837..6e1a765be7 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/Eigen/Eigenvalues"
+#include "eigen-eigen-2f482bcc8b95/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index 3c491eeef9..be0e8c7eb8 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/Eigen/LU"
+#include "eigen-eigen-2f482bcc8b95/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index 5a97880470..17a6a1e34e 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/Eigen/QR"
+#include "eigen-eigen-2f482bcc8b95/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index 20150d0594..57404ad17d 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-88444e025a5c/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-2f482bcc8b95/unsupported/Eigen/CXX11/Tensor"