-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 13
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/squeeze.cc | 99
-rw-r--r--  tensorflow/contrib/lite/kernels/squeeze_test.cc | 113
-rw-r--r--  tensorflow/contrib/lite/model.cc | 11
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc | 1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 6
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 179
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 35
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc (renamed from tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc) | 11
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 23
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc | 9
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 2
18 files changed, 501 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index b9d88128c2..0c333f9e8c 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -214,6 +214,13 @@ typedef struct {
bool keep_dims;
} TfLiteMeanParams;
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
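
TfLiteSqueezeParams keeps the axes in a fixed-size array of eight entries because the builtin-op data structs cannot carry dynamically sized members yet (see the TODO above). As a minimal sketch, assuming a hypothetical FillSqueezeParams helper that is not part of this change, populating the struct from a runtime list of axes could look like this:

    #include <cstddef>
    #include <vector>

    // Mirrors the TfLiteSqueezeParams struct added above (capacity of 8 axes).
    typedef struct {
      int squeeze_dims[8];
      int num_squeeze_dims;
    } TfLiteSqueezeParams;

    // Hypothetical helper: copies up to eight axes into the fixed-size array
    // and rejects inputs that exceed the struct's capacity.
    static bool FillSqueezeParams(const std::vector<int>& axes,
                                  TfLiteSqueezeParams* params) {
      if (axes.size() > 8) return false;
      for (std::size_t i = 0; i < axes.size(); ++i) {
        params->squeeze_dims[i] = axes[i];
      }
      params->num_squeeze_dims = static_cast<int>(axes.size());
      return true;
    }

The model loader below (model.cc) fills the same struct from the serialized SqueezeOptions via FlatBufferIntVectorToArray.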
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index b30b28a6ad..7e9644f36c 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -102,6 +102,7 @@ cc_library(
"skip_gram.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "squeeze.cc",
"sub.cc",
"svdf.cc",
"transpose.cc",
@@ -492,6 +493,18 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "squeeze_test",
+ size = "small",
+ srcs = ["squeeze_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index fa63846020..45ad5f1890 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -56,6 +56,7 @@ TfLiteRegistration* Register_SPACE_TO_DEPTH();
TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_TRANSPOSE();
TfLiteRegistration* Register_MEAN();
+TfLiteRegistration* Register_SQUEEZE();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -98,6 +99,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN());
AddBuiltin(BuiltinOperator_DIV, Register_DIV());
AddBuiltin(BuiltinOperator_SUB, Register_SUB());
+ AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
}
TfLiteRegistration* BuiltinOpResolver::FindOp(
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
new file mode 100644
index 0000000000..29447ab021
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace squeeze {
+
+struct SqueezeContext {
+ SqueezeContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ TfLiteSqueezeParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ SqueezeContext op_context(context, node);
+ int input_num_dims = NumDimensions(op_context.input);
+ int num_squeeze_dims = op_context.params->num_squeeze_dims;
+
+ // Determines number of dimensions of output tensor after squeeze.
+ const TfLiteIntArray* input_dims = op_context.input->dims;
+ const int* squeeze_dims = op_context.params->squeeze_dims;
+ TF_LITE_ENSURE(context, input_num_dims <= 8);
+ bool should_squeeze[8] = {false};
+ int num_squeezed_dims = 0;
+ if (num_squeeze_dims == 0) {
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ if (input_dims->data[idx] == 1) {
+ should_squeeze[idx] = true;
+ ++num_squeezed_dims;
+ }
+ }
+ } else {
+ for (int idx = 0; idx < num_squeeze_dims; ++idx) {
+ int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims
+ : squeeze_dims[idx];
+ TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims &&
+ input_dims->data[current] == 1);
+ if (!should_squeeze[current]) ++num_squeezed_dims;
+ should_squeeze[current] = true;
+ }
+ }
+ // Sets output dimensions.
+ TfLiteIntArray* output_dims =
+ TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims);
+ for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) {
+ if (!should_squeeze[in_idx]) {
+ output_dims->data[out_idx++] = input_dims->data[in_idx];
+ }
+ }
+ return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ SqueezeContext op_context(context, node);
+ TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
+ memcpy(op_context.output->data.raw, op_context.input->data.raw,
+ op_context.input->bytes);
+ return kTfLiteOk;
+}
+
+} // namespace squeeze
+
+TfLiteRegistration* Register_SQUEEZE() {
+ static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare,
+ squeeze::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
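
The shape logic in Prepare above can be read on its own: normalize negative axes, check that every targeted dimension has size 1 (or, when no axes are given, pick every size-1 dimension), then copy the surviving dimensions. A standalone sketch of that computation, using std::vector in place of TfLiteIntArray and a bool flag in place of TF_LITE_ENSURE (SqueezedShape is an illustrative name, not an API in this change):

    #include <vector>

    // Computes the output shape of a squeeze, mirroring squeeze::Prepare.
    std::vector<int> SqueezedShape(const std::vector<int>& input_dims,
                                   const std::vector<int>& squeeze_dims,
                                   bool* ok) {
      const int num_dims = static_cast<int>(input_dims.size());
      if (num_dims > 8) {  // the kernel only supports up to 8 dimensions
        *ok = false;
        return {};
      }
      bool should_squeeze[8] = {false};
      if (squeeze_dims.empty()) {
        // No axes given: drop every dimension of size 1.
        for (int i = 0; i < num_dims; ++i) {
          if (input_dims[i] == 1) should_squeeze[i] = true;
        }
      } else {
        for (int axis : squeeze_dims) {
          const int current = axis < 0 ? axis + num_dims : axis;  // wrap negatives
          if (current < 0 || current >= num_dims || input_dims[current] != 1) {
            *ok = false;
            return {};
          }
          should_squeeze[current] = true;
        }
      }
      *ok = true;
      std::vector<int> output_dims;
      for (int i = 0; i < num_dims; ++i) {
        if (!should_squeeze[i]) output_dims.push_back(input_dims[i]);
      }
      return output_dims;
    }

For a {1, 24, 1} input with axes {-1, 0}, both axes resolve to size-1 dimensions and the result is {24}, which is the case exercised by the SqueezeNegativeAxis test below.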
diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc
new file mode 100644
index 0000000000..409227b626
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseSqueezeOpModel : public SingleOpModel {
+ public:
+ BaseSqueezeOpModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(
+ BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
+ CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
+ .Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ int input() { return input_; }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatSqueezeOpModel : public BaseSqueezeOpModel {
+ public:
+ using BaseSqueezeOpModel::BaseSqueezeOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+};
+
+TEST(FloatSqueezeOpTest, SqueezeAll) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+ {TensorType_FLOAT32, {24}}, {});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+TEST(FloatSqueezeOpTest, SqueezeSelectedAxis) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+ {TensorType_FLOAT32, {24}}, {2});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24}));
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+ {TensorType_FLOAT32, {24}}, {-1, 0});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 09fc9b9613..86e613736d 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -596,6 +596,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_SQUEEZE: {
+ auto* params = MallocPOD<TfLiteSqueezeParams>();
+ if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+ const auto& squeeze_dims = schema_params->squeeze_dims();
+ FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+ params->squeeze_dims, error_reporter);
+ params->num_squeeze_dims = squeeze_dims->Length();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
}
return builtin_data;
}
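
FlatBufferIntVectorToArray is an existing helper in model.cc; conceptually it copies the serialized squeeze_dims vector into the fixed-size array on TfLiteSqueezeParams without overflowing it. A rough standalone sketch of that copy, under the assumption that it only needs to bound-check and copy (this is not the actual helper's implementation):

    #include <cstddef>
    #include <cstdint>
    #include "flatbuffers/flatbuffers.h"

    // Copies a flatbuffers int vector into a fixed-capacity int array.
    // `dst_size` is sizeof(destination array) in bytes, as passed above.
    // Returns false when the serialized vector would not fit.
    static bool CopyIntVectorToArray(std::size_t dst_size,
                                     const flatbuffers::Vector<int32_t>* src,
                                     int* dst) {
      if (src == nullptr) return false;
      const std::size_t max_entries = dst_size / sizeof(int);
      if (src->Length() > max_entries) return false;
      for (flatbuffers::uoffset_t i = 0; i < src->Length(); ++i) {
        dst[i] = src->Get(i);
      }
      return true;
    }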
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 468c78dcce..b3602f799e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -337,6 +337,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_MEAN:
case tflite::BuiltinOperator_DIV:
case tflite::BuiltinOperator_SUB:
+ case tflite::BuiltinOperator_SQUEEZE:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 80c05f34cb..f5251031b3 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -116,6 +116,7 @@ enum BuiltinOperator : byte {
MEAN = 40,
SUB = 41,
DIV = 42,
+ SQUEEZE = 43,
}
// Options for the builtin operators.
@@ -149,6 +150,7 @@ union BuiltinOptions {
MeanOptions,
SubOptions,
DivOptions,
+ SqueezeOptions,
}
enum Padding : byte { SAME, VALID }
@@ -326,6 +328,10 @@ table MeanOptions {
keep_dims: bool;
}
+table SqueezeOptions {
+ squeeze_dims:[int];
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
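
SqueezeOptions carries only the optional list of axes to squeeze. With the generated C++ API (schema_generated.h below), an options table can be built through CreateSqueezeOptionsDirect; the BuildSqueezeOptions wrapper and the particular axes here are illustrative:

    #include <cstdint>
    #include <vector>
    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/contrib/lite/schema/schema_generated.h"

    // Builds a SqueezeOptions table that requests squeezing axes 0 and 2.
    flatbuffers::Offset<tflite::SqueezeOptions> BuildSqueezeOptions(
        flatbuffers::FlatBufferBuilder* builder) {
      std::vector<int32_t> axes = {0, 2};
      return tflite::CreateSqueezeOptionsDirect(*builder, &axes);
    }

The returned offset is stored as Operator.builtin_options with builtin_options_type set to BuiltinOptions_SqueezeOptions, which is what the kernel test above does through SingleOpModel::SetBuiltinOp.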
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 06c62604e4..a2ec8e40e9 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -114,6 +114,9 @@ struct TransposeOptionsT;
struct MeanOptions;
struct MeanOptionsT;
+struct SqueezeOptions;
+struct SqueezeOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -199,11 +202,12 @@ enum BuiltinOperator {
BuiltinOperator_MEAN = 40,
BuiltinOperator_SUB = 41,
BuiltinOperator_DIV = 42,
+ BuiltinOperator_SQUEEZE = 43,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_DIV
+ BuiltinOperator_MAX = BuiltinOperator_SQUEEZE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[40] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -244,7 +248,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[40] {
BuiltinOperator_TRANSPOSE,
BuiltinOperator_MEAN,
BuiltinOperator_SUB,
- BuiltinOperator_DIV};
+ BuiltinOperator_DIV,
+ BuiltinOperator_SQUEEZE};
return values;
}
@@ -292,6 +297,7 @@ inline const char **EnumNamesBuiltinOperator() {
"MEAN",
"SUB",
"DIV",
+ "SQUEEZE",
nullptr};
return names;
}
@@ -332,11 +338,12 @@ enum BuiltinOptions {
BuiltinOptions_MeanOptions = 27,
BuiltinOptions_SubOptions = 28,
BuiltinOptions_DivOptions = 29,
+ BuiltinOptions_SqueezeOptions = 30,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_DivOptions
+ BuiltinOptions_MAX = BuiltinOptions_SqueezeOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[30] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -367,7 +374,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[30] {
BuiltinOptions_TransposeOptions,
BuiltinOptions_MeanOptions,
BuiltinOptions_SubOptions,
- BuiltinOptions_DivOptions};
+ BuiltinOptions_DivOptions,
+ BuiltinOptions_SqueezeOptions};
return values;
}
@@ -402,6 +410,7 @@ inline const char **EnumNamesBuiltinOptions() {
"MeanOptions",
"SubOptions",
"DivOptions",
+ "SqueezeOptions",
nullptr};
return names;
}
@@ -565,6 +574,11 @@ struct BuiltinOptionsTraits<DivOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_DivOptions;
};
+template <>
+struct BuiltinOptionsTraits<SqueezeOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -902,6 +916,16 @@ struct BuiltinOptionsUnion {
? reinterpret_cast<const DivOptionsT *>(value)
: nullptr;
}
+ SqueezeOptionsT *AsSqueezeOptions() {
+ return type == BuiltinOptions_SqueezeOptions
+ ? reinterpret_cast<SqueezeOptionsT *>(value)
+ : nullptr;
+ }
+ const SqueezeOptionsT *AsSqueezeOptions() const {
+ return type == BuiltinOptions_SqueezeOptions
+ ? reinterpret_cast<const SqueezeOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -3348,6 +3372,71 @@ flatbuffers::Offset<MeanOptions> CreateMeanOptions(
flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SqueezeOptionsT : public flatbuffers::NativeTable {
+ typedef SqueezeOptions TableType;
+ std::vector<int32_t> squeeze_dims;
+ SqueezeOptionsT() {}
+};
+
+struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SqueezeOptionsT NativeTableType;
+ enum { VT_SQUEEZE_DIMS = 4 };
+ const flatbuffers::Vector<int32_t> *squeeze_dims() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SQUEEZE_DIMS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
+ verifier.Verify(squeeze_dims()) && verifier.EndTable();
+ }
+ SqueezeOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ SqueezeOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SqueezeOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SqueezeOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_squeeze_dims(
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> squeeze_dims) {
+ fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims);
+ }
+ explicit SqueezeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SqueezeOptionsBuilder &operator=(const SqueezeOptionsBuilder &);
+ flatbuffers::Offset<SqueezeOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SqueezeOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> squeeze_dims = 0) {
+ SqueezeOptionsBuilder builder_(_fbb);
+ builder_.add_squeeze_dims(squeeze_dims);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *squeeze_dims = nullptr) {
+ return tflite::CreateSqueezeOptions(
+ _fbb, squeeze_dims ? _fbb.CreateVector<int32_t>(*squeeze_dims) : 0);
+}
+
+flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -3622,6 +3711,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
? static_cast<const DivOptions *>(builtin_options())
: nullptr;
}
+ const SqueezeOptions *builtin_options_as_SqueezeOptions() const {
+ return builtin_options_type() == BuiltinOptions_SqueezeOptions
+ ? static_cast<const SqueezeOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -3817,6 +3911,12 @@ inline const DivOptions *Operator::builtin_options_as<DivOptions>() const {
return builtin_options_as_DivOptions();
}
+template <>
+inline const SqueezeOptions *Operator::builtin_options_as<SqueezeOptions>()
+ const {
+ return builtin_options_as_SqueezeOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5744,6 +5844,51 @@ inline flatbuffers::Offset<MeanOptions> CreateMeanOptions(
return tflite::CreateMeanOptions(_fbb, _axis, _keep_dims);
}
+inline SqueezeOptionsT *SqueezeOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SqueezeOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SqueezeOptions::UnPackTo(
+ SqueezeOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = squeeze_dims();
+ if (_e) {
+ _o->squeeze_dims.resize(_e->size());
+ for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+ _o->squeeze_dims[_i] = _e->Get(_i);
+ }
+ }
+ };
+}
+
+inline flatbuffers::Offset<SqueezeOptions> SqueezeOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSqueezeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const SqueezeOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _squeeze_dims =
+ _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0;
+ return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
@@ -6248,6 +6393,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
auto ptr = reinterpret_cast<const DivOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SqueezeOptions: {
+ auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
@@ -6388,6 +6537,10 @@ inline void *BuiltinOptionsUnion::UnPack(
auto ptr = reinterpret_cast<const DivOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SqueezeOptions: {
+ auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
@@ -6515,6 +6668,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
auto ptr = reinterpret_cast<const DivOptionsT *>(value);
return CreateDivOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SqueezeOptions: {
+ auto ptr = reinterpret_cast<const SqueezeOptionsT *>(value);
+ return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
@@ -6655,6 +6812,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
value = new DivOptionsT(*reinterpret_cast<DivOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SqueezeOptions: {
+ value =
+ new SqueezeOptionsT(*reinterpret_cast<SqueezeOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -6807,6 +6969,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SqueezeOptions: {
+ auto ptr = reinterpret_cast<SqueezeOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 48a7536d41..933da11353 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -45,6 +45,7 @@ gen_zipped_test_files(
"softmax.zip",
"space_to_batch_nd.zip",
"space_to_depth.zip",
+ "squeeze.zip",
"sub.zip",
"transpose.zip",
],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 9e17c2a370..29bf2cd7fc 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1346,6 +1346,40 @@ def make_transpose_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_squeeze_tests(zip_path):
+ """Make a set of tests to do squeeze."""
+
+ test_parameters = [{
+ "dtype": [tf.int32, tf.float32, tf.int64],
+ "input_shape": [[1, 2, 1, 3, 1, 4, 1, 1]],
+ "axis": [
+ None, [], [0, 2], [4, 7], [-1, 0, 2, 0, 7, -6], [1], [2, 3, 2],
+ [-1, -2, -4, -6, -8], [0, 2, 4, 6, 7], [7, 6, 4, 2, 0], [6, 6],
+ [0, 1, 2, 3, 4, 5, 6, 7], [-2, -3, 1, 0, 7, -5]
+ ],
+ }, {
+ "dtype": [tf.int32, tf.float32, tf.int64],
+ "input_shape": [[1]],
+ "axis": [None, [], [0], [-1]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ out = tf.squeeze(input_tensor, axis=parameters["axis"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
return tf.sqrt(tf.nn.avg_pool(
@@ -1403,6 +1437,7 @@ def main(unused_args):
"space_to_depth.zip": make_space_to_depth_tests,
"transpose.zip": make_transpose_tests,
"mean.zip": make_mean_tests,
+ "squeeze.zip": make_squeeze_tests,
}
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 3c01302c43..c8a6e07abd 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -262,6 +262,7 @@ INSTANTIATE_TESTS(sub)
INSTANTIATE_TESTS(div)
INSTANTIATE_TESTS(transpose)
INSTANTIATE_TESTS(mean)
+INSTANTIATE_TESTS(squeeze)
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index cea5b4e92d..967e304742 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -219,11 +219,11 @@ cc_library(
"graph_transformations/resolve_reshape_attributes.cc",
"graph_transformations/resolve_slice_attributes.cc",
"graph_transformations/resolve_space_to_batch_nd_attributes.cc",
+ "graph_transformations/resolve_squeeze_attributes.cc",
"graph_transformations/resolve_strided_slice_attributes.cc",
"graph_transformations/resolve_tensorflow_concat.cc",
"graph_transformations/resolve_tensorflow_matmul.cc",
"graph_transformations/resolve_tensorflow_merge.cc",
- "graph_transformations/resolve_tensorflow_squeeze.cc",
"graph_transformations/resolve_tensorflow_switch.cc",
"graph_transformations/resolve_tensorflow_tile.cc",
"graph_transformations/resolve_transpose_attributes.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 9300ab53a7..9ec9f92c90 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -146,7 +146,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
-DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
index 1d3f42b5ec..dd3e73635a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
@@ -25,15 +25,13 @@ limitations under the License.
namespace toco {
-bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
- const auto squeeze_it = model->operators.begin() + op_index;
- const auto* squeeze_op = squeeze_it->get();
+bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
+ auto* squeeze_op = model->operators[op_index].get();
if (squeeze_op->type != OperatorType::kSqueeze) {
return false;
}
-
- CHECK_EQ(squeeze_op->inputs.size(), 1);
- CHECK_EQ(squeeze_op->outputs.size(), 1);
+ DCHECK_EQ(squeeze_op->inputs.size(), 1);
+ DCHECK_EQ(squeeze_op->outputs.size(), 1);
// If the output is consumed by a reshape op, it's a trivial squeeze.
if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
@@ -47,7 +45,6 @@ bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
return RemoveTrivialPassthroughOp(this, model, op_index);
}
}
-
return false;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 5b98b71155..0111e1ed92 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -584,6 +584,27 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
}
};
+class Squeeze
+ : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
+ ::tflite::BuiltinOptions_SqueezeOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
+ return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->squeeze_dims.insert(op->squeeze_dims.end(),
+ options.squeeze_dims()->begin(),
+ options.squeeze_dims()->end());
+ }
+};
+
class Split : public CustomOperator<TensorFlowSplitOperator> {
public:
using CustomOperator::CustomOperator;
@@ -754,6 +775,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kTranspose));
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
+ ops.emplace_back(
+ new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
// Custom Operators.
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
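
The Squeeze wrapper round-trips the axis list: WriteOptions serializes op.squeeze_dims through CreateSqueezeOptions, and ReadOptions appends the deserialized values back onto the TOCO operator. A condensed sketch of that round trip using the generated flatbuffer API directly (RoundTripSqueezeDims is an illustrative function, and the surrounding BuiltinOperator plumbing is omitted):

    #include <cstdint>
    #include <vector>
    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/contrib/lite/schema/schema_generated.h"

    // Serializes a list of axes into SqueezeOptions and reads it back,
    // mimicking the WriteOptions / ReadOptions pair above.
    std::vector<int32_t> RoundTripSqueezeDims(
        const std::vector<int32_t>& squeeze_dims) {
      flatbuffers::FlatBufferBuilder builder;
      auto dims = builder.CreateVector(squeeze_dims);
      builder.Finish(::tflite::CreateSqueezeOptions(builder, dims));

      auto* options = flatbuffers::GetRoot<::tflite::SqueezeOptions>(
          builder.GetBufferPointer());
      return std::vector<int32_t>(options->squeeze_dims()->begin(),
                                  options->squeeze_dims()->end());
    }

The operator_test.cc change below exercises this path end to end via SerializeAndDeserialize.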
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index debce63760..77c70847d1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -389,6 +389,15 @@ TEST_F(OperatorTest, Transpose) {
EXPECT_EQ(op.perm, output_toco_op->perm);
}
+TEST_F(OperatorTest, Squeeze) {
+ SqueezeOperator op;
+ op.squeeze_dims = {-2, -3, 4, 1, 4};
+
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("SQUEEZE", OperatorType::kSqueeze), op);
+ EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 0bcf4596de..94b4d14696 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -74,7 +74,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveConstantStridedSlice);
transformations->Add(new ResolveConstantUnaryOperator);
transformations->Add(new ResolveTensorFlowMerge);
- transformations->Add(new ResolveTensorFlowSqueeze);
+ transformations->Add(new ResolveSqueezeAttributes);
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowTile);
transformations->Add(new ResolveTensorFlowConcat);