diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-09-07 13:16:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-07 14:33:31 -0700 |
commit | 2677d3be19952488662a22cf9f42374a493ffd50 (patch) | |
tree | 0c2364940cf6e0c28d641da9bcec6a78f52454fd /tensorflow/c | |
parent | 9a55ed98a8edd44f2779f3a644a902ab05afbd32 (diff) |
TensorFlow C API: Add a Set and Get Tensor Shape function.
Wires up ShapeRefiner into C and C++ API so that we can
query and set the shapes of outputs. Currently only works
for the C-API, but the plumbing exists for the C++ API,
which we can only turn on once we are using the C++ shape
functions for everything.
Change: 132479208
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/c_api.cc | 116 | ||||
-rw-r--r-- | tensorflow/c/c_api.h | 45 | ||||
-rw-r--r-- | tensorflow/c/c_api_test.cc | 82 |
3 files changed, 237 insertions, 6 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 8f9b90a9af..a42b5362ae 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/shape_refiner.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" @@ -648,17 +650,22 @@ extern "C" { struct TF_Graph { TF_Graph() : graph(OpRegistry::Global()), num_sessions(0), delete_requested(false) {} - mutex mu; // protects all of the following - Graph graph; - std::unordered_map<tensorflow::string, Node*> name_map; + mutex mu; + Graph graph GUARDED_BY(mu); + + // Runs shape inference. + tensorflow::ShapeRefiner refiner GUARDED_BY(mu); + + // Maps from name of an operation to the Node* in 'graph'. + std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu); // TF_Graph may only / must be deleted when // num_sessions == 0 && delete_requested == true // num_sessions incremented by TF_NewSessionWithGraph, and decremented by // TF_DeleteSessionWithGraph. - int num_sessions; - bool delete_requested; // set true by TF_DeleteGraph + int num_sessions GUARDED_BY(mu); + bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph }; struct TF_OperationDescription { @@ -711,6 +718,96 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, } // namespace +// Shape functions ----------------------------------------------------------- + +void TF_GraphSetTensorShape(TF_Graph* graph, TF_Port port, const int64_t* dims, + const int num_dims, TF_Status* status) { + Node* node = &port.oper->node; + + mutex_lock l(graph->mu); + // Set the shape. + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Node ", node->name(), " was not found in the graph"); + return; + } + + std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + + tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec); + status->status = graph->refiner.SetShape(node, port.index, new_shape); +} + +int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Port port, TF_Status* status) { + Node* node = &port.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Node ", node->name(), " was not found in the graph"); + return -1; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index); + + // Unknown rank means the number of dimensions is -1. + if (!ic->RankKnown(shape)) { + return -1; + } + + return ic->Rank(shape); +} + +void TF_GraphGetTensorShape(TF_Graph* graph, TF_Port port, int64_t* dims, + int num_dims, TF_Status* status) { + Node* node = &port.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Node ", node->name(), " was not found in the graph"); + return; + } + + tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index); + + int rank = -1; + if (ic->RankKnown(shape)) { + rank = ic->Rank(shape); + } + + if (num_dims != rank) { + status->status = tensorflow::errors::InvalidArgument( + "Expected rank is ", num_dims, " but actual rank is ", rank); + return; + } + + if (num_dims == 0) { + // Output shape is a scalar. + return; + } + + // Rank is greater than 0, so fill in the values, if known, and + // -1 for unknown values. + for (int i = 0; i < num_dims; ++i) { + auto dim = ic->Dim(shape, i); + tensorflow::int64 value = -1; + if (ic->ValueKnown(dim)) { + value = ic->Value(dim); + } + dims[i] = value; + } +} + // TF_OperationDescription functions ------------------------------------------ extern "C" { @@ -946,7 +1043,16 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, "Duplicate node name in graph: '", desc->node_builder.node_name(), "'"); } else { status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); + if (status->status.ok()) { + // Run shape inference function for newly added node. + // + // TODO(b/28152992): Enable returning the result of this + // code-path once we have converted all python shape functions + // to call their C++ versions. + desc->graph->refiner.AddNode(ret); + + // Add the node to the name-to-node mapping. desc->graph->name_map[ret->name()] = ret; } } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3f1d5346f8..69ea6fe5a4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -281,6 +281,51 @@ typedef struct TF_Port { int index; // Specifies the index of the input or output within oper. } TF_Port; +// Sets the shape of the Tensor referenced by `port` in `graph` to +// the shape described by `dims` and `num_dims`. +// +// If the number of dimensions is unknown, `num_dims` must be +// set to -1 and dims can be null. If a dimension is unknown, +// the corresponding entry in the `dims` array must be -1. +// +// This does not overwrite the existing shape associated with `port`, +// but merges the input shape with the existing shape. For example, +// setting a shape of [-1, 2] with an existing shape [2, -1] would set +// a final shape of [2, 2] based on shape merging semantics. +// +// Returns an error into `status` if: +// * `port` is not in `graph`. +// * An invalid shape is being set (e.g., the shape being set +// is incompatible with the existing shape). +extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Port port, + const int64_t* dims, const int num_dims, + TF_Status* status); + +// Returns the number of dimensions of the Tensor referenced by `port` +// in `graph`. +// +// If the number of dimensions in the shape is unknown, returns -1. +// +// Returns an error into `status` if: +// * `port` is not in `graph`. +extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Port port, + TF_Status* status); + +// Returns the shape of the Tensor referenced by `port` in `graph` +// into `dims`. `dims` must be an array large enough to hold `num_dims` +// entries (e.g., the return value of TF_GraphGetTensorNumDims). +// +// If the number of dimensions in the shape is unknown or the shape is +// a scalar, `dims` will remain untouched. Otherwise, each element of +// `dims` will be set corresponding to the size of the dimension. An +// unknown dimension is represented by `-1`. +// +// Returns an error into `status` if: +// * `port` is not in `graph`. +// * `num_dims` does not match the actual number of dimensions. +extern void TF_GraphGetTensorShape(TF_Graph* graph, TF_Port port, int64_t* dims, + int num_dims, TF_Status* status); + // Operation will only be added to *graph when TF_FinishOperation() is // called (assuming TF_FinishOperation() does not return an error). // *graph must not be deleted until after TF_FinishOperation() is diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index a6b3a013bc..613030d15a 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -361,11 +361,91 @@ bool GetAttrValue(TF_Operation* oper, const char* attr_name, return ret; } +TEST(CAPI, SetShape) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + TF_Operation* feed = Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Port feed_out_0 = TF_Port{feed, 0}; + int num_dims; + + // Fetch the shape, it should be completely unknown. + num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(-1, num_dims); + + // Set the shape to be 2 x Unknown + int64_t dims[] = {2, -1}; + TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Fetch the shape and validate it is 2 by -1. + num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(2, num_dims); + + // Resize the dimension vector appropriately. + int64_t returned_dims[2]; + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(dims[0], returned_dims[0]); + EXPECT_EQ(dims[1], returned_dims[1]); + + // Set to a new valid shape: [2, 3] + dims[1] = 3; + TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Fetch and see that the new value is returned. + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(dims[0], returned_dims[0]); + EXPECT_EQ(dims[1], returned_dims[1]); + + // Try to set 'unknown' on the shape and see that + // it doesn't change. + dims[0] = -1; + dims[1] = -1; + TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + // Fetch and see that the new value is returned. + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); + + // Try to fetch a shape with the wrong num_dims + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); + + // Try to set an invalid shape (cannot change 2x3 to a 2x5). + dims[1] = 5; + TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); + + // Test for a scalar. + TF_Operation* three = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Port three_out_0 = TF_Port{three, 0}; + + num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(0, num_dims); + TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Clean up + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + TEST(CAPI, Graph) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); - // Make a placeholder oper. + // Make a placeholder operation. TF_Operation* feed = Placeholder(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); |