aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/testing/BUILD2
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py3
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc5
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc9
4 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index dc9492f5e2..555ea90034 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -29,7 +29,7 @@ gen_zipped_test_files(
"exp.zip",
"fully_connected.zip",
"fused_batch_norm.zip",
- # "gather.zip", #TODO(b/76437794)
+ "gather.zip",
"global_batch_norm.zip",
"l2_pool.zip",
"l2norm.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index e4ef17585f..cb5c500136 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -874,12 +874,11 @@ def make_gather_tests(zip_path):
# TODO(mgubin): add string tests when they are supported by Toco.
# TODO(mgubin): add tests for Nd indices when they are supported by
# TfLite.
- # TODO(mgubin): add tests for axis != 0 when it is supported by TfLite.
"params_dtype": [tf.float32, tf.int32],
"params_shape": [[10], [1, 2, 20]],
"indices_dtype": [tf.int32],
"indices_shape": [[3], [5]],
- "axis": [0], # axis!=0 is GatherV2
+ "axis": [0, 1],
}]
def build_graph(parameters):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 08354b762c..a4a7283508 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -91,6 +91,9 @@ std::map<string, string> kBrokenTests = {
// PRelu only supports 4D input with (1, 1, channels) 3D alpha now.
{R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
+
+ // No support for axis!=0 in GatherV2.
+ {R"(^\/gather.*axis=1)", "76910444"},
};
// Allows test data to be unzipped into a temporary directory and makes
@@ -244,7 +247,7 @@ INSTANTIATE_TESTS(div)
INSTANTIATE_TESTS(exp)
INSTANTIATE_TESTS(fully_connected)
INSTANTIATE_TESTS(fused_batch_norm)
-// INSTANTIATE_TESTS(gather) //TODO(b/76437794)
+INSTANTIATE_TESTS(gather)
INSTANTIATE_TESTS(global_batch_norm)
INSTANTIATE_TESTS(l2_pool)
INSTANTIATE_TESTS(l2norm)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b844e0b948..c26e4bddff 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1343,13 +1343,16 @@ void ConvertFloorOperator(const NodeDef& node,
void ConvertGatherOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "Gather");
- CheckInputsCount(node, tf_import_flags, 2);
+ CHECK(node.op() == "Gather" || node.op() == "GatherV2");
+ if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2);
+ if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3);
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
+ // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we
+ // should read it an pass it on to the TF Lite Interpreter.
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
}
@@ -2119,7 +2122,7 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
ConvertCastOperator(node, tf_import_flags, model);
} else if (node.op() == "Floor") {
ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather") {
+ } else if (node.op() == "Gather" || node.op() == "GatherV2") {
ConvertGatherOperator(node, tf_import_flags, model);
} else if (node.op() == "ResizeBilinear") {
ConvertResizeBilinearOperator(node, tf_import_flags, model);