aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-03-09 11:13:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-09 11:30:12 -0800
commitcdecf416365c85f8274393e097ecab163cbea7c3 (patch)
tree5c57b539c5dba68678c22f2cbdf7da91c5c822cd /tensorflow
parenta62ed13bf366cb1a99183f93c859b6786f9d3070 (diff)
Enable the direct use of TensorHandles as feed values through ResourceHandles
This is motivated by, among other goals, the need to enhance memory efficiency during TFDBG's stepper operations. The stepper caches TensorHandles to already-continued-to tensors and use them as feeds if later continue-to actions depend on the tensors as transitive inputs. However, previously the TensorHandles had to be converted to Numpy arrays by calling eval() and the Numpy arrays were then fed back to next Session.run() calls. This mode of operation involved at least two unnecessary tensor-numpy and numpy-tensor copying. This CL makes it possible to use the ResourceHandle representations TensorHandles directly as feed values, eliminating the need for the aforementioned copying. To this end, the following changes are made 1) the underlying representations of TensorHandles are changed from string to ResourceHandle. A custom numpy struct type is created to allow ResourceHandle of the TensorHandle subtype to be fed during Session.run() calls. 2) added GetSessionHandleOpV2, which deprecates GetSessionHandleOp. The V2 op outputs a DT_RESOURCE Tensor, instead of a string Tensor in the deprecated version. Change: 149672538
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc31
-rw-r--r--tensorflow/core/common_runtime/direct_session.h3
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc12
-rw-r--r--tensorflow/core/common_runtime/session_state.cc2
-rw-r--r--tensorflow/core/framework/session_state.h2
-rw-r--r--tensorflow/core/graph/graph.cc1
-rw-r--r--tensorflow/core/kernels/session_ops.cc31
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc10
-rw-r--r--tensorflow/python/client/session.py16
-rw-r--r--tensorflow/python/client/tf_session_helper.cc108
-rw-r--r--tensorflow/python/debug/lib/debug_data.py13
-rw-r--r--tensorflow/python/framework/dtypes.py3
-rw-r--r--tensorflow/python/kernel_tests/session_ops_test.py41
-rw-r--r--tensorflow/python/ops/data_flow_grad.py1
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/session_ops.py43
16 files changed, 283 insertions, 35 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 18bc8fb634..c4b2b6c12a 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -739,6 +739,26 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
return s;
}
+Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
+ Tensor* retrieved_tensor) {
+ if (resource_tensor.dtype() != DT_RESOURCE) {
+ return errors::InvalidArgument(strings::StrCat(
+ "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
+ resource_tensor.dtype()));
+ }
+
+ ResourceHandle resource_handle = resource_tensor.scalar<ResourceHandle>()();
+
+ if (resource_handle.hash_code() == MakeTypeIndex<Tensor>().hash_code()) {
+ return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid resource type hash code: ", resource_handle.hash_code(),
+ "(name: ", resource_handle.name(),
+ " type: ", resource_handle.maybe_type_name(), ")"));
+ }
+}
+
Status DirectSession::SendInputs(const NamedTensorList& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez) {
@@ -759,7 +779,16 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
return s;
}
- s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
+ if (input.second.dtype() == DT_RESOURCE) {
+ Tensor tensor_from_handle;
+ s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
+ if (s.ok()) {
+ s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
+ }
+ } else {
+ s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
+ }
+
if (!s.ok()) {
rendez->StartAbort(s);
return s;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 3e3a5eaa8f..1495648631 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -192,6 +192,9 @@ class DirectSession : public Session {
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ ::tensorflow::Status ResourceHandleToInputTensor(
+ const Tensor& resource_tensor, Tensor* retrieved_tensor);
+
// Feeds more inputs to the executors, triggering further execution.
::tensorflow::Status SendInputs(
const std::vector<std::pair<string, Tensor>>& inputs,
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 9e717dfc23..c8b8a09b8e 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -627,7 +627,7 @@ TEST(DirectSessionTest, RunHandleTest) {
value1.scalar<float>()() = 2.0;
Node* const1 = test::graph::Constant(&g, value1);
Node* node3 = test::graph::Add(&g, identity0, const1);
- Node* node4 = test::graph::Unary(&g, "GetSessionHandle", node3);
+ Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);
Tensor value2(DT_STRING, TensorShape({}));
Node* const2 = test::graph::Constant(&g, value2);
@@ -648,17 +648,21 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
+ ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ Tensor string_handle(DT_STRING, {});
+ string_handle.flat<string>().setConstant(resource_handle.name());
+
// Second run call: Use a handle.
std::vector<Tensor> outputs1;
- s = session->Run({{const2->name(), outputs[0]}}, {node6->name() + ":0"}, {},
- &outputs1);
+ s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
+ {}, &outputs1);
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs1.size());
ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));
// Third run call: Delete a handle.
std::vector<Tensor> outputs2;
- s = session->Run({{const2->name(), outputs[0]}}, {}, {node7->name()},
+ s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
&outputs2);
ASSERT_TRUE(s.ok());
}
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 2c80c4d112..7e7200070d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -18,6 +18,8 @@ limitations under the License.
namespace tensorflow {
+const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle";
+
Status SessionState::GetTensor(const string& handle, Tensor* tensor) {
mutex_lock l(state_lock_);
auto it = tensors_.find(handle);
diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h
index a3eafcf474..8fbe940f6a 100644
--- a/tensorflow/core/framework/session_state.h
+++ b/tensorflow/core/framework/session_state.h
@@ -41,6 +41,8 @@ class SessionState {
int64 GetNewId();
+ static const char* kTensorHandleResourceTypeName;
+
private:
mutex state_lock_;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 509c67c11f..6d9b114e90 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -98,6 +98,7 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
SET_CLASS(NC_VARIABLE, ts, "VariableV2", "");
SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity");
SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandle", "");
+ SET_CLASS(NC_GET_SESSION_HANDLE, ts, "GetSessionHandleV2", "");
SET_CLASS(NC_GET_SESSION_TENSOR, ts, "GetSessionTensor", "");
SET_CLASS(NC_DELETE_SESSION_TENSOR, ts, "DeleteSessionTensor", "");
if (class_ == NC_UNINITIALIZED) {
diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc
index 59fb225b92..54eca4a20a 100644
--- a/tensorflow/core/kernels/session_ops.cc
+++ b/tensorflow/core/kernels/session_ops.cc
@@ -41,13 +41,24 @@ class GetSessionHandleOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
- const Tensor& val = ctx->input(0);
+ Tensor val = ctx->input(0);
int64 id = ctx->session_state()->GetNewId();
TensorStore::TensorAndKey tk{val, id, def().device()};
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(def().name(), tk));
+
Tensor* handle = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
- handle->flat<string>().setConstant(tk.GetHandle(def().name()));
+ if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
+ ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
+ ctx, SessionState::kTensorHandleResourceTypeName,
+ tk.GetHandle(def().name()));
+ resource_handle.set_maybe_type_name(
+ SessionState::kTensorHandleResourceTypeName);
+ handle->scalar<ResourceHandle>()() = resource_handle;
+ } else {
+ // Legacy behavior in V1.
+ handle->flat<string>().setConstant(tk.GetHandle(def().name()));
+ }
}
TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
@@ -55,12 +66,19 @@ class GetSessionHandleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
GetSessionHandleOp);
+REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
+ GetSessionHandleOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \
.Device(DEVICE_GPU) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
+ GetSessionHandleOp) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
GetSessionHandleOp)
TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
@@ -73,12 +91,17 @@ REGISTER_GPU_KERNEL(bool);
.Device(DEVICE_SYCL) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
+ GetSessionHandleOp) \
+ REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
GetSessionHandleOp)
TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
REGISTER_SYCL_KERNEL(bool);
#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
class GetSessionTensorOp : public OpKernel {
public:
@@ -147,5 +170,5 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(
Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
DeleteSessionTensorOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 365716b372..f2a78956e3 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -2152,11 +2152,19 @@ REGISTER_OP("GetSessionHandle")
.Output("handle: string")
.Attr("T: type")
.SetShapeFn(shape_inference::ScalarShape)
+ .Deprecated(23, "Use GetSessionHandleV2");
+
+REGISTER_OP("GetSessionHandleV2")
+ .Input("value: T")
+ .Output("handle: resource")
+ .Attr("T: type")
+ .SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Store the input tensor in the state of the current session.
value: The tensor to be stored.
-handle: The handle for the tensor stored in the session state.
+handle: The handle for the tensor stored in the session state, represented
+ as a ResourceHandle object.
)doc");
REGISTER_OP("GetSessionTensor")
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index aa06d0ee70..7429bafff6 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -422,7 +422,9 @@ class _FetchHandler(object):
self._fetches.append(fetch_name)
self._ops.append(False)
# Remember the fetch if it is for a tensor handle.
- if isinstance(fetch, ops.Tensor) and fetch.op.type == 'GetSessionHandle':
+ if (isinstance(fetch, ops.Tensor) and
+ (fetch.op.type == 'GetSessionHandle' or
+ fetch.op.type == 'GetSessionHandleV2')):
self._fetch_handles[fetch_name] = fetch.op.inputs[0].dtype
self._final_fetches = [x for x in self._fetches if x not in feeds]
@@ -926,7 +928,7 @@ class BaseSession(SessionInterface):
if isinstance(subfeed_val, ops.Tensor):
raise TypeError('The value of a feed cannot be a tf.Tensor object. '
'Acceptable feed values include Python scalars, '
- 'strings, lists, or numpy ndarrays.')
+ 'strings, lists, numpy ndarrays, or TensorHandles.')
subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
if isinstance(subfeed_val,
@@ -937,9 +939,15 @@ class BaseSession(SessionInterface):
' Try explicitly setting the type of the feed tensor'
' to a larger type (e.g. int64).')
- np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
+ is_tensor_handle_feed = isinstance(subfeed_val,
+ session_ops.TensorHandle)
+ if is_tensor_handle_feed:
+ np_val = subfeed_val.to_numpy_array()
+ else:
+ np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
- if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
+ if (not is_tensor_handle_feed and
+ not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
raise ValueError(
'Cannot feed value of shape %r for Tensor %r, '
'which has shape %r'
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index a69c56368f..a370b904b6 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -50,7 +50,11 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
PyObject* value;
Py_ssize_t pos = 0;
if (PyDict_Next(descr->fields, &pos, &key, &value)) {
- const char* key_string = PyBytes_AsString(key);
+ // In Python 3, the keys of numpy custom struct types are unicode, unlike
+ // Python 2, where the keys are bytes.
+ const char* key_string =
+ PyBytes_Check(key) ? PyBytes_AsString(key)
+ : PyBytes_AsString(PyUnicode_AsASCIIString(key));
if (!key_string) {
return errors::Internal("Corrupt numpy type descriptor");
}
@@ -69,6 +73,8 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
*out_tf_datatype = TF_QUINT16;
} else if (key == "qint32") {
*out_tf_datatype = TF_QINT32;
+ } else if (key == "resource") {
+ *out_tf_datatype = TF_RESOURCE;
} else {
return errors::Internal("Unsupported numpy data type");
}
@@ -125,6 +131,8 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
// Quantized types are currently represented as custom struct types.
// PyArray_TYPE returns NPY_VOID for structs, and we should look into
// descr to derive the actual type.
+ // Direct feeds of certain types of ResourceHandles are represented as a
+ // custom struct type.
return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
default:
// TODO(mrry): Support these.
@@ -175,6 +183,9 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
case TF_STRING:
*out_pyarray_type = NPY_OBJECT;
break;
+ case TF_RESOURCE:
+ *out_pyarray_type = NPY_VOID;
+ break;
// TODO(keveman): These should be changed to NPY_VOID, and the type used for
// the resulting numpy array should be the custom struct types that we
// expect for quantized types.
@@ -322,6 +333,61 @@ static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
return Status::OK();
}
+// Determine the dimensions of a numpy ndarray to be created to represent an
+// output Tensor.
+gtl::InlinedVector<npy_intp, 4> GetPyArrayDimensionsForTensor(
+ const TF_Tensor* tensor, tensorflow::int64* nelems) {
+ if (TF_TensorType(tensor) == TF_RESOURCE) {
+ gtl::InlinedVector<npy_intp, 4> dims(1);
+ ResourceHandle* resource_handle =
+ reinterpret_cast<ResourceHandle*>(TF_TensorData(tensor));
+ dims[0] = resource_handle->SerializeAsString().size();
+ *nelems = dims[0];
+
+ return dims;
+ } else {
+ const int ndims = TF_NumDims(tensor);
+ gtl::InlinedVector<npy_intp, 4> dims(ndims);
+ *nelems = 1;
+ for (int i = 0; i < ndims; ++i) {
+ dims[i] = TF_Dim(tensor, i);
+ *nelems *= dims[i];
+ }
+
+ return dims;
+ }
+}
+
+// Determine the type description (PyArray_Descr) of a numpy ndarray to be
+// created to represent an output Tensor.
+Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
+ PyArray_Descr** descr) {
+ if (TF_TensorType(tensor) == TF_RESOURCE) {
+ PyObject* field = PyTuple_New(3);
+#if PY_MAJOR_VERSION < 3
+ PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
+#else
+ PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
+#endif
+ PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
+ PyTuple_SetItem(field, 2, PyLong_FromLong(1));
+ PyObject* fields = PyList_New(1);
+ PyList_SetItem(fields, 0, field);
+ int convert_result = PyArray_DescrConverter(fields, descr);
+ if (convert_result != 1) {
+ return errors::Internal("Failed to create numpy array description for ",
+ "TF_RESOURCE-type tensor");
+ }
+ } else {
+ int type_num = -1;
+ TF_RETURN_IF_ERROR(
+ TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
+ *descr = PyArray_DescrFromType(type_num);
+ }
+
+ return Status::OK();
+}
+
// Converts the given TF_Tensor to a Numpy array.
// If the returned status is OK, the caller becomes the owner of *out_array.
Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) {
@@ -333,26 +399,20 @@ Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) {
return Status::OK();
}
- const int ndims = TF_NumDims(tensor);
- gtl::InlinedVector<npy_intp, 4> dims(ndims);
- tensorflow::int64 nelems = 1;
- for (int i = 0; i < ndims; ++i) {
- dims[i] = TF_Dim(tensor, i);
- nelems *= dims[i];
- }
+ tensorflow::int64 nelems = -1;
+ gtl::InlinedVector<npy_intp, 4> dims =
+ GetPyArrayDimensionsForTensor(tensor, &nelems);
// Convert TensorFlow dtype to numpy type descriptor.
- int type_num = -1;
- TF_RETURN_IF_ERROR(
- TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
- PyArray_Descr* descr = PyArray_DescrFromType(type_num);
+ PyArray_Descr* descr = nullptr;
+ TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor, &descr));
// Copy the TF_TensorData into a newly-created ndarray and return it.
// TODO(mrry): Perhaps investigate zero-copy approaches. This would involve
// creating an ndarray-like object that wraps the TF_Tensor buffer, and
// maps its destructor to TF_DeleteTensor.
Safe_PyObjectPtr safe_out_array =
- tensorflow::make_safe(PyArray_Empty(ndims, dims.data(), descr, 0));
+ tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
if (!safe_out_array) {
return errors::Internal("Could not allocate ndarray");
}
@@ -371,6 +431,12 @@ Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) {
}
PyArray_ITER_NEXT(iter.get());
}
+ } else if (TF_TensorType(tensor) == TF_RESOURCE) {
+ ResourceHandle* resource_handle =
+ reinterpret_cast<ResourceHandle*>(TF_TensorData(tensor));
+ memcpy(PyArray_DATA(py_array),
+ resource_handle->SerializeAsString().c_str(),
+ PyArray_NBYTES(py_array));
} else {
return errors::Internal("ndarray was ", PyArray_NBYTES(py_array),
" bytes but TF_Tensor was ",
@@ -418,6 +484,8 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
Py_ssize_t pos = 0;
int index = 0;
Status s;
+
+ gtl::InlinedVector<std::shared_ptr<ResourceHandle>, 4> resource_handles;
while (PyDict_Next(feed_dict, &pos, &key, &value)) {
char* key_string = PyBytes_AsString(key);
if (!key_string) {
@@ -457,7 +525,19 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
// type, this steals a reference to array, which will be relinquished when
// the underlying buffer is deallocated. For string, a new temporary buffer
// is allocated into which the strings are encoded.
- if (dtype != TF_STRING) {
+ if (dtype == TF_RESOURCE) {
+ const string serialized(reinterpret_cast<char*>(PyArray_DATA(array)),
+ PyArray_NBYTES(array));
+ std::shared_ptr<ResourceHandle> resource_handle(new ResourceHandle());
+ resource_handle->ParseFromString(serialized);
+ resource_handles.emplace_back(resource_handle);
+ TF_Tensor* tensor =
+ TF_AllocateTensor(dtype, {}, 0, sizeof(ResourceHandle));
+ std::memcpy(TF_TensorData(tensor),
+ reinterpret_cast<void*>(resource_handle.get()),
+ sizeof(ResourceHandle));
+ inputs_safe.emplace_back(make_safe(tensor));
+ } else if (dtype != TF_STRING) {
// NOTE(mrry): We currently copy the numpy array into a new
// buffer to avoid possible issues on deallocation (such as
// having to acquire the Python Global Interpreter Lock).
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index baaa15abca..af1256ee32 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
@@ -77,10 +78,14 @@ def load_tensor_from_event(event):
if (event.summary.value[0].tensor.tensor_content or
event.summary.value[0].tensor.string_val):
# Initialized tensor.
- try:
- tensor_value = tensor_util.MakeNdarray(event.summary.value[0].tensor)
- except KeyError:
- tensor_value = None
+ tensor_proto = event.summary.value[0].tensor
+ if tensor_proto.dtype == types_pb2.DT_RESOURCE:
+ return None
+ else:
+ try:
+ tensor_value = tensor_util.MakeNdarray(tensor_proto)
+ except KeyError:
+ tensor_value = None
else:
# Uninitialized tensor or tensor of unconvertible data type.
tensor_value = None
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 7564cbcb05..d373bac47a 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -458,6 +458,9 @@ _np_qint16 = np.dtype([("qint16", np.int16, 1)])
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
+# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
+np_resource = np.dtype([("resource", np.ubyte, 1)])
+
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = frozenset([
(np.float16, float16),
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
index 25d60c5259..41b678feb9 100644
--- a/tensorflow/python/kernel_tests/session_ops_test.py
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -22,6 +22,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import session_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -229,6 +231,45 @@ class SessionOpsTest(test.TestCase):
b_p: b_handle.handle})
self.assertEqual(3.0, c_handle.eval())
+ def testFeedOneHandleDirectly(self):
+ with self.test_session() as sess:
+ a = constant_op.constant(10.0)
+ b = constant_op.constant(5.0)
+ c = math_ops.multiply(a, b)
+ d = math_ops.multiply(c, c)
+
+ h_c = sess.run(session_ops.get_session_handle(c))
+
+ self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
+
+ def testFeedTwoHandlesDirectly(self):
+ with self.test_session() as sess:
+ a = constant_op.constant(10.0)
+ b = constant_op.constant(5.0)
+ c = math_ops.multiply(a, b)
+ d = math_ops.div(a, b)
+ e = math_ops.subtract(c, d)
+
+ h_c = sess.run(session_ops.get_session_handle(c))
+ h_d = sess.run(session_ops.get_session_handle(d))
+
+ self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
+ self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
+
+ def testFeedHandleToVariableDirectly(self):
+ with self.test_session() as sess:
+ a = variables.Variable(12.0)
+ inc_a = state_ops.assign_add(a, 2.0)
+ b = math_ops.add(a, 5.0)
+ sess.run(a.initializer)
+
+ h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
+ self.assertAllClose(12.0, sess.run(a))
+
+ self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
+ sess.run(inc_a)
+ self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index 95c15f334d..79e94dace0 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -78,5 +78,6 @@ ops.NotDifferentiable("StackPop")
ops.NotDifferentiable("StackClose")
ops.NotDifferentiable("GetSessionHandle")
+ops.NotDifferentiable("GetSessionHandleV2")
ops.NotDifferentiable("GetSessionTensor")
ops.NotDifferentiable("DeleteSessionTensor")
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 4937f1a50a..fbfd7bb7d6 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -60,6 +60,7 @@ FakeQueue
FIFOQueue
FIFOQueueV2
GetSessionHandle
+GetSessionHandleV2
GetSessionTensor
HashTable
InitializeTable
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index ff5d5a2b2f..0a06982ad7 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -16,6 +16,7 @@
"""Tensor Handle Operations. See the @{python/session_ops} guide.
@@get_session_handle
+@@get_session_handle_v2
@@get_session_tensor
@@delete_session_tensor
"""
@@ -25,6 +26,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+import numpy as np
+
+from tensorflow.core.framework import resource_handle_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -33,6 +39,22 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.util import compat
+def decode_resource_handle(encoded):
+ """Decode a ResourceHandle proto encoded as custom numpy struct type."""
+ resource_handle = resource_handle_pb2.ResourceHandle()
+ if sys.version_info.major < 3:
+ resource_handle.ParseFromString("".join([chr(ch[0]) for ch in encoded]))
+ else:
+ resource_handle.ParseFromString(bytes([ch[0] for ch in encoded]))
+ return resource_handle
+
+
+def encode_resource_handle(resource_handle):
+ """Encode a ResourceHandle proto as custom numpy struct type."""
+ return np.asarray(bytearray(resource_handle.SerializeToString()),
+ dtype=dtypes.np_resource)
+
+
class TensorHandle(object):
"""Represents a handle for a live tensor in a session."""
@@ -47,7 +69,8 @@ class TensorHandle(object):
dtype: The data type of the tensor represented by `handle`.
session: The session in which the tensor is produced.
"""
- self._handle = compat.as_str_any(handle)
+ self._resource_handle = decode_resource_handle(handle)
+ self._handle = compat.as_str_any(self._resource_handle.name)
self._dtype = dtype
self._session = session
self._auto_gc_enabled = True
@@ -60,6 +83,20 @@ class TensorHandle(object):
return self._handle
@property
+ def resource_handle(self):
+ """The ResourceHandle representation of this handle."""
+ return self._resource_handle
+
+ def to_numpy_array(self):
+ """Convert a TensorHandle object to a feedable numpy value.
+
+ Returns:
+ A numpy array of a custom struct type that can be used as a feed value
+ to run().
+ """
+ return encode_resource_handle(self.resource_handle)
+
+ @property
def handle(self):
"""The string representation of this handle."""
return self._handle
@@ -154,7 +191,7 @@ def get_session_handle(data, name=None):
# Colocate this operation with data.
with ops.colocate_with(data):
- return gen_data_flow_ops._get_session_handle(data, name=name)
+ return gen_data_flow_ops._get_session_handle_v2(data, name=name) # pylint: disable=protected-access
def get_session_tensor(handle, dtype, name=None):
@@ -259,7 +296,7 @@ def _get_handle_mover(graph, feeder, handle):
# Create mover if we haven't done it.
holder, reader = _get_handle_reader(graph, handle, dtype)
with graph.as_default(), graph.device(feeder.op.device):
- mover = gen_data_flow_ops._get_session_handle(reader)
+ mover = gen_data_flow_ops._get_session_handle_v2(reader) # pylint: disable=protected-access
result = (holder, mover)
graph._handle_movers[graph_key] = result
return result