aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2017-12-14 21:26:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 21:31:06 -0800
commitdbcb1ffcca6a3c52e3c109a1739018350bc41925 (patch)
treeaa4f440d1f8f4382cea7a1b4e4cfc623da2a76f0 /tensorflow
parentf806269602219d5095265d036f294cc9a6260971 (diff)
Support BatchToSpaceND in TFLite
The internal implementation only support 4D tensors for now. The dimension has to be 1 batch + 2 spatial + 1 other. The most common format within this restriction is NHWC. Cropping is not supported by the internal implementation. PiperOrigin-RevId: 179143332
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h11
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD13
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc161
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc78
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc18
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs8
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/lite/schema/schema_generated.h234
-rw-r--r--tensorflow/contrib/lite/testing/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py42
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc1
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc70
-rw-r--r--tensorflow/contrib/lite/toco/model.h4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc34
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc13
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc1
19 files changed, 688 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 548864a1e9..5c6f3016b1 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -105,6 +105,17 @@ typedef struct {
} TfLiteAddParams;
typedef struct {
+ // Number of spatial dimensions.
+ // For now only NHWC is supported, and the value should always be 2.
+ int num_spatial_dimensions;
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int block_shape[2];
+ int before_crops[2];
+ int after_crops[2];
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
TfLiteFusedActivation activation;
} TfLiteMulParams;
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 3908960c33..cc02cddb3d 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -77,6 +77,7 @@ cc_library(
"activations.cc",
"add.cc",
"basic_rnn.cc",
+ "batch_to_space_nd.cc",
"concatenation.cc",
"conv.cc",
"depthwise_conv.cc",
@@ -157,6 +158,18 @@ tf_cc_test(
)
tf_cc_test(
+ name = "batch_to_space_nd_test",
+ size = "small",
+ srcs = ["batch_to_space_nd_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "concatenation_test",
size = "small",
srcs = ["concatenation_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
new file mode 100644
index 0000000000..0eed680fdc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace batch_to_space_nd {
+
+// This file has two implementations of BatchToSpaceND.
+enum KernelType {
+ kReference,
+ kGenericOptimized,
+};
+
+struct BatchToSpaceNDContext {
+ BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteBatchToSpaceNDParams*>(node->builtin_data);
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ TfLiteBatchToSpaceNDParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+// Currently, only 4D NHWC input/output op_context are supported.
+// The 4D array need to have exactly 2 spatial dimensions.
+// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
+const int kInputDimensionNum = 4;
+const int kOutputDimensionNum = 4;
+const int kSpatialDimensionNum = 2;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now.
+ TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ BatchToSpaceNDContext op_context(context, node);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
+ kInputDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions,
+ kSpatialDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+
+ const TfLiteIntArray* input_size = op_context.input->dims;
+ const int* block_shape = op_context.params->block_shape;
+
+ // Number of batch must be multiple of (block_shape[0] * block_shape[1]).
+ TF_LITE_ENSURE_EQ(context,
+ input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
+
+ const int output_batch_size =
+ input_size->data[0] / (block_shape[0] * block_shape[1]);
+ const int output_height = input_size->data[1] * block_shape[0];
+ const int output_width = input_size->data[2] * block_shape[1];
+ const int output_channel_size = input_size->data[3];
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum);
+ output_size->data[0] = output_batch_size;
+ output_size->data[1] = output_height;
+ output_size->data[2] = output_width;
+ output_size->data[3] = output_channel_size;
+
+ return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ BatchToSpaceNDContext op_context(context, node);
+
+ int block_shape_dims_array[1] = {kSpatialDimensionNum};
+ Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1);
+
+#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
+ type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), \
+ op_context.params->block_shape, block_shape_dims, \
+ GetTensorData<scalar>(op_context.output), \
+ GetTensorDims(op_context.output))
+ switch (op_context.input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
+ } else {
+ TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
+ } else {
+ TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
+ } else {
+ TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
+ } else {
+ TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported by BatchToSpace.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_BATCH_TO_SPACE_ND
+ return kTfLiteOk;
+}
+
+} // namespace batch_to_space_nd
+
+TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, batch_to_space_nd::Prepare,
+ batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, batch_to_space_nd::Prepare,
+ batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
+ return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
new file mode 100644
index 0000000000..3ec4efbebc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BatchToSpaceNDOpModel : public SingleOpModel {
+ public:
+ BatchToSpaceNDOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> block_shape,
+ std::initializer_list<int> before_crops,
+ std::initializer_list<int> after_crops) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
+ BuiltinOptions_BatchToSpaceNDOptions,
+ CreateBatchToSpaceNDOptions(
+ builder_, builder_.CreateVector<int>(block_shape),
+ builder_.CreateVector<int>(before_crops),
+ builder_.CreateVector<int>(after_crops))
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(BatchToSpaceNDOpTest, SimpleTest) {
+ BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
+ 4, 8, 11, 15, 12, 16}));
+}
+
+TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
+ EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}),
+ "Cannot allocate tensors");
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 3d1edeef01..d4e7503f48 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -40,6 +40,7 @@ TfLiteRegistration* Register_HASHTABLE_LOOKUP();
TfLiteRegistration* Register_SOFTMAX();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
TfLiteRegistration* Register_MUL();
TfLiteRegistration* Register_L2_NORMALIZATION();
TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION();
@@ -75,6 +76,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+ AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND());
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 4ef2c942c1..94e22b2659 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -518,6 +518,24 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_BATCH_TO_SPACE_ND: {
+ auto* params = MallocPOD<TfLiteBatchToSpaceNDParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_BatchToSpaceNDOptions()) {
+ const auto& block_shape = schema_params->block_shape();
+ FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape,
+ params->block_shape, error_reporter);
+ const auto& before_crops = schema_params->before_crops();
+ FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops,
+ params->before_crops, error_reporter);
+ const auto& after_crops = schema_params->after_crops();
+ FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops,
+ params->after_crops, error_reporter);
+ params->num_spatial_dimensions = block_shape->Length();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
}
return builtin_data;
}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 6b93a70bff..5cb0afcea0 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -307,6 +307,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SKIP_GRAM:
case tflite::BuiltinOperator_RELU1:
case tflite::BuiltinOperator_GATHER:
+ case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 8b48543fc8..cc31e03dfc 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -107,6 +107,7 @@ enum BuiltinOperator : byte {
PAD = 34,
UNIDIRECTIONAL_SEQUENCE_RNN = 35,
GATHER = 36,
+ BATCH_TO_SPACE_ND = 37,
}
// Options for the builtin operators.
@@ -134,6 +135,7 @@ union BuiltinOptions {
MulOptions,
PadOptions,
GatherOptions,
+ BatchToSpaceNDOptions,
}
enum Padding : byte { SAME, VALID }
@@ -258,6 +260,12 @@ table ReshapeOptions {
new_shape:[int];
}
+table BatchToSpaceNDOptions {
+ block_shape:[int];
+ before_crops:[int];
+ after_crops:[int];
+}
+
table SkipGramOptions {
ngram_size: int;
max_skip_size: int;
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 7de205e1e4..aa169198fe 100644..100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -1,3 +1,4 @@
+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -85,6 +86,9 @@ struct PadOptionsT;
struct ReshapeOptions;
struct ReshapeOptionsT;
+struct BatchToSpaceNDOptions;
+struct BatchToSpaceNDOptionsT;
+
struct SkipGramOptions;
struct SkipGramOptionsT;
@@ -176,11 +180,12 @@ enum BuiltinOperator {
BuiltinOperator_PAD = 34,
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35,
BuiltinOperator_GATHER = 36,
+ BuiltinOperator_BATCH_TO_SPACE_ND = 37,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_GATHER
+ BuiltinOperator_MAX = BuiltinOperator_BATCH_TO_SPACE_ND
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[34] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[35] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -215,7 +220,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[34] {
BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
BuiltinOperator_PAD,
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOperator_GATHER};
+ BuiltinOperator_GATHER,
+ BuiltinOperator_BATCH_TO_SPACE_ND};
return values;
}
@@ -257,6 +263,7 @@ inline const char **EnumNamesBuiltinOperator() {
"PAD",
"UNIDIRECTIONAL_SEQUENCE_RNN",
"GATHER",
+ "BATCH_TO_SPACE_ND",
nullptr};
return names;
}
@@ -291,11 +298,12 @@ enum BuiltinOptions {
BuiltinOptions_MulOptions = 21,
BuiltinOptions_PadOptions = 22,
BuiltinOptions_GatherOptions = 23,
+ BuiltinOptions_BatchToSpaceNDOptions = 24,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_GatherOptions
+ BuiltinOptions_MAX = BuiltinOptions_BatchToSpaceNDOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[24] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[25] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -320,7 +328,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[24] {
BuiltinOptions_EmbeddingLookupSparseOptions,
BuiltinOptions_MulOptions,
BuiltinOptions_PadOptions,
- BuiltinOptions_GatherOptions};
+ BuiltinOptions_GatherOptions,
+ BuiltinOptions_BatchToSpaceNDOptions};
return values;
}
@@ -349,6 +358,7 @@ inline const char **EnumNamesBuiltinOptions() {
"MulOptions",
"PadOptions",
"GatherOptions",
+ "BatchToSpaceNDOptions",
nullptr};
return names;
}
@@ -482,6 +492,11 @@ struct BuiltinOptionsTraits<GatherOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions;
};
+template <>
+struct BuiltinOptionsTraits<BatchToSpaceNDOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -759,6 +774,16 @@ struct BuiltinOptionsUnion {
? reinterpret_cast<const GatherOptionsT *>(value)
: nullptr;
}
+ BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() {
+ return type == BuiltinOptions_BatchToSpaceNDOptions
+ ? reinterpret_cast<BatchToSpaceNDOptionsT *>(value)
+ : nullptr;
+ }
+ const BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const {
+ return type == BuiltinOptions_BatchToSpaceNDOptions
+ ? reinterpret_cast<const BatchToSpaceNDOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -2512,6 +2537,101 @@ flatbuffers::Offset<ReshapeOptions> CreateReshapeOptions(
flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable {
+ typedef BatchToSpaceNDOptions TableType;
+ std::vector<int32_t> block_shape;
+ std::vector<int32_t> before_crops;
+ std::vector<int32_t> after_crops;
+ BatchToSpaceNDOptionsT() {}
+};
+
+struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS
+ : private flatbuffers::Table {
+ typedef BatchToSpaceNDOptionsT NativeTableType;
+ enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 };
+ const flatbuffers::Vector<int32_t> *block_shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BLOCK_SHAPE);
+ }
+ const flatbuffers::Vector<int32_t> *before_crops() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEFORE_CROPS);
+ }
+ const flatbuffers::Vector<int32_t> *after_crops() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AFTER_CROPS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BLOCK_SHAPE) &&
+ verifier.Verify(block_shape()) &&
+ VerifyOffset(verifier, VT_BEFORE_CROPS) &&
+ verifier.Verify(before_crops()) &&
+ VerifyOffset(verifier, VT_AFTER_CROPS) &&
+ verifier.Verify(after_crops()) && verifier.EndTable();
+ }
+ BatchToSpaceNDOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<BatchToSpaceNDOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BatchToSpaceNDOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_block_shape(
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape) {
+ fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape);
+ }
+ void add_before_crops(
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops) {
+ fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops);
+ }
+ void add_after_crops(
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops) {
+ fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops);
+ }
+ explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BatchToSpaceNDOptionsBuilder &operator=(const BatchToSpaceNDOptionsBuilder &);
+ flatbuffers::Offset<BatchToSpaceNDOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BatchToSpaceNDOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops = 0) {
+ BatchToSpaceNDOptionsBuilder builder_(_fbb);
+ builder_.add_after_crops(after_crops);
+ builder_.add_before_crops(before_crops);
+ builder_.add_block_shape(block_shape);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<BatchToSpaceNDOptions>
+CreateBatchToSpaceNDOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *block_shape = nullptr,
+ const std::vector<int32_t> *before_crops = nullptr,
+ const std::vector<int32_t> *after_crops = nullptr) {
+ return tflite::CreateBatchToSpaceNDOptions(
+ _fbb, block_shape ? _fbb.CreateVector<int32_t>(*block_shape) : 0,
+ before_crops ? _fbb.CreateVector<int32_t>(*before_crops) : 0,
+ after_crops ? _fbb.CreateVector<int32_t>(*after_crops) : 0);
+}
+
+flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct SkipGramOptionsT : public flatbuffers::NativeTable {
typedef SkipGramOptions TableType;
int32_t ngram_size;
@@ -3000,6 +3120,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
? static_cast<const GatherOptions *>(builtin_options())
: nullptr;
}
+ const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions()
+ const {
+ return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions
+ ? static_cast<const BatchToSpaceNDOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -3162,6 +3288,12 @@ inline const GatherOptions *Operator::builtin_options_as<GatherOptions>()
return builtin_options_as_GatherOptions();
}
+template <>
+inline const BatchToSpaceNDOptions *
+Operator::builtin_options_as<BatchToSpaceNDOptions>() const {
+ return builtin_options_as_BatchToSpaceNDOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -4614,6 +4746,74 @@ inline flatbuffers::Offset<ReshapeOptions> CreateReshapeOptions(
return tflite::CreateReshapeOptions(_fbb, _new_shape);
}
+inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new BatchToSpaceNDOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void BatchToSpaceNDOptions::UnPackTo(
+ BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = block_shape();
+ if (_e) {
+ _o->block_shape.resize(_e->size());
+ for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+ _o->block_shape[_i] = _e->Get(_i);
+ }
+ }
+ };
+ {
+ auto _e = before_crops();
+ if (_e) {
+ _o->before_crops.resize(_e->size());
+ for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+ _o->before_crops[_i] = _e->Get(_i);
+ }
+ }
+ };
+ {
+ auto _e = after_crops();
+ if (_e) {
+ _o->after_crops.resize(_e->size());
+ for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+ _o->after_crops[_i] = _e->Get(_i);
+ }
+ }
+ };
+}
+
+inline flatbuffers::Offset<BatchToSpaceNDOptions> BatchToSpaceNDOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateBatchToSpaceNDOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const BatchToSpaceNDOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _block_shape =
+ _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0;
+ auto _before_crops =
+ _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0;
+ auto _after_crops =
+ _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0;
+ return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops,
+ _after_crops);
+}
+
inline SkipGramOptionsT *SkipGramOptions::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new SkipGramOptionsT();
@@ -5265,6 +5465,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
auto ptr = reinterpret_cast<const GatherOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_BatchToSpaceNDOptions: {
+ auto ptr = reinterpret_cast<const BatchToSpaceNDOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
@@ -5381,6 +5585,10 @@ inline void *BuiltinOptionsUnion::UnPack(
auto ptr = reinterpret_cast<const GatherOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_BatchToSpaceNDOptions: {
+ auto ptr = reinterpret_cast<const BatchToSpaceNDOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
@@ -5484,6 +5692,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
auto ptr = reinterpret_cast<const GatherOptionsT *>(value);
return CreateGatherOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_BatchToSpaceNDOptions: {
+ auto ptr = reinterpret_cast<const BatchToSpaceNDOptionsT *>(value);
+ return CreateBatchToSpaceNDOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
@@ -5597,6 +5809,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
value = new GatherOptionsT(*reinterpret_cast<GatherOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_BatchToSpaceNDOptions: {
+ value = new BatchToSpaceNDOptionsT(
+ *reinterpret_cast<BatchToSpaceNDOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -5719,6 +5936,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_BatchToSpaceNDOptions: {
+ auto ptr = reinterpret_cast<BatchToSpaceNDOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index b63c0c058c..96800304e5 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -18,6 +18,7 @@ gen_zipped_test_files(
files = [
"add.zip",
"avg_pool.zip",
+ "batch_to_space_nd.zip",
"concat.zip",
"constant.zip",
"control_dep.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 4c01fedb1e..02f59438cd 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -96,6 +96,10 @@ KNOWN_BUGS = {
r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
# Gather doesn't support int64 indices.
r"gather.*indices_dtype=int64": "XXXX",
+ # BatchToSpaceND doesn't support cropping.
+ r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634",
+ # BatchToSpaceND only supports 4D tensors.
+ r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
}
@@ -1198,6 +1202,43 @@ def make_space_to_depth_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_batch_to_space_nd_tests(zip_path):
+ """Make a set of tests to do batch_to_space_nd."""
+
+ test_parameters = [
+ {
+ "dtype": [tf.float32, tf.int64, tf.int32],
+ "input_shape": [[12, 2, 2, 1]],
+ "block_shape": [[1, 4], [2, 2], [3, 4]],
+ "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]],
+ },
+ # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others.
+ {
+ "dtype": [tf.float32],
+ "input_shape": [[8, 2, 2, 2, 1, 1]],
+ "block_shape": [[2, 2, 2]],
+ "crops": [[[0, 0], [0, 0], [0, 0]]],
+ },
+ ]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"],
+ parameters["crops"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
return tf.sqrt(tf.nn.avg_pool(
@@ -1226,6 +1267,7 @@ def main(unused_args):
dispatch = {
"control_dep.zip": make_control_dep_tests,
"add.zip": make_add_tests,
+ "batch_to_space_nd.zip": make_batch_to_space_nd_tests,
"conv.zip": make_conv_tests,
"constant.zip": make_constant_tests,
"depthwiseconv.zip": make_depthwiseconv_tests,
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 29f0c68ba4..4c05979e24 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -241,6 +241,7 @@ TEST_P(OpsTest, RunStuff) {
INSTANTIATE_TESTS(add)
INSTANTIATE_TESTS(avg_pool)
+INSTANTIATE_TESTS(batch_to_space_nd)
INSTANTIATE_TESTS(concat)
INSTANTIATE_TESTS(constant)
INSTANTIATE_TESTS(control_dep)
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 78c036fa77..7556a402f9 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -202,6 +202,7 @@ cc_library(
"graph_transformations/remove_trivial_reshape.cc",
"graph_transformations/remove_unused_op.cc",
"graph_transformations/resolve_batch_normalization.cc",
+ "graph_transformations/resolve_batch_to_space_nd_attributes.cc",
"graph_transformations/resolve_constant_binary.cc",
"graph_transformations/resolve_constant_concatenation.cc",
"graph_transformations/resolve_constant_fake_quant.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index c1dc41170c..2eb244ee08 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -152,6 +152,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
new file mode 100644
index 0000000000..a4f198e92f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false;
+
+ auto* op = static_cast<BatchToSpaceNDOperator*>(op_it->get());
+
+ // The attributes are resolved only when the 3 attributes (block_shape,
+ // before_crops, after_crops) are all constant.
+ if (!op->block_shape.empty()) {
+ return false;
+ }
+
+ CHECK_EQ(op->inputs.size(), 3);
+ if (!IsConstantParameterArray(*model, op->inputs[1]) or
+ !IsConstantParameterArray(*model, op->inputs[2]))
+ return false;
+
+ // Handling block_shape.
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ if (!block_shape_array.has_shape()) return false;
+ const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
+ CHECK_EQ(block_shape_dims.size(), 1);
+ std::vector<int> block_shape_buffer =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < block_shape_dims[0]; ++i) {
+ op->block_shape.push_back(block_shape_buffer[i]);
+ }
+
+ // Handling crops.
+ const auto& crops_array = *model->arrays[op->inputs[2]];
+ if (!crops_array.has_shape()) return false;
+ const std::vector<int>& crops_dims = crops_array.shape().dims();
+ CHECK_EQ(crops_dims.size(), 2);
+ std::vector<int> crops_buffer =
+ crops_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < crops_dims[0]; ++i) {
+ op->before_crops.push_back(crops_buffer[i * 2]);
+ op->after_crops.push_back(crops_buffer[i * 2 + 1]);
+ }
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index d155d2bb5c..7305f858da 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -1261,6 +1261,10 @@ struct SpaceToBatchNDOperator : Operator {
// TensorFlow equivalent: BatchToSpaceND
struct BatchToSpaceNDOperator : Operator {
BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}
+
+ std::vector<int> block_shape;
+ std::vector<int> before_crops;
+ std::vector<int> after_crops;
};
// Mean operator.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7a68c6dbc9..ede6df88ab 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -130,6 +130,37 @@ class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
}
};
+class BatchToSpaceND
+ : public BuiltinOperator<BatchToSpaceNDOperator,
+ ::tflite::BatchToSpaceNDOptions,
+ ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ auto block_shape = builder->CreateVector(op.block_shape);
+ auto before_crops = builder->CreateVector(op.before_crops);
+ auto after_crops = builder->CreateVector(op.after_crops);
+ return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape,
+ before_crops, after_crops);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->block_shape.insert(op->block_shape.end(),
+ options.block_shape()->begin(),
+ options.block_shape()->end());
+ op->before_crops.insert(op->before_crops.end(),
+ options.before_crops()->begin(),
+ options.before_crops()->end());
+ op->after_crops.insert(op->after_crops.end(),
+ options.after_crops()->begin(),
+ options.after_crops()->end());
+ }
+};
+
class Cast : public CustomOperator<CastOperator> {
public:
using CustomOperator::CustomOperator;
@@ -571,6 +602,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
OperatorType::kAveragePool));
+ ops.emplace_back(
+ new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
+ OperatorType::kBatchToSpaceND));
ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
OperatorType::kConcatenation));
ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index caecbd0325..735eea4ddc 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -119,6 +119,19 @@ TEST_F(OperatorTest, BuiltinAdd) {
output_toco_op->fused_activation_function);
}
+TEST_F(OperatorTest, BuiltinBatchToSpaceND) {
+ BatchToSpaceNDOperator op;
+ op.block_shape = {2, 2};
+ op.before_crops = {1, 2};
+ op.after_crops = {3, 4};
+
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op);
+ EXPECT_EQ(op.block_shape, output_toco_op->block_shape);
+ EXPECT_EQ(op.before_crops, output_toco_op->before_crops);
+ EXPECT_EQ(op.after_crops, output_toco_op->after_crops);
+}
+
TEST_F(OperatorTest, CustomCast) {
CastOperator op;
op.src_data_type = ArrayDataType::kFloat;
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 7e50c2207f..d6652b7a41 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -78,6 +78,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new IdentifyRelu1);
transformations->Add(new RemoveTrivialBinaryOperator);
transformations->Add(new ReadFakeQuantMinMax);
+ transformations->Add(new ResolveBatchToSpaceNDAttributes);
transformations->Add(new ResolvePadAttributes);
transformations->Add(new ResolveStridedSliceAttributes);
transformations->Add(new ResolveSliceAttributes);