aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_experimental.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-26 15:37:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 15:40:16 -0700
commit290632966fae0619db30c1ba777634db9a43b757 (patch)
treeea02368f1ffa13f5ce42a71a96eef5fad9ebde4a /tensorflow/c/c_api_experimental.cc
parenteee15c1f8ea56dbb516fa9e35392e0a224e99966 (diff)
In the experimental C API, parametrized batch_size for the generate dataset / iterator stack.
PiperOrigin-RevId: 190536945
Diffstat (limited to 'tensorflow/c/c_api_experimental.cc')
-rw-r--r--tensorflow/c/c_api_experimental.cc67
1 files changed, 41 insertions, 26 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index f411efc941..bea9378571 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -7125,7 +7125,8 @@ library {
// sets `dataset_name` to the created dataset name. The returned functions must
// be deleted by calling TF_DeleteFunction.
static std::vector<UniqueFuncPtr> CreateMNISTDatasetFunctions(
- const char* file_path, std::string* dataset_name, TF_Status* status) {
+ const char* file_path, int batch_size, std::string* dataset_name,
+ TF_Status* status) {
const char* func_def = R"PREFIX(
library {
function {
@@ -8089,7 +8090,7 @@ library {
dtype: DT_INT64
tensor_shape {
}
- int64_val: 128
+ int64_val: -123
}
}
}
@@ -8145,7 +8146,7 @@ library {
dtype: DT_INT64
tensor_shape {
}
- int64_val: 128
+ int64_val: -123
}
}
}
@@ -8211,35 +8212,48 @@ library {
*dataset_name = "_make_dataset_2451e43a";
std::function<void(FunctionDef*)> mutate_proto_func =
- [dataset_name, file_path](FunctionDef* fdef) {
+ [dataset_name, file_path, batch_size](FunctionDef* fdef) {
VLOG(1) << "Processsing function " << fdef->DebugString();
if (std::string(fdef->signature().name()) != *dataset_name) return;
// Change the input file pattern to `file_path`.
- bool found = false;
+ bool found_file_path = false, found_batch_size = false;
// `node_def` may be mutated.
for (auto& node_def : *fdef->mutable_node_def()) {
- if (node_def.name() != "FixedLengthRecordDataset/filenames" &&
- node_def.name() != "FixedLengthRecordDataset_1/filenames_1")
- continue;
- DCHECK_EQ(node_def.op(), "Const");
- DCHECK_GT(node_def.attr().count("value"), 0);
- found = true;
- // Replace $(DATA_DIR)/foo with <file_path>/foo
- // TODO(hongm): Use StringPiece manipulation for better efficiency.
- const std::string cur_value =
- node_def.attr().at("value").tensor().string_val(0);
- const std::string pattern = "$(DATA_DIR)";
- DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0);
- const std::string new_value =
- file_path + cur_value.substr(pattern.length());
- VLOG(1) << "Setting the value of node_def " << node_def.name()
- << " to " << new_value;
- auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor();
- tensor->clear_string_val();
- tensor->add_string_val(new_value);
+ if (node_def.name() == "FixedLengthRecordDataset/filenames" ||
+ node_def.name() == "FixedLengthRecordDataset_1/filenames_1") {
+ DCHECK_EQ(node_def.op(), "Const");
+ DCHECK_GT(node_def.attr().count("value"), 0);
+ found_file_path = true;
+ // Replace $(DATA_DIR)/foo with <file_path>/foo
+ // TODO(hongm): Use StringPiece manipulation for better efficiency.
+ const std::string cur_value =
+ node_def.attr().at("value").tensor().string_val(0);
+ const std::string pattern = "$(DATA_DIR)";
+ DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0);
+ const std::string new_value =
+ file_path + cur_value.substr(pattern.length());
+ VLOG(1) << "Setting the value of node_def " << node_def.name()
+ << " to " << new_value;
+ auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor();
+ tensor->clear_string_val();
+ tensor->add_string_val(new_value);
+ } else if (node_def.name() == "BatchDataset/batch_size" ||
+ node_def.name() == "FilterDataset/batch_size_1") {
+ DCHECK_EQ(node_def.op(), "Const");
+ DCHECK_GT(node_def.attr().count("value"), 0);
+ found_batch_size = true;
+ // Replace $(BATCH_SIZE) with `batch_size`
+ DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123);
+ VLOG(1) << "Setting the batch size attr value of node_def "
+ << node_def.name() << " to " << batch_size;
+ auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor();
+ tensor->clear_int64_val();
+ tensor->add_int64_val(batch_size);
+ }
}
VLOG(1) << "Rewrote function to " << fdef->DebugString();
- DCHECK(found);
+ DCHECK(found_file_path);
+ DCHECK(found_batch_size);
};
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
}
@@ -8341,7 +8355,8 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
std::string dataset_name;
const auto& funcs =
is_mnist
- ? CreateMNISTDatasetFunctions(file_path, &dataset_name, status)
+ ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name,
+ status)
: CreateImagenetDatasetFunctions(file_path, &dataset_name, status);
if (!status->status.ok()) {
return nullptr;