diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-18 12:55:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 12:59:31 -0700 |
commit | f1a1496148ea8a828e37201b8d0ab5d7e4979a1a (patch) | |
tree | f33c2a456b66af0204833e4e77c04cd3eda1a7eb /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | 9fd62257907ed2df20bd89f9346dbe91ada1fbbc (diff) |
Fixing gather to support round-tripping non-zero axis.
PiperOrigin-RevId: 205122011
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 0d7eff5db4..9dde7a8bd6 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1197,8 +1197,17 @@ tensorflow::Status ConvertGatherOperator( auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); return tensorflow::Status::OK(); |