aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-09-07 13:16:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-07 14:33:31 -0700
commit2677d3be19952488662a22cf9f42374a493ffd50 (patch)
tree0c2364940cf6e0c28d641da9bcec6a78f52454fd /tensorflow/c
parent9a55ed98a8edd44f2779f3a644a902ab05afbd32 (diff)
TensorFlow C API: Add a Set and Get Tensor Shape function.
Wires up ShapeRefiner into C and C++ API so that we can query and set the shapes of outputs. Currently only works for the C-API, but the plumbing exists for the C++ API, which we can only turn on once we are using the C++ shape functions for everything. Change: 132479208
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.cc116
-rw-r--r--tensorflow/c/c_api.h45
-rw-r--r--tensorflow/c/c_api_test.cc82
3 files changed, 237 insertions, 6 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 8f9b90a9af..a42b5362ae 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/shape_refiner.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
@@ -648,17 +650,22 @@ extern "C" {
struct TF_Graph {
TF_Graph()
: graph(OpRegistry::Global()), num_sessions(0), delete_requested(false) {}
- mutex mu; // protects all of the following
- Graph graph;
- std::unordered_map<tensorflow::string, Node*> name_map;
+ mutex mu;
+ Graph graph GUARDED_BY(mu);
+
+ // Runs shape inference.
+ tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
+
+ // Maps from name of an operation to the Node* in 'graph'.
+ std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true
// num_sessions incremented by TF_NewSessionWithGraph, and decremented by
// TF_DeleteSessionWithGraph.
- int num_sessions;
- bool delete_requested; // set true by TF_DeleteGraph
+ int num_sessions GUARDED_BY(mu);
+ bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
};
struct TF_OperationDescription {
@@ -711,6 +718,96 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
} // namespace
+// Shape functions -----------------------------------------------------------
+
+void TF_GraphSetTensorShape(TF_Graph* graph, TF_Port port, const int64_t* dims,
+ const int num_dims, TF_Status* status) {
+ Node* node = &port.oper->node;
+
+ mutex_lock l(graph->mu);
+ // Set the shape.
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ if (ic == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Node ", node->name(), " was not found in the graph");
+ return;
+ }
+
+ std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
+ for (int i = 0; i < num_dims; ++i) {
+ dim_vec.push_back(ic->MakeDim(dims[i]));
+ }
+
+ tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec);
+ status->status = graph->refiner.SetShape(node, port.index, new_shape);
+}
+
+int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Port port, TF_Status* status) {
+ Node* node = &port.oper->node;
+
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ if (ic == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Node ", node->name(), " was not found in the graph");
+ return -1;
+ }
+
+ tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index);
+
+ // Unknown rank means the number of dimensions is -1.
+ if (!ic->RankKnown(shape)) {
+ return -1;
+ }
+
+ return ic->Rank(shape);
+}
+
+void TF_GraphGetTensorShape(TF_Graph* graph, TF_Port port, int64_t* dims,
+ int num_dims, TF_Status* status) {
+ Node* node = &port.oper->node;
+
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ if (ic == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Node ", node->name(), " was not found in the graph");
+ return;
+ }
+
+ tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index);
+
+ int rank = -1;
+ if (ic->RankKnown(shape)) {
+ rank = ic->Rank(shape);
+ }
+
+ if (num_dims != rank) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Expected rank is ", num_dims, " but actual rank is ", rank);
+ return;
+ }
+
+ if (num_dims == 0) {
+ // Output shape is a scalar.
+ return;
+ }
+
+ // Rank is greater than 0, so fill in the values, if known, and
+ // -1 for unknown values.
+ for (int i = 0; i < num_dims; ++i) {
+ auto dim = ic->Dim(shape, i);
+ tensorflow::int64 value = -1;
+ if (ic->ValueKnown(dim)) {
+ value = ic->Value(dim);
+ }
+ dims[i] = value;
+ }
+}
+
// TF_OperationDescription functions ------------------------------------------
extern "C" {
@@ -946,7 +1043,16 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
"Duplicate node name in graph: '", desc->node_builder.node_name(), "'");
} else {
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
+
if (status->status.ok()) {
+ // Run shape inference function for newly added node.
+ //
+ // TODO(b/28152992): Enable returning the result of this
+ // code-path once we have converted all python shape functions
+ // to call their C++ versions.
+ desc->graph->refiner.AddNode(ret);
+
+ // Add the node to the name-to-node mapping.
desc->graph->name_map[ret->name()] = ret;
}
}
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 3f1d5346f8..69ea6fe5a4 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -281,6 +281,51 @@ typedef struct TF_Port {
int index; // Specifies the index of the input or output within oper.
} TF_Port;
+// Sets the shape of the Tensor referenced by `port` in `graph` to
+// the shape described by `dims` and `num_dims`.
+//
+// If the number of dimensions is unknown, `num_dims` must be
+// set to -1 and dims can be null. If a dimension is unknown,
+// the corresponding entry in the `dims` array must be -1.
+//
+// This does not overwrite the existing shape associated with `port`,
+// but merges the input shape with the existing shape. For example,
+// setting a shape of [-1, 2] with an existing shape [2, -1] would set
+// a final shape of [2, 2] based on shape merging semantics.
+//
+// Returns an error into `status` if:
+// * `port` is not in `graph`.
+// * An invalid shape is being set (e.g., the shape being set
+// is incompatible with the existing shape).
+extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Port port,
+ const int64_t* dims, const int num_dims,
+ TF_Status* status);
+
+// Returns the number of dimensions of the Tensor referenced by `port`
+// in `graph`.
+//
+// If the number of dimensions in the shape is unknown, returns -1.
+//
+// Returns an error into `status` if:
+// * `port` is not in `graph`.
+extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Port port,
+ TF_Status* status);
+
+// Returns the shape of the Tensor referenced by `port` in `graph`
+// into `dims`. `dims` must be an array large enough to hold `num_dims`
+// entries (e.g., the return value of TF_GraphGetTensorNumDims).
+//
+// If the number of dimensions in the shape is unknown or the shape is
+// a scalar, `dims` will remain untouched. Otherwise, each element of
+// `dims` will be set corresponding to the size of the dimension. An
+// unknown dimension is represented by `-1`.
+//
+// Returns an error into `status` if:
+// * `port` is not in `graph`.
+// * `num_dims` does not match the actual number of dimensions.
+extern void TF_GraphGetTensorShape(TF_Graph* graph, TF_Port port, int64_t* dims,
+ int num_dims, TF_Status* status);
+
// Operation will only be added to *graph when TF_FinishOperation() is
// called (assuming TF_FinishOperation() does not return an error).
// *graph must not be deleted until after TF_FinishOperation() is
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index a6b3a013bc..613030d15a 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -361,11 +361,91 @@ bool GetAttrValue(TF_Operation* oper, const char* attr_name,
return ret;
}
+TEST(CAPI, SetShape) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ TF_Operation* feed = Placeholder(graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Port feed_out_0 = TF_Port{feed, 0};
+ int num_dims;
+
+ // Fetch the shape, it should be completely unknown.
+ num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(-1, num_dims);
+
+ // Set the shape to be 2 x Unknown
+ int64_t dims[] = {2, -1};
+ TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Fetch the shape and validate it is 2 by -1.
+ num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(2, num_dims);
+
+ // Resize the dimension vector appropriately.
+ int64_t returned_dims[2];
+ TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(dims[0], returned_dims[0]);
+ EXPECT_EQ(dims[1], returned_dims[1]);
+
+ // Set to a new valid shape: [2, 3]
+ dims[1] = 3;
+ TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Fetch and see that the new value is returned.
+ TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(dims[0], returned_dims[0]);
+ EXPECT_EQ(dims[1], returned_dims[1]);
+
+ // Try to set 'unknown' on the shape and see that
+ // it doesn't change.
+ dims[0] = -1;
+ dims[1] = -1;
+ TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ // Fetch and see that the new value is returned.
+ TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(2, num_dims);
+ EXPECT_EQ(2, returned_dims[0]);
+ EXPECT_EQ(3, returned_dims[1]);
+
+ // Try to fetch a shape with the wrong num_dims
+ TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
+
+ // Try to set an invalid shape (cannot change 2x3 to a 2x5).
+ dims[1] = 5;
+ TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
+
+ // Test for a scalar.
+ TF_Operation* three = ScalarConst(3, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Port three_out_0 = TF_Port{three, 0};
+
+ num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(0, num_dims);
+ TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Clean up
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
TEST(CAPI, Graph) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
- // Make a placeholder oper.
+ // Make a placeholder operation.
TF_Operation* feed = Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);