author    Vijay Vasudevan <vrv@google.com>  2015-11-16 09:15:41 -0800
committer Vijay Vasudevan <vrv@google.com>  2015-11-16 09:15:41 -0800
commit    cb9fa5fc9de9f3fc97c15bbcce252d7d7fdcb73b (patch)
tree      7f0dd44e9aff8a01fff6f323664c3591d04c1225
parent    011e9baccd343eb943d25014c4e8aec53eac396b (diff)
TensorFlow: Upstream changes since last Thursday.
Changes:
- Bug fix to input.py by @dave-andersen with tests.
- Some fixes to names in the whitepaper
- Include cfloat for FLT_MAX by @benoitsteiner
- Fix broken link in mnist tutorials by @mrry
- Typos and fixes to documentation by @Sohl-Dickstein, @ebrevdo, @vrv, Yaroslav, @martinwicke
- Fixed unnecessary check around a delete by @girving
- More compatibility tests by @josh11b
- Misc other typos and fixes by Googlers.

Base CL: 107944070
 tensorflow/core/framework/op_compatibility_test.cc                      | 146
 tensorflow/core/framework/op_def_util.cc                                | 162
 tensorflow/core/framework/op_def_util.h                                 |   5
 tensorflow/core/kernels/maxpooling_op_gpu.cu.cc                         |   1
 tensorflow/core/kernels/reverse_op.h                                    |  10
 tensorflow/core/util/tensor_slice_reader_cache.cc                       |   4
 tensorflow/g3doc/api_docs/leftnav_files                                 |  30
 tensorflow/g3doc/api_docs/python/array_ops.md                           |   2
 tensorflow/g3doc/api_docs/python/framework.md                           |   5
 tensorflow/g3doc/api_docs/python/io_ops.md                              |   2
 tensorflow/g3doc/api_docs/python/nn.md                                  |   3
 tensorflow/g3doc/api_docs/python/train.md                               |   8
 tensorflow/g3doc/get_started/leftnav_files                              |   3
 tensorflow/g3doc/how_tos/leftnav_files                                  |  10
 tensorflow/g3doc/resources/leftnav_files                                |   5
 tensorflow/g3doc/tutorials/leftnav_files                                |  10
 tensorflow/g3doc/tutorials/mnist/beginners/index.md                     |   2
 tensorflow/g3doc/tutorials/mnist/pros/index.md                          |   6
 tensorflow/g3doc/tutorials/recurrent/index.md                           |   2
 tensorflow/g3doc/tutorials/seq2seq/index.md                             |   2
 tensorflow/models/embedding/BUILD                                       |   3
 tensorflow/models/embedding/word2vec.py                                 |   3
 tensorflow/models/embedding/word2vec_optimized.py                       |   3
 tensorflow/models/rnn/ptb/ptb_word_lm.py                                |   3
 tensorflow/models/rnn/rnn_cell.py                                       |   4
 tensorflow/python/client/session_test.py                                |  23
 tensorflow/python/framework/ops.py                                      |  24
 tensorflow/python/kernel_tests/gradient_checker.py                      |   2
 tensorflow/python/ops/array_ops.py                                      |   2
 tensorflow/python/training/input.py                                     |   2
 tensorflow/python/training/input_test.py                                |  25
 tensorflow/python/training/moving_averages.py                           |   2
 tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts      |  15
 tensorflow/tensorboard/components/tf-graph-loader/tf-graph-loader.html  |   4
 third_party/eigen3/unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h |  33
 35 files changed, 488 insertions(+), 78 deletions(-)
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index 7716b34ca1..0f5e27de79 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -49,6 +49,7 @@ class OpCompatibilityTest : public OpsTestBase {
&new_out_types));
ASSERT_EQ(new_in_types, old_in_types);
ASSERT_EQ(new_out_types, old_out_types);
+ ASSERT_OK(OpDefCompatible(old_op_def, *new_op_def));
// Verify the Op actually runs. Result() will return the output.
ASSERT_OK(InitOp());
@@ -57,7 +58,21 @@ class OpCompatibilityTest : public OpsTestBase {
string Result() { return GetOutput(0)->scalar<string>()(); }
- void ExpectInvalid(const OpDef& old_op_def, string error) {
+ void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def,
+ const string& error) {
+ // Test OpDefCompatible gives the same answer without the node_def.
+ Status status = OpDefCompatible(old_op_def, new_op_def);
+ if (status.ok()) {
+ ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. "
+ << SummarizeOpDef(new_op_def);
+ } else {
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(error))
+ << status << " does not contain " << error;
+ }
+ }
+
+ void ExpectInvalid(const OpDef& old_op_def, const string& validation_error,
+ const string& compatibility_error) {
// Record the original signature before we change *node_def().
DataTypeVector old_in_types, old_out_types;
ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
@@ -72,12 +87,16 @@ class OpCompatibilityTest : public OpsTestBase {
if (status.ok()) {
ADD_FAILURE() << SummarizeNodeDef(*node_def());
} else {
- EXPECT_TRUE(StringPiece(status.error_message()).contains(error))
- << status << " does not contain " << error;
+ EXPECT_TRUE(
+ StringPiece(status.error_message()).contains(validation_error))
+ << status << " does not contain " << validation_error;
}
+
+ ExpectIncompatible(old_op_def, *new_op_def, compatibility_error);
}
- void ExpectTypeMismatch(const OpDef& old_op_def) {
+ void ExpectTypeMismatch(const OpDef& old_op_def,
+ const string& compatibility_error) {
// Record the original signature before we change *node_def().
DataTypeVector old_in_types, old_out_types;
ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
@@ -98,6 +117,8 @@ class OpCompatibilityTest : public OpsTestBase {
<< DataTypeVectorString(new_in_types) << " -> "
<< DataTypeVectorString(new_out_types);
}
+
+ ExpectIncompatible(old_op_def, *new_op_def, compatibility_error);
}
};
@@ -338,6 +359,57 @@ TEST_F(OpCompatibilityTest, PolyIntoList) {
EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result());
}
+// Should be able to make a multiple inputs/outputs into a list with
+// the default types matching the inputs/outputs being replaced.
+
+// Changing from int32, int32 -> N * int32 (where N: int = 2 by default).
+REGISTER_OP("MakeMultipleSameList")
+ .Input("a: N * int32")
+ .Output("ndef: string")
+ .Attr("N: int = 2");
+REGISTER_KERNEL_BUILDER(Name("MakeMultipleSameList").Device(DEVICE_CPU),
+ TestKernel);
+
+TEST_F(OpCompatibilityTest, MakeMultipleSameList) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("MakeMultipleSameList")
+ .Input("a: int32")
+ .Input("b: int32")
+ .Output("ndef: string")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ExpectSuccess(old_op_def);
+ EXPECT_EQ("make_list = MakeMultipleSameList[N=2](a, b)", Result());
+}
+
+// Changing from int32, float -> T
+// (where T: list(type) = [int32, float] by default).
+REGISTER_OP("MakeMultipleAnyList")
+ .Input("a: T")
+ .Output("ndef: string")
+ .Attr("T: list(type) = [DT_INT32, DT_FLOAT]");
+REGISTER_KERNEL_BUILDER(Name("MakeMultipleAnyList").Device(DEVICE_CPU),
+ TestKernel);
+
+TEST_F(OpCompatibilityTest, MakeMultipleAnyList) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("MakeMultipleAnyList")
+ .Input("a: int32")
+ .Input("b: float")
+ .Output("ndef: string")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ExpectSuccess(old_op_def);
+ EXPECT_EQ("make_list = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)",
+ Result());
+}
+
// Should be able to change the name of an input/output.
REGISTER_OP("ChangeName").Input("y: int32").Output("ndef: string");
REGISTER_KERNEL_BUILDER(Name("ChangeName").Device(DEVICE_CPU), TestKernel);
@@ -477,6 +549,18 @@ TEST_F(OpCompatibilityTest, ShorterSameList) {
// Negative tests -------------------------------------------------------------
+// Can't remove an attr.
+REGISTER_OP("RemoveAttr");
+
+TEST_F(OpCompatibilityTest, RemoveAttrFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op_def));
+ ASSERT_OK(
+ NodeDefBuilder("fails", &old_op_def).Attr("a", 3).Finalize(node_def()));
+ ExpectInvalid(old_op_def, "NodeDef mentions attr 'a' not in",
+ "Attr 'a' removed");
+}
+
// Can't add an attr without a default.
REGISTER_OP("AddAttrNoDefault").Attr("a: int");
@@ -484,7 +568,8 @@ TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) {
OpDef old_op_def;
ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def, "NodeDef missing attr 'a'");
+ ExpectInvalid(old_op_def, "NodeDef missing attr 'a'",
+ "Attr 'a' added without default");
}
// Can't add a non-list input/output.
@@ -495,7 +580,8 @@ TEST_F(OpCompatibilityTest, AddSingleInputFails) {
ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
ExpectInvalid(old_op_def,
- "expected inputs 'int32' do not match 0 inputs specified");
+ "expected inputs 'int32' do not match 0 inputs specified",
+ "Input signature mismatch '' vs. 'int32'");
}
// Can't add a list input/output without an empty default.
@@ -514,7 +600,8 @@ TEST_F(OpCompatibilityTest, AddNIntsBigDefaultFails) {
ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
ExpectInvalid(old_op_def,
- "expected inputs 'int32' do not match 0 inputs specified");
+ "expected inputs 'int32' do not match 0 inputs specified",
+ "Input signature mismatch '' vs. 'int32'");
}
TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) {
@@ -522,7 +609,8 @@ TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) {
ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
ExpectInvalid(old_op_def,
- "expected inputs 'int32' do not match 0 inputs specified");
+ "expected inputs 'int32' do not match 0 inputs specified",
+ "Input signature mismatch '' vs. 'int32'");
}
TEST_F(OpCompatibilityTest, AddListBigDefaultFails) {
@@ -530,7 +618,8 @@ TEST_F(OpCompatibilityTest, AddListBigDefaultFails) {
ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
ExpectInvalid(old_op_def,
- "expected inputs 'int32' do not match 0 inputs specified");
+ "expected inputs 'int32' do not match 0 inputs specified",
+ "Input signature mismatch '' vs. 'int32'");
}
// Can't change the type of an input/output.
@@ -543,7 +632,8 @@ TEST_F(OpCompatibilityTest, ChangeTypeFails) {
ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectTypeMismatch(old_op_def);
+ ExpectTypeMismatch(old_op_def,
+ "Input signature mismatch 'int32' vs. 'float'");
}
// Can't change the order of inputs/outputs.
@@ -560,7 +650,24 @@ TEST_F(OpCompatibilityTest, ChangeOrderFails) {
.Input(FakeInput())
.Input(FakeInput())
.Finalize(node_def()));
- ExpectTypeMismatch(old_op_def);
+ ExpectTypeMismatch(
+ old_op_def, "Input signature mismatch 'int32, float' vs. 'float, int32'");
+}
+
+// Can't remove inputs/outputs.
+
+REGISTER_OP("RemoveInput");
+
+TEST_F(OpCompatibilityTest, RemoveInputFails) {
+ OpDef old_op_def;
+ ASSERT_OK(
+ OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "expected inputs '' do not match 1 inputs specified",
+ "Input signature mismatch 'float' vs. ''");
}
// Can't change the type of an attr.
@@ -574,7 +681,8 @@ TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) {
ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
.Attr("a", true)
.Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'bool' when 'int' expected");
+ ExpectInvalid(old_op_def, "value with type 'bool' when 'int' expected",
+ "Attr 'a' changed type 'bool' -> 'int'");
}
// Can't change an attr from a list.
@@ -587,7 +695,8 @@ TEST_F(OpCompatibilityTest, AttrFromListFails) {
OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op_def));
ASSERT_OK(
NodeDefBuilder("fails", &old_op_def).Attr("a", {5}).Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'list(int)' when 'int' expected");
+ ExpectInvalid(old_op_def, "value with type 'list(int)' when 'int' expected",
+ "Attr 'a' changed type 'list(int)' -> 'int'");
}
// Can't change an attr to a list.
@@ -599,7 +708,8 @@ TEST_F(OpCompatibilityTest, AttrToListFails) {
ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op_def));
ASSERT_OK(
NodeDefBuilder("fails", &old_op_def).Attr("a", 5).Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'int' when 'list(int)' expected");
+ ExpectInvalid(old_op_def, "value with type 'int' when 'list(int)' expected",
+ "Attr 'a' changed type 'int' -> 'list(int)'");
}
// Can't change an input from polymorphic to a list of any type.
@@ -615,8 +725,8 @@ TEST_F(OpCompatibilityTest, PolymorphicToAnyListFails) {
ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
- "value with type 'type' when 'list(type)' expected");
+ ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected",
+ "Attr 'T' changed type 'type' -> 'list(type)'");
}
// Can't change an input from a list of the same type to a list of any type.
@@ -636,8 +746,8 @@ TEST_F(OpCompatibilityTest, SameToAnyListFails) {
ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
.Input(FakeInput(1, DT_INT32))
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
- "value with type 'type' when 'list(type)' expected");
+ ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected",
+ "Attr 'T' changed type 'type' -> 'list(type)'");
}
// Changing an attr's default is not technically illegal, but should
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index e3aef011de..7b0fa668bf 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -1,13 +1,14 @@
#include "tensorflow/core/framework/op_def_util.h"
#include <set>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
@@ -341,4 +342,161 @@ string SummarizeOpDef(const OpDef& op_def) {
return ret;
}
+namespace {
+
+typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
+void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
+ for (const auto& attr : op_def.attr()) {
+ (*attr_map)[attr.name()] = &attr;
+ }
+}
+
+// Add a comma to *s every call but the first (*add_comma should be
+// initialized to false).
+void AddComma(string* s, bool* add_comma) {
+ if (*add_comma) {
+ strings::StrAppend(s, ", ");
+ } else {
+ *add_comma = true;
+ }
+}
+
+// Compute a signature for either inputs or outputs that will be the
+// same for both the old and new OpDef if they are compatible. We
+// assume that new_attrs is a superset of old_attrs, and that any attr
+// in the difference has a default. Our strategy is to make a list of
+// types, where the types are things like:
+// * "int32", "float", etc.,
+// * "T" for some attr "T" in old_attrs, or
+// * "N * type" for "N" either some attr in old_attrs.
+//
+// We get the types by either using the attrs in args if they are in
+// old_attrs, or substituting the default value from new_attrs.
+string ComputeArgSignature(
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const AttrMap& old_attrs, const AttrMap& new_attrs) {
+ string s;
+ bool add_comma = false;
+ for (const OpDef::ArgDef& arg : args) {
+ if (!arg.type_list_attr().empty()) {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
+ if (old_attr) {
+ // Both old and new have the list(type) attr, so can use it directly.
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, arg.type_list_attr());
+ if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ } else {
+ // Missing the list(type) attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
+ const auto& type_list = new_attr->default_value().list().type();
+ if (type_list.empty()) continue;
+ for (int i = 0; i < type_list.size(); ++i) {
+ AddComma(&s, &add_comma);
+ strings::StrAppend(
+ &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
+ if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ }
+ }
+ } else {
+ int num = 1; // How many input/outputs does this represent?
+ if (!arg.number_attr().empty()) {
+ // N * type case.
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.number_attr());
+ if (old_attr) {
+ // Both old and new have the number attr, so can use it directly.
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, arg.number_attr(), " * ");
+ add_comma = false; // Don't add another comma before the type.
+ } else {
+ // Missing the number attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.number_attr());
+ num = new_attr->default_value().i();
+ }
+ }
+
+ string type; // What is the type of this arg?
+ if (arg.type() != DT_INVALID) {
+ // int32, float, etc. case
+ type = DataTypeString(arg.type());
+ } else {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.type_attr());
+ if (old_attr) {
+ // Both old and new have the type attr, so can use it directly.
+ type = arg.type_attr();
+ } else {
+ // Missing the type attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.type_attr());
+ type = DataTypeString(new_attr->default_value().type());
+ }
+ }
+ if (arg.is_ref()) strings::StrAppend(&type, " ref");
+
+ // Record `num` * `type` in the signature.
+ for (int i = 0; i < num; ++i) {
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, type);
+ }
+ }
+ }
+
+ return s;
+}
+
+} // namespace
+
+Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
+#define VALIDATE(CONDITION, ...) \
+ if (!(CONDITION)) { \
+ return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
+ "; old: ", SummarizeOpDef(old_op), \
+ "; new: ", SummarizeOpDef(new_op)); \
+ }
+
+ VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
+
+ AttrMap new_attrs, old_attrs;
+ FillAttrMap(old_op, &old_attrs);
+ FillAttrMap(new_op, &new_attrs);
+ for (const auto& old_attr : old_op.attr()) {
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, old_attr.name());
+ VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
+ VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
+ "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
+ "'");
+ }
+
+ for (const auto& new_attr : new_op.attr()) {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, new_attr.name());
+ VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
+ new_attr.name(), "' added without default");
+ }
+
+ const string old_in_sig =
+ ComputeArgSignature(old_op.input_arg(), old_attrs, new_attrs);
+ const string new_in_sig =
+ ComputeArgSignature(new_op.input_arg(), old_attrs, new_attrs);
+ VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
+ "' vs. '", new_in_sig, "'");
+
+ const string old_out_sig =
+ ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs);
+ const string new_out_sig =
+ ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs);
+ VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
+ old_out_sig, "' vs. '", new_out_sig, "'");
+
+ return Status::OK();
+}
+
} // namespace tensorflow
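
For intuition, here is a rough Python sketch of the signature strategy described in the comment above: flatten each side's args into a list of type strings, substituting defaults from the new op for any attr the old op lacks, and require the two lists to match. The dict-based representation and names here are illustrative, not the real proto API.

```python
def arg_signature(args, old_attrs, new_attrs):
    """Flatten args into a comma-separated type signature (sketch)."""
    sig = []
    for arg in args:
        n = arg.get("number_attr")
        if n is not None and n not in old_attrs:
            # Old op lacks "N": expand the new op's default count inline.
            sig.extend([arg["type"]] * new_attrs[n]["default"])
        elif n is not None:
            sig.append("%s * %s" % (n, arg["type"]))
        else:
            sig.append(arg["type"])
    return ", ".join(sig)

# The MakeMultipleSameList case: int32, int32  ->  N * int32 (N defaults to 2).
old_args = [{"type": "int32"}, {"type": "int32"}]
new_args = [{"type": "int32", "number_attr": "N"}]
new_attrs = {"N": {"default": 2}}
assert (arg_signature(old_args, {}, new_attrs) ==
        arg_signature(new_args, {}, new_attrs))  # both "int32, int32"
```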
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index a9fecf3fa0..c8440c5def 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -27,6 +27,11 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
// than a text-format proto. Excludes descriptions.
string SummarizeOpDef(const OpDef& op_def);
+// Returns an error if new_op is not backwards-compatible with (more
+// accepting than) old_op.
+// REQUIRES: old_op and new_op must pass validation.
+Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 65262eb54e..13684c82f7 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -3,6 +3,7 @@
#define EIGEN_USE_GPU
#include <stdio.h>
+#include <cfloat>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
diff --git a/tensorflow/core/kernels/reverse_op.h b/tensorflow/core/kernels/reverse_op.h
index bba25f70e8..ca13b98a71 100644
--- a/tensorflow/core/kernels/reverse_op.h
+++ b/tensorflow/core/kernels/reverse_op.h
@@ -22,6 +22,16 @@ struct Reverse {
}
};
+template <typename Device, typename T>
+struct Reverse<Device, T, 0> {
+ void operator()(const Device& d, typename TTypes<T, 0>::ConstTensor input,
+ typename TTypes<bool, 1>::ConstTensor,
+ typename TTypes<T, 0>::Tensor output) {
+ // Reversing a scalar is copying it.
+ output.device(d) = input;
+ }
+};
+
} // namespace functor
} // namespace tensorflow
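
The rank-0 specialization works because a scalar has no axis to reverse, so reversal degenerates to a copy. A quick numpy analogy (not TensorFlow code):

```python
import numpy as np

v = np.array([1.0, 2.0, 3.0])
print(v[::-1])           # rank 1: reversal reorders elements

s = np.array(3.0)        # rank 0: nothing to reorder, so reversing
print(np.copy(s))        # is just a copy, which is what the new
                         # Reverse<Device, T, 0> specialization emits
```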
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc
index af81d0115e..24eedd5c5c 100644
--- a/tensorflow/core/util/tensor_slice_reader_cache.cc
+++ b/tensorflow/core/util/tensor_slice_reader_cache.cc
@@ -9,9 +9,7 @@ namespace checkpoint {
TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {}
TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() {
- if (cache_) {
- delete cache_;
- }
+ delete cache_;
cache_ = nullptr;
}
diff --git a/tensorflow/g3doc/api_docs/leftnav_files b/tensorflow/g3doc/api_docs/leftnav_files
new file mode 100644
index 0000000000..ceeb88af8f
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/leftnav_files
@@ -0,0 +1,30 @@
+### [Overview](/api_docs/index.md)
+### [Python API](/api_docs/python/index.md)
+python/framework.md
+python/constant_op.md
+python/state_ops.md
+python/array_ops.md
+python/math_ops.md
+python/control_flow_ops.md
+python/image.md
+python/sparse_ops.md
+python/io_ops.md
+python/python_io.md
+python/nn.md
+python/client.md
+python/train.md
+### [C++ API](/api_docs/cc/index.md)
+cc/ClassEnv.md
+cc/ClassRandomAccessFile.md
+cc/ClassWritableFile.md
+cc/ClassEnvWrapper.md
+cc/ClassSession.md
+cc/StructSessionOptions.md
+cc/ClassStatus.md
+cc/StructState.md
+cc/ClassTensor.md
+cc/ClassTensorShape.md
+cc/StructTensorShapeDim.md
+cc/ClassTensorShapeUtils.md
+cc/ClassThread.md
+cc/StructThreadOptions.md
\ No newline at end of file
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index 10a6df41ff..1b05a48862 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -329,7 +329,7 @@ reshape(t, [3, 3]) ==> [[1, 2, 3]
# tensor 't' is [[[1, 1], [2, 2]]
# [[3, 3], [4, 4]]]
-# tensor 't' has shape [2, 2]
+# tensor 't' has shape [2, 2, 2]
reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
[3, 3, 4, 4]]
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 1f88ad0ca4..eab4ec0152 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -147,6 +147,11 @@ This method is thread-safe.
A [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
protocol buffer.
+##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
+
+
+* <b>`ValueError`</b>: If the graph_def would be too large.
+
- - -
diff --git a/tensorflow/g3doc/api_docs/python/io_ops.md b/tensorflow/g3doc/api_docs/python/io_ops.md
index a728ba579e..b10cc79bc8 100644
--- a/tensorflow/g3doc/api_docs/python/io_ops.md
+++ b/tensorflow/g3doc/api_docs/python/io_ops.md
@@ -65,7 +65,7 @@ be fed using the `feed_dict` optional argument to `Session.run()`,
For example:
```python
-x = tf.placeholder(float, shape=(1024, 1024))
+x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)
with tf.Session() as sess:
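
The doc snippet above is cut off by the hunk boundary; a minimal self-contained version of the corrected example (a sketch in the 0.x Session style, with the feed step filled in) is:

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)

with tf.Session() as sess:
    # Placeholders hold no value of their own; y can only be computed
    # once x is supplied through feed_dict.
    rand_array = np.random.rand(1024, 1024)
    print(sess.run(y, feed_dict={x: rand_array}))
```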
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 4f4292ff5e..8cb1f96e09 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -139,7 +139,8 @@ kept independently and each row and column will be kept or not kept together.
* <b>`x`</b>: A tensor.
-* <b>`keep_prob`</b>: A Python float. The probability that each element is kept.
+* <b>`keep_prob`</b>: A scalar `Tensor` with the same type as x. The probability
+ that each element is kept.
* <b>`noise_shape`</b>: A 1-D `Tensor` of type `int32`, representing the
shape for randomly generated keep/drop flags.
* <b>`seed`</b>: A Python integer. Used to create random seeds. See
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 85739e6d3b..a1ea3c2a2f 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -850,9 +850,9 @@ Example: decay every 100000 steps with a base of 0.96:
...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
-learning_rate = tf.exponential_decay(starter_learning_rate, global_step,
- 100000, 0.96, staircase=True)
-optimizer = tf.GradientDescent(learning_rate)
+learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+ 100000, 0.96, staircase=True)
+optimizer = tf.GradientDescentOptimizer(learning_rate)
# Passing global_step to minimize() will increment it at each step.
optimizer.minimize(...my loss..., global_step=global_step)
```
@@ -888,7 +888,7 @@ moving averages for evaluations often improve results significantly.
### `class tf.train.ExponentialMovingAverage` <a class="md-anchor" id="ExponentialMovingAverage"></a>
-Maintains moving averages of variables by employing and exponential decay.
+Maintains moving averages of variables by employing an exponential decay.
When training a model, it is often beneficial to maintain moving averages of
the trained parameters. Evaluations that use averaged parameters sometimes
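
With staircase=True, the corrected schedule decays in discrete jumps: the rate at step s is 0.1 * 0.96**(s // 100000). A quick standalone check of the documented decay rule:

```python
def lr(step, base=0.1, decay=0.96, decay_steps=100000, staircase=True):
    # decayed_rate = base * decay ** (step / decay_steps), with integer
    # division when staircase=True -- the rule exponential_decay documents.
    exponent = step // decay_steps if staircase else step / decay_steps
    return base * decay ** exponent

print([round(lr(s), 5) for s in (0, 99999, 100000, 200000)])
# [0.1, 0.1, 0.096, 0.09216]
```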
diff --git a/tensorflow/g3doc/get_started/leftnav_files b/tensorflow/g3doc/get_started/leftnav_files
new file mode 100644
index 0000000000..3fccbc0db5
--- /dev/null
+++ b/tensorflow/g3doc/get_started/leftnav_files
@@ -0,0 +1,3 @@
+index.md
+os_setup.md
+basic_usage.md
\ No newline at end of file
diff --git a/tensorflow/g3doc/how_tos/leftnav_files b/tensorflow/g3doc/how_tos/leftnav_files
new file mode 100644
index 0000000000..f2b0a9fe9d
--- /dev/null
+++ b/tensorflow/g3doc/how_tos/leftnav_files
@@ -0,0 +1,10 @@
+variables/index.md
+../tutorials/mnist/tf/index.md
+summaries_and_tensorboard/index.md
+graph_viz/index.md
+reading_data/index.md
+threading_and_queues/index.md
+adding_an_op/index.md
+new_data_formats/index.md
+using_gpu/index.md
+variable_scope/index.md
\ No newline at end of file
diff --git a/tensorflow/g3doc/resources/leftnav_files b/tensorflow/g3doc/resources/leftnav_files
new file mode 100644
index 0000000000..2e1940b5d4
--- /dev/null
+++ b/tensorflow/g3doc/resources/leftnav_files
@@ -0,0 +1,5 @@
+bib.md
+uses.md
+faq.md
+glossary.md
+dims_types.md
diff --git a/tensorflow/g3doc/tutorials/leftnav_files b/tensorflow/g3doc/tutorials/leftnav_files
new file mode 100644
index 0000000000..d27ec1bbf6
--- /dev/null
+++ b/tensorflow/g3doc/tutorials/leftnav_files
@@ -0,0 +1,10 @@
+mnist/beginners/index.md
+mnist/pros/index.md
+mnist/tf/index.md
+deep_cnn/index.md
+word2vec/index.md
+recurrent/index.md
+seq2seq/index.md
+mandelbrot/index.md
+pdes/index.md
+mnist/download/index.md
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index fcd891c58b..f21c97b7c6 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -418,6 +418,6 @@ a look at this
[list of results](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html).)
What matters is that we learned from this model. Still, if you're feeling a bit
-down about these results, check out [the next tutorial](../../../tutorials/index.md) where we
+down about these results, check out [the next tutorial](../../../tutorials/mnist/pros/index.md) where we
do a lot better, and learn how to build more sophisticated models using
TensorFlow!
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index 0a7d1aeae0..58cec4c348 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -232,9 +232,9 @@ print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a>
Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
-section, we'll fix that, jumping from a very simple model to something moderately
-sophisticated: a small convolutional neural network. This will get us to around
-99.2% accuracy -- not state of the art, but respectable.
+section, we'll fix that, jumping from a very simple model to something
+moderately sophisticated: a small convolutional neural network. This will get us
+to around 99.2% accuracy -- not state of the art, but respectable.
### Weight Initialization <a class="md-anchor" id="AUTOGENERATED-weight-initialization"></a>
diff --git a/tensorflow/g3doc/tutorials/recurrent/index.md b/tensorflow/g3doc/tutorials/recurrent/index.md
index d1be50e00f..510ecd5d66 100644
--- a/tensorflow/g3doc/tutorials/recurrent/index.md
+++ b/tensorflow/g3doc/tutorials/recurrent/index.md
@@ -186,7 +186,7 @@ Now we can run the model:
```
bazel-bin/tensorflow/models/rnn/ptb/ptb_word_lm \
- --data_path=/tmp/simple-examples/data/ --alsologtostderr --model small
+ --data_path=/tmp/simple-examples/data/ --model small
```
There are 3 supported model configurations in the tutorial code: "small",
diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md
index a13f7cd95c..66c653e1f8 100644
--- a/tensorflow/g3doc/tutorials/seq2seq/index.md
+++ b/tensorflow/g3doc/tutorials/seq2seq/index.md
@@ -12,7 +12,7 @@ This tutorial will show you how to build and train such a system end-to-end.
You can start by running this binary.
```
-bazel run -c opt <...>/models/rnn/translate/translate.py
+bazel run -c opt <...>/models/rnn/translate:translate
--data_dir [your_data_directory]
```
diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD
index fe52778fa9..9cd0d24b5b 100644
--- a/tensorflow/models/embedding/BUILD
+++ b/tensorflow/models/embedding/BUILD
@@ -49,6 +49,9 @@ py_test(
size = "small",
srcs = ["word2vec_optimized_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "notsan",
+ ],
deps = [
":word2vec_optimized",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
index 29bd9602bb..400168c8ca 100644
--- a/tensorflow/models/embedding/word2vec.py
+++ b/tensorflow/models/embedding/word2vec.py
@@ -18,13 +18,14 @@ from __future__ import division
from __future__ import print_function
import os
-from six.moves import xrange # pylint: disable=redefined-builtin
import sys
import threading
import time
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
+
import numpy as np
import tensorflow as tf
diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py
index e6fccf689a..547abf8c3e 100644
--- a/tensorflow/models/embedding/word2vec_optimized.py
+++ b/tensorflow/models/embedding/word2vec_optimized.py
@@ -17,13 +17,14 @@ from __future__ import division
from __future__ import print_function
import os
-from six.moves import xrange # pylint: disable=redefined-builtin
import sys
import threading
import time
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
+
import numpy as np
import tensorflow as tf
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index 8a6c7203c2..3a7330836b 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -37,8 +37,7 @@ To compile on GPU:
bazel build -c opt tensorflow --config=cuda \
tensorflow/models/rnn/ptb:ptb_word_lm
To run:
- ./bazel-bin/.../ptb_word_lm \
- --data_path=/tmp/simple-examples/data/ --alsologtostderr
+ ./bazel-bin/.../ptb_word_lm --data_path=/tmp/simple-examples/data/
"""
from __future__ import absolute_import
diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py
index ff93b47892..3ffff7815c 100644
--- a/tensorflow/models/rnn/rnn_cell.py
+++ b/tensorflow/models/rnn/rnn_cell.py
@@ -68,8 +68,8 @@ class RNNCell(object):
A 2D Tensor of shape [batch_size x state_size] filled with zeros.
"""
zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype)
- # The reshape below is a no-op, but it allows shape inference of shape[1].
- return tf.reshape(zeros, [-1, self.state_size])
+ zeros.set_shape([None, self.state_size])
+ return zeros
class BasicRNNCell(RNNCell):
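
The set_shape change matters because tf.zeros built from a dynamic tf.pack shape carries no static shape information; set_shape merges in the statically known state_size without inserting any op, whereas the old reshape added a node solely for shape inference. A small sketch against the 2015-era API (tf.pack, get_shape):

```python
import tensorflow as tf

batch_size = tf.placeholder(tf.int32, [])       # unknown until runtime
zeros = tf.zeros(tf.pack([batch_size, 128]), dtype=tf.float32)
print(zeros.get_shape())                        # <unknown>

zeros.set_shape([None, 128])                    # annotate dim 1 statically;
print(zeros.get_shape())                        # (?, 128), no extra op added
```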
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 96b7136b15..99bf356cc2 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -602,6 +602,29 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(45.0, sess.run(b'f:0'))
self.assertEqual(45.0, sess.run(r'f:0'))
+ def testIncorrectGraph(self):
+ with ops.Graph().as_default() as g_1:
+ c_1 = constant_op.constant(1.0, name='c')
+
+ with ops.Graph().as_default() as g_2:
+ c_2 = constant_op.constant(2.0, name='c')
+
+ self.assertEqual('c', c_1.op.name)
+ self.assertEqual('c', c_2.op.name)
+
+ with session.Session(graph=g_1) as sess_1:
+ self.assertEqual(1.0, sess_1.run(c_1))
+ with self.assertRaises(ValueError):
+ sess_1.run(c_2)
+ with self.assertRaises(ValueError):
+ sess_1.run(c_2.op)
+
+ with session.Session(graph=g_2) as sess_2:
+ with self.assertRaises(ValueError):
+ sess_2.run(c_1)
+ with self.assertRaises(ValueError):
+ sess_2.run(c_1.op)
+ self.assertEqual(2.0, sess_2.run(c_2))
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 54917a506f..e5a8f8a4dc 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1552,7 +1552,7 @@ class Graph(object):
# new operations can be added.
self._finalized = False
# Functions defined in the graph
- self._functions = []
+ self._functions = collections.OrderedDict()
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -1647,6 +1647,9 @@ class Graph(object):
Returns:
A [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
protocol buffer.
+
+ Raises:
+ ValueError: If the graph_def would be too large.
"""
graph = graph_pb2.GraphDef()
bytesize = 0
@@ -1658,18 +1661,18 @@ class Graph(object):
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
if self._functions:
- for f in self._functions:
+ for f in self._functions.values():
bytesize += f.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
- graph.library.function.extend(self._functions)
+ graph.library.function.extend(self._functions.values())
return graph
def _add_function(self, function_def):
"""Adds a function to the graph.
The function is specified as a [`FunctionDef`]
- (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/function.proto)
protocol buffer.
After the function has been added, you can call to the function by
@@ -1679,7 +1682,14 @@ class Graph(object):
Args:
function_def: A `FunctionDef` protocol buffer.
"""
- self._functions.append(function_def)
+ previous_def = self._functions.get(function_def.signature.name, None)
+ if previous_def:
+ if previous_def != function_def:
+ raise ValueError("Another function is already defined with that name")
+ else:
+ # No need to add again.
+ return
+ self._functions[function_def.signature.name] = function_def
# Helper functions to create operations.
def create_op(self, op_type, inputs, dtypes,
@@ -1871,9 +1881,13 @@ class Graph(object):
elif isinstance(obj, Tensor) and allow_tensor:
# Actually obj is just the object it's referring to.
+ if obj.graph is not self:
+ raise ValueError("Tensor %s is not an element of this graph." % obj)
return obj
elif isinstance(obj, Operation) and allow_operation:
# Actually obj is just the object it's referring to.
+ if obj.graph is not self:
+ raise ValueError("Operation %s is not an element of this graph." % obj)
return obj
else:
# We give up!
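
The switch from a list to an OrderedDict keys functions by signature name, which is what lets _add_function treat an identical re-registration as a no-op and reject a conflicting one. A standalone sketch of that dedup rule (hypothetical class; the real logic lives on tf.Graph):

```python
import collections

class FunctionRegistry(object):
    def __init__(self):
        self._functions = collections.OrderedDict()   # name -> definition

    def add(self, name, definition):
        previous = self._functions.get(name, None)
        if previous is not None:
            if previous != definition:
                raise ValueError(
                    "Another function is already defined with that name")
            return                        # identical re-add: nothing to do
        self._functions[name] = definition

reg = FunctionRegistry()
reg.add("f", "def f(x): return x")
reg.add("f", "def f(x): return x")        # no-op, no error
```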
diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py
index 1adf22d512..d08fc393fa 100644
--- a/tensorflow/python/kernel_tests/gradient_checker.py
+++ b/tensorflow/python/kernel_tests/gradient_checker.py
@@ -80,7 +80,7 @@ def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx):
def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta):
"""Computes the numeric Jacobian for dy/dx.
- Computes the numeric Japcobian by slightly perturbing the inputs and
+ Computes the numeric Jacobian by slightly perturbing the inputs and
measuring the differences on the output.
Args:
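
For reference, the perturb-and-measure scheme the docstring describes amounts to finite differences; a minimal standalone sketch (central differences, a hypothetical helper rather than the TF-internal code):

```python
import numpy as np

def numeric_jacobian(f, x, delta=1e-3):
    # J[i, j] = d f(x)[j] / d x[i], estimated by perturbing x[i] by +/- delta.
    x = np.asarray(x, dtype=np.float64)
    y = np.atleast_1d(f(x))
    jac = np.zeros((x.size, y.size))
    for i in range(x.size):
        hi, lo = x.copy(), x.copy()
        hi.flat[i] += delta
        lo.flat[i] -= delta
        jac[i] = (np.atleast_1d(f(hi)) - np.atleast_1d(f(lo))) / (2 * delta)
    return jac

print(numeric_jacobian(lambda v: v ** 2, [1.0, 2.0]))  # ~ diag([2, 4])
```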
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 78262a55f3..33fdc017c7 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -624,7 +624,7 @@ def placeholder(dtype, shape=None, name=None):
For example:
```python
- x = tf.placeholder(float, shape=(1024, 1024))
+ x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)
with tf.Session() as sess:
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 77c496fb85..3f9371fead 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -243,7 +243,7 @@ def _merge_shapes(shape_list, enqueue_many):
shape_list = [tensor_shape.as_shape(s) for s in shape_list]
if enqueue_many:
# We want the shapes without the leading batch dimension.
- shape_list = [s.WithRankAtLeast(1)[1:] for s in shape_list]
+ shape_list = [s.with_rank_at_least(1)[1:] for s in shape_list]
merged_shape = shape_list[0]
for s in shape_list[1:]:
merged_shape.merge_with(s)
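
The corrected call is the snake_case TensorShape method; what _merge_shapes does with enqueue_many=True is strip the leading batch dimension before merging. A two-line illustration using the real TensorShape API:

```python
from tensorflow.python.framework import tensor_shape

s = tensor_shape.as_shape([32, 28, 28])     # batched: 32 examples of 28x28
print(s.with_rank_at_least(1)[1:])          # (28, 28), the per-example shape
```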
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index d85df52431..bd6e199054 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -269,6 +269,31 @@ class BatchTest(tf.test.TestCase):
for thread in threads:
thread.join()
+ def testOneThreadEnqueueMany(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ pre_batched = tf.train.batch([counter, "string"], batch_size=2)
+ batched = tf.train.batch(pre_batched, enqueue_many=True,
+ batch_size=batch_size)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for i in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual(results[0], np.arange(i * batch_size,
+ (i + 1) * batch_size))
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
def testManyThreads(self):
with self.test_session() as sess:
batch_size = 10
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index d834b698ba..50a8f9bbce 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -45,7 +45,7 @@ def assign_moving_average(variable, value, decay, name=None):
class ExponentialMovingAverage(object):
- """Maintains moving averages of variables by employing and exponential decay.
+ """Maintains moving averages of variables by employing an exponential decay.
When training a model, it is often beneficial to maintain moving averages of
the trained parameters. Evaluations that use averaged parameters sometimes
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
index 8ada82770c..36692dd0f0 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
@@ -372,7 +372,7 @@ function findEdgeTargetsInGraph(
interface HierarchyParams {
verifyTemplate: boolean;
- groupSeries: boolean;
+ seriesNodeMinSize: number;
}
/**
@@ -396,8 +396,8 @@ export function build(graph: tf.graph.SlimGraph, params: HierarchyParams,
}, tracker)
.then(() => {
return runAsyncTask("Detect series", 20, () => {
- if (params.groupSeries) {
- groupSeries(h.root, h, seriesNames);
+ if (params.seriesNodeMinSize > 0) {
+ groupSeries(h.root, h, seriesNames, params.seriesNodeMinSize);
}
}, tracker);
})
@@ -550,15 +550,17 @@ function addEdges(h: Hierarchy, graph: SlimGraph,
*
* @param metanode
* @param hierarchy
+ * @param threshold If the series has this many nodes or more, then group them
+ * into a series.
* @return A dictionary from node name to series node name that contains the node
*/
function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
- seriesNames: { [name: string]: string }) {
+ seriesNames: { [name: string]: string }, threshold: number) {
let metagraph = metanode.metagraph;
_.each(metagraph.nodes(), n => {
let child = metagraph.node(n);
if (child.type === tf.graph.NodeType.META) {
- groupSeries(<Metanode>child, hierarchy, seriesNames);
+ groupSeries(<Metanode>child, hierarchy, seriesNames, threshold);
}
});
@@ -569,6 +571,9 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
// metagraph.
_.each(seriesDict, function(seriesNode: SeriesNode, seriesName: string) {
let nodeMemberNames = seriesNode.metagraph.nodes();
+ if (nodeMemberNames.length < threshold) {
+ return;
+ }
let firstMember = seriesNode.metagraph.node(nodeMemberNames[0]);
let seriesType = firstMember.type;
diff --git a/tensorflow/tensorboard/components/tf-graph-loader/tf-graph-loader.html b/tensorflow/tensorboard/components/tf-graph-loader/tf-graph-loader.html
index 3f290f2152..c3647a41b5 100644
--- a/tensorflow/tensorboard/components/tf-graph-loader/tf-graph-loader.html
+++ b/tensorflow/tensorboard/components/tf-graph-loader/tf-graph-loader.html
@@ -126,7 +126,9 @@ Polymer({
}
var hierarchyParams = {
verifyTemplate: true,
- groupSeries: true,
+ // If a set of numbered op nodes has at least this number of nodes
+ // then group them into a series node.
+ seriesNodeMinSize: 5,
};
var hierarchyTracker = tf.getSubtaskTracker(tracker, 50,
'Namespace hierarchy');
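
The boolean groupSeries flag becomes a size threshold: a detected run of numbered op nodes is collapsed into a series node only when it has at least seriesNodeMinSize members. A rough Python sketch of that rule (hypothetical helper, not the TypeScript code):

```python
def maybe_group_series(member_names, threshold=5):
    if len(member_names) < threshold:
        return None                       # small runs stay as separate nodes
    base = member_names[0].rstrip("0123456789_")
    return "%s[0-%d]" % (base, len(member_names) - 1)

print(maybe_group_series(["conv_%d" % i for i in range(8)]))  # conv[0-7]
print(maybe_group_series(["conv_0", "conv_1"]))               # None
```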
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h
index 8dea22806c..942b060ba7 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/NeuralNetworks/Pooling.h
@@ -196,7 +196,7 @@ namespace internal {
template <typename T> struct AvgPoolMeanReducer
{
-#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64 || defined (EIGEN_USE_GPU) || defined(__CUDACC__) || defined(__CUDA_ARCH__))
+#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
// We only support packet access for floats.
static const bool PacketAccess = internal::is_same<T, float>::value;
#else
@@ -216,8 +216,16 @@ template <typename T> struct AvgPoolMeanReducer
}
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
+ return static_cast<T>(0);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
+ eigen_assert(scalarCount_ > 0);
+ return accum / scalarCount_;
+ }
-#if (!defined (EIGEN_USE_GPU) || !defined(__CUDACC__) || !defined(__CUDA_ARCH__))
+#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
#ifdef EIGEN_VECTORIZE_AVX
#define pequal(a,b) _mm256_cmp_ps(a,b,_CMP_EQ_UQ)
#define psel(a,b,false_mask) _mm256_blendv_ps(a,b,false_mask)
@@ -238,29 +246,11 @@ template <typename T> struct AvgPoolMeanReducer
packetCount_ = padd<Packet>(packetCount_, psel(pset1<Packet>(1), pset1<Packet>(0), skip_mask));
}
-#else
-#define pequal(a,b) make_float4(a.x == b.x ? 1.f : 0, a.y == b.y ? 1.f : 0, a.z == b.z ? 1.f : 0, a.w == b.w ? 1.f : 0)
-#define psel(a,b,c) make_float4(c.x ? b.x : a.x, c.y ? b.y : a.y, c.z ? b.z : a.z, c.w ? b.w : a.w)
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const float4& p, float4* accum) {
- float4 skip_mask = pequal(p, pset1<float4>(-Eigen::NumTraits<float>::highest()));
- (*accum) = padd<float4>(*accum, psel(p, pset1<float4>(0), skip_mask));
- packetCount_ = padd<float4>(packetCount_, psel(pset1<float4>(1), pset1<float4>(0), skip_mask));
- }
-
-#endif
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return static_cast<T>(0);
- }
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
return pset1<Packet>(0);
}
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
- eigen_assert(scalarCount_ > 0);
- return accum / scalarCount_;
- }
+
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
return pdiv(vaccum, packetCount_);
@@ -269,6 +259,7 @@ template <typename T> struct AvgPoolMeanReducer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return (saccum + predux(vaccum)) / (scalarCount_ + predux(packetCount_));
}
+#endif
protected:
typedef typename packet_traits<T>::type Packet;
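
The initialize/finalize pair moved so a single implementation is shared by the CPU and CUDA paths; the reducer's skip_mask trick is that padded slots carry -FLT_MAX and are excluded from both the running sum and the element count. A numpy sketch of that masking (illustrative, not the Eigen code):

```python
import numpy as np

FLT_MAX = np.finfo(np.float32).max

def avg_pool_mean(window):
    # Padded positions are filled with -FLT_MAX; mask them out of the
    # sum and the count, then divide -- mirroring AvgPoolMeanReducer.
    window = np.asarray(window, dtype=np.float32)
    keep = window != -FLT_MAX
    assert keep.sum() > 0               # eigen_assert(scalarCount_ > 0)
    return window[keep].sum() / keep.sum()

print(avg_pool_mean([1.0, 3.0, -FLT_MAX]))   # 2.0, padding ignored
```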