author    Vijay Vasudevan <vrv@google.com>  2015-12-03 17:56:16 -0800
committer Vijay Vasudevan <vrv@google.com>  2015-12-03 17:56:16 -0800
commit    54a644f33f34a32fa2cb5e7a489c64540b16e166
tree      6c15047163a4e2554bdb737079f6abb536b8d641
parent    eb5e56e479a41ad3696ea351e5804d17464d521a
TensorFlow: upstream changes to git
Change 109366961
    TensorFlow BUILD: now that we have an ops library, set linkstatic to 1.
    This fixes a breakage in the would-be opensource build, and it *might*
    mean we can get rid of all of the RequireDefaultOps() calls in our code.

    The ops library is much smaller than the kernels library that was
    previously linked together. We had set linkstatic=0 presumably because
    we didn't want to package a static copy of the (very large) kernels
    everywhere. But the op definitions are small, so this seems like a safe
    change to make. Build times for the various tests did not increase after
    this change, and inspecting the example_trainer binary showed no large
    size increase.

Change 109363613
    TensorFlow: new graph_def_builder_test needs to RequireDefaultOps.

Change 109362569
    Split ":ops" out of ":kernels" target in tensorflow/core.

Change 109360666
    Catch dtype and some shape errors sooner in `QueueBase`.

    Some avoidable errors were not being caught (e.g. the dtypes of the
    enqueue components were not checked against the queue's dtypes in
    Python), leading to cryptic messages at runtime. After this CL, they will
    be caught earlier.

Change 109359569
    TensorFlow: Expect g_ != nullptr in test.

Change 109350735
    Add a version number to GraphDef.

    We would like to be able to deprecate behavior in newly generated graphs
    without invalidating TensorFlow's ability to read and evaluate old
    graphs. For this purpose, GraphDef now has a version field which can be
    checked inside op kernels to determine how backwards compatible to be.
    version.h defines TF_GRAPH_DEF_VERSION_MIN and TF_GRAPH_DEF_VERSION_MAX,
    specifying the range of supported GraphDef versions in the current
    version of TensorFlow.

    Also expose tf.__version__ and tf.__graph_def_version{,_min,_max}__ for
    Python interrogation purposes.

    Whenever we want to deprecate or change some GraphDef semantics, we will
    proceed as follows:

    1. Bump TF_GRAPH_DEF_VERSION_MAX, leaving TF_GRAPH_DEF_VERSION_MIN
       unchanged. Describe the change in graph.proto, including the date
       introduced.
    2. In each relevant kernel, implement the new behavior if the GraphDef
       version is new, but preserve the old behavior for previous GraphDef
       versions.
    3. Wait six months or so (we need to formalize this somewhere).
    4. Bump TF_GRAPH_DEF_VERSION_MIN and remove the backwards compatibility.

    The GraphDef version is distinct from the open source version, but at
    least (4) and possibly (1) correspond to major version number bumps. The
    first GraphDef version bump is the upcoming scalar strictness change,
    which affects Google users only, since open source is already scalar
    strict.

    This commit does not yet plumb the version number into
    OpKernelConstruction so that ops can access it. That will follow.

Change 109350260
    Made TensorShapeProto implicitly convertible to TensorShape.

Base CL: 109366982
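To make the new Python-facing surface concrete, here is a minimal interrogation sketch (an illustration, not part of the commit). It uses the names the diff below actually exports: tf.__version__ and tf.GRAPH_DEF_VERSION{,_MIN,_MAX} from versions.py, plus the Graph.graph_def_version property from ops.py.

    import tensorflow as tf

    print(tf.__version__)            # release string, e.g. "0.6.0"
    print(tf.GRAPH_DEF_VERSION)      # version stamped into newly built graphs
    print(tf.GRAPH_DEF_VERSION_MIN)  # oldest GraphDef version this build accepts
    print(tf.GRAPH_DEF_VERSION_MAX)  # newest GraphDef version this build accepts

    g = tf.Graph()
    print(g.graph_def_version)       # defaults to tf.GRAPH_DEF_VERSION
    print(g.as_graph_def().version)  # the serialized GraphDef carries the same value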
 tensorflow/core/BUILD | 41
 tensorflow/core/common_runtime/direct_session.cc | 5
 tensorflow/core/common_runtime/function.cc | 1
 tensorflow/core/framework/function_testlib.cc | 2
 tensorflow/core/framework/graph.proto | 11
 tensorflow/core/framework/graph_def_util.cc | 1
 tensorflow/core/graph/equal_graph_def.cc | 8
 tensorflow/core/graph/equal_graph_def_test.cc | 10
 tensorflow/core/graph/graph.cc | 4
 tensorflow/core/graph/graph.h | 9
 tensorflow/core/graph/graph_constructor.cc | 14
 tensorflow/core/graph/graph_constructor_test.cc | 32
 tensorflow/core/graph/graph_def_builder_test.cc | 48
 tensorflow/core/graph/graph_partition.cc | 5
 tensorflow/core/graph/graph_partition_test.cc | 7
 tensorflow/core/public/version.h | 5
 tensorflow/python/BUILD | 13
 tensorflow/python/__init__.py | 2
 tensorflow/python/client/session_test.py | 4
 tensorflow/python/client/tf_session.i | 7
 tensorflow/python/framework/importer.py | 1
 tensorflow/python/framework/importer_test.py | 25
 tensorflow/python/framework/ops.py | 34
 tensorflow/python/framework/ops_test.py | 16
 tensorflow/python/framework/tensor_shape.py | 4
 tensorflow/python/framework/tensor_shape_test.py | 14
 tensorflow/python/framework/test_util.py | 6
 tensorflow/python/framework/versions.py | 33
 tensorflow/python/framework/versions_test.py | 45
 tensorflow/python/kernel_tests/fifo_queue_test.py | 38
 tensorflow/python/kernel_tests/rnn_test.py | 69
 tensorflow/python/ops/data_flow_ops.py | 68
 tensorflow/python/ops/rnn.py | 89
33 files changed, 620 insertions, 51 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index cb26384ebe..8d0e71efd6 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -187,10 +187,7 @@ cc_library(
"graph/testlib.h",
],
copts = tf_copts(),
- visibility = [
- ":friends",
- "//tensorflow:internal",
- ],
+ visibility = ["//visibility:public"],
deps = [
":core_cpu",
":tensorflow",
@@ -213,11 +210,9 @@ tf_cuda_library(
)
tf_cuda_library(
- name = "kernels",
+ name = "ops",
srcs = glob(
[
- "kernels/**/*.h",
- "kernels/**/*.cc",
"ops/**/*.h",
"ops/**/*.cc",
"user_ops/**/*.h",
@@ -226,14 +221,38 @@ tf_cuda_library(
exclude = [
"**/*test*",
"**/*main.cc",
- "kernels/**/*.cu.cc",
"user_ops/**/*.cu.cc",
],
),
copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core",
+ ":lib",
+ ":protos_cc",
+ "//tensorflow/models/embedding:word2vec_ops",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_cuda_library(
+ name = "kernels",
+ srcs = glob(
+ [
+ "kernels/**/*.h",
+ "kernels/**/*.cc",
+ ],
+ exclude = [
+ "**/*test*",
+ "**/*main.cc",
+ "kernels/**/*.cu.cc",
+ ],
+ ),
+ copts = tf_copts(),
cuda_deps = [
":gpu_kernels",
- ":cuda",
],
linkstatic = 0,
visibility = ["//visibility:public"],
@@ -241,10 +260,10 @@ tf_cuda_library(
"@gemmlowp//:eight_bit_int_gemm",
":core",
":lib",
+ ":ops",
":protos_cc",
":stream_executor",
"//tensorflow/models/embedding:word2vec_kernels",
- "//tensorflow/models/embedding:word2vec_ops",
"//third_party/eigen3",
],
alwayslink = 1,
@@ -262,6 +281,7 @@ tf_gpu_kernel_library(
),
visibility = ["//visibility:public"],
deps = [
+ ":cuda",
"//third_party/eigen3",
],
)
@@ -416,6 +436,7 @@ tf_cc_tests(
":direct_session",
":kernels",
":lib",
+ ":ops",
":strict_headers",
":test_main",
":testlib",
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 90eae3447a..fd5b2d5927 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -164,6 +164,11 @@ Status DirectSession::Extend(const GraphDef& graph) {
}
Status DirectSession::ExtendLocked(const GraphDef& graph) {
+ if (graph_created_ && graph_def_.version() != graph.version()) {
+ return errors::InvalidArgument("Incompatible GraphDef versions in Extend: ",
+ graph_def_.version(), " != ",
+ graph.version());
+ }
graph_created_ = true; // In case this is first call
graph_def_.MergeFrom(graph);
return Status::OK();
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index c7791e26d6..528ded205a 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -980,6 +980,7 @@ static void ToGraphDef(const Graph* g, GraphDef* gdef) {
}
gtl::InlinedVector<const Edge*, 4> inputs;
gdef->Clear();
+ gdef->set_version(g->version());
while (!ready.empty()) {
const Node* n = ready.front();
ready.pop_front();
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 7a86bd0e2d..96720bb5fb 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace test {
@@ -27,6 +28,7 @@ typedef FunctionDefHelper FDH;
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
gtl::ArraySlice<FunctionDef> funcs) {
GraphDef g;
+ g.set_version(TF_GRAPH_DEF_VERSION);
for (auto n : nodes) {
*(g.add_node()) = n;
}
diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto
index a9bc07e88c..55b548d457 100644
--- a/tensorflow/core/framework/graph.proto
+++ b/tensorflow/core/framework/graph.proto
@@ -15,6 +15,17 @@ import "tensorflow/core/framework/function.proto";
message GraphDef {
repeated NodeDef node = 1;
+ // Compatibility version of the graph. Newly created graphs use
+ // the most recent version. Version history:
+ //
+ // 0. Graphs created before GraphDef versioning
+ // 1. First real version (2dec2015)
+ //
+ // The GraphDef version is distinct from the TensorFlow version.
+ // Each released version of TensorFlow will support a range of
+ // GraphDef versions.
+ int32 version = 3;
+
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index 802eaa1aa8..c5fb630d01 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -24,6 +24,7 @@ namespace tensorflow {
string SummarizeGraphDef(const GraphDef& graph_def) {
string ret;
+ strings::StrAppend(&ret, "version = ", graph_def.version(), ";\n");
for (const NodeDef& node : graph_def.node()) {
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
}
diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc
index d22690de86..ddd10bc257 100644
--- a/tensorflow/core/graph/equal_graph_def.cc
+++ b/tensorflow/core/graph/equal_graph_def.cc
@@ -26,6 +26,14 @@ namespace tensorflow {
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff) {
+ if (actual.version() != expected.version()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Expected version ", expected.version(),
+ ", got version ", actual.version());
+ }
+ return false;
+ }
+
std::unordered_map<string, const NodeDef*> actual_index;
for (const NodeDef& node : actual.node()) {
actual_index[node.name()] = &node;
diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc
index 2135f6c58b..2c879b329d 100644
--- a/tensorflow/core/graph/equal_graph_def_test.cc
+++ b/tensorflow/core/graph/equal_graph_def_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -88,10 +89,11 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
Input(a_.opts().WithName("A"));
Input(a_.opts().WithName("B"));
EXPECT_FALSE(Match());
- EXPECT_EQ(
- "Found unexpected node 'B = Input[]()' not in expected graph:\n"
- "A = Input[]();\n",
- diff_);
+ EXPECT_EQ(strings::StrCat(
+ "Found unexpected node 'B = Input[]()' not in expected graph:\n"
+ "version = ",
+ TF_GRAPH_DEF_VERSION, ";\nA = Input[]();\n"),
+ diff_);
}
TEST_F(EqualGraphDefTest, NodeOrder) {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 1d947cbbf1..d2c0d9e736 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -105,7 +106,7 @@ Node::Properties::~Properties() {}
// Graph
Graph::Graph(const OpRegistryInterface* ops)
- : ops_(ops), arena_(8 << 10 /* 8kB */) {
+ : ops_(ops), version_(TF_GRAPH_DEF_VERSION), arena_(8 << 10 /* 8kB */) {
// Source and sink have no endpoints, just control edges.
NodeDef def;
def.set_name("_SOURCE");
@@ -253,6 +254,7 @@ void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
void Graph::ToGraphDef(GraphDef* graph_def) const {
graph_def->Clear();
+ graph_def->set_version(version());
std::vector<const Edge*>
inputs; // Construct this outside the loop for speed.
for (const Node* node : nodes()) {
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index d425b417a3..e251b27bfe 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -187,11 +187,17 @@ class Graph {
// single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
//
// The graph can hold ops found in registry.
+ //
+ // The version defaults to TF_GRAPH_DEF_VERSION.
explicit Graph(const OpRegistryInterface* registry);
~Graph();
static const int kControlSlot = -1;
+ // The GraphDef version of this graph (see graph.proto).
+ int version() const { return version_; }
+ void set_version(int version) { version_ = version; }
+
// Adds a new node to this graph, and returns it. Infers the Op and
// input/output types for the node. *this owns the returned instance.
// Returns nullptr and sets *status on error.
@@ -274,6 +280,9 @@ class Graph {
// Registry of all known ops. Not owned.
const OpRegistryInterface* const ops_;
+ // GraphDef version
+ int version_;
+
// Allocator which will give us good locality.
core::Arena arena_;
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index edf71ec168..4459d0b54b 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -45,6 +46,19 @@ class GraphConstructor {
GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
Graph* g, Status* status)
: opts_(opts), gdef_(gdef), g_(g), status_(status) {
+ const int version = gdef->version();
+ if (!(TF_GRAPH_DEF_VERSION_MIN <= version &&
+ version <= TF_GRAPH_DEF_VERSION_MAX)) {
+ bool low = version < TF_GRAPH_DEF_VERSION_MAX;
+ *status = errors::InvalidArgument(
+ "GraphDef version ", version, " is ", low ? "no longer" : "not yet",
+ " supported: TensorFlow ", TF_VERSION_STRING, " needs ",
+ TF_GRAPH_DEF_VERSION_MIN, " <= version <= ", TF_GRAPH_DEF_VERSION_MAX,
+ ". ",
+ low ? "Please regenerate your graph." : "Please upgrade TensorFlow.");
+ return;
+ }
+ g->set_version(gdef->version());
BuildNodeIndex();
InitFromEdges();
Convert();
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 336cc5057b..7706a3d0c6 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/version.h"
// TODO(josh11b): Test InitCostModel().
// TODO(josh11b): Test setting the "device" field of a NodeDef.
@@ -58,6 +59,12 @@ class GraphConstructorTest : public ::testing::Test {
TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
}
+ void ExpectVersion(int version) {
+ EXPECT_NE(nullptr, g_);
+ EXPECT_EQ(version, g_->version()) << "Expected version " << version
+ << ", got " << g_->version();
+ }
+
Node* FindNode(const string& name) {
for (Node* n : g_->nodes()) {
if (n->name() == name) return n;
@@ -160,7 +167,30 @@ TEST_F(GraphConstructorTest, TypeMismatch) {
"expected int32.");
}
-TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); }
+TEST_F(GraphConstructorTest, EmptyGraph) {
+ ExpectOK("");
+ ExpectVersion(0); // The default GraphDef version is 0
+}
+
+TEST_F(GraphConstructorTest, VersionGraph) {
+ ASSERT_LT(0, TF_GRAPH_DEF_VERSION); // Verify the assertion is nontrivial
+ ExpectOK(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION));
+ ExpectVersion(TF_GRAPH_DEF_VERSION);
+}
+
+TEST_F(GraphConstructorTest, LowVersion) {
+ ExpectError(strings::StrCat("version: ", -1),
+ R"(^GraphDef version -1 is no longer supported: TensorFlow \S+ )"
+ R"(needs \d+ <= version <= \d+\. )"
+ R"(Please regenerate your graph\.$)");
+}
+
+TEST_F(GraphConstructorTest, HighVersion) {
+ ExpectError(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION_MAX + 1),
+ R"(^GraphDef version \d+ is not yet supported: TensorFlow \S+ )"
+ R"(needs \d+ <= version <= \d+\. )"
+ R"(Please upgrade TensorFlow\.$)");
+}
TEST_F(GraphConstructorTest, SimpleModel) {
ExpectOK(
diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc
new file mode 100644
index 0000000000..a05594b98c
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder_test.cc
@@ -0,0 +1,48 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/graph/graph_def_builder.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(GraphDefBuilderTest, Version) {
+ RequireDefaultOps();
+
+ // Verify that our assertions will be nontrivial
+ ASSERT_LT(0, TF_GRAPH_DEF_VERSION);
+
+ // Newly built graphs should use the current version
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+
+ // Check version when we convert to a Graph
+ Graph graph(OpRegistry::Global());
+ EXPECT_OK(builder.ToGraph(&graph));
+ ASSERT_EQ(graph.version(), TF_GRAPH_DEF_VERSION);
+
+ // Check version when we convert to a GraphDef
+ GraphDef graph_def;
+ EXPECT_OK(builder.ToGraphDef(&graph_def));
+ ASSERT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index abb3d68676..5a39d8e358 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -1051,6 +1051,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
}
+ // Set versions
+ for (auto& it : *partitions) {
+ it.second.set_version(g->version());
+ }
+
// Set the start times for recvs at the very end.
if (opts.scheduling_for_recvs) {
for (auto& it : dup_recv) {
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index fc7f581fd3..33e05e803e 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -72,6 +73,12 @@ void Partition(const GraphDef& graph_def,
popts.control_flow_added = false;
Status s = Partition(popts, &g, partitions);
CHECK(s.ok()) << s;
+
+ // Check versions
+ EXPECT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
+ for (auto& it : *partitions) {
+ EXPECT_EQ(graph_def.version(), it.second.version());
+ }
}
void CheckLoopConstruction(const GraphDef& graph_def) {
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 68152ac48e..0cf11e6dd2 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -36,4 +36,9 @@ limitations under the License.
// TODO(josh11b): Public API functions for exporting the above.
+// Supported GraphDef versions (see graph.proto).
+#define TF_GRAPH_DEF_VERSION_MIN 0
+#define TF_GRAPH_DEF_VERSION_MAX 1
+#define TF_GRAPH_DEF_VERSION TF_GRAPH_DEF_VERSION_MAX
+
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1fd5c539b8..8f152a7231 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -138,6 +138,7 @@ py_library(
"framework/tensor_shape.py",
"framework/dtypes.py",
"framework/tensor_util.py",
+ "framework/versions.py",
"ops/common_shapes.py",
],
srcs_version = "PY2AND3",
@@ -196,6 +197,18 @@ py_test(
)
py_test(
+ name = "framework_versions_test",
+ srcs = ["framework/versions_test.py"],
+ main = "framework/versions_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
name = "framework_importer_test",
srcs = ["framework/importer_test.py"],
main = "framework/importer_test.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 718ab5cd93..a930d62951 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -48,6 +48,7 @@ from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import *
+from tensorflow.python.framework.versions import *
from tensorflow.python.framework import errors
# Session
@@ -81,3 +82,4 @@ _whitelist = set([app, compat, errors, flags, image, logging, nn,
_whitelist.update([ops, tensor_util]) # pylint: disable=undefined-variable
__all__ = [name for name, x in locals().items() if not name.startswith('_') and
(not inspect.ismodule(x) or x in _whitelist)]
+__all__.append('__version__')
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 7dbc5aad87..b5ef5b64f8 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
@@ -425,7 +426,8 @@ class SessionTest(test_util.TensorFlowTestCase):
def testGraphDef(self):
with session.Session() as sess:
- self.assertProtoEquals('', sess.graph_def)
+ self.assertProtoEquals('version: %d' % versions.GRAPH_DEF_VERSION,
+ sess.graph_def)
c = constant_op.constant(5.0, name='c')
self.assertEquals(len(sess.graph_def.node), 1)
d = constant_op.constant(6.0, name='d')
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 168a323a7e..2d6a73eb9e 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/version.h"
%}
@@ -32,6 +33,12 @@ limitations under the License.
tensorflow::ImportNumpy();
%}
+// TensorFlow version and GraphDef versions
+%constant const char* __version__ = TF_VERSION_STRING;
+%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
+%constant int GRAPH_DEF_VERSION_MIN = TF_GRAPH_DEF_VERSION_MIN;
+%constant int GRAPH_DEF_VERSION_MAX = TF_GRAPH_DEF_VERSION_MAX;
+
// Release the Python GIL for the duration of most methods.
%exception {
Py_BEGIN_ALLOW_THREADS;
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index e52a0e2037..0703ca2ae1 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -215,6 +215,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
with ops.op_scope(input_map.values(), name, 'import'):
g = ops.get_default_graph()
+ g.graph_def_version = graph_def.version
with ops.name_scope('_inputs'):
input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 154c550ed5..efc977aeb5 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -111,7 +111,8 @@ for op_def in _op_list.op:
class ImportGraphDefTest(tf.test.TestCase):
- def _MakeGraphDef(self, text):
+ def _MakeGraphDef(self, text, version=tf.GRAPH_DEF_VERSION):
+ text = "version: %d\n%s" % (version, text)
ret = tf.GraphDef()
text_format.Merge(text, ret)
return ret
@@ -610,6 +611,28 @@ class ImportGraphDefTest(tf.test.TestCase):
g = tf.identity(t)
g.eval()
+ def testVersion(self):
+ for version in tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX:
+ with tf.Graph().as_default():
+ a, = tf.import_graph_def(
+ self._MakeGraphDef("node { name: 'A' op: 'Oii' }", version=version),
+ return_elements=['A'])
+ self.assertEqual(a.graph.graph_def_version, version)
+
+ def testVersionLow(self):
+ with tf.Graph().as_default():
+ pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ "
+ r"needs \d+ <= version <= \d+. Please regenerate your graph.$")
+ with self.assertRaisesRegexp(ValueError, pat):
+ tf.import_graph_def(self._MakeGraphDef("", version=-1))
+
+ def testVersionHigh(self):
+ with tf.Graph().as_default():
+ pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ "
+ r"needs \d+ <= version <= \d+. Please upgrade TensorFlow.$")
+ with self.assertRaisesRegexp(ValueError, pat):
+ tf.import_graph_def(self._MakeGraphDef("", version=1 << 30))
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d3527c693e..f9796ca679 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -37,6 +37,7 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import versions
from tensorflow.python.util import compat
@@ -1545,6 +1546,7 @@ class Graph(object):
@@seed
@@unique_name
@@version
+ @@graph_def_version
@@create_op
@@gradient_override_map
@@ -1585,6 +1587,8 @@ class Graph(object):
self._finalized = False
# Functions defined in the graph
self._functions = collections.OrderedDict()
+ # Default GraphDef version
+ self._graph_def_version = versions.GRAPH_DEF_VERSION
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -1620,10 +1624,37 @@ class Graph(object):
@property
def version(self):
- """Returns a version number that increases as ops are added to the graph."""
+ """Returns a version number that increases as ops are added to the graph.
+
+ Note that this is unrelated to the
+ [GraphDef version](#Graph.graph_def_version).
+ """
return self._next_id_counter
@property
+ def graph_def_version(self):
+ """The GraphDef version of this graph.
+
+ For details on the meaning of each version, see [`GraphDef`]
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto).
+ """
+ return self._graph_def_version
+
+ @graph_def_version.setter
+ def graph_def_version(self, version):
+ if not (versions.GRAPH_DEF_VERSION_MIN <= version <=
+ versions.GRAPH_DEF_VERSION_MAX):
+ low = version < versions.GRAPH_DEF_VERSION_MIN
+ raise ValueError(
+ "GraphDef version %d is %s supported: TensorFlow %s needs %d <= "
+ "version <= %d. Please %s." %
+ (version, "no longer" if low else "not yet",
+ versions.__version__, versions.GRAPH_DEF_VERSION_MIN,
+ versions.GRAPH_DEF_VERSION_MAX,
+ "regenerate your graph" if low else "upgrade TensorFlow"))
+ self._graph_def_version = version
+
+ @property
def seed(self):
return self._seed
@@ -1684,6 +1715,7 @@ class Graph(object):
ValueError: If the `graph_def` would be too large.
"""
graph = graph_pb2.GraphDef()
+ graph.version = self._graph_def_version
bytesize = 0
for op_id in sorted(self._nodes_by_id):
op = self._nodes_by_id[op_id]
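A quick usage note on the graph_def_version setter above (a sketch, not part of the diff): assigning an out-of-range value raises the same style of ValueError that the C++ check in graph_constructor.cc produces.

    import tensorflow as tf

    g = tf.Graph()
    g.graph_def_version = tf.GRAPH_DEF_VERSION_MIN  # in range: accepted
    try:
      g.graph_def_version = -1                      # below the supported range
    except ValueError as e:
      print(e)  # "GraphDef version -1 is no longer supported: ..."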
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 8eafddca32..43044b1d39 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -410,7 +410,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
op = g.create_op("an_op", [], [dtypes.float32])
self.assertEqual(None, op.device)
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op" }
""", gd)
@@ -419,7 +419,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
with g.device("/job:worker/replica:2"):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
""", gd)
@@ -430,7 +430,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
device_index=3)):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/task:0/device:CPU:3" }
""", gd)
@@ -443,7 +443,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2" }
node { name: "an_op_1" op: "an_op"
@@ -460,7 +460,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2" }
node { name: "an_op_1" op: "an_op"
@@ -477,7 +477,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "an_op_1" op: "an_op"
@@ -501,7 +501,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/device:GPU:0" }
node { name: "an_op_1" op: "an_op"
@@ -522,7 +522,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
g.create_op("an_op", [], [dtypes.float32])
g.create_op("an_op", [], [dtypes.float32])
gd = g.as_graph_def()
- self.assertProtoEquals("""
+ self.assertProtoEqualsVersion("""
node { name: "an_op" op: "an_op"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "an_op_1" op: "an_op" }
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 6914db0d34..d1fb6ddc06 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import tensorflow.python.platform
+from tensorflow.core.framework import tensor_shape_pb2
+
class Dimension(object):
"""Represents the value of one dimension in a TensorShape."""
@@ -407,6 +409,8 @@ class TensorShape(object):
# TODO(irving): Eliminate the single integer special case.
if dims is None:
self._dims = None
+ elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
+ self._dims = [as_dimension(dim.size) for dim in dims.dim]
else:
try:
dims_iter = iter(dims)
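A small usage sketch of the proto-to-shape conversion added above, mirroring the test in the next file; tensor_util.MakeTensorShapeProto is used here only as a convenient way to build the proto.

    import tensorflow.python.platform  # imported first, as the surrounding test files do

    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.framework import tensor_util

    proto = tensor_util.MakeTensorShapeProto([1, 37, 42])
    shape = tensor_shape.TensorShape(proto)        # construct directly from the proto
    print(shape.as_list())                         # [1, 37, 42]
    print(tensor_shape.as_shape(proto).as_list())  # as_shape() accepts the proto as well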
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index 05681c4b16..ca5da2ba72 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import tensorflow.python.platform
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@@ -254,6 +255,19 @@ class ShapeTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
unknown / unknown # pylint: disable=pointless-statement
+ def testConvertFromProto(self):
+ proto = tensor_util.MakeTensorShapeProto([])
+ self.assertEqual(tensor_shape.TensorShape([]),
+ tensor_shape.TensorShape(proto))
+ self.assertEqual(tensor_shape.TensorShape([]),
+ tensor_shape.as_shape(proto))
+
+ proto = tensor_util.MakeTensorShapeProto([1, 37, 42])
+ self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
+ tensor_shape.TensorShape(proto))
+ self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
+ tensor_shape.as_shape(proto))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 91d2aa48fe..2b65e483de 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -38,6 +38,7 @@ from tensorflow.python.client import graph_util
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import versions
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging
from tensorflow.python.util.protobuf import compare
@@ -113,6 +114,11 @@ class TensorFlowTestCase(googletest.TestCase):
type(expected_message_maybe_ascii) + " and " +
type(message))
+ def assertProtoEqualsVersion(self, expected, actual,
+ version=versions.GRAPH_DEF_VERSION):
+ expected = "version: %d\n%s" % (version, expected)
+ self.assertProtoEquals(expected, actual)
+
def assertStartsWith(self, actual, expected_start, msg=None):
"""Assert that actual.startswith(expected_start) is True.
diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py
new file mode 100644
index 0000000000..98d2a77e1e
--- /dev/null
+++ b/tensorflow/python/framework/versions.py
@@ -0,0 +1,33 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""TensorFlow versions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+from tensorflow.python import pywrap_tensorflow
+
+__version__ = pywrap_tensorflow.__version__
+GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
+GRAPH_DEF_VERSION_MIN = pywrap_tensorflow.GRAPH_DEF_VERSION_MIN
+GRAPH_DEF_VERSION_MAX = pywrap_tensorflow.GRAPH_DEF_VERSION_MAX
+
+# Make sure these symbols are exported even though one starts with _.
+__all__ = ["__version__", "GRAPH_DEF_VERSION", "GRAPH_DEF_VERSION_MIN",
+ "GRAPH_DEF_VERSION_MAX"]
diff --git a/tensorflow/python/framework/versions_test.py b/tensorflow/python/framework/versions_test.py
new file mode 100644
index 0000000000..1d5f091409
--- /dev/null
+++ b/tensorflow/python/framework/versions_test.py
@@ -0,0 +1,45 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for exposed tensorflow versions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class VersionTest(tf.test.TestCase):
+
+ def testVersion(self):
+ self.assertEqual(type(tf.__version__), str)
+ # This pattern will need to grow as we include alpha, builds, etc.
+ self.assertRegexpMatches(tf.__version__, r'^\d+\.\d+\.\d+$')
+
+ def testGraphDefVersion(self):
+ version = tf.GRAPH_DEF_VERSION
+ min = tf.GRAPH_DEF_VERSION_MIN
+ max = tf.GRAPH_DEF_VERSION_MAX
+ for v in version, min, max:
+ self.assertEqual(type(v), int)
+ self.assertLessEqual(0, min)
+ self.assertLessEqual(min, version)
+ self.assertLessEqual(version, max)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 77d1f2f3d2..f02e16a4ae 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -344,6 +344,42 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(dequeued_t.eval(), elems)
def testEnqueueWrongShape(self):
+ q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((), (2)))
+
+ with self.assertRaises(ValueError):
+ q.enqueue(([1, 2], [2, 2]))
+
+ with self.assertRaises(ValueError):
+ q.enqueue_many((7, [[1, 2], [3, 4], [5, 6]]))
+
+ def testBatchSizeMismatch(self):
+ q = tf.FIFOQueue(10, (tf.int32, tf.int32, tf.int32), ((), (), ()))
+
+ with self.assertRaises(ValueError):
+ q.enqueue_many(([1, 2, 3], [1, 2], [1, 2, 3]))
+
+ with self.assertRaises(ValueError):
+ q.enqueue_many(([1, 2, 3], [1, 2], tf.placeholder(tf.int32)))
+
+ with self.assertRaises(ValueError):
+ q.enqueue_many((tf.placeholder(tf.int32), [1, 2], [1, 2, 3]))
+
+ def testEnqueueManyEmptyTypeConversion(self):
+ q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
+ enq = q.enqueue_many(([], []))
+ self.assertEqual(tf.int32, enq.inputs[1].dtype)
+ self.assertEqual(tf.float32, enq.inputs[2].dtype)
+
+ def testEnqueueWrongType(self):
+ q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
+
+ with self.assertRaises(ValueError):
+ q.enqueue((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
+
+ with self.assertRaises(ValueError):
+ q.enqueue_many((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
+
+ def testEnqueueWrongShapeAtRuntime(self):
with self.test_session() as sess:
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((2, 2), (3, 3)))
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
@@ -353,8 +389,6 @@ class FIFOQueueTest(tf.test.TestCase):
tf.errors.InvalidArgumentError, r"Expected \[3,3\], got \[3,4\]"):
sess.run([enqueue_op],
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
- sess.run([enqueue_op],
- feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongShape(self):
with self.test_session() as sess:
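To illustrate the earlier error reporting described in Change 109360666, here is a condensed sketch of the behavior these new tests exercise (the `QueueBase` changes themselves appear in data_flow_ops.py further below): dtype and batch-size mismatches now fail with ValueError while the graph is being built, instead of producing cryptic errors at runtime.

    import tensorflow as tf

    # A wrong dtype for a component is rejected at graph-construction time.
    q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
    try:
      q.enqueue((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
    except ValueError as e:
      print(e)  # dtype mismatch reported while building the graph

    # Mismatched batch sizes across enqueue_many components are caught as well.
    q3 = tf.FIFOQueue(10, (tf.int32, tf.int32, tf.int32), ((), (), ()))
    try:
      q3.enqueue_many(([1, 2, 3], [1, 2], [1, 2, 3]))
    except ValueError as e:
      print(e)  # incompatible leading dimensions reported while building the graph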
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 1ed53c0a1f..604936a0d5 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -485,5 +485,74 @@ class LSTMTest(tf.test.TestCase):
self._testDoubleInputWithDropoutAndDynamicCalculation(True)
+class BidirectionalRNNTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._seed = 23489
+ np.random.seed(self._seed)
+
+ def _testBidirectionalRNN(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ sequence_length = tf.placeholder(tf.int64)
+ cell_fw = tf.nn.rnn_cell.LSTMCell(
+ num_units, input_size, initializer=initializer)
+ cell_bw = tf.nn.rnn_cell.LSTMCell(
+ num_units, input_size, initializer=initializer)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ outputs = tf.nn.bidirectional_rnn(
+ cell_fw, cell_bw, inputs, dtype=tf.float32,
+ sequence_length=sequence_length)
+
+ self.assertEqual(len(outputs), len(inputs))
+ for out in outputs:
+ self.assertEqual(out.get_shape().as_list(), [batch_size, 2 * num_units])
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ # Run with pre-specified sequence length of 2, 3
+ out = sess.run(outputs, feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+
+ # Since the forward and backward LSTM cells were initialized with the
+ # same parameters, the forward and backward output has to be the same,
+ # but reversed in time. The format is output[time][batch][depth], and
+ # due to depth concatenation (as num_units=3 for both RNNs):
+ # - forward output: out[][][depth] for 0 <= depth < 3
+ # - backward output: out[][][depth] for 3 <= depth < 6
+ #
+ # First sequence in batch is length=2
+ # Check that the time=0 forward output is equal to time=1 backward output
+ self.assertEqual(out[0][0][0], out[1][0][3])
+ self.assertEqual(out[0][0][1], out[1][0][4])
+ self.assertEqual(out[0][0][2], out[1][0][5])
+ # Check that the time=1 forward output is equal to time=0 backward output
+ self.assertEqual(out[1][0][0], out[0][0][3])
+ self.assertEqual(out[1][0][1], out[0][0][4])
+ self.assertEqual(out[1][0][2], out[0][0][5])
+
+ # Second sequence in batch is length=3
+ # Check that the time=0 forward output is equal to time=2 backward output
+ self.assertEqual(out[0][1][0], out[2][1][3])
+ self.assertEqual(out[0][1][1], out[2][1][4])
+ self.assertEqual(out[0][1][2], out[2][1][5])
+ # Check that the time=1 forward output is equal to time=1 backward output
+ self.assertEqual(out[1][1][0], out[1][1][3])
+ self.assertEqual(out[1][1][1], out[1][1][4])
+ self.assertEqual(out[1][1][2], out[1][1][5])
+ # Check that the time=2 forward output is equal to time=0 backward output
+ self.assertEqual(out[2][1][0], out[0][1][3])
+ self.assertEqual(out[2][1][1], out[0][1][4])
+ self.assertEqual(out[2][1][2], out[0][1][5])
+
+ def testBidirectionalRNN(self):
+ self._testBidirectionalRNN(use_gpu=False)
+ self._testBidirectionalRNN(use_gpu=True)
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 261193715c..9d7ae04571 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -157,6 +157,26 @@ class QueueBase(object):
"""The list of dtypes for each component of a queue element."""
return self._dtypes
+ def _check_enqueue_dtypes(self, vals):
+ """Returns `vals` as a list of `Tensor`s, having checked their dtypes.
+
+ Args:
+ vals: A tensor or a list of tensors, corresponding to an
+ enqueue(_many) tuple.
+
+ Returns:
+ A list of `Tensor` objects.
+ """
+ if not isinstance(vals, (list, tuple)):
+ vals = [vals]
+
+ tensors = []
+ for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
+ tensors.append(ops.convert_to_tensor(val, dtype=dtype,
+ name="component_%d" % i))
+
+ return tensors
+
def enqueue(self, vals, name=None):
"""Enqueues one element to this queue.
@@ -170,16 +190,18 @@ class QueueBase(object):
Returns:
The operation that enqueues a new tuple of tensors to the queue.
"""
- if name is None:
- name = "%s_enqueue" % self._name
- ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)
+ if not isinstance(vals, (list, tuple)):
+ vals = [vals]
- # NOTE(mrry): Not using a shape function because we need access to
- # the Queue object.
- for val, shape in zip(ret.inputs[1:], self._shapes):
- val.get_shape().assert_is_compatible_with(shape)
+ with ops.op_scope(vals, name, "%s_enqueue" % self._name) as scope:
+ vals = self._check_enqueue_dtypes(vals)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ for val, shape in zip(vals, self._shapes):
+ val.get_shape().assert_is_compatible_with(shape)
- return ret
+ return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope)
def enqueue_many(self, vals, name=None):
"""Enqueues zero or more elements to this queue.
@@ -199,20 +221,22 @@ class QueueBase(object):
Returns:
The operation that enqueues a batch of tuples of tensors to the queue.
"""
- if name is None:
- name = "%s_EnqueueMany" % self._name
-
- ret = gen_data_flow_ops._queue_enqueue_many(
- self._queue_ref, vals, name=name)
-
- # NOTE(mrry): Not using a shape function because we need access to
- # the `QueueBase` object.
- batch_dim = ret.inputs[1].get_shape()[0]
- for val, shape in zip(ret.inputs[1:], self._shapes):
- batch_dim.merge_with(val.get_shape()[0])
- val.get_shape()[1:].assert_is_compatible_with(shape)
-
- return ret
+ if not isinstance(vals, (list, tuple)):
+ vals = [vals]
+
+ with ops.op_scope(vals, name, "%s_EnqueueMany" % self._name) as scope:
+ vals = self._check_enqueue_dtypes(vals)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
+ for val, shape in zip(vals, self._shapes):
+ batch_dim = batch_dim.merge_with(
+ val.get_shape().with_rank_at_least(1)[0])
+ val.get_shape()[1:].assert_is_compatible_with(shape)
+
+ return gen_data_flow_ops._queue_enqueue_many(
+ self._queue_ref, vals, name=scope)
def dequeue(self, name=None):
"""Dequeues one element from this queue.
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index e7d70ea79e..dcd2334e19 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -148,3 +148,92 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
outputs[-1] = array_ops.identity(outputs[-1])
return (outputs, states)
+
+
+def _reverse_seq(input_seq, lengths):
+ """Reverse a list of Tensors up to specified lengths.
+
+ Args:
+ input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
+ lengths: A tensor of dimension batch_size, containing lengths for each
+ sequence in the batch. If "None" is specified, simply reverses
+ the list.
+
+ Returns:
+ time-reversed sequence
+ """
+ if lengths is None:
+ return list(reversed(input_seq))
+
+ # Join into (time, batch_size, depth)
+ s_joined = array_ops.pack(input_seq)
+ # Reverse along dimension 0
+ s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
+ # Split again into list
+ result = array_ops.unpack(s_reversed)
+ return result
+
+
+def bidirectional_rnn(cell_fw, cell_bw, inputs,
+ initial_state_fw=None, initial_state_bw=None,
+ dtype=None, sequence_length=None, scope=None):
+ """Creates a bidirectional recurrent neural network.
+
+ Similar to the unidirectional case above (rnn) but takes input and builds
+ independent forward and backward RNNs with the final forward and backward
+ outputs depth-concatenated, such that the output will have the format
+ [time][batch][cell_fw.output_size + cell_bw.output_size]. The initial state
+ for both directions is zero by default (but can be set optionally) and no
+ intermediate states are ever returned -- the network is fully unrolled for
+ the given (passed in) length(s) of the sequence(s).
+
+ Args:
+ cell_fw: An instance of RNNCell, to be used for forward direction.
+ cell_bw: An instance of RNNCell, to be used for backward direction.
+ inputs: A length T list of inputs, each a tensor of shape [batch_size, input_size].
+ initial_state_fw: (optional) An initial state for the forward RNN.
+ This must be a tensor of appropriate type and shape
+ [batch_size x cell.state_size].
+ initial_state_bw: (optional) Same as for initial_state_fw.
+ dtype: (optional) The data type for the initial state. Required if either
+ of the initial states are not provided.
+ sequence_length: An int64 vector (tensor) of size [batch_size], containing
+ the actual lengths for each of the sequences.
+ scope: VariableScope for the created subgraph; defaults to "BiRNN"
+
+ Returns:
+ A set of output `Tensors` where:
+ outputs is a length T list of outputs (one for each input), which
+ are depth-concatenated forward and backward outputs
+
+ Raises:
+ TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
+ ValueError: If inputs is None or an empty list.
+ ValueError: If sequence_length is not defined.
+ """
+
+ if not isinstance(cell_fw, rnn_cell.RNNCell):
+ raise TypeError("cell_fw must be an instance of RNNCell")
+ if not isinstance(cell_bw, rnn_cell.RNNCell):
+ raise TypeError("cell_bw must be an instance of RNNCell")
+ if not isinstance(inputs, list):
+ raise TypeError("inputs must be a list")
+ if not sequence_length:
+ raise ValueError("sequence_length has to be defined")
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ name = scope or "BiRNN"
+ # Forward direction
+ with vs.variable_scope(name + "_FW"):
+ output_fw, _ = rnn(cell_fw, inputs, initial_state_fw, dtype)
+ # Backward direction
+ with vs.variable_scope(name + "_BW"):
+ tmp, _ = rnn(
+ cell_bw, _reverse_seq(inputs, sequence_length), initial_state_bw, dtype)
+ output_bw = _reverse_seq(tmp, sequence_length)
+ # Concat each of the forward/backward outputs
+ outputs = [array_ops.concat(1, [fw, bw])
+ for fw, bw in zip(output_fw, output_bw)]
+
+ return outputs
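Finally, a condensed usage sketch of the new bidirectional_rnn (distilled from BidirectionalRNNTest above; the tf.nn.bidirectional_rnn export and the tf.nn.rnn_cell.LSTMCell constructor are taken from that test rather than from this file):

    import numpy as np
    import tensorflow as tf

    num_units, input_size, batch_size, seq_len = 3, 5, 2, 10
    cell_fw = tf.nn.rnn_cell.LSTMCell(num_units, input_size)
    cell_bw = tf.nn.rnn_cell.LSTMCell(num_units, input_size)
    inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))
              for _ in range(seq_len)]
    sequence_length = tf.placeholder(tf.int64)

    # One output per time step; forward and backward halves are depth-concatenated.
    outputs = tf.nn.bidirectional_rnn(cell_fw, cell_bw, inputs, dtype=tf.float32,
                                      sequence_length=sequence_length)
    assert outputs[0].get_shape().as_list() == [batch_size, 2 * num_units]

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      feed = {inp: np.random.randn(batch_size, input_size) for inp in inputs}
      feed[sequence_length] = [2, 3]            # per-example sequence lengths
      out = sess.run(outputs, feed_dict=feed)   # len(out) == seq_len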