aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-04-10 08:46:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-10 09:51:54 -0700
commit098f930de4ef044021f3ef1d3cdd6848c23eddb0 (patch)
tree107b20a63c2c1f4069a804af84489b38e9899478 /tensorflow
parentcc9560e8f449060feeaa73f47eb41e4d77079573 (diff)
This is another step to make TensorFlow more interactive and flexible to users. It allows a tensor produced by a run call to stay "in-place" so that a future run call can use it in-place. To achieve this, a run call can now return a handle of a tensor to the client, which can then be fed to a subsequent run call. This feature is complimentary to partial run, though there are some overlaps.
Here are a few properties of the current implementation: 1. Tensors are stored in the state of a session. The tensors are garbage collected if the client doesn't have a reference to the tensor or the session is closed. 2. There is no change to the current session API. We introduced two ops to manage the conversions between tensors and its handles. (There is a third op to garbage collect a tensor.) See the example below. 3. It fits quite well into the current feed-fetch design/implementation. It tries to reuse the graph (and caches) as much as possible so to make things efficient. Below is a simple example. More examples can be found in sessopn_ops_test.py. # Return a handle. a = tf.constant(10) b = tf.constant(5) c = tf.mul(a, b) h = tf.get_session_handle(c).eval() # Feed a tensor handle. f, x = tf.get_session_tensor(dtypes.int32) y = tf.mul(x, 10) result = sess.run(y, feed_dict={f: h.handle}) # result == 500 Change: 119481352
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc5
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc26
-rw-r--r--tensorflow/core/common_runtime/direct_session.h7
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc49
-rw-r--r--tensorflow/core/common_runtime/executor.cc6
-rw-r--r--tensorflow/core/common_runtime/executor.h3
-rw-r--r--tensorflow/core/common_runtime/session_state.cc83
-rw-r--r--tensorflow/core/framework/op_kernel.h13
-rw-r--r--tensorflow/core/framework/session_state.h85
-rw-r--r--tensorflow/core/graph/graph.cc3
-rw-r--r--tensorflow/core/graph/graph.h8
-rw-r--r--tensorflow/core/graph/testlib.cc9
-rw-r--r--tensorflow/core/graph/testlib.h3
-rw-r--r--tensorflow/core/kernels/BUILD2
-rw-r--r--tensorflow/core/kernels/session_ops.cc120
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc31
-rw-r--r--tensorflow/g3doc/api_docs/python/client.md3
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md11
-rw-r--r--tensorflow/g3doc/api_docs/python/math_ops.md164
-rw-r--r--tensorflow/g3doc/api_docs/python/session_ops.md102
-rw-r--r--tensorflow/g3doc/api_docs/python/state_ops.md4
-rw-r--r--tensorflow/g3doc/api_docs/python/string_ops.md96
-rw-r--r--tensorflow/python/BUILD4
-rw-r--r--tensorflow/python/__init__.py5
-rw-r--r--tensorflow/python/client/session.py123
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py1
-rw-r--r--tensorflow/python/framework/ops.py8
-rw-r--r--tensorflow/python/kernel_tests/session_ops_test.py157
-rw-r--r--tensorflow/python/ops/control_flow_ops.py4
-rw-r--r--tensorflow/python/ops/data_flow_grad.py4
-rw-r--r--tensorflow/python/ops/data_flow_ops.py5
-rw-r--r--tensorflow/python/ops/session_ops.py255
-rw-r--r--tensorflow/python/ops/standard_ops.py1
34 files changed, 1346 insertions, 57 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index e7d9ab13a0..64fbd933b8 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -245,6 +245,7 @@ tf_cuda_library(
"framework/register_types.h",
"framework/resource_mgr.h",
"framework/selective_registration.h",
+ "framework/session_state.h",
"framework/tensor.h",
"framework/tensor_shape.h",
"framework/tensor_slice.h",
@@ -856,6 +857,7 @@ filegroup(
"framework/partial_tensor_shape.h",
"framework/rendezvous.h",
"framework/selective_registration.h",
+ "framework/session_state.h",
"framework/tensor.h",
"framework/tensor_reference.h",
"framework/tensor_shape.h",
@@ -1268,6 +1270,7 @@ tf_cc_test(
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:queue_ops",
+ "//tensorflow/core/kernels:session_ops",
"//tensorflow/core/kernels:variable_ops",
"//third_party/eigen3",
],
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 07f08c5577..71dbd6d680 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -50,6 +50,11 @@ bool IsConstantFoldable(const Node* n,
if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
return false;
}
+ // TODO(yuanbyu): For now disable these session handle operations.
+ if (n->IsGetSessionHandle() || n->IsGetSessionTensor() ||
+ n->IsDeleteSessionTensor()) {
+ return false;
+ }
if (n->IsSource()) {
return false;
}
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index b6e2988f3b..67605e23e5 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -313,6 +313,8 @@ Status DirectSession::Run(const RunOptions& run_options,
args.rendezvous = run_state.rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
+ args.session_state = &session_state_;
+ args.tensor_store = &run_state.tensor_store;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
@@ -340,6 +342,11 @@ Status DirectSession::Run(const RunOptions& run_options,
// Receive outputs.
TF_RETURN_IF_ERROR(
RecvOutputs(output_names, executors_and_keys, &run_state, outputs));
+
+ // Save the output tensors of this run we choose to keep.
+ TF_RETURN_IF_ERROR(
+ run_state.tensor_store.SaveTensors(output_names, &session_state_));
+
return Status::OK();
}
@@ -369,9 +376,8 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
{
mutex_lock l(executor_lock_);
if (!partial_runs_.insert({run_state_args.handle, run_state}).second) {
- return errors::Internal("The handle ", run_state_args.handle,
- " created for this partial"
- " run is not unique.");
+ return errors::Internal("The handle '", run_state_args.handle,
+ "' created for this partial run is not unique.");
}
}
@@ -390,13 +396,12 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
});
Executor::Args args;
- {
- mutex_lock l(mu_);
- args.step_id = name_counter_++;
- }
+ args.step_id = step_id_counter_.fetch_add(1);
args.rendezvous = run_state->rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
+ args.session_state = &session_state_;
+ args.tensor_store = &run_state->tensor_store;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
@@ -470,9 +475,14 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
s = RecvOutputs(output_names, executors_and_keys, run_state, outputs);
}
- // Delete the run state if there is an error or all fetches are done.
+ // Save the output tensors of this run we choose to keep.
+ if (s.ok()) {
+ s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
+ }
+
{
mutex_lock l(executor_lock_);
+ // Delete the run state if there is an error or all fetches are done.
bool done = true;
if (s.ok()) {
{
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 15c3b2625a..a35036ecd8 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -78,6 +79,7 @@ class DirectSession : public Session {
::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override;
+
::tensorflow::Status Close() override;
// NOTE: This is a temporary api that is only meant to enable testing.
@@ -135,6 +137,7 @@ class DirectSession : public Session {
Notification executors_done;
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
+ TensorStore tensor_store;
RunState(const std::vector<string>& input_names,
const std::vector<string>& output_names) {
@@ -146,6 +149,7 @@ class DirectSession : public Session {
pending_outputs.emplace(name);
}
}
+
~RunState();
};
@@ -228,6 +232,9 @@ class DirectSession : public Session {
std::unordered_map<string, RunState*> partial_runs_
GUARDED_BY(executor_lock_);
+ // This holds all the tensors that are currently alive in the session.
+ SessionState session_state_;
+
CancellationManager* cancellation_manager_;
// Saves and restores device placements for stateful nodes.
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 37e0ff11d4..75a1235f0b 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -564,6 +564,55 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
+TEST(DirectSessionTest, RunHandleTest) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+
+ Tensor value0(DT_FLOAT, TensorShape({}));
+ value0.scalar<float>()() = 1.0;
+ Node* const0 = test::graph::Constant(&g, value0);
+ Node* identity0 = test::graph::Identity(&g, const0);
+
+ Tensor value1(DT_FLOAT, TensorShape({}));
+ value1.scalar<float>()() = 2.0;
+ Node* const1 = test::graph::Constant(&g, value1);
+ Node* node3 = test::graph::Add(&g, identity0, const1);
+ Node* node4 = test::graph::Unary(&g, "GetSessionHandle", node3);
+
+ Tensor value2(DT_STRING, TensorShape({}));
+ Node* const2 = test::graph::Constant(&g, value2);
+ Node* node5 = test::graph::GetSessionTensor(&g, const2);
+ Node* node6 = test::graph::Add(&g, node5, const1);
+
+ Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def));
+
+ // First run call: Create a handle.
+ std::vector<Tensor> outputs;
+ Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs.size());
+
+ // Second run call: Use a handle.
+ std::vector<Tensor> outputs1;
+ s = session->Run({{const2->name(), outputs[0]}}, {node6->name() + ":0"}, {},
+ &outputs1);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs1.size());
+ ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));
+
+ // Third run call: Delete a handle.
+ std::vector<Tensor> outputs2;
+ s = session->Run({{const2->name(), outputs[0]}}, {}, {node7->name()},
+ &outputs2);
+ ASSERT_TRUE(s.ok());
+}
+
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
Graph graph(OpRegistry::Global());
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 1051fe7193..87868462bb 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -645,6 +645,8 @@ class ExecutorState {
int64 step_id_;
// Not owned.
Rendezvous* rendezvous_;
+ SessionState* session_state_;
+ TensorStore* tensor_store_;
StepStatsCollector* stats_collector_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -793,6 +795,8 @@ class ExecutorState {
ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
: step_id_(args.step_id),
rendezvous_(args.rendezvous),
+ session_state_(args.session_state),
+ tensor_store_(args.tensor_store),
stats_collector_(args.stats_collector),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
@@ -938,6 +942,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
// track allocations if and only if we are collecting statistics
params.track_allocations = (stats_collector_ != nullptr);
params.rendezvous = rendezvous_;
+ params.session_state = session_state_;
+ params.tensor_store = tensor_store_;
params.cancellation_manager = cancellation_manager_;
params.call_frame = call_frame_;
params.function_library = impl_->params_.function_library;
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 1d4972d04d..b013927980 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/notification.h"
@@ -85,6 +86,8 @@ class Executor {
StepStatsCollector* stats_collector = nullptr;
FunctionCallFrame* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;
+ SessionState* session_state = nullptr;
+ TensorStore* tensor_store = nullptr;
typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
new file mode 100644
index 0000000000..10e614cce5
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -0,0 +1,83 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/session_state.h"
+#include "tensorflow/core/graph/tensor_id.h"
+
+namespace tensorflow {
+
+Status SessionState::GetTensor(const string& handle, Tensor* tensor) {
+ mutex_lock l(state_lock_);
+ auto it = tensors_.find(handle);
+ if (it == tensors_.end()) {
+ return errors::InvalidArgument("The tensor with handle '", handle,
+ "' is not in the session store.");
+ }
+ *tensor = it->second;
+ return Status::OK();
+}
+
+Status SessionState::AddTensor(const string& handle, const Tensor& tensor) {
+ mutex_lock l(state_lock_);
+ if (!tensors_.insert({handle, tensor}).second) {
+ return errors::InvalidArgument("Failed to add a tensor with handle '",
+ handle, "' to the session store.");
+ }
+ return Status::OK();
+}
+
+Status SessionState::DeleteTensor(const string& handle) {
+ mutex_lock l(state_lock_);
+ if (tensors_.erase(handle) == 0) {
+ return errors::InvalidArgument("Failed to delete a tensor with handle '",
+ handle, "' in the session store.");
+ }
+ return Status::OK();
+}
+
+int64 SessionState::GetNewId() {
+ mutex_lock l(state_lock_);
+ return tensor_id_++;
+}
+
+Status TensorStore::AddTensor(const string& name, const TensorAndKey& tk) {
+ mutex_lock l(lock_);
+ if (!tensors_.insert({name, tk}).second) {
+ return errors::InvalidArgument("Failed to add a tensor with name '", name,
+ "' to the tensor store.");
+ }
+ return Status::OK();
+}
+
+Status TensorStore::SaveTensors(const std::vector<string>& output_names,
+ SessionState* session_state) {
+ mutex_lock l(lock_);
+ if (tensors_.size() != 0) {
+ // Save only the tensors in output_names in the session.
+ for (const string& name : output_names) {
+ TensorId id(ParseTensorName(name));
+ const string& op_name = id.first.ToString();
+ auto it = tensors_.find(op_name);
+ if (it != tensors_.end()) {
+ // Save the tensor to the session state.
+ string key = it->second.GetHandle(op_name);
+ TF_RETURN_IF_ERROR(session_state->AddTensor(key, it->second.tensor));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 61d15edf7b..7667b1cd0f 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
+#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -502,6 +503,12 @@ class OpKernelContext {
// computations running on other devices.
Rendezvous* rendezvous = nullptr;
+ // The session state for this op.
+ SessionState* session_state = nullptr;
+
+ // The tensor store for this op.
+ TensorStore* tensor_store = nullptr;
+
// Mechanism used by this op kernel invocation to register a callback
// for its cancellation.
CancellationManager* cancellation_manager = nullptr;
@@ -841,6 +848,12 @@ class OpKernelContext {
// Rendezvous Send() and Recv().
Rendezvous* rendezvous() const { return params_->rendezvous; }
+ // An op kernel can access the session state it belongs to.
+ SessionState* session_state() const { return params_->session_state; }
+
+ // An op kernel can access the tensor store of the run it belongs to.
+ TensorStore* tensor_store() const { return params_->tensor_store; }
+
// Function call support.
//
// If this kernel invocation is within a function execution,
diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h
new file mode 100644
index 0000000000..0093e91f9b
--- /dev/null
+++ b/tensorflow/core/framework/session_state.h
@@ -0,0 +1,85 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#define TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// The session state remembers the tensors we choose to keep across
+// multiple run calls.
+class SessionState {
+ public:
+ // Get a tensor from the session state.
+ Status GetTensor(const string& handle, Tensor* tensor);
+
+ // Store a tensor in the session state.
+ Status AddTensor(const string& handle, const Tensor& tensor);
+
+ // Delete a tensdor from the session state.
+ Status DeleteTensor(const string& handle);
+
+ int64 GetNewId();
+
+ private:
+ mutex state_lock_;
+
+ // For generating unique ids for tensors stored in the session.
+ int64 tensor_id_ = 0;
+
+ // The live tensors in the session. A map from tensor handle to tensor.
+ std::unordered_map<string, Tensor> tensors_;
+};
+
+// The tensor store remembers the tensors we choose to keep for the
+// current run call. It is available to every op kernel.
+class TensorStore {
+ public:
+ struct TensorAndKey {
+ Tensor tensor;
+ int64 id;
+ string device_name;
+
+ string GetHandle(const string& tensor_name) {
+ return strings::StrCat(tensor_name, ";", id, ";", device_name);
+ }
+ };
+
+ // Add the named tensor to the tensor store for this run.
+ Status AddTensor(const string& name, const TensorAndKey& tk);
+
+ // Save the tensors in the tensor store of this run to the session.
+ Status SaveTensors(const std::vector<string>& output_names,
+ SessionState* session_state);
+
+ private:
+ mutex lock_;
+
+ // The tensors that will be saved to session state when this run completes.
+ // A map from tensor string name to tensor.
+ std::unordered_map<string, TensorAndKey> tensors_ GUARDED_BY(lock_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 57c5b2b200..80eaed56a9 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -95,6 +95,9 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
SET_CLASS(NC_CONSTANT, ts, "Const", "HostConst");
SET_CLASS(NC_VARIABLE, ts, "Variable", "");
SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity");
+ SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandle", "");
+ SET_CLASS(NC_GET_SESSION_TENSOR, ts, "GetSessionTensor", "");
+ SET_CLASS(NC_DELETE_SESSION_TENSOR, ts, "DeleteSessionTensor", "");
if (class_ == NC_UNINITIALIZED) {
class_ = NC_OTHER; // Catch all
}
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 4ad2a306b2..23aa211c84 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -118,6 +118,11 @@ class Node {
bool IsConstant() const { return (class_ == NC_CONSTANT); }
bool IsVariable() const { return (class_ == NC_VARIABLE); }
bool IsIdentity() const { return (class_ == NC_IDENTITY); }
+ bool IsGetSessionHandle() const { return (class_ == NC_GET_SESSION_HANDLE); }
+ bool IsGetSessionTensor() const { return (class_ == NC_GET_SESSION_TENSOR); }
+ bool IsDeleteSessionTensor() const {
+ return (class_ == NC_DELETE_SESSION_TENSOR);
+ }
bool IsControlFlow() const {
return (class_ != NC_OTHER) && // Fast path
(IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
@@ -172,6 +177,9 @@ class Node {
NC_CONSTANT,
NC_VARIABLE,
NC_IDENTITY,
+ NC_GET_SESSION_HANDLE,
+ NC_GET_SESSION_TENSOR,
+ NC_DELETE_SESSION_TENSOR,
NC_OTHER // Not a special kind of node
};
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index f3164009fc..e5267c4575 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -360,6 +360,15 @@ Node* Gather(Graph* g, Node* in0, Node* in1) {
return ret;
}
+Node* GetSessionTensor(Graph* g, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
+ .Input(in, 0)
+ .Attr("dtype", DT_FLOAT)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index cb6f2468f2..f61265d6f4 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -161,6 +161,9 @@ Node* Gather(Graph* g, Node* in0, Node* in1);
// Computes the args needed broadcast gradient function.
Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1);
+// Gets a tensor stored in the session state.
+Node* GetSessionTensor(Graph* g, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index af9cae06d1..3188a730bc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -542,6 +542,7 @@ tf_kernel_libraries(
"padding_fifo_queue_op",
"queue_ops",
"random_shuffle_queue_op",
+ "session_ops",
"stack_ops",
"tensor_array_ops",
],
@@ -1518,6 +1519,7 @@ filegroup(
"restore_op.cc",
"save_op.cc",
"save_restore_tensor.cc",
+ "session_ops.cc",
"softplus_op.cc",
"softsign_op.cc",
"sparse_to_dense_op.cc",
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
new file mode 100644
index 0000000000..6c814e2d40
--- /dev/null
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -0,0 +1,120 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <limits.h>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class GetSessionHandleOp : public OpKernel {
+ public:
+ explicit GetSessionHandleOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor val = ctx->input(0);
+ int64 id = ctx->session_state()->GetNewId();
+ TensorStore::TensorAndKey tk{val, id, def().device()};
+ OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(def().name(), tk));
+ Tensor* handle = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
+ handle->flat<string>().setConstant(tk.GetHandle(def().name()));
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
+ GetSessionHandleOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
+ GetSessionHandleOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
+#undef REGISTER_GPU_KERNEL
+
+class GetSessionTensorOp : public OpKernel {
+ public:
+ explicit GetSessionTensorOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& handle = ctx->input(0);
+ const string& name = handle.scalar<string>()();
+ Tensor val;
+ OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val));
+ ctx->set_output(0, val);
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU),
+ GetSessionTensorOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("dtype"), \
+ GetSessionTensorOp)
+
+TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
+#undef REGISTER_GPU_KERNEL
+
+class DeleteSessionTensorOp : public OpKernel {
+ public:
+ explicit DeleteSessionTensorOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& handle = ctx->input(0);
+ const string& name = handle.scalar<string>()();
+ OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name));
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU),
+ DeleteSessionTensorOp);
+REGISTER_KERNEL_BUILDER(
+ Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
+ DeleteSessionTensorOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index e26a55f3b2..cef74ca8ac 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -686,4 +686,35 @@ keys: Keys of type Tkey.
values: Values of type Tval. Same shape as `keys`.
)doc");
+REGISTER_OP("GetSessionHandle")
+ .Input("value: T")
+ .Output("handle: string")
+ .Attr("T: type")
+ .Doc(R"doc(
+Store the input tensor in the state of the current session.
+
+value: The tensor to be stored.
+handle: The handle for the tensor stored in the session state.
+)doc");
+
+REGISTER_OP("GetSessionTensor")
+ .Input("handle: string")
+ .Output("value: dtype")
+ .Attr("dtype: type")
+ .Doc(R"doc(
+Get the value of the tensor specified by its handle.
+
+handle: The handle for a tensor stored in the session state.
+value: The tensor for the given handle.
+dtype: The type of the output value.
+)doc");
+
+REGISTER_OP("DeleteSessionTensor")
+ .Input("handle: string")
+ .Doc(R"doc(
+Delete the tensor specified by its handle in the session.
+
+handle: The handle for a tensor stored in the session state.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md
index cdb9df53a5..690707b1b2 100644
--- a/tensorflow/g3doc/api_docs/python/client.md
+++ b/tensorflow/g3doc/api_docs/python/client.md
@@ -117,6 +117,9 @@ method. A graph element can be one of the following types:
the *i*th return value will be a
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue)
containing the value of that sparse tensor.
+* If the *i*th element of `fetches` is produced by a `get_tensor_handle` op,
+ the *i*th return value will be a numpy ndarray containing the handle of
+ that tensor.
The optional `feed_dict` argument allows the caller to override
the value of tensors in the graph. Each key in `feed_dict` can be
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 295f956da9..49636d657a 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -141,6 +141,8 @@
* [`batch_ifft3d`](../../api_docs/python/math_ops.md#batch_ifft3d)
* [`batch_matmul`](../../api_docs/python/math_ops.md#batch_matmul)
* [`batch_matrix_determinant`](../../api_docs/python/math_ops.md#batch_matrix_determinant)
+ * [`batch_matrix_diag`](../../api_docs/python/math_ops.md#batch_matrix_diag)
+ * [`batch_matrix_diag_part`](../../api_docs/python/math_ops.md#batch_matrix_diag_part)
* [`batch_matrix_inverse`](../../api_docs/python/math_ops.md#batch_matrix_inverse)
* [`batch_matrix_solve`](../../api_docs/python/math_ops.md#batch_matrix_solve)
* [`batch_matrix_solve_ls`](../../api_docs/python/math_ops.md#batch_matrix_solve_ls)
@@ -224,6 +226,10 @@
* [`unsorted_segment_sum`](../../api_docs/python/math_ops.md#unsorted_segment_sum)
* [`where`](../../api_docs/python/math_ops.md#where)
+* **[Strings](../../api_docs/python/string_ops.md)**:
+ * [`reduce_join`](../../api_docs/python/string_ops.md#reduce_join)
+ * [`string_to_hash_bucket`](../../api_docs/python/string_ops.md#string_to_hash_bucket)
+
* **[Histograms](../../api_docs/python/histogram_ops.md)**:
* [`histogram_fixed_width`](../../api_docs/python/histogram_ops.md#histogram_fixed_width)
@@ -262,6 +268,11 @@
* [`map_fn`](../../api_docs/python/functional_ops.md#map_fn)
* [`scan`](../../api_docs/python/functional_ops.md#scan)
+* **[Tensor Handle Operations](../../api_docs/python/session_ops.md)**:
+ * [`delete_session_tensor`](../../api_docs/python/session_ops.md#delete_session_tensor)
+ * [`get_session_handle`](../../api_docs/python/session_ops.md#get_session_handle)
+ * [`get_session_tensor`](../../api_docs/python/session_ops.md#get_session_tensor)
+
* **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 403621e310..627b6fa5d0 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -743,6 +743,101 @@ mathematical functions for matrices to your graph.
- - -
+### `tf.batch_matrix_diag(diagonal, name=None)` {#batch_matrix_diag}
+
+Returns a batched diagonal tensor with a given batched diagonal values.
+
+Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+everything else padded with zeros. The diagonal is computed as follows:
+
+Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a
+tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
+
+`output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`.
+
+For example:
+
+```prettyprint
+# 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]]
+
+and diagonal.shape = (2, 4)
+
+tf.batch_matrix_diag(diagonal) ==> [[[1, 0, 0, 0]
+ [0, 2, 0, 0]
+ [0, 0, 3, 0]
+ [0, 0, 0, 4]],
+ [[5, 0, 0, 0]
+ [0, 6, 0, 0]
+ [0, 0, 7, 0]
+ [0, 0, 0, 8]]]
+
+which has shape (2, 4, 4)
+```
+
+##### Args:
+
+
+* <b>`diagonal`</b>: A `Tensor`. Rank `k`, where `k >= 1`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor`. Has the same type as `diagonal`.
+ Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`.
+
+
+- - -
+
+### `tf.batch_matrix_diag_part(input, name=None)` {#batch_matrix_diag_part}
+
+Returns the batched diagonal part of a batched tensor.
+
+This operation returns a tensor with the `diagonal` part
+of the batched `input`. The `diagonal` part is computed as follows:
+
+Assume `input` has `k` dimensions `[I, J, K, ..., N, N]`, then the output is a
+tensor of rank `k - 1` with dimensions `[I, J, K, ..., N]` where:
+
+`diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`.
+
+The input must be at least a matrix.
+
+For example:
+
+```prettyprint
+# 'input' is [[[1, 0, 0, 0]
+ [0, 2, 0, 0]
+ [0, 0, 3, 0]
+ [0, 0, 0, 4]],
+ [[5, 0, 0, 0]
+ [0, 6, 0, 0]
+ [0, 0, 7, 0]
+ [0, 0, 0, 8]]]
+
+and input.shape = (2, 4, 4)
+
+tf.batch_matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]]
+
+which has shape (2, 4)
+```
+
+##### Args:
+
+
+* <b>`input`</b>: A `Tensor`.
+ Rank `k` tensor where `k >= 2` and the last two dimensions are equal.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor`. Has the same type as `input`.
+ The extracted diagonal(s) having shape
+ `diagonal.shape = input.shape[:-1]`.
+
+
+
+- - -
+
### `tf.diag(diagonal, name=None)` {#diag}
Returns a diagonal tensor with a given diagonal values.
@@ -1192,7 +1287,7 @@ eigenvalues, and subsequent [...,1:, :] containing the eigenvectors.
- - -
-### `tf.matrix_solve(matrix, rhs, name=None)` {#matrix_solve}
+### `tf.matrix_solve(matrix, rhs, adjoint=None, name=None)` {#matrix_solve}
Solves a system of linear equations. Checks for invertibility.
@@ -1202,25 +1297,30 @@ Solves a system of linear equations. Checks for invertibility.
* <b>`matrix`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
Shape is `[M, M]`.
* <b>`rhs`</b>: A `Tensor`. Must have the same type as `matrix`. Shape is `[M, K]`.
+* <b>`adjoint`</b>: An optional `bool`. Defaults to `False`.
+ Boolean indicating whether to solve with `matrix` or its adjoint.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `matrix`.
- Shape is `[M, K]` containing the tensor that solves
- matrix * output = rhs.
+ Shape is `[M, K]`. If `adjoint` is `False` then `output` that solves
+ `matrix` * `output` = `rhs`. If `adjoint` is `True` then `output` that solves
+ `adjoint(matrix)` * `output` = `rhs`.
- - -
-### `tf.batch_matrix_solve(matrix, rhs, name=None)` {#batch_matrix_solve}
+### `tf.batch_matrix_solve(matrix, rhs, adjoint=None, name=None)` {#batch_matrix_solve}
Solves systems of linear equations. Checks for invertibility.
Matrix is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. Rhs is a tensor of shape
-`[..., M, K]`. The output is a tensor shape `[..., M, K]` where each output
-matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :].
+`[..., M, K]`. The output is a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output
+matrix satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+If `adjoint` is `True` then each output
+matrix satisfies `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
##### Args:
@@ -1229,6 +1329,9 @@ matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :].
Shape is `[..., M, M]`.
* <b>`rhs`</b>: A `Tensor`. Must have the same type as `matrix`.
Shape is `[..., M, K]`.
+* <b>`adjoint`</b>: An optional `bool`. Defaults to `False`.
+ Boolean indicating whether to solve with `matrix` or its (block-wise)
+ adjoint.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
@@ -1239,21 +1342,24 @@ matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :].
- - -
-### `tf.matrix_triangular_solve(matrix, rhs, lower=None, name=None)` {#matrix_triangular_solve}
+### `tf.matrix_triangular_solve(matrix, rhs, lower=None, adjoint=None, name=None)` {#matrix_triangular_solve}
Solves a system of linear equations with an upper or lower triangular matrix by
backsubstitution.
`matrix` is a matrix of shape `[M, M]`. If `lower` is `True` then the strictly
-upper triangular part of `matrix` is ignored. If `lower` is False then the
-strictly lower triangular part of `matrix` is ignored. `rhs` is a matrix of
-shape [M, K]`.
+upper triangular part of `matrix` is assumed to be zero and not accessed.
+If `lower` is False then the strictly lower triangular part of `matrix` is
+assumed to be zero and not accessed.
+`rhs` is a matrix of shape [M, K]`.
-The output is a matrix of shape `[M, K]`. If `lower` is `True` then the output
-satisfies \\(\sum_{k=0}^{i}\\) matrix[i, k] * output[k, j] = rhs[i, j].
-If `lower` is false then output satisfies
-\\(\sum_{k=i}^{K-1}\\) matrix[i, k] * output[k, j] = rhs[i, j].
+The output is a matrix of shape `[M, K]`. If `adjoint` is `False` the output
+satisfies the matrix equation `matrix` * `output` = `rhs`.
+If `adjoint` is `False` then `output` satisfies the matrix equation
+`matrix` * `output` = `rhs`.
+If `adjoint` is `True` then `output` satisfies the matrix equation
+`adjoint(matrix)` * `output` = `rhs`.
##### Args:
@@ -1262,7 +1368,9 @@ If `lower` is false then output satisfies
Shape is `[M, M]`.
* <b>`rhs`</b>: A `Tensor`. Must have the same type as `matrix`. Shape is `[M, K]`.
* <b>`lower`</b>: An optional `bool`. Defaults to `True`.
- Boolean indicating whether matrix is lower or upper triangular.
+ Boolean indicating whether `matrix` is lower or upper triangular
+* <b>`adjoint`</b>: An optional `bool`. Defaults to `False`.
+ Boolean indicating whether to solve with `matrix` or its adjoint.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
@@ -1272,7 +1380,7 @@ If `lower` is false then output satisfies
- - -
-### `tf.batch_matrix_triangular_solve(matrix, rhs, lower=None, name=None)` {#batch_matrix_triangular_solve}
+### `tf.batch_matrix_triangular_solve(matrix, rhs, lower=None, adjoint=None, name=None)` {#batch_matrix_triangular_solve}
Solves systems of linear equations with upper or lower triangular matrices by
@@ -1280,15 +1388,17 @@ backsubstitution.
`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
square matrices. If `lower` is `True` then the strictly upper triangular part
-of each inner-most matrix is ignored. If `lower` is False then the strictly
-lower triangular part of each inner-most matrix is ignored. `rhs` is a tensor
-of shape [..., M, K]`.
+of each inner-most matrix is assumed to be zero and not accessed.
+If `lower` is False then the strictly lower triangular part of each inner-most
+matrix is assumed to be zero and not accessed.
+`rhs` is a tensor of shape [..., M, K]`.
-The output is a tensor of shape `[..., M, K]`. If `lower` is `True` then the
-output satisfies
-\\(\sum_{k=0}^{i}\\) matrix[..., i, k] * output[..., k, j] = rhs[..., i, j].
-If `lower` is false then the strictly then the output satisfies
-\\(sum_{k=i}^{K-1}\\) matrix[..., i, k] * output[..., k, j] = rhs[..., i, j].
+The output is a tensor of shape `[..., M, K]`. If `adjoint` is `True` then the
+innermost matrices in output` satisfy matrix equations
+`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+If `adjoint` is `False` then the strictly then the innermost matrices in
+`output` satisfy matrix equations
+`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
##### Args:
@@ -1298,7 +1408,11 @@ If `lower` is false then the strictly then the output satisfies
* <b>`rhs`</b>: A `Tensor`. Must have the same type as `matrix`.
Shape is `[..., M, K]`.
* <b>`lower`</b>: An optional `bool`. Defaults to `True`.
- Boolean indicating whether matrix is lower or upper triangular.
+ Boolean indicating whether the innermost matrices in `matrix` are
+ lower or upper triangular.
+* <b>`adjoint`</b>: An optional `bool`. Defaults to `False`.
+ Boolean indicating whether to solve with `matrix` or its (block-wise)
+ adjoint.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/session_ops.md b/tensorflow/g3doc/api_docs/python/session_ops.md
new file mode 100644
index 0000000000..388c2cb81b
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/session_ops.md
@@ -0,0 +1,102 @@
+<!-- This file is machine generated: DO NOT EDIT! -->
+
+# Tensor Handle Operations
+
+Note: Functions taking `Tensor` arguments can also take anything accepted by
+[`tf.convert_to_tensor`](framework.md#convert_to_tensor).
+
+[TOC]
+
+## Tensor Handle Operations.
+
+TensorFlow provides several operators that allows the user to keep tensors
+"in-place" across run calls.
+
+- - -
+
+### `tf.get_session_handle(data, name=None)` {#get_session_handle}
+
+Return the handle of `data`.
+
+This is EXPERIMENTAL and subject to change.
+
+Keep `data` "in-place" in the runtime and create a handle that can be
+used to retrieve `data` in a subsequent run().
+
+Combined with `get_session_tensor`, we can keep a tensor produced in
+one run call in place, and use it as the input in a future run call.
+Below is a simple example:
+
+```python
+c = tf.mul(a, b)
+h = tf.get_session_handle(c)
+h = sess.run(h)
+
+p, a = tf.get_session_tensor(tf.float32)
+b = tf.mul(a, 10)
+c = sess.run(b, feed_dict={p: h.handle})
+```
+
+##### Args:
+
+
+* <b>`data`</b>: A tensor to be stored in the session.
+* <b>`name`</b>: Optional name prefix for the return tensor.
+
+##### Returns:
+
+ A scalar string tensor representing a unique handle for `data`.
+
+##### Raises:
+
+
+* <b>`TypeError`</b>: if `data` is not a Tensor.
+
+
+- - -
+
+### `tf.get_session_tensor(dtype, name=None)` {#get_session_tensor}
+
+Get the tensor of type `dtype` by feeding a tensor handle.
+
+This is EXPERIMENTAL and subject to change.
+
+Get the value of the tensor from a tensor handle. The tensor
+is produced in a previous run() and stored in the state of the
+session.
+
+##### Args:
+
+
+* <b>`dtype`</b>: The type of the output tensor.
+* <b>`name`</b>: Optional name prefix for the return tensor.
+
+##### Returns:
+
+ A pair of tensors. The first is a placeholder for feeding a
+ tensor handle and the second is the tensor in the session state
+ keyed by the tensor handle.
+
+
+- - -
+
+### `tf.delete_session_tensor(name=None)` {#delete_session_tensor}
+
+Delete the tensor by feeding a tensor handle.
+
+This is EXPERIMENTAL and subject to change.
+
+Delete the tensor of a given tensor handle. The tensor is produced
+in a previous run() and stored in the state of the session.
+
+##### Args:
+
+
+* <b>`name`</b>: Optional name prefix for the return tensor.
+
+##### Returns:
+
+ A pair of graph elements. The first is a placeholder for feeding a
+ tensor handle and the second is a deletion operation.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 6e43a50045..172c478500 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -781,7 +781,7 @@ checkpoints per device.
- - -
-#### `tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta')` {#Saver.save}
+#### `tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True)` {#Saver.save}
Saves variables.
@@ -807,6 +807,8 @@ path can be passed directly to a call to `restore()`.
managed by the saver to keep track of recent checkpoints. Defaults to
'checkpoint'.
* <b>`meta_graph_suffix`</b>: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+* <b>`write_meta_graph`</b>: `Boolean` indicating whether or not to write the meta
+ graph file.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/string_ops.md b/tensorflow/g3doc/api_docs/python/string_ops.md
new file mode 100644
index 0000000000..c3d275ac6d
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/string_ops.md
@@ -0,0 +1,96 @@
+<!-- This file is machine generated: DO NOT EDIT! -->
+
+# Strings
+
+Note: Functions taking `Tensor` arguments can also take anything accepted by
+[`tf.convert_to_tensor`](framework.md#convert_to_tensor).
+
+[TOC]
+
+## Hashing
+
+String hashing ops take a string input tensor and map each element to an
+integer.
+
+- - -
+
+### `tf.string_to_hash_bucket(string_tensor, num_buckets, name=None)` {#string_to_hash_bucket}
+
+Converts each string in the input Tensor to its hash mod by a number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process.
+
+Note that the hash function may change from time to time.
+
+##### Args:
+
+
+* <b>`string_tensor`</b>: A `Tensor` of type `string`.
+* <b>`num_buckets`</b>: An `int` that is `>= 1`. The number of buckets.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor` of type `int64`.
+ A Tensor of the same shape as the input `string_tensor`.
+
+
+
+## Joining
+
+String joining ops concatenate elements of input string tensors to produce a new
+string tensor.
+
+- - -
+
+### `tf.reduce_join(inputs, reduction_indices, keep_dims=None, separator=None, name=None)` {#reduce_join}
+
+Joins a string Tensor across the given dimensions.
+
+Computes the string join across dimensions in the given string Tensor of shape
+`[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input
+strings with the given separator (default: empty string). Negative indices are
+counted backwards from the end, with `-1` being equivalent to `n - 1`. Passing
+an empty `reduction_indices` joins all strings in linear index order and outputs
+a scalar string.
+
+
+For example:
+```
+# tensor `a` is [["a", "b"], ["c", "d"]]
+tf.reduce_join(a, 0) ==> ["ac", "bd"]
+tf.reduce_join(a, 1) ==> ["ab", "cd"]
+tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"]
+tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"]
+tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]]
+tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]]
+tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"]
+tf.reduce_join(a, [0, 1]) ==> ["acbd"]
+tf.reduce_join(a, [1, 0]) ==> ["abcd"]
+tf.reduce_join(a, []) ==> ["abcd"]
+```
+
+##### Args:
+
+
+* <b>`inputs`</b>: A `Tensor` of type `string`.
+ The input to be joined. All reduced indices must have non-zero size.
+* <b>`reduction_indices`</b>: A `Tensor` of type `int32`.
+ The dimensions to reduce over. Dimensions are reduced in the
+ order specified. If `reduction_indices` has higher rank than `1`, it is
+ flattened. Omitting `reduction_indices` is equivalent to passing
+ `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported.
+* <b>`keep_dims`</b>: An optional `bool`. Defaults to `False`.
+ If `True`, retain reduced dimensions with length `1`.
+* <b>`separator`</b>: An optional `string`. Defaults to `""`.
+ The separator to use when joining.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor` of type `string`.
+ Has shape equal to that of the input with reduced dimensions removed or
+ set to `1` depending on `keep_dims`.
+
+
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 958c499159..a8ac2c8209 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -575,6 +575,9 @@ tf_gen_op_wrapper_py(
"TensorArraySplit",
"TensorArrayUnpack",
"TensorArrayWrite",
+ "GetSessionHandle",
+ "GetSessionTensor",
+ "DeleteSessionTensor",
],
require_shape_functions = True,
)
@@ -810,6 +813,7 @@ py_library(
"ops/rnn_cell.py",
"ops/script_ops.py",
"ops/seq2seq.py",
+ "ops/session_ops.py",
"ops/sparse_grad.py",
"ops/sparse_ops.py",
"ops/standard_ops.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 9c02283f64..4306d38e69 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -106,6 +106,7 @@ from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import session_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
@@ -121,8 +122,8 @@ _whitelist = set([app, compat, contrib, errors, flags, gfile, image,
__all__ = make_all(__name__,
[framework_lib, array_ops, client_lib, constant_op,
control_flow_ops, functional_ops, histogram_ops, io_ops,
- math_ops, nn, script_ops, sparse_ops, state_ops, string_ops,
- train])
+ math_ops, nn, script_ops, session_ops, sparse_ops,
+ state_ops, string_ops, train])
# Symbols whitelisted for export without documentation.
# TODO(cwhipkey): review these and move to contrib, expose through
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index a9bfdb63c0..817965f992 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.ops import session_ops
from tensorflow.python.platform import logging
from tensorflow.python.util import compat
@@ -99,6 +100,9 @@ class BaseSession(SessionInterface):
self._extend_lock = threading.Lock()
self._target = target
+ self._delete_lock = threading.Lock()
+ self._dead_handles = []
+
self._session = None
opts = tf_session.TF_NewSessionOptions(target=target, config=config)
@@ -277,6 +281,9 @@ class BaseSession(SessionInterface):
the *i*th return value will be a
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue)
containing the value of that sparse tensor.
+ * If the *i*th element of `fetches` is produced by a `get_tensor_handle` op,
+ the *i*th return value will be a numpy ndarray containing the handle of
+ that tensor.
The optional `feed_dict` argument allows the caller to override
the value of tensors in the graph. Each key in `feed_dict` can be
@@ -350,17 +357,22 @@ class BaseSession(SessionInterface):
list of feeds and fetches that will be used in the subsequent
`partial_run` calls.
- Below is a simple example:
+ The optional `feed_dict` argument allows the caller to override
+ the value of tensors in the graph. See run() for more information.
- a = array_ops.placeholder(dtypes.float32, shape=[])
- b = array_ops.placeholder(dtypes.float32, shape=[])
- c = array_ops.placeholder(dtypes.float32, shape=[])
- r1 = math_ops.add(a, b)
- r2 = math_ops.mul(r1, c)
+ Below is a simple example:
- h = sess.partial_run_setup([r1, r2], [a, b, c])
- res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
- res = sess.partial_run(h, r2, feed_dict={c: res})
+ ```python
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ res = sess.partial_run(h, r2, feed_dict={c: res})
+ ```
Args:
handle: A handle for a sequence of partial runs.
@@ -410,7 +422,7 @@ class BaseSession(SessionInterface):
'graph before calling run().')
# Validate and process fetches.
- unique_fetches, target_list, _ = self._process_fetches(fetches)
+ unique_fetches, target_list, _, _ = self._process_fetches(fetches)
# Create request.
feed_list = []
@@ -455,6 +467,7 @@ class BaseSession(SessionInterface):
fetches = [fetches]
unique_fetch_targets = set()
+ unique_fetch_handles = {}
target_list = []
fetch_info = []
@@ -465,10 +478,15 @@ class BaseSession(SessionInterface):
try:
fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
allow_operation=True)
+ fetch_name = compat.as_bytes(fetch_t.name)
if isinstance(fetch_t, ops.Operation):
- target_list.append(compat.as_bytes(fetch_t.name))
+ target_list.append(fetch_name)
else:
- subfetch_names.append(compat.as_bytes(fetch_t.name))
+ subfetch_names.append(fetch_name)
+ # Remember the fetch if it is for a tensor handle.
+ if (isinstance(fetch_t, ops.Tensor) and
+ fetch_t.op.type == 'GetSessionHandle'):
+ unique_fetch_handles[fetch_name] = fetch_t.op.inputs[0].dtype
except TypeError as e:
raise TypeError('Fetch argument %r of %r has invalid type %r, '
'must be a string or Tensor. (%s)'
@@ -483,7 +501,7 @@ class BaseSession(SessionInterface):
fetch_info.append((subfetch_names, fetch_contraction_fn))
unique_fetch_targets = list(unique_fetch_targets)
- return unique_fetch_targets, target_list, fetch_info
+ return unique_fetch_targets, target_list, fetch_info, unique_fetch_handles
def _run(self, handle, fetches, feed_dict, options, run_metadata):
"""Perform either run or partial_run, depending the exitence of `handle`."""
@@ -502,10 +520,15 @@ class BaseSession(SessionInterface):
'graph before calling run().')
# Validate and process fetches.
- unique_fetches, target_list, fetch_info = self._process_fetches(fetches)
+ processed_fetches = self._process_fetches(fetches)
+ unique_fetches = processed_fetches[0]
+ target_list = processed_fetches[1]
+ fetch_info = processed_fetches[2]
+ unique_handles = processed_fetches[3]
# Create request.
feed_dict_string = {}
+ feed_map = {}
# Validate and process feed_dict.
if feed_dict:
@@ -522,7 +545,6 @@ class BaseSession(SessionInterface):
raise TypeError('The value of a feed cannot be a tf.Tensor object. '
'Acceptable feed values include Python scalars, '
'strings, lists, or numpy ndarrays.')
-
np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
raise ValueError(
@@ -531,17 +553,31 @@ class BaseSession(SessionInterface):
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
if not self.graph.is_feedable(subfeed_t):
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
- feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val
+ subfeed_name = compat.as_bytes(subfeed_t.name)
+ feed_dict_string[subfeed_name] = np_val
+ feed_map[subfeed_name] = (subfeed_t, subfeed_val)
# Run request and get response.
- results = self._do_run(handle, target_list, unique_fetches,
- feed_dict_string, options, run_metadata)
+ movers = self._update_with_movers(feed_dict_string, feed_map)
+ try:
+ results = self._do_run(handle, target_list, unique_fetches,
+ feed_dict_string, options, run_metadata)
+ finally:
+ # The movers are no longer used. Delete them.
+ for handle in movers:
+ self._register_dead_handle(handle)
# User may have fetched the same tensor multiple times, but we
# only fetch them from the runtime once. Furthermore, they may
# be wrapped as a tuple of tensors. Here we map the results back
# to what the client asked for.
- fetched_results = dict(zip(unique_fetches, results))
+ # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
+ fetched_results = {}
+ for fetch, result in zip(unique_fetches, results):
+ dtype = unique_handles.get(fetch)
+ if dtype:
+ result = session_ops.TensorHandle(result, dtype, self)
+ fetched_results[fetch] = result
ret = []
for fetch_names, fetch_contraction_fn in fetch_info:
if fetch_names:
@@ -642,6 +678,55 @@ class BaseSession(SessionInterface):
self._current_version = self._graph.version
+ # The threshold to run garbage collection to delete dead tensors.
+ _DEAD_HANDLES_THRESHOLD = 10
+
+ def _register_dead_handle(self, handle):
+ # Register a dead handle in the session. Delete the dead tensors when
+ # the number of dead tensors exceeds certain threshold.
+ tensors_to_delete = None
+ with self._delete_lock:
+ self._dead_handles.append(handle)
+ if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD:
+ tensors_to_delete = self._dead_handles
+ self._dead_handles = []
+ # Delete the dead tensors.
+ # TODO(yuanbyu): For now we use a sequence of runs to minimize the graph
+ # size and the overhead of graph construction/partitioning.
+ if tensors_to_delete:
+ for tensor_handle in tensors_to_delete:
+ feeds = {}
+ fetches = []
+ holder, deleter = session_ops._get_handle_deleter(self.graph,
+ tensor_handle)
+ feeds[holder] = tensor_handle
+ fetches.append(deleter)
+ self.run(fetches, feed_dict=feeds)
+
+ def _update_with_movers(self, feed_dict, feed_map):
+ # If a tensor handle that is fed to a device incompatible placeholder,
+ # we move the tensor to the right device, generate a new tensor handle,
+ # and update `feed_dict` to use the new handle.
+ handle_movers = []
+ for feed_name, val in feed_map.items():
+ mover = session_ops._get_handle_mover(self.graph, *val)
+ if mover:
+ handle_movers.append((feed_name, val[1], mover))
+ # Transfer a tensor to the right device if needed.
+ if not handle_movers:
+ return []
+ else:
+ feeds = {}
+ fetches = []
+ for _, handle, mover in handle_movers:
+ feeds[mover[0]] = handle
+ fetches.append(mover[1])
+ handles = self.run(fetches, feed_dict=feeds)
+ for handle_mover, handle in zip(handle_movers, handles):
+ np_val = np.array(handle.handle, dtype=np.object)
+ feed_dict[handle_mover[0]] = np_val
+ return handles
+
class Session(BaseSession):
"""A class for running TensorFlow operations.
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index e29d0d70ab..8bcde1f6b4 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -85,6 +85,7 @@ def all_libraries(module_to_name, members, documented):
library("histogram_ops", "Histograms"),
library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT),
+ library("session_ops", "Tensor Handle Operations", prefix=PREFIX_TEXT),
library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
prefix=PREFIX_TEXT),
library("sparse_ops", "Sparse Tensors",
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index aee976c71e..b7bc6690b2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1871,6 +1871,14 @@ class Graph(object):
self._colocation_stack = []
# Set of tensors that are dangerous to feed!
self._unfeedable_tensors = set()
+ # A map of tensor handle placeholder to tensor dtype.
+ self._handle_feeders = {}
+ # A map from tensor handle to its read op.
+ self._handle_readers = {}
+ # A map from tensor handle to its move op.
+ self._handle_movers = {}
+ # A map from tensor handle to its delete op.
+ self._handle_deleters = {}
def _check_not_finalized(self):
"""Check if the graph is finalized.
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
new file mode 100644
index 0000000000..4f61055cbc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -0,0 +1,157 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.session_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class SessionOpsTest(tf.test.TestCase):
+
+ def testHandleBasic(self):
+ with self.test_session() as sess:
+ # Return a handle.
+ a = tf.constant(10)
+ b = tf.constant(5)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ h = sess.run(h)
+
+ # Feed a tensor handle.
+ f, x = tf.get_session_tensor(tf.int32)
+ y = tf.mul(x, 10)
+ self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
+
+ def testHandleEval(self):
+ with self.test_session() as sess:
+ # Return a handle.
+ a = tf.constant(10)
+ b = tf.constant(5)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ h = sess.run(h)
+
+ # Get the tensor from its handle.
+ self.assertEqual(50, h.eval())
+
+ def testHandleAndValue(self):
+ with self.test_session() as sess:
+ # Return a handle and a value.
+ a = tf.constant(10)
+ b = tf.constant(5)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ v = tf.mul(a, c)
+ h, v = sess.run([h, v])
+
+ self.assertEqual(50, h.eval())
+ self.assertEqual(500, v)
+
+ def testHandleCond(self):
+ with self.test_session() as sess:
+ # Return a handle and a value
+ a = tf.constant(10)
+ b = tf.constant(5)
+ p = tf.less(a, b)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ p, h = sess.run([p, h])
+
+ # Run by feeding a tensor handle.
+ f, x = tf.get_session_tensor(tf.int32)
+ if p:
+ y = tf.mul(x, 10)
+ else:
+ y = tf.mul(x, 100)
+ result = sess.run(y, feed_dict={f: h.handle})
+
+ self.assertEqual(5000, result)
+
+ def testHandleForLoop(self):
+ with self.test_session() as sess:
+ # Initialize a handle.
+ a = tf.constant(0)
+ h = tf.get_session_handle(a)
+ h = sess.run(h)
+
+ # Do some computation.
+ f, x = tf.get_session_tensor(tf.int32)
+ # Must define the loop body outside the loop.
+ h_x = tf.get_session_handle(tf.add(x, 1))
+ for _ in range(100):
+ # This exercises garbage collection.
+ h = sess.run(h_x, feed_dict={f: h.handle})
+
+ self.assertEqual(100, h.eval())
+
+ def testHandleWhileLoop(self):
+ with self.test_session() as sess:
+ # Initialize a handle.
+ a = tf.constant(0)
+ h = tf.get_session_handle(a)
+ h = sess.run(h)
+
+ # Do some computation.
+ f, x = tf.get_session_tensor(tf.int32)
+ b = tf.constant(100)
+ p = tf.less(x, b)
+ # Must define the loop body outside the loop.
+ h_x = tf.get_session_handle(tf.add(x, 1))
+ while True:
+ rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
+ if not rp:
+ break
+
+ self.assertEqual(101, h.eval())
+
+ def testHandleMover(self):
+ with self.test_session() as sess:
+ # Return a handle.
+ a = tf.constant(10)
+ b = tf.constant(5)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ h = sess.run(h)
+
+ # Feed a tensor handle.
+ f, x = tf.get_session_tensor(tf.int32)
+ y = tf.mul(x, 10)
+ self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
+
+ # Feed another tensor handle.
+ with tf.device("/gpu:0"):
+ a = tf.constant(10)
+ h = tf.get_session_handle(a)
+ h = sess.run(h)
+ self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
+
+ def testHandleDeleter(self):
+ with self.test_session() as sess:
+ # Return a handle.
+ a = tf.constant(10)
+ b = tf.constant(5)
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ h = sess.run(h)
+
+ # Delete using a raw tensor handle.
+ h = h.get_raw_handle()
+ f, x = tf.delete_session_tensor()
+ sess.run(x, feed_dict={f: h})
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 5458444337..1b09c90503 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -257,6 +257,7 @@ def merge(inputs, name=None):
else:
dense_shape = None
return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+# pylint: enable=protected-access
def _SwitchRefOrTensor(data, pred, name="Switch"):
@@ -970,9 +971,8 @@ class ControlFlowContext(object):
"""
while_ctxt = self.GetWhileContext()
if while_ctxt is not None:
- # pylint: disable=protected-access
op._add_control_input(while_ctxt.GetControlPivot().op)
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
class CondContext(ControlFlowContext):
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index 84cb9a39b1..dedecaa375 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -76,3 +76,7 @@ ops.NoGradient("Stack")
ops.NoGradient("StackPush")
ops.NoGradient("StackPop")
ops.NoGradient("StackClose")
+
+ops.NoGradient("GetSessionHandle")
+ops.NoGradient("GetSessionTensor")
+ops.NoGradient("DeleteSessionTensor")
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 214b7bb29a..3f72ccf5cd 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -570,6 +570,11 @@ ops.RegisterShape("StackPush")(common_shapes.unknown_shape)
ops.RegisterShape("StackPop")(common_shapes.unknown_shape)
ops.RegisterShape("StackClose")(_ScalarToVoidShape)
+# NOTE(yuanbyu): We probably can do better here.
+ops.RegisterShape("GetSessionHandle")(common_shapes.scalar_shape)
+ops.RegisterShape("GetSessionTensor")(common_shapes.unknown_shape)
+ops.RegisterShape("DeleteSessionTensor")(_ScalarToVoidShape)
+
@ops.RegisterShape("DynamicPartition")
def _DynamicPartitionShape(op):
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
new file mode 100644
index 0000000000..8a1ba164d6
--- /dev/null
+++ b/tensorflow/python/ops/session_ops.py
@@ -0,0 +1,255 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""## Tensor Handle Operations.
+
+TensorFlow provides several operators that allows the user to keep tensors
+"in-place" across run calls.
+
+@@get_session_handle
+@@get_session_tensor
+@@delete_session_tensor
+"""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.util import compat
+
+
+class TensorHandle(object):
+ """Represents a handle for a live tensor in a session."""
+
+ def __init__(self, handle, dtype, session):
+ """Constructs a new tensor handle.
+
+ A tensor handle for a persistent tensor is a python string
+ that has the form of "tensor_name;unique_id;device_name".
+
+ Args:
+ handle: A tensor handle.
+ dtype: The data type of the tensor represented by `handle`.
+ session: The session in which the tensor is produced.
+ """
+ self._handle = compat.as_str_any(handle)
+ self._dtype = dtype
+ self._session = session
+ self._auto_gc_enabled = True
+
+ def __del__(self):
+ if self._auto_gc_enabled:
+ self._session._register_dead_handle(self.handle)
+
+ def __str__(self):
+ return self._handle
+
+ @property
+ def handle(self):
+ return self._handle
+
+ def eval(self):
+ """Return the value of the tensor represented by this handle."""
+ holder, reader = _get_handle_reader(self._session.graph, self._handle,
+ self._dtype)
+ return self._session.run(reader, feed_dict={holder: self._handle})
+
+ def get_raw_handle(self):
+ """Return the raw handle of the tensor.
+
+ Note that the method disables the automatic garbage collection of this
+ persistent tensor. The caller is now responsible for managing the life
+ time of the tensor.
+ """
+ self._auto_gc_enabled = False
+ return self._handle
+
+ @staticmethod
+ def _get_device_name(handle):
+ """The device name encoded in the handle."""
+ handle_str = compat.as_str_any(handle)
+ return pydev.canonical_name(handle_str.split(';')[-1])
+
+ @staticmethod
+ def _get_reader_key(handle):
+ """The graph key for reader."""
+ handle_parts = str(handle).split(';')
+ return handle_parts[0] + ';' + handle_parts[-1]
+
+ @staticmethod
+ def _get_deleter_key(handle):
+ """The graph key for deleter."""
+ return str(handle).split(';')[-1]
+
+ @staticmethod
+ def _get_mover_key(feeder, handle):
+ """The graph key for mover."""
+ return feeder.op.name + ';' + TensorHandle._get_reader_key(handle)
+
+
+def get_session_handle(data, name=None):
+ """Return the handle of `data`.
+
+ This is EXPERIMENTAL and subject to change.
+
+ Keep `data` "in-place" in the runtime and create a handle that can be
+ used to retrieve `data` in a subsequent run().
+
+ Combined with `get_session_tensor`, we can keep a tensor produced in
+ one run call in place, and use it as the input in a future run call.
+ Below is a simple example:
+
+ ```python
+ c = tf.mul(a, b)
+ h = tf.get_session_handle(c)
+ h = sess.run(h)
+
+ p, a = tf.get_session_tensor(tf.float32)
+ b = tf.mul(a, 10)
+ c = sess.run(b, feed_dict={p: h.handle})
+ ```
+
+ Args:
+ data: A tensor to be stored in the session.
+ name: Optional name prefix for the return tensor.
+
+ Returns:
+ A scalar string tensor representing a unique handle for `data`.
+
+ Raises:
+ TypeError: if `data` is not a Tensor.
+ """
+ if not isinstance(data, ops.Tensor):
+ raise TypeError('`data` must be of type Tensor.')
+
+ # Colocate this operation with data.
+ with ops.colocate_with(data):
+ return gen_data_flow_ops._get_session_handle(data, name=name)
+
+
+def get_session_tensor(dtype, name=None):
+ """Get the tensor of type `dtype` by feeding a tensor handle.
+
+ This is EXPERIMENTAL and subject to change.
+
+ Get the value of the tensor from a tensor handle. The tensor
+ is produced in a previous run() and stored in the state of the
+ session.
+
+ Args:
+ dtype: The type of the output tensor.
+ name: Optional name prefix for the return tensor.
+
+ Returns:
+ A pair of tensors. The first is a placeholder for feeding a
+ tensor handle and the second is the tensor in the session state
+ keyed by the tensor handle.
+ """
+ with ops.device(None):
+ # Commit the device when it is used the first time.
+ holder = array_ops.placeholder(dtypes.string)
+ _register_handle_feeder(holder.graph, holder, dtype)
+ tensor = gen_data_flow_ops._get_session_tensor(holder, dtype, name=name)
+ return (holder, tensor)
+
+
+def delete_session_tensor(name=None):
+ """Delete the tensor by feeding a tensor handle.
+
+ This is EXPERIMENTAL and subject to change.
+
+ Delete the tensor of a given tensor handle. The tensor is produced
+ in a previous run() and stored in the state of the session.
+
+ Args:
+ name: Optional name prefix for the return tensor.
+
+ Returns:
+ A pair of graph elements. The first is a placeholder for feeding a
+ tensor handle and the second is a deletion operation.
+ """
+ with ops.device(None):
+ # We will commit the device at the time it is used.
+ holder = array_ops.placeholder(dtypes.string)
+ deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name)
+ return (holder, deleter)
+
+
+def _register_handle_feeder(graph, feeder, dtype):
+ graph._handle_feeders[feeder.op.name] = dtype
+
+
+def _get_handle_feeder(graph, feeder):
+ return graph._handle_feeders.get(feeder.op.name)
+
+
+def _get_handle_reader(graph, handle, dtype):
+ """Return a read subgraph for this handle."""
+ graph_key = TensorHandle._get_reader_key(handle)
+ result = graph._handle_readers.get(graph_key)
+ if result is None:
+ # Create reader if we haven't done it.
+ handle_device = TensorHandle._get_device_name(handle)
+ with ops.device(handle_device):
+ holder = array_ops.placeholder(dtypes.string)
+ _register_handle_feeder(holder.graph, holder, dtype)
+ reader = gen_data_flow_ops._get_session_tensor(holder, dtype)
+ result = (holder, reader)
+ graph._handle_readers[graph_key] = result
+ return result
+
+
+def _get_handle_mover(graph, feeder, handle):
+ """Return a move subgraph for this pair of feeder and handle."""
+ dtype = _get_handle_feeder(graph, feeder)
+ if dtype is None:
+ return None
+ handle_device = TensorHandle._get_device_name(handle)
+ if not feeder.op.device:
+ feeder.op._set_device(handle_device)
+ return None
+ if feeder.op.device == handle_device:
+ return None
+ # Now we know we have to move the tensor.
+ graph_key = TensorHandle._get_mover_key(feeder, handle)
+ result = graph._handle_movers.get(graph_key)
+ if result is None:
+ # Create mover if we haven't done it.
+ holder, reader = _get_handle_reader(graph, handle, dtype)
+ with ops.device(feeder.op.device):
+ mover = gen_data_flow_ops._get_session_handle(reader)
+ result = (holder, mover)
+ graph._handle_movers[graph_key] = result
+ return result
+
+
+def _get_handle_deleter(graph, handle):
+ """Return a deletion subgraph for this handle."""
+ graph_key = TensorHandle._get_deleter_key(handle)
+ result = graph._handle_deleters.get(graph_key)
+ if result is None:
+ # Create deleter if we haven't done it.
+ handle_device = TensorHandle._get_device_name(handle)
+ with ops.device(handle_device):
+ holder = array_ops.placeholder(dtypes.string)
+ deleter = gen_data_flow_ops._delete_session_tensor(holder)
+ result = (holder, deleter)
+ graph._handle_deleters[graph_key] = result
+ return result
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index ace9cd2c5c..28b2a4075d 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -52,6 +52,7 @@ from tensorflow.python.ops.parsing_ops import *
from tensorflow.python.ops.partitioned_variables import *
from tensorflow.python.ops.random_ops import *
from tensorflow.python.ops.script_ops import py_func
+from tensorflow.python.ops.session_ops import *
from tensorflow.python.ops.sparse_ops import *
from tensorflow.python.ops.state_ops import assign
from tensorflow.python.ops.state_ops import assign_add