diff options
4 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index dc9492f5e2..555ea90034 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -29,7 +29,7 @@ gen_zipped_test_files( "exp.zip", "fully_connected.zip", "fused_batch_norm.zip", - # "gather.zip", #TODO(b/76437794) + "gather.zip", "global_batch_norm.zip", "l2_pool.zip", "l2norm.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index e4ef17585f..cb5c500136 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -874,12 +874,11 @@ def make_gather_tests(zip_path): # TODO(mgubin): add string tests when they are supported by Toco. # TODO(mgubin): add tests for Nd indices when they are supported by # TfLite. - # TODO(mgubin): add tests for axis != 0 when it is supported by TfLite. "params_dtype": [tf.float32, tf.int32], "params_shape": [[10], [1, 2, 20]], "indices_dtype": [tf.int32], "indices_shape": [[3], [5]], - "axis": [0], # axis!=0 is GatherV2 + "axis": [0, 1], }] def build_graph(parameters): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 08354b762c..a4a7283508 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -91,6 +91,9 @@ std::map<string, string> kBrokenTests = { // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, + + // No support for axis!=0 in GatherV2. + {R"(^\/gather.*axis=1)", "76910444"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -244,7 +247,7 @@ INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(exp) INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) -// INSTANTIATE_TESTS(gather) //TODO(b/76437794) +INSTANTIATE_TESTS(gather) INSTANTIATE_TESTS(global_batch_norm) INSTANTIATE_TESTS(l2_pool) INSTANTIATE_TESTS(l2norm) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index b844e0b948..c26e4bddff 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1343,13 +1343,16 @@ void ConvertFloorOperator(const NodeDef& node, void ConvertGatherOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Gather"); - CheckInputsCount(node, tf_import_flags, 2); + CHECK(node.op() == "Gather" || node.op() == "GatherV2"); + if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2); + if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); + // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we + // should read it an pass it on to the TF Lite Interpreter. op->outputs.push_back(node.name()); model->operators.emplace_back(op); } @@ -2119,7 +2122,7 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( ConvertCastOperator(node, tf_import_flags, model); } else if (node.op() == "Floor") { ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather") { + } else if (node.op() == "Gather" || node.op() == "GatherV2") { ConvertGatherOperator(node, tf_import_flags, model); } else if (node.op() == "ResizeBilinear") { ConvertResizeBilinearOperator(node, tf_import_flags, model); |