aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-03-09 11:13:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-09 11:30:12 -0800
commitcdecf416365c85f8274393e097ecab163cbea7c3 (patch)
tree5c57b539c5dba68678c22f2cbdf7da91c5c822cd /tensorflow/core
parenta62ed13bf366cb1a99183f93c859b6786f9d3070 (diff)
Enable the direct use of TensorHandles as feed values through ResourceHandles
This is motivated by, among other goals, the need to improve memory efficiency during TFDBG's stepper operations. The stepper caches TensorHandles to already-continued-to tensors and uses them as feeds if later continue-to actions depend on those tensors as transitive inputs. Previously, however, the TensorHandles had to be converted to Numpy arrays by calling eval(), and the Numpy arrays were then fed back to the next Session.run() calls. This mode of operation involved at least two unnecessary copies: tensor-to-numpy and numpy-to-tensor. This CL makes it possible to use the ResourceHandle representations of TensorHandles directly as feed values, eliminating the need for the aforementioned copying. To this end, the following changes are made: 1) The underlying representations of TensorHandles are changed from string to ResourceHandle. A custom numpy struct type is created to allow ResourceHandles of the TensorHandle subtype to be fed during Session.run() calls. 2) The GetSessionHandleV2 op is added, which deprecates GetSessionHandle. The V2 op outputs a DT_RESOURCE Tensor instead of the string Tensor output by the deprecated version. Change: 149672538
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc31
-rw-r--r--tensorflow/core/common_runtime/direct_session.h3
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc12
-rw-r--r--tensorflow/core/common_runtime/session_state.cc2
-rw-r--r--tensorflow/core/framework/session_state.h2
-rw-r--r--tensorflow/core/graph/graph.cc1
-rw-r--r--tensorflow/core/kernels/session_ops.cc31
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc10
8 files changed, 82 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 18bc8fb634..c4b2b6c12a 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -739,6 +739,26 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
return s;
}
+Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
+ Tensor* retrieved_tensor) {
+ if (resource_tensor.dtype() != DT_RESOURCE) {
+ return errors::InvalidArgument(strings::StrCat(
+ "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
+ resource_tensor.dtype()));
+ }
+
+ ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
+
+ if (resource_handle.hash_code() == MakeTypeIndex<Tensor>().hash_code()) {
+ return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid resource type hash code: ", resource_handle.hash_code(),
+ "(name: ", resource_handle.name(),
+ " type: ", resource_handle.maybe_type_name(), ")"));
+ }
+}
+
Status DirectSession::SendInputs(const NamedTensorList& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez) {
@@ -759,7 +779,16 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
return s;
}
- s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
+ if (input.second.dtype() == DT_RESOURCE) {
+ Tensor tensor_from_handle;
+ s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
+ if (s.ok()) {
+ s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
+ }
+ } else {
+ s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
+ }
+
if (!s.ok()) {
rendez->StartAbort(s);
return s;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 3e3a5eaa8f..1495648631 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -192,6 +192,9 @@ class DirectSession : public Session {
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ ::tensorflow::Status ResourceHandleToInputTensor(
+ const Tensor& resource_tensor, Tensor* retrieved_tensor);
+
// Feeds more inputs to the executors, triggering further execution.
::tensorflow::Status SendInputs(
const std::vector<std::pair<string, Tensor>>& inputs,
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 9e717dfc23..c8b8a09b8e 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -627,7 +627,7 @@ TEST(DirectSessionTest, RunHandleTest) {
value1.scalar<float>()() = 2.0;
Node* const1 = test::graph::Constant(&g, value1);
Node* node3 = test::graph::Add(&g, identity0, const1);
- Node* node4 = test::graph::Unary(&g, "GetSessionHandle", node3);
+ Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);
Tensor value2(DT_STRING, TensorShape({}));
Node* const2 = test::graph::Constant(&g, value2);
@@ -648,17 +648,21 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
+ ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ Tensor string_handle(DT_STRING, {});
+ string_handle.flat<string>().setConstant(resource_handle.name());
+
// Second run call: Use a handle.
std::vector<Tensor> outputs1;
- s = session->Run({{const2->name(), outputs[0]}}, {node6->name() + ":0"}, {},
- &outputs1);
+ s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
+ {}, &outputs1);
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs1.size());
ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));
// Third run call: Delete a handle.
std::vector<Tensor> outputs2;
- s = session->Run({{const2->name(), outputs[0]}}, {}, {node7->name()},
+ s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
&outputs2);
ASSERT_TRUE(s.ok());
}
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 2c80c4d112..7e7200070d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -18,6 +18,8 @@ limitations under the License.
namespace tensorflow {
+const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle";
+
Status SessionState::GetTensor(const string& handle, Tensor* tensor) {
mutex_lock l(state_lock_);
auto it = tensors_.find(handle);
diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h
index a3eafcf474..8fbe940f6a 100644
--- a/tensorflow/core/framework/session_state.h
+++ b/tensorflow/core/framework/session_state.h
@@ -41,6 +41,8 @@ class SessionState {
int64 GetNewId();
+ static const char* kTensorHandleResourceTypeName;
+
private:
mutex state_lock_;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 509c67c11f..6d9b114e90 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -98,6 +98,7 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
SET_CLASS(NC_VARIABLE, ts, "VariableV2", "");
SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity");
SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandle", "");
+ SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandleV2", "");
SET_CLASS(NC_GET_SESSION_TENSOR, ts, "GetSessionTensor", "");
SET_CLASS(NC_DELETE_SESSION_TENSOR, ts, "DeleteSessionTensor", "");
if (class_ == NC_UNINITIALIZED) {
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index 59fb225b92..54eca4a20a 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -41,13 +41,24 @@ class GetSessionHandleOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
- const Tensor& val = ctx->input(0);
+ Tensor val = ctx->input(0);
int64 id = ctx->session_state()->GetNewId();
TensorStore::TensorAndKey tk{val, id, def().device()};
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(def().name(), tk));
+
Tensor* handle = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
- handle->flat<string>().setConstant(tk.GetHandle(def().name()));
+ if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
+ ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
+ ctx, SessionState::kTensorHandleResourceTypeName,
+ tk.GetHandle(def().name()));
+ resource_handle.set_maybe_type_name(
+ SessionState::kTensorHandleResourceTypeName);
+ handle->scalar<ResourceHandle>()() = resource_handle;
+ } else {
+ // Legacy behavior in V1.
+ handle->flat<string>().setConstant(tk.GetHandle(def().name()));
+ }
}
TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
@@ -55,12 +66,19 @@ class GetSessionHandleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
GetSessionHandleOp);
+REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
+ GetSessionHandleOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
.Device(DEVICE_GPU) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
+ GetSessionHandleOp) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
GetSessionHandleOp)
TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
@@ -73,12 +91,17 @@ REGISTER_GPU_KERNEL(bool);
.Device(DEVICE_SYCL) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
+ GetSessionHandleOp) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
GetSessionHandleOp)
TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
REGISTER_SYCL_KERNEL(bool);
#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
class GetSessionTensorOp : public OpKernel {
public:
@@ -147,5 +170,5 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(
Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
DeleteSessionTensorOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 365716b372..f2a78956e3 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -2152,11 +2152,19 @@ REGISTER_OP("GetSessionHandle")
.Output("handle: string")
.Attr("T: type")
.SetShapeFn(shape_inference::ScalarShape)
+ .Deprecated(23, "Use GetSessionHandleV2");
+
+REGISTER_OP("GetSessionHandleV2")
+ .Input("value: T")
+ .Output("handle: resource")
+ .Attr("T: type")
+ .SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Store the input tensor in the state of the current session.
value: The tensor to be stored.
-handle: The handle for the tensor stored in the session state.
+handle: The handle for the tensor stored in the session state, represented
+ as a ResourceHandle object.
)doc");
REGISTER_OP("GetSessionTensor")