diff options
author | Mingsheng Hong <hongm@google.com> | 2018-03-19 22:29:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-19 22:33:26 -0700 |
commit | 1a8258e0593270a8e2370517dff8faafce40a687 (patch) | |
tree | e097a315ba6c21f8f2ac72c5af17564d73efbc4e /tensorflow/c/c_test_util.cc | |
parent | 407ddd1c0539cfc5d33ab2629230eab5a958b7d4 (diff) |
Added infeed support for experimental C APIs associated with TPU graph rewrite.
This initial design of the C API is different from (and mostly higher level
than) the python API counterparts for infeed, in that the python API has
explicit graph construction APIs for generating infeed enqueue/dequeue
ops (e.g. split_inputs_and_generate_enqueue_ops() and generate_dequeue_op()),
while the C API takes an input graph and redirects all input nodes to feed the
infeed enqueue.
One requirement/restriction is that the input nodes in the TF
graph (e.g. Placeholder) must specify their tensor shapes, for infeed enqueue
and dequeue nodes to properly compile with XLA. The API for more general shape
support will be designed and implemented later.
PiperOrigin-RevId: 189693028
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r-- | tensorflow/c/c_test_util.cc | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 22f77e7b87..f3b28c1708 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -94,18 +94,22 @@ TF_Tensor* FloatTensor(float v) { // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, - TF_DataType dtype, TF_Operation** op) { + TF_DataType dtype, const std::vector<int64_t>& dims, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", dtype); + if (!dims.empty()) { + TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } *op = TF_FinishOperation(desc, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_NE(*op, nullptr); } TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name, - TF_DataType dtype) { + TF_DataType dtype, const std::vector<int64_t>& dims) { TF_Operation* op; - PlaceholderHelper(graph, s, name, dtype, &op); + PlaceholderHelper(graph, s, name, dtype, dims, &op); return op; } |