aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/benchmark/benchmark_model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/benchmark/benchmark_model.cc')
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 10a20db956..078fc95076 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -109,8 +109,21 @@ void CreateTensorsFromInputInfo(
InitializeTensor<uint8>(input.initialization_values, &input_tensor);
break;
}
+ case DT_BOOL: {
+ InitializeTensor<bool>(input.initialization_values, &input_tensor);
+ break;
+ }
+ case DT_STRING: {
+ if (!input.initialization_values.empty()) {
+ LOG(FATAL) << "Initialization values are not supported for strings";
+ }
+ auto type_tensor = input_tensor.flat<string>();
+ type_tensor = type_tensor.constant("");
+ break;
+ }
default:
- LOG(FATAL) << "Unsupported input type: " << input.data_type;
+ LOG(FATAL) << "Unsupported input type: "
+ << DataTypeString(input.data_type);
}
input_tensors->push_back({input.name, input_tensor});
}
@@ -212,6 +225,7 @@ Status RunBenchmark(const std::vector<InputLayerInfo>& inputs,
if (!s.ok()) {
LOG(ERROR) << "Error during inference: " << s;
+ return s;
}
assert(run_metadata.has_step_stats());