diff options
Diffstat (limited to 'tensorflow/tools/benchmark/benchmark_model.cc')
-rw-r--r-- | tensorflow/tools/benchmark/benchmark_model.cc | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 10a20db956..078fc95076 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -109,8 +109,21 @@ void CreateTensorsFromInputInfo( InitializeTensor<uint8>(input.initialization_values, &input_tensor); break; } + case DT_BOOL: { + InitializeTensor<bool>(input.initialization_values, &input_tensor); + break; + } + case DT_STRING: { + if (!input.initialization_values.empty()) { + LOG(FATAL) << "Initialization values are not supported for strings"; + } + auto type_tensor = input_tensor.flat<string>(); + type_tensor = type_tensor.constant(""); + break; + } default: - LOG(FATAL) << "Unsupported input type: " << input.data_type; + LOG(FATAL) << "Unsupported input type: " + << DataTypeString(input.data_type); } input_tensors->push_back({input.name, input_tensor}); } @@ -212,6 +225,7 @@ Status RunBenchmark(const std::vector<InputLayerInfo>& inputs, if (!s.ok()) { LOG(ERROR) << "Error during inference: " << s; + return s; } assert(run_metadata.has_step_stats()); |