diff options
Diffstat (limited to 'tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c')
-rw-r--r-- | tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c | 81 |
1 files changed, 54 insertions, 27 deletions
diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c index 7c82158522..d83b58dc6b 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c @@ -93,13 +93,16 @@ static uint32_t FindMaxIdx(const float* data, uint32_t entries) { return FindMaxIdxWithExcludeList(data, entries, 0, NULL); } -void hexagon_controller_PrintMaxNIdx(const float *data, const uint32_t entries, +void hexagon_controller_PrintMaxNIdx(const float* data, const uint32_t entries, const int n, int* out_ranking) { if (DUMP_OUTPUT) { for (int i = 0; i < entries; ++i) { TFMLOGD("%d: val = %f", i, data[i]); } } + if (n >= entries) { + TFMLOGD("Too many N %d >= %d", n, entries); + } for (int i = 0; i < n; ++i) { out_ranking[i] = INT_MAX; } @@ -182,6 +185,32 @@ uint32_t hexagon_controller_SetupGraph(int version) { return nn_id; } +bool hexagon_controller_ExecuteGraphWithMultipleInOut( + const uint32_t nn_id, const int input_count, hexagon_nn_tensordef* inputs, + const int output_count, hexagon_nn_tensordef* outputs) { + if (DBG_EXECUTION) { + TFMLOGD("Preparing to execute... in = %d, out = %d", input_count, + output_count); + LogDHexagon("Execute graph!"); + } + + const int err = + hexagon_nn_execute_new(nn_id, inputs, input_count, outputs, output_count); + if (err != 0) { + if (DBG_EXECUTION) { + LogDHexagon("Execution failed!"); + TFMLOGE("execute got err: %d\n", err); + hexagon_controller_PrintLog(nn_id); + } + return false; + } else { + if (DBG_EXECUTION) { + LogDHexagon("Execution succeeded!"); + } + return true; + } +} + bool hexagon_controller_ExecuteGraph( const uint32_t nn_id, const uint32_t batches, @@ -197,7 +226,6 @@ bool hexagon_controller_ExecuteGraph( uint8_t* out_vals, const uint32_t output_val_byte_size, uint32_t* out_data_byte_size) { - int err; if (DBG_EXECUTION) { TFMLOGD("Preparing to execute..."); TFMLOGD("Input: %d, %d, %d, %d, %d, %d", @@ -205,35 +233,34 @@ bool hexagon_controller_ExecuteGraph( TFMLOGD("Output: %d, %p", output_val_byte_size, out_vals); LogDHexagon("Execute graph!"); } - - if ((err = hexagon_nn_execute(nn_id, - batches, - height, - width, - depth, - int_data, - int_data_size, - out_batches, - out_height, - out_width, - out_depth, - out_vals, - output_val_byte_size, - out_data_byte_size)) != 0) { - if (DBG_EXECUTION) { - LogDHexagon("Execution failed!"); - TFMLOGE("execute got err: %d\n",err); - } + + hexagon_nn_tensordef input; + hexagon_nn_tensordef output; + + input.batches = batches; + input.height = height; + input.width = width; + input.depth = depth; + input.data = int_data; + input.dataLen = int_data_size; + + output.data = out_vals; + output.dataLen = output_val_byte_size; + + if (!hexagon_controller_ExecuteGraphWithMultipleInOut(nn_id, 1, &input, 1, + &output)) { return false; } else { + *out_batches = output.batches; + *out_height = output.height; + *out_width = output.width; + *out_depth = output.depth; + *out_data_byte_size = output.dataLen; + if (DBG_EXECUTION) { LogDHexagon("Execution succeeded!"); - TFMLOGD("%d x %d x %d x %d, byte size = %d\n", - *out_batches, - *out_height, - *out_width, - *out_depth, - *out_data_byte_size); + TFMLOGD("%d x %d x %d x %d, byte size = %d\n", *out_batches, *out_height, + *out_width, *out_depth, *out_data_byte_size); } return true; } |