diff options
author | Mingsheng Hong <hongm@google.com> | 2018-03-26 15:37:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 15:40:16 -0700 |
commit | 290632966fae0619db30c1ba777634db9a43b757 (patch) | |
tree | ea02368f1ffa13f5ce42a71a96eef5fad9ebde4a /tensorflow/c/c_api_experimental.cc | |
parent | eee15c1f8ea56dbb516fa9e35392e0a224e99966 (diff) |
In the experimental C API, parametrized batch_size for the generate dataset / iterator stack.
PiperOrigin-RevId: 190536945
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r-- | tensorflow/c/c_api_experimental.cc | 67 |
1 files changed, 41 insertions, 26 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index f411efc941..bea9378571 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -7125,7 +7125,8 @@ library { // sets `dataset_name` to the created dataset name. The returned functions must // be deleted by calling TF_DeleteFunction. static std::vector<UniqueFuncPtr> CreateMNISTDatasetFunctions( - const char* file_path, std::string* dataset_name, TF_Status* status) { + const char* file_path, int batch_size, std::string* dataset_name, + TF_Status* status) { const char* func_def = R"PREFIX( library { function { @@ -8089,7 +8090,7 @@ library { dtype: DT_INT64 tensor_shape { } - int64_val: 128 + int64_val: -123 } } } @@ -8145,7 +8146,7 @@ library { dtype: DT_INT64 tensor_shape { } - int64_val: 128 + int64_val: -123 } } } @@ -8211,35 +8212,48 @@ library { *dataset_name = "_make_dataset_2451e43a"; std::function<void(FunctionDef*)> mutate_proto_func = - [dataset_name, file_path](FunctionDef* fdef) { + [dataset_name, file_path, batch_size](FunctionDef* fdef) { VLOG(1) << "Processsing function " << fdef->DebugString(); if (std::string(fdef->signature().name()) != *dataset_name) return; // Change the input file pattern to `file_path`. - bool found = false; + bool found_file_path = false, found_batch_size = false; // `node_def` may be mutated. for (auto& node_def : *fdef->mutable_node_def()) { - if (node_def.name() != "FixedLengthRecordDataset/filenames" && - node_def.name() != "FixedLengthRecordDataset_1/filenames_1") - continue; - DCHECK_EQ(node_def.op(), "Const"); - DCHECK_GT(node_def.attr().count("value"), 0); - found = true; - // Replace $(DATA_DIR)/foo with <file_path>/foo - // TODO(hongm): Use StringPiece manipulation for better efficiency. - const std::string cur_value = - node_def.attr().at("value").tensor().string_val(0); - const std::string pattern = "$(DATA_DIR)"; - DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); - const std::string new_value = - file_path + cur_value.substr(pattern.length()); - VLOG(1) << "Setting the value of node_def " << node_def.name() - << " to " << new_value; - auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); - tensor->clear_string_val(); - tensor->add_string_val(new_value); + if (node_def.name() == "FixedLengthRecordDataset/filenames" || + node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_file_path = true; + // Replace $(DATA_DIR)/foo with <file_path>/foo + // TODO(hongm): Use StringPiece manipulation for better efficiency. + const std::string cur_value = + node_def.attr().at("value").tensor().string_val(0); + const std::string pattern = "$(DATA_DIR)"; + DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); + const std::string new_value = + file_path + cur_value.substr(pattern.length()); + VLOG(1) << "Setting the value of node_def " << node_def.name() + << " to " << new_value; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(new_value); + } else if (node_def.name() == "BatchDataset/batch_size" || + node_def.name() == "FilterDataset/batch_size_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_batch_size = true; + // Replace $(BATCH_SIZE) with `batch_size` + DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); + VLOG(1) << "Setting the batch size attr value of node_def " + << node_def.name() << " to " << batch_size; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_int64_val(); + tensor->add_int64_val(batch_size); + } } VLOG(1) << "Rewrote function to " << fdef->DebugString(); - DCHECK(found); + DCHECK(found_file_path); + DCHECK(found_batch_size); }; return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); } @@ -8341,7 +8355,8 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( std::string dataset_name; const auto& funcs = is_mnist - ? CreateMNISTDatasetFunctions(file_path, &dataset_name, status) + ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, + status) : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); if (!status->status.ok()) { return nullptr; |