aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_experimental.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-26 11:50:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 11:53:04 -0700
commitf9cfb9e917c8937152b248c300b095798d79501a (patch)
treec269dd414b3dded2d8b6a2dc89010c7540926369 /tensorflow/c/c_api_experimental.cc
parentd2604f8dcb8a63ca063f712c24ce5aa63403b0aa (diff)
Extended experimental C API with MNIST dataset/iterators support.
PiperOrigin-RevId: 190500020
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--tensorflow/c/c_api_experimental.cc1151
1 files changed, 1140 insertions, 11 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 1c809cb21e..f411efc941 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -138,7 +138,7 @@ static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
return {};
}
std::vector<UniqueFuncPtr> ret;
- for (const auto& fdef : fdef_lib.function()) {
+ for (const FunctionDef& fdef : fdef_lib.function()) {
// Make a copy so that we can mutate it.
FunctionDef fdef_to_load = fdef;
if (mutate_proto_func) {
@@ -148,8 +148,8 @@ static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
fdef_to_load.SerializeToArray(binary_proto_buf.data(),
binary_proto_buf.size());
- auto func = TF_FunctionImportFunctionDef(binary_proto_buf.data(),
- binary_proto_buf.size(), status);
+ TF_Function* func = TF_FunctionImportFunctionDef(
+ binary_proto_buf.data(), binary_proto_buf.size(), status);
if (!status->status.ok()) return {};
ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
}
@@ -7120,6 +7120,1130 @@ library {
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
}
+// On success, returns a set of TF_Function instances encoding a dataset
+// node stack that reads an MNIST file dataset from `file_path`, and
+// sets `dataset_name` to the created dataset name. The returned functions must
+// be deleted by calling TF_DeleteFunction.
+static std::vector<UniqueFuncPtr> CreateMNISTDatasetFunctions(
+ const char* file_path, std::string* dataset_name, TF_Status* status) {
+ const char* func_def = R"PREFIX(
+library {
+ function {
+ signature {
+ name: "tf_map_func_521bfd08"
+ input_arg {
+ name: "arg0"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "truediv"
+ type: DT_FLOAT
+ }
+ description: "A wrapper for Defun that facilitates shape inference."
+ }
+ node_def {
+ name: "DecodeRaw"
+ op: "DecodeRaw"
+ input: "arg0"
+ attr {
+ key: "little_endian"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_UINT8
+ }
+ }
+ }
+ node_def {
+ name: "Cast"
+ op: "Cast"
+ input: "DecodeRaw:output:0"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_UINT8
+ }
+ }
+ }
+ node_def {
+ name: "Reshape/shape"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 784
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Reshape"
+ op: "Reshape"
+ input: "Cast:y:0"
+ input: "Reshape/shape:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node_def {
+ name: "truediv/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 255.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "truediv"
+ op: "RealDiv"
+ input: "Reshape:output:0"
+ input: "truediv/y:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "truediv"
+ value: "truediv:z:0"
+ }
+ }
+ function {
+ signature {
+ name: "tf_map_func_9a08860d"
+ input_arg {
+ name: "arg0"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "ToInt32"
+ type: DT_INT32
+ }
+ description: "A wrapper for Defun that facilitates shape inference."
+ }
+ node_def {
+ name: "DecodeRaw"
+ op: "DecodeRaw"
+ input: "arg0"
+ attr {
+ key: "little_endian"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_UINT8
+ }
+ }
+ }
+ node_def {
+ name: "Reshape/shape"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Reshape"
+ op: "Reshape"
+ input: "DecodeRaw:output:0"
+ input: "Reshape/shape:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_UINT8
+ }
+ }
+ attr {
+ key: "Tshape"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node_def {
+ name: "ToInt32"
+ op: "Cast"
+ input: "Reshape:output:0"
+ attr {
+ key: "DstT"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "SrcT"
+ value {
+ type: DT_UINT8
+ }
+ }
+ }
+ ret {
+ key: "ToInt32"
+ value: "ToInt32:y:0"
+ }
+ }
+ function {
+ signature {
+ name: "tf_predicate_7089b845"
+ input_arg {
+ name: "arg0"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "arg1"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "Equal/Placeholder"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "Equal"
+ type: DT_BOOL
+ }
+ description: "A wrapper for Defun that facilitates shape inference."
+ }
+ node_def {
+ name: "Shape"
+ op: "Shape"
+ input: "arg0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ node_def {
+ name: "strided_slice/stack"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "strided_slice/stack_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node_def {
+ name: "strided_slice/stack_2"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node_def {
+ name: "strided_slice"
+ op: "StridedSlice"
+ input: "Shape:output:0"
+ input: "strided_slice/stack:output:0"
+ input: "strided_slice/stack_1:output:0"
+ input: "strided_slice/stack_2:output:0"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "begin_mask"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "ellipsis_mask"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "end_mask"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "new_axis_mask"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "shrink_axis_mask"
+ value {
+ i: 1
+ }
+ }
+ }
+ node_def {
+ name: "Equal"
+ op: "Equal"
+ input: "strided_slice:output:0"
+ input: "Equal/Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT64
+ }
+ }
+ }
+ ret {
+ key: "Equal"
+ value: "Equal:z:0"
+ }
+ }
+ function {
+ signature {
+ name: "_make_dataset_2451e43a"
+ output_arg {
+ name: "FilterDataset"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+ }
+ node_def {
+ name: "FixedLengthRecordDataset/filenames"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "$(DATA_DIR)/train-images-idx3-ubyte"
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset/header_bytes"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 16
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset/record_bytes"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 784
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset/footer_bytes"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset/buffer_size"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 262144
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset"
+ op: "FixedLengthRecordDataset"
+ input: "FixedLengthRecordDataset/filenames:output:0"
+ input: "FixedLengthRecordDataset/header_bytes:output:0"
+ input: "FixedLengthRecordDataset/record_bytes:output:0"
+ input: "FixedLengthRecordDataset/footer_bytes:output:0"
+ input: "FixedLengthRecordDataset/buffer_size:output:0"
+ }
+ node_def {
+ name: "MapDataset"
+ op: "MapDataset"
+ input: "FixedLengthRecordDataset:handle:0"
+ attr {
+ key: "Targuments"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "f"
+ value {
+ func {
+ name: "tf_map_func_521bfd08"
+ }
+ }
+ }
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1/filenames_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: "$(DATA_DIR)/train-labels-idx1-ubyte"
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1/header_bytes_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 8
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1/record_bytes_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 1
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1/footer_bytes_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1/buffer_size_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 262144
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FixedLengthRecordDataset_1"
+ op: "FixedLengthRecordDataset"
+ input: "FixedLengthRecordDataset_1/filenames_1:output:0"
+ input: "FixedLengthRecordDataset_1/header_bytes_1:output:0"
+ input: "FixedLengthRecordDataset_1/record_bytes_1:output:0"
+ input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0"
+ input: "FixedLengthRecordDataset_1/buffer_size_1:output:0"
+ }
+ node_def {
+ name: "MapDataset_1"
+ op: "MapDataset"
+ input: "FixedLengthRecordDataset_1:handle:0"
+ attr {
+ key: "Targuments"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "f"
+ value {
+ func {
+ name: "tf_map_func_9a08860d"
+ }
+ }
+ }
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_INT32
+ }
+ }
+ }
+ }
+ node_def {
+ name: "ZipDataset"
+ op: "ZipDataset"
+ input: "MapDataset:handle:0"
+ input: "MapDataset_1:handle:0"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ }
+ node_def {
+ name: "CacheDataset/filename"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ }
+ string_val: ""
+ }
+ }
+ }
+ }
+ node_def {
+ name: "CacheDataset"
+ op: "CacheDataset"
+ input: "ZipDataset:handle:0"
+ input: "CacheDataset/filename:output:0"
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ }
+ node_def {
+ name: "RepeatDataset/count"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: -1
+ }
+ }
+ }
+ }
+ node_def {
+ name: "RepeatDataset"
+ op: "RepeatDataset"
+ input: "CacheDataset:handle:0"
+ input: "RepeatDataset/count:output:0"
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ }
+ node_def {
+ name: "ShuffleDataset/buffer_size_2"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 50000
+ }
+ }
+ }
+ }
+ node_def {
+ name: "ShuffleDataset/seed"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "ShuffleDataset/seed2"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "ShuffleDataset"
+ op: "ShuffleDataset"
+ input: "RepeatDataset:handle:0"
+ input: "ShuffleDataset/buffer_size_2:output:0"
+ input: "ShuffleDataset/seed:output:0"
+ input: "ShuffleDataset/seed2:output:0"
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ attr {
+ key: "reshuffle_each_iteration"
+ value {
+ b: true
+ }
+ }
+ }
+ node_def {
+ name: "BatchDataset/batch_size"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 128
+ }
+ }
+ }
+ }
+ node_def {
+ name: "BatchDataset"
+ op: "BatchDataset"
+ input: "ShuffleDataset:handle:0"
+ input: "BatchDataset/batch_size:output:0"
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FilterDataset/batch_size_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT64
+ tensor_shape {
+ }
+ int64_val: 128
+ }
+ }
+ }
+ }
+ node_def {
+ name: "FilterDataset"
+ op: "FilterDataset"
+ input: "BatchDataset:handle:0"
+ input: "FilterDataset/batch_size_1:output:0"
+ attr {
+ key: "Targuments"
+ value {
+ list {
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 784
+ }
+ }
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "output_types"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_INT32
+ }
+ }
+ }
+ attr {
+ key: "predicate"
+ value {
+ func {
+ name: "tf_predicate_7089b845"
+ }
+ }
+ }
+ }
+ ret {
+ key: "FilterDataset"
+ value: "FilterDataset:handle:0"
+ }
+ }
+}
+)PREFIX";
+
+ *dataset_name = "_make_dataset_2451e43a";
+ std::function<void(FunctionDef*)> mutate_proto_func =
+ [dataset_name, file_path](FunctionDef* fdef) {
+ VLOG(1) << "Processsing function " << fdef->DebugString();
+ if (std::string(fdef->signature().name()) != *dataset_name) return;
+ // Change the input file pattern to `file_path`.
+ bool found = false;
+ // `node_def` may be mutated.
+ for (auto& node_def : *fdef->mutable_node_def()) {
+ if (node_def.name() != "FixedLengthRecordDataset/filenames" &&
+ node_def.name() != "FixedLengthRecordDataset_1/filenames_1")
+ continue;
+ DCHECK_EQ(node_def.op(), "Const");
+ DCHECK_GT(node_def.attr().count("value"), 0);
+ found = true;
+ // Replace $(DATA_DIR)/foo with <file_path>/foo
+ // TODO(hongm): Use StringPiece manipulation for better efficiency.
+ const std::string cur_value =
+ node_def.attr().at("value").tensor().string_val(0);
+ const std::string pattern = "$(DATA_DIR)";
+ DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0);
+ const std::string new_value =
+ file_path + cur_value.substr(pattern.length());
+ VLOG(1) << "Setting the value of node_def " << node_def.name()
+ << " to " << new_value;
+ auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor();
+ tensor->clear_string_val();
+ tensor->add_string_val(new_value);
+ }
+ VLOG(1) << "Rewrote function to " << fdef->DebugString();
+ DCHECK(found);
+ };
+ return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
+}
+
// Adds the input functions to `graph`. On success, returns the created
// IteratorGetNext node.
static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph(
@@ -7209,15 +8333,16 @@ TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph,
return getnext_node;
}
-TF_Operation* TF_MakeImagenetIteratorGetNextWithDatasets(TF_Graph* graph,
- const char* file_path,
- int batch_size,
- TF_Status* status) {
+TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
+ TF_Graph* graph, const char* file_path, int batch_size,
+ unsigned char is_mnist, TF_Status* status) {
tensorflow::Status s;
std::string dataset_name;
const auto& funcs =
- CreateImagenetDatasetFunctions(file_path, &dataset_name, status);
+ is_mnist
+ ? CreateMNISTDatasetFunctions(file_path, &dataset_name, status)
+ : CreateImagenetDatasetFunctions(file_path, &dataset_name, status);
if (!status->status.ok()) {
return nullptr;
}
@@ -7226,9 +8351,13 @@ TF_Operation* TF_MakeImagenetIteratorGetNextWithDatasets(TF_Graph* graph,
// batch_size X 224 X 224 X 3
auto image_shape = tensorflow::TensorShapeProto();
image_shape.add_dim()->set_size(batch_size);
- image_shape.add_dim()->set_size(224);
- image_shape.add_dim()->set_size(224);
- image_shape.add_dim()->set_size(3);
+ if (is_mnist) {
+ image_shape.add_dim()->set_size(784);
+ } else {
+ image_shape.add_dim()->set_size(224);
+ image_shape.add_dim()->set_size(224);
+ image_shape.add_dim()->set_size(3);
+ }
output_shape_list.push_back(image_shape);
// batch_size