aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c')
-rw-r--r--tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c81
1 files changed, 54 insertions, 27 deletions
diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c
index 7c82158522..d83b58dc6b 100644
--- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c
+++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c
@@ -93,13 +93,16 @@ static uint32_t FindMaxIdx(const float* data, uint32_t entries) {
return FindMaxIdxWithExcludeList(data, entries, 0, NULL);
}
-void hexagon_controller_PrintMaxNIdx(const float *data, const uint32_t entries,
+void hexagon_controller_PrintMaxNIdx(const float* data, const uint32_t entries,
const int n, int* out_ranking) {
if (DUMP_OUTPUT) {
for (int i = 0; i < entries; ++i) {
TFMLOGD("%d: val = %f", i, data[i]);
}
}
+ if (n >= entries) {
+ TFMLOGD("Too many N %d >= %d", n, entries);
+ }
for (int i = 0; i < n; ++i) {
out_ranking[i] = INT_MAX;
}
@@ -182,6 +185,32 @@ uint32_t hexagon_controller_SetupGraph(int version) {
return nn_id;
}
+bool hexagon_controller_ExecuteGraphWithMultipleInOut(
+ const uint32_t nn_id, const int input_count, hexagon_nn_tensordef* inputs,
+ const int output_count, hexagon_nn_tensordef* outputs) {
+ if (DBG_EXECUTION) {
+ TFMLOGD("Preparing to execute... in = %d, out = %d", input_count,
+ output_count);
+ LogDHexagon("Execute graph!");
+ }
+
+ const int err =
+ hexagon_nn_execute_new(nn_id, inputs, input_count, outputs, output_count);
+ if (err != 0) {
+ if (DBG_EXECUTION) {
+ LogDHexagon("Execution failed!");
+ TFMLOGE("execute got err: %d\n", err);
+ hexagon_controller_PrintLog(nn_id);
+ }
+ return false;
+ } else {
+ if (DBG_EXECUTION) {
+ LogDHexagon("Execution succeeded!");
+ }
+ return true;
+ }
+}
+
bool hexagon_controller_ExecuteGraph(
const uint32_t nn_id,
const uint32_t batches,
@@ -197,7 +226,6 @@ bool hexagon_controller_ExecuteGraph(
uint8_t* out_vals,
const uint32_t output_val_byte_size,
uint32_t* out_data_byte_size) {
- int err;
if (DBG_EXECUTION) {
TFMLOGD("Preparing to execute...");
TFMLOGD("Input: %d, %d, %d, %d, %d, %d",
@@ -205,35 +233,34 @@ bool hexagon_controller_ExecuteGraph(
TFMLOGD("Output: %d, %p", output_val_byte_size, out_vals);
LogDHexagon("Execute graph!");
}
-
- if ((err = hexagon_nn_execute(nn_id,
- batches,
- height,
- width,
- depth,
- int_data,
- int_data_size,
- out_batches,
- out_height,
- out_width,
- out_depth,
- out_vals,
- output_val_byte_size,
- out_data_byte_size)) != 0) {
- if (DBG_EXECUTION) {
- LogDHexagon("Execution failed!");
- TFMLOGE("execute got err: %d\n",err);
- }
+
+ hexagon_nn_tensordef input;
+ hexagon_nn_tensordef output;
+
+ input.batches = batches;
+ input.height = height;
+ input.width = width;
+ input.depth = depth;
+ input.data = int_data;
+ input.dataLen = int_data_size;
+
+ output.data = out_vals;
+ output.dataLen = output_val_byte_size;
+
+ if (!hexagon_controller_ExecuteGraphWithMultipleInOut(nn_id, 1, &input, 1,
+ &output)) {
return false;
} else {
+ *out_batches = output.batches;
+ *out_height = output.height;
+ *out_width = output.width;
+ *out_depth = output.depth;
+ *out_data_byte_size = output.dataLen;
+
if (DBG_EXECUTION) {
LogDHexagon("Execution succeeded!");
- TFMLOGD("%d x %d x %d x %d, byte size = %d\n",
- *out_batches,
- *out_height,
- *out_width,
- *out_depth,
- *out_data_byte_size);
+ TFMLOGD("%d x %d x %d x %d, byte size = %d\n", *out_batches, *out_height,
+ *out_width, *out_depth, *out_data_byte_size);
}
return true;
}