diff options
author | Akshay Agrawal <akshayka@google.com> | 2017-12-21 11:23:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-21 11:28:14 -0800 |
commit | 1938feab97e36275f18a0745804299acfe137dc8 (patch) | |
tree | 36fc978e9203487326e7855b27f0faafe01c3cd9 /tensorflow/c/c_api_internal.h | |
parent | c7a05f4b18df0a9bd6b594d6f3d67b7489fcb54e (diff) |
This change adds a mechanism to the internal C API for updating an output handle's shapes and types after its source operation has been created.
Context: framework/ops.py was recently updated to use the C API when setting shapes for an op's outputs. This update broke shape inference for graph functions that captured resource handles; this, in turn, made it impossible to create graph functions from Python methods that required fully defined shapes (e.g., like MNIST's `call` method). In particular, the C API computes shapes for ops when they are created and does not update them thereafter; this is problematic because when a resource handle is captured while building a function, we need to update the captured operation's output handle in order to propagate its outputs shapes and dtypes.
PiperOrigin-RevId: 179837104
Diffstat (limited to 'tensorflow/c/c_api_internal.h')
-rw-r--r-- | tensorflow/c/c_api_internal.h | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f8edc90a9f..91667056e0 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -194,6 +194,21 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +// Set the shapes and types of the output's handle. +// +// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must +// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the +// rank is known), then it must be equal to the length of `shapes[i]`; if +// `ranks[i] == 1`, then `shapes[i]` may be nullptr. +// +// TODO(akshayka): Implement a corresponding getter method. +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status); + void RecordMutation(TF_Graph* graph, const TF_Operation& op, const char* mutation_type); |