aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-18 15:02:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 15:05:23 -0700
commitd964834a922e77198fd387aac6c6cc5970a31e7d (patch)
treee6e5e914abf941b161d46ad4c2e940422643eccf /tensorflow/c/eager
parent325ba9ece698d04082b173ba300a10623d27de96 (diff)
Merged commit includes the following changes:
193422827 by yifeif: Fix buildifier error. -- 193421691 by skyewm: Make GraphModeFunctions work with _USE_C_SHAPES=True. Tensor._handle_data is going away. This change adds special hooks for propagating the resource handle shape information through EagerTensors. -- 193421473 by A. Unique TensorFlower: Register dynamic_stitch for DT_VARIANT type. -- 193421175 by nolivia: disabling flaky tsan test -- 193420117 by nolivia: disabling flaky test in tensorflow that has no apparent culprit -- PiperOrigin-RevId: 193422827
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/BUILD2
-rw-r--r--tensorflow/c/eager/c_api.cc57
-rw-r--r--tensorflow/c/eager/c_api.h14
3 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index a2d96357ac..3e14c10727 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -41,6 +41,8 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ # TODO(b/74620627): move this here
+ "//tensorflow/python:cpp_shape_inference_proto_cc",
],
}) + select({
"//tensorflow:with_xla_support": [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 393851d13c..369342b142 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
using tensorflow::int64;
using tensorflow::string;
@@ -1015,6 +1016,62 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
ctx->context.RunMetadataProto()->Clear();
}
+void TFE_GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ TF_Buffer* output_proto,
+ TF_Status* status) {
+ tensorflow::Node* node = &output.oper->node;
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ handle_data.set_is_set(true);
+ {
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ CHECK(ic != nullptr);
+ CHECK_LT(output.index, ic->num_outputs());
+ const auto* shapes_and_types =
+ ic->output_handle_shapes_and_types(output.index);
+ if (shapes_and_types == nullptr) {
+ output_proto->data = nullptr;
+ output_proto->length = 0;
+ output_proto->data_deallocator = nullptr;
+ return;
+ }
+
+ for (const auto& p : *shapes_and_types) {
+ auto* out_shape_and_type = handle_data.add_shape_and_type();
+ ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+ out_shape_and_type->set_dtype(p.dtype);
+ }
+ }
+ status->status = MessageToBuffer(handle_data, output_proto);
+}
+
+void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status) {
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ if (!handle_data.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Couldn't deserialize HandleData proto");
+ return;
+ }
+ DCHECK(handle_data.is_set());
+
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(&output.oper->node);
+
+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+ for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+ tensorflow::shape_inference::ShapeHandle shape;
+ status->status =
+ ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+    if (!status->status.ok()) return;
+ shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+ }
+ ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 3926c22ce1..15ac0f376c 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -329,6 +329,20 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if it's a resource tensor, or otherwise returns an empty buffer.
+TF_CAPI_EXPORT extern void TFE_GetResourceHandleShapeAndType(
+ TF_Graph* graph, TF_Output output, TF_Buffer* output_proto,
+ TF_Status* status);
+
+// Sets `output` based on `proto`, which should be a serialized
+// CppShapeInferenceResult::HandleData proto.
+TF_CAPI_EXPORT extern void TFE_SetResourceHandleShapeAndType(TF_Graph* graph,
+ TF_Output output,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif