aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-18 12:55:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 12:59:31 -0700
commitf1a1496148ea8a828e37201b8d0ab5d7e4979a1a (patch)
treef33c2a456b66af0204833e4e77c04cd3eda1a7eb /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent9fd62257907ed2df20bd89f9346dbe91ada1fbbc (diff)
Fixing gather to support round-tripping non-zero axis.
PiperOrigin-RevId: 205122011
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 0d7eff5db4..9dde7a8bd6 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1197,8 +1197,17 @@ tensorflow::Status ConvertGatherOperator(
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
- // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we
- // should read it an pass it on to the TF Lite Interpreter.
+ if (node.input_size() >= 3) {
+ // GatherV2 form where we are provided an axis. It may be either a constant
+ // or runtime defined value, so we just wire up the array and let
+ // ResolveGatherAttributes take care of it later on.
+ const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
+ CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
+ op->inputs.push_back(node.input(2));
+ } else {
+ // Gather form that assumes axis=0.
+ op->axis = {0};
+ }
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
return tensorflow::Status::OK();