aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-19 22:29:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 22:33:26 -0700
commit1a8258e0593270a8e2370517dff8faafce40a687 (patch)
treee097a315ba6c21f8f2ac72c5af17564d73efbc4e /tensorflow/c/c_test_util.cc
parent407ddd1c0539cfc5d33ab2629230eab5a958b7d4 (diff)
Added infeed support for experimental C APIs associated with TPU graph rewrite.
This initial design of the C API is different from (and mostly higher level than) the python API counterparts for infeed, in that the python API has explicit graph construction APIs for generating infeed enqueue/dequeue ops (e.g. split_inputs_and_generate_enqueue_ops() and generate_dequeue_op()), while the C API takes an input graph and redirects all input nodes to feed the infeed enqueue. One requirement/restriction is that the input nodes in the TF graph (e.g. Placeholder) must specify their tensor shapes, for infeed enqueue and dequeue nodes to properly compile with XLA. The API for more general shape support will be designed and implemented later. PiperOrigin-RevId: 189693028
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 22f77e7b87..f3b28c1708 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -94,18 +94,22 @@ TF_Tensor* FloatTensor(float v) {
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation)
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
- TF_DataType dtype, TF_Operation** op) {
+ TF_DataType dtype, const std::vector<int64_t>& dims,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", dtype);
+ if (!dims.empty()) {
+ TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
+ }
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
- TF_DataType dtype) {
+ TF_DataType dtype, const std::vector<int64_t>& dims) {
TF_Operation* op;
- PlaceholderHelper(graph, s, name, dtype, &op);
+ PlaceholderHelper(graph, s, name, dtype, dims, &op);
return op;
}