diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-18 15:02:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-18 15:05:23 -0700 |
commit | d964834a922e77198fd387aac6c6cc5970a31e7d (patch) | |
tree | e6e5e914abf941b161d46ad4c2e940422643eccf /tensorflow/c/eager | |
parent | 325ba9ece698d04082b173ba300a10623d27de96 (diff) |
Merged commit includes the following changes:
193422827 by yifeif:
Fix buildifier error.
--
193421691 by skyewm:
Make GraphModeFunctions work with _USE_C_SHAPES=True.
Tensor._handle_data is going away. This change adds special hooks for
propagating the resource handle shape information through
EagerTensors.
--
193421473 by A. Unique TensorFlower:
Register dynamic_stitch for DT_VARIANT type.
--
193421175 by nolivia:
disabling flaky tsan test
--
193420117 by nolivia:
disabling flaky test in tensorflow that has no apparent culprit
--
PiperOrigin-RevId: 193422827
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 57 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 14 |
3 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a2d96357ac..3e14c10727 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -41,6 +41,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + # TODO(b/74620627): move this here + "//tensorflow/python:cpp_shape_inference_proto_cc", ], }) + select({ "//tensorflow:with_xla_support": [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 393851d13c..369342b142 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" using tensorflow::int64; using tensorflow::string; @@ -1015,6 +1016,62 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, ctx->context.RunMetadataProto()->Clear(); } +void TFE_GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + TF_Buffer* output_proto, + TF_Status* status) { + tensorflow::Node* node = &output.oper->node; + tensorflow::CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) { + output_proto->data = nullptr; + output_proto->length = 0; + output_proto->data_deallocator = nullptr; + return; + } + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + status->status = MessageToBuffer(handle_data, output_proto); +} + +void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (status->status.ok()) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + namespace { TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, TF_Status* status) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 3926c22ce1..15ac0f376c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -329,6 +329,20 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource tensor, or otherwise returns an empty buffer. +TF_CAPI_EXPORT extern void TFE_GetResourceHandleShapeAndType( + TF_Graph* graph, TF_Output output, TF_Buffer* output_proto, + TF_Status* status); + +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. +TF_CAPI_EXPORT extern void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif |