aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-05 18:45:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 18:48:59 -0800
commit1a0b637df8d082301118dd0f85ec63704f862aeb (patch)
tree9638dd8ae3174a20c456dd7d1569eddeea7330ba /tensorflow/contrib
parent09138338ed81b56aeaff88b9ee5a11c3f6d77b70 (diff)
Export align_corners to TF Lite
PiperOrigin-RevId: 184622482
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h1
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc12
-rw-r--r--tensorflow/contrib/lite/model.cc1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs3
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h40
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc3
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc22
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc10
8 files changed, 65 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index a1037a525c..5dbeadd165 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -151,6 +151,7 @@ typedef struct {
} TfLiteLSTMParams;
typedef struct {
+ bool align_corners;
} TfLiteResizeBilinearParams;
typedef struct {
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 4a2101f246..c5d60cae3a 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -75,6 +75,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* size = GetInput(context, node, kSizeTensor);
@@ -86,10 +89,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type) \
- type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_RESIZE_BILINEAR(type) \
+ type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<int32>(size), GetTensorDims(size), \
+ GetTensorData<float>(output), GetTensorDims(output), \
+ params->align_corners)
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops);
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index b36bfcef84..14b6709964 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -469,6 +469,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
auto* params = MallocPOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
}
builtin_data = reinterpret_cast<void*>(params);
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index c0b220e872..36cc2724eb 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -273,6 +273,9 @@ table LSTMOptions {
}
table ResizeBilinearOptions {
+ new_height: int (deprecated);
+ new_width: int (deprecated);
+ align_corners: bool;
}
// A call operation options
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 29f3a17be7..e2ac0b9d1e 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -2626,28 +2626,36 @@ flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
typedef ResizeBilinearOptions TableType;
- ResizeBilinearOptionsT() {}
+ bool align_corners;
+ ResizeBilinearOptionsT()
+ : align_corners(false) {
+ }
};
-struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS
- : private flatbuffers::Table {
+struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ResizeBilinearOptionsT NativeTableType;
+ enum {
+ VT_ALIGN_CORNERS = 8
+ };
+ bool align_corners() const {
+ return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) && verifier.EndTable();
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) &&
+ verifier.EndTable();
}
- ResizeBilinearOptionsT *UnPack(
- const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(
- ResizeBilinearOptionsT *_o,
- const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<ResizeBilinearOptions> Pack(
- flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
- const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+ ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ResizeBilinearOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct ResizeBilinearOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
+ void add_align_corners(bool align_corners) {
+ fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0);
+ }
explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2661,14 +2669,14 @@ struct ResizeBilinearOptionsBuilder {
};
inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
- flatbuffers::FlatBufferBuilder &_fbb) {
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool align_corners = false) {
ResizeBilinearOptionsBuilder builder_(_fbb);
+ builder_.add_align_corners(align_corners);
return builder_.Finish();
}
-flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
- flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
- const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct CallOptionsT : public flatbuffers::NativeTable {
typedef CallOptions TableType;
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e8b425a592..5ea3e21f6a 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -89,9 +89,6 @@ std::map<string, string> kBrokenTests = {
// ResizeBilinear looks completely incompatible with Tensorflow
{R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"},
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 461494fd99..04aaedd59d 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -540,6 +540,24 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
}
};
+class ResizeBilinear
+ : public BuiltinOperator<ResizeBilinearOperator,
+ ::tflite::ResizeBilinearOptions,
+ ::tflite::BuiltinOptions_ResizeBilinearOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->align_corners = options.align_corners();
+ }
+};
+
class Squeeze
: public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
::tflite::BuiltinOptions_SqueezeOptions> {
@@ -755,6 +773,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kTranspose));
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
+ ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
+ OperatorType::kResizeBilinear));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
@@ -787,8 +807,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
ops.emplace_back(
new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
- ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
- "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
ops.emplace_back(new SimpleOperator<LogisticOperator>(
"LOGISTIC", OperatorType::kLogistic));
ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 6daa296282..796534be53 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -104,8 +104,6 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
- CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR",
- OperatorType::kResizeBilinear);
CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
}
@@ -331,6 +329,14 @@ TEST_F(OperatorTest, BuiltinMul) {
output_toco_op->fused_activation_function);
}
+TEST_F(OperatorTest, ResizeBilinear) {
+ ResizeBilinearOperator op;
+ op.align_corners = true;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
+ EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
+}
+
TEST_F(OperatorTest, Svdf) {
SvdfOperator op;
op.fused_activation_function = FusedActivationFunctionType::kRelu;