aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_internal.h
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-12-21 11:23:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 11:28:14 -0800
commit1938feab97e36275f18a0745804299acfe137dc8 (patch)
tree36fc978e9203487326e7855b27f0faafe01c3cd9 /tensorflow/c/c_api_internal.h
parentc7a05f4b18df0a9bd6b594d6f3d67b7489fcb54e (diff)
This change adds a mechanism to the internal C API for updating an output handle's shapes and types after its source operation has been created.
Context: framework/ops.py was recently updated to use the C API when setting shapes for an op's outputs. This update broke shape inference for graph functions that captured resource handles; this, in turn, made it impossible to create graph functions from Python methods that required fully defined shapes (e.g., like MNIST's `call` method). In particular, the C API computes shapes for ops when they are created and does not update them thereafter; this is problematic because when a resource handle is captured while building a function, we need to update the captured operation's output handle in order to propagate its outputs shapes and dtypes. PiperOrigin-RevId: 179837104
Diffstat (limited to 'tensorflow/c/c_api_internal.h')
-rw-r--r--tensorflow/c/c_api_internal.h15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index f8edc90a9f..91667056e0 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -194,6 +194,21 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
+// Set the shapes and types of the output's handle.
+//
+// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
+// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
+// rank is known), then it must be equal to the length of `shapes[i]`; if
+// `ranks[i] == 1`, then `shapes[i]` may be nullptr.
+//
+// TODO(akshayka): Implement a corresponding getter method.
+void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
+ int num_shapes_and_types,
+ const int64_t** shapes,
+ const int* ranks,
+ const TF_DataType* types,
+ TF_Status* status);
+
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type);