diff options
author | Mingsheng Hong <hongm@google.com> | 2018-03-22 20:51:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-22 20:54:37 -0700 |
commit | 0191d264a6e3da12ff7db5ba8002fed6356f071b (patch) | |
tree | 1e974458c633f2f64e3a6d6165e66905c7605083 /tensorflow/c/c_api_experimental.cc | |
parent | 72f48771ea4fce68af893bcda862ca390a0e6b70 (diff) |
Added experimental C APIs to build a hard-coded stack of dataset + iterator.
PiperOrigin-RevId: 190170332
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r-- | tensorflow/c/c_api_experimental.cc | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 29caf508e7..8593a8eb50 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -22,7 +22,15 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/protobuf/config.pb.h" +using tensorflow::Node; +using tensorflow::NodeBuilder; using tensorflow::Status; +using tensorflow::Tensor; + +// struct TF_Operation { tensorflow::Node node; }; +static TF_Operation* ToTF_Operation(Node* node) { + return static_cast<TF_Operation*>(static_cast<void*>(node)); +} void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { tensorflow::ConfigProto& config = options->options.config; @@ -103,3 +111,218 @@ TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, memcpy(ret, debug_str.c_str(), *len + 1); return ret; } + +// TODO(hongm): Replace this will a real implementation. +static tensorflow::Status BuildDatasetTest(TF_Graph* dataset_graph, + Node** dataset_node) { + tensorflow::mutex_lock c(dataset_graph->mu); + Tensor const_t(tensorflow::DT_INT32, tensorflow::TensorShape({})); + const_t.flat<tensorflow::int32>()(0) = 1; + + Node* const_node; + TF_RETURN_IF_ERROR(NodeBuilder("Const", "Const") + .Attr("dtype", tensorflow::DT_INT32) + .Attr("value", const_t) + .Finalize(&dataset_graph->graph, &const_node)); + + std::vector<NodeBuilder::NodeOut> input_list; + input_list.push_back(NodeBuilder::NodeOut(const_node, 0)); + + return NodeBuilder("TensorDataset", "TensorDataset") + .Input(input_list) + .Attr("Toutput_types", {tensorflow::DT_INT32}) + .Attr("output_shapes", {tensorflow::TensorShapeProto()}) + .Finalize(&dataset_graph->graph, dataset_node); +} + +// On success, returns a newly created TF_Function instance from +// `text_proto`. It must be deleted by calling TF_DeleteFunction. +static TF_Function* CreateFunctionFromTextProto(const char* text_proto, + TF_Status* status) { + tensorflow::FunctionDef fdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &fdef)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for FunctionDef: ", text_proto); + return nullptr; + } + std::vector<char> binary_proto_buf(fdef.ByteSizeLong()); + fdef.SerializeToArray(binary_proto_buf.data(), binary_proto_buf.size()); + return TF_FunctionImportFunctionDef(binary_proto_buf.data(), + binary_proto_buf.size(), status); +} + +// On success, returns a newly created TF_Function instance from `proto_file`, +// and sets `dataset_name` to the created dataset name. The returned function +// must be deleted by calling TF_DeleteFunction. +// +// TODO(hongm): Support reading the file given by `proto_file`. +static TF_Function* LoadDatasetFunction(const char* proto_file, + std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +signature { + name: "_make_dataset_d8de2712" + output_arg { + name: "TensorSliceDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000(B\000\000,B\000\0000B" + } + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/tensors/component_0:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "TensorSliceDataset" + value: "TensorSliceDataset:handle:0" + })PREFIX"; + + *dataset_name = "_make_dataset_d8de2712"; + return CreateFunctionFromTextProto(func_def, status); +} + +// TODO(hongm): Use `file_path` in the implementation. +TF_Operation* TF_MakeIteratorGetNextWithDatasets(TF_Graph* graph, + const char* file_path, + TF_Function** dataset_func, + TF_Status* status) { + tensorflow::Status s; + + // We can parameterize the function name, if we ever need more than 1 + // iterators in a graph. + const std::string dataset_name = "UNIQUE_DATASET"; + + std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> dataset_graph( + TF_NewGraph(), TF_DeleteGraph); + Node* dataset_node = nullptr; + s = BuildDatasetTest(dataset_graph.get(), &dataset_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + TF_Output output{ToTF_Operation(dataset_node), 0}; + std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> result_func( + TF_GraphToFunction(dataset_graph.get(), dataset_name.c_str(), + /*append_hash_to_fn_name*/ false, + /*num_opers*/ -1, + /*opers*/ nullptr, + /*numinputs*/ 0, + /*inputs*/ nullptr, + /*noutputs*/ 1, + /*outputs*/ &output, + /*outputnames*/ nullptr, + /*functionoptions*/ nullptr, "", status), + TF_DeleteFunction); + if (!status->status.ok()) { + return nullptr; + } + + TF_GraphCopyFunction(graph, result_func.get(), /*gradient*/ nullptr, status); + + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::mutex_lock c(graph->mu); + + tensorflow::NameAttrList func; + func.set_name(dataset_name); + // Run the iterator node on CPU. + Node* oneshot_iterator_node; + std::vector<tensorflow::TensorShapeProto> output_shape_list; + output_shape_list.push_back(tensorflow::TensorShapeProto()); + s = NodeBuilder("OneShotIterator", "OneShotIterator") + .Device("/device:CPU:0") + .Attr("container", "") + .Attr("dataset_factory", func) + .Attr("output_types", {tensorflow::DT_INT32}) + .Attr("output_shapes", output_shape_list) + .Attr("shared_name", "") + .Finalize(&graph->graph, &oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + // Run the iterator node on CPU. + Node* getnext_node; + s = NodeBuilder("IteratorGetNext", "IteratorGetNext") + .Input(oneshot_iterator_node) + .Device("/device:CPU:0") + .Attr("output_types", {tensorflow::DT_INT32}) + .Attr("output_shapes", output_shape_list) + .Finalize(&graph->graph, &getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); + *dataset_func = result_func.release(); + return ToTF_Operation(getnext_node); +} + +void TF_GetAttrScalarTensorShapeProto(TF_Buffer* value, TF_Status* status) { + status->status = Status::OK(); + auto shape = tensorflow::TensorShape({}); + tensorflow::TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + status->status = MessageToBuffer(shape_proto, value); +} |