author     Anna R <annarev@google.com>  2018-03-28 16:52:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-28 16:55:15 -0700
commit     108178da2a20ea2d3899417ee932d46ba1a5c652 (patch)
tree       313bd8cec176f8c9ef67b25c6484a650d1f2092a /tensorflow/contrib/lite
parent     390e19ab990f5656e09d98624c92b3c80e52937d (diff)
Automated g4 rollback of changelist 190835392
PiperOrigin-RevId: 190858242
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/README.md                                   |   3
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h                               |   1
-rw-r--r--  tensorflow/contrib/lite/g3doc/models.md                             |   2
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD                               |  13
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  |  25
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum.cc                          | 106
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum_test.cc                     |  81
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc                         |   2
-rw-r--r--  tensorflow/contrib/lite/model.cc                                    |   3
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc                           |   1
-rw-r--r--  tensorflow/contrib/lite/python/lite.py                              |  22
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs                           |   5
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h                   | 124
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD                               |   1
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py                |  36
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc      |   1
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                     |   2
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc                |   2
18 files changed, 18 insertions(+), 412 deletions(-)
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index c15ae3f233..2680d515eb 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -126,9 +126,6 @@ The above pre-trained models have been trained on the ImageNet data set, which c
The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference.
-# Getting started with RaspberryPi
-
-Using RaspberryPi can be accomplished by following the [Makefile instructions](g3doc/rpi.md). That will give you a static library (.a) that you can build your app against. Python bindings will be coming soon, as well as a demo app.
### Train a custom model
A developer may choose to train a custom model using TensorFlow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework, the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model.
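
The exported GraphDef typically still needs the trained weights folded in before conversion. A minimal TF 1.x freezing sketch of that step (the checkpoint path and output node name below are placeholders, not values from this commit):

```python
import tensorflow as tf

def freeze_to_graphdef(checkpoint_path, output_node, pb_path):
  """Fold checkpoint variables into constants and write a frozen .pb file."""
  with tf.Session(graph=tf.Graph()) as sess:
    # Rebuild the graph from the checkpoint's MetaGraph and restore weights.
    saver = tf.train.import_meta_graph(checkpoint_path + ".meta")
    saver.restore(sess, checkpoint_path)
    # Replace variables with constants so the .pb is self-contained.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output_node])
    with tf.gfile.GFile(pb_path, "wb") as f:
      f.write(frozen.SerializeToString())

# Hypothetical paths and node name, for illustration only.
freeze_to_graphdef("/tmp/model.ckpt", "MobilenetV1/Predictions/Reshape_1",
                   "/tmp/frozen.pb")
```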
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 17b791e4e2..d7993e60cc 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -79,7 +79,6 @@ typedef enum {
kTfLiteBuiltinBidirectionalSequenceLstm = 52,
kTfLiteBuiltinCast = 53,
kTfLiteBuiltinPrelu = 54,
- kTfLiteBuiltinMaximum = 55,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 48f43d4fc4..5b393140d6 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,4 +1,4 @@
-# List of Hosted Models
+#List of Hosted Models
* [Inception V3 2015](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_2015_2017_11_10.zip)
* [Inception V3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip)
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index c423c00bf5..1450c1e14b 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -156,7 +156,6 @@ cc_library(
"local_response_norm.cc",
"lsh_projection.cc",
"lstm.cc",
- "maximum.cc",
"mean.cc",
"mfcc.cc",
"mul.cc",
@@ -538,18 +537,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "maximum_test",
- size = "small",
- srcs = ["maximum_test.cc"],
- deps = [
- ":builtin_ops",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite/kernels:test_util",
- "@com_google_googletest//:gtest",
- ],
-)
-
-tf_cc_test(
name = "mean_test",
size = "small",
srcs = ["mean_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 3575974ae9..33d60afa26 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -404,7 +404,6 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
const int in_d =
out_d + ((out_h % block_size) * block_size + out_w % block_size) *
output_depth;
-
const int in_w = out_w / block_size;
const int in_h = out_h / block_size;
const int in_b = out_b;
@@ -3364,30 +3363,6 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- auto out_idx = Offset(output_dims, c, x, y, b);
- auto in1_idx = SubscriptToIndex(desc1, c, x, y, b);
- auto in2_idx = SubscriptToIndex(desc2, c, x, y, b);
- auto in1_val = input1_data[in1_idx];
- auto in2_val = input2_data[in2_idx];
- output_data[out_idx] = in1_val > in2_val ? in1_val : in2_val;
- }
- }
- }
- }
-}
-
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
T2* output_data, const Dims<4>& output_dims) {
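
The deleted overload above is the broadcasting variant of TensorFlowMaximum: it walks every output coordinate and maps it back into each input through an NdArrayDesc, so a size-1 dimension in either input is repeated across the output. A rough NumPy sketch of the same indexing scheme (illustrative only, not TFLite API):

```python
import numpy as np

def broadcast_maximum(a, b):
  """Elementwise max of two equal-rank arrays, with explicit index mapping
  mirroring the deleted reference kernel's broadcast loops."""
  out_shape = np.broadcast(a, b).shape
  out = np.empty(out_shape, dtype=a.dtype)
  for idx in np.ndindex(out_shape):
    # A size-1 input dimension always maps to index 0, which is what
    # NdArrayDescsForElementwiseBroadcast arranges in the C++ kernel.
    a_idx = tuple(i if d > 1 else 0 for i, d in zip(idx, a.shape))
    b_idx = tuple(i if d > 1 else 0 for i, d in zip(idx, b.shape))
    out[idx] = max(a[a_idx], b[b_idx])
  return out

x = np.random.randn(2, 1, 2, 3).astype(np.float32)
y = np.random.randn(2, 4, 2, 1).astype(np.float32)
assert np.array_equal(broadcast_maximum(x, y), np.maximum(x, y))
```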
diff --git a/tensorflow/contrib/lite/kernels/maximum.cc b/tensorflow/contrib/lite/kernels/maximum.cc
deleted file mode 100644
index 9fdf2b47ea..0000000000
--- a/tensorflow/contrib/lite/kernels/maximum.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <string.h>
-#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
-#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
-#include "tensorflow/contrib/lite/kernels/kernel_util.h"
-#include "tensorflow/contrib/lite/kernels/op_macros.h"
-
-namespace tflite {
-namespace ops {
-namespace builtin {
-namespace maximum {
-
-// This file has a reference implementation of TFMaximum.
-enum KernelType {
- kReference,
-};
-
-constexpr int kInputTensor1 = 0;
-constexpr int kInputTensor2 = 1;
-constexpr int kOutputTensor = 0;
-
-struct MaximumContext {
- MaximumContext(TfLiteContext* context, TfLiteNode* node) {
- input1 = GetInput(context, node, kInputTensor1);
- input2 = GetInput(context, node, kInputTensor2);
- output = GetOutput(context, node, kOutputTensor);
- }
- TfLiteTensor* input1;
- TfLiteTensor* input2;
- TfLiteTensor* output;
-};
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
- MaximumContext op_context(context, node);
- TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type);
- TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input2->dims);
- op_context.output->type = op_context.input2->type;
- return context->ResizeTensor(context, op_context.output, output_dims);
-}
-
-template <KernelType kernel_type>
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- MaximumContext op_context(context, node);
-
-#define TF_LITE_MAXIMUM(kernel_type, data_type) \
- kernel_type::TensorFlowMaximum<data_type>( \
- GetTensorData<data_type>(op_context.input1), \
- GetTensorDims(op_context.input1), \
- GetTensorData<data_type>(op_context.input2), \
- GetTensorDims(op_context.input2), \
- GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
-
- if (kernel_type == kReference) {
- switch (op_context.output->type) {
- case kTfLiteFloat32:
- TF_LITE_MAXIMUM(reference_ops, float);
- break;
- default:
- context->ReportError(context,
- "Type %d is currently not supported by Maximum.",
- op_context.output->type);
- return kTfLiteError;
- }
- } else {
- context->ReportError(context,
- "Type %d is currently not supported by Maximum.",
- op_context.output->type);
- return kTfLiteError;
- }
-#undef TF_LITE_MAXIMUM
- return kTfLiteOk;
-}
-
-} // namespace maximum
-
-TfLiteRegistration* Register_MAXIMUM_REF() {
- static TfLiteRegistration r = {nullptr, nullptr, maximum::Prepare,
- maximum::Eval<maximum::kReference>};
- return &r;
-}
-
-TfLiteRegistration* Register_MAXIMUM() { return Register_MAXIMUM_REF(); }
-
-} // namespace builtin
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/maximum_test.cc b/tensorflow/contrib/lite/kernels/maximum_test.cc
deleted file mode 100644
index b3fd7d4e6f..0000000000
--- a/tensorflow/contrib/lite/kernels/maximum_test.cc
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/kernels/test_util.h"
-#include "tensorflow/contrib/lite/model.h"
-
-namespace tflite {
-namespace {
-
-using ::testing::ElementsAreArray;
-
-class MaximumOpModel : public SingleOpModel {
- public:
- MaximumOpModel(const TensorData& input1, const TensorData& input2,
- const TensorType& output) {
- input1_ = AddInput(input1);
- input2_ = AddInput(input2);
- output_ = AddOutput(output);
- SetBuiltinOp(BuiltinOperator_MAXIMUM, BuiltinOptions_MaximumOptions,
- CreateMaximumOptions(builder_).Union());
- BuildInterpreter({GetShape(input1_), GetShape(input2_)});
- }
-
- template <class T>
- void SetInput1(std::initializer_list<T> data) {
- PopulateTensor(input1_, data);
- }
-
- template <class T>
- void SetInput2(std::initializer_list<T> data) {
- PopulateTensor(input2_, data);
- }
-
- template <class T>
- std::vector<T> GetOutput() {
- return ExtractVector<T>(output_);
- }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- protected:
- int input1_;
- int input2_;
- int output_;
-};
-
-TEST(MaximumOpTest, FloatTest) {
- std::initializer_list<float> data1 = {1.0, 0.0, -1.0, 11.0, -2.0, -1.44};
- std::initializer_list<float> data2 = {-1.0, 0.0, 1.0, 12.0, -3.0, -1.43};
- MaximumOpModel m({TensorType_FLOAT32, {3, 1, 2}},
- {TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32);
- m.SetInput1<float>(data1);
- m.SetInput2<float>(data2);
- m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
- EXPECT_THAT(
- m.GetOutput<float>(),
- ElementsAreArray(ArrayFloatNear({1.0, 0.0, 1.0, 12.0, -2.0, -1.43})));
-}
-
-} // namespace
-} // namespace tflite
-
-int main(int argc, char** argv) {
- ::tflite::LogToStderr();
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0f98154b90..62045f0a4d 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -76,7 +76,6 @@ TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE();
TfLiteRegistration* Register_PRELU();
-TfLiteRegistration* Register_MAXIMUM();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -134,7 +133,6 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
- AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 791d1378f3..b7ccdf070b 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -597,9 +597,6 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_MAXIMUM: {
- break;
- }
case BuiltinOperator_DELEGATE: {
// TODO(ycling): Revisit when supporting saving delegated models.
error_reporter->Report("DELEGATE op shouldn't exist in model.");
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index decaf9f160..e31b7c03a5 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -350,7 +350,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_DELEGATE:
case tflite::BuiltinOperator_CAST:
case tflite::BuiltinOperator_PRELU:
- case tflite::BuiltinOperator_MAXIMUM:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index ed6dd036f9..35d224924e 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -25,9 +25,9 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os as _os
-import subprocess as _subprocess
-import tempfile as _tempfile
+import os
+import subprocess
+import tempfile
# pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
@@ -74,7 +74,7 @@ else:
_toco_from_proto_bin = _resource_loader.get_path_to_datafile(
"../toco/python/toco_from_protos")
-if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
+if _toco_from_proto_bin and not os.path.exists(_toco_from_proto_bin):
_toco_from_proto_bin = "toco_from_protos"
@@ -102,10 +102,10 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
return _toco_python.TocoConvert(
model_flags_str, toco_flags_str, input_data_str)
- with _tempfile.NamedTemporaryFile() as fp_toco, \
- _tempfile.NamedTemporaryFile() as fp_model, \
- _tempfile.NamedTemporaryFile() as fp_input, \
- _tempfile.NamedTemporaryFile() as fp_output:
+ with tempfile.NamedTemporaryFile() as fp_toco, \
+ tempfile.NamedTemporaryFile() as fp_model, \
+ tempfile.NamedTemporaryFile() as fp_input, \
+ tempfile.NamedTemporaryFile() as fp_output:
fp_model.write(model_flags_str)
fp_toco.write(toco_flags_str)
fp_input.write(input_data_str)
@@ -118,11 +118,11 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
fp_output.name
]
cmdline = " ".join(cmd)
- proc = _subprocess.Popen(
+ proc = subprocess.Popen(
cmdline,
shell=True,
- stdout=_subprocess.PIPE,
- stderr=_subprocess.STDOUT,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
close_fds=True)
stdout, stderr = proc.communicate()
exitcode = proc.returncode
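
For context, the hunk above touches the fallback path of toco_convert_protos: when the native _toco_python extension is unavailable, the serialized flags and model are handed to the toco_from_protos binary through temp files. A generic sketch of that pattern, assuming the model/toco/input/output argument order shown in the hunk (the input bytes are placeholders):

```python
import subprocess
import tempfile

def run_converter(binary, model_flags, toco_flags, input_data):
  """Sketch of the temp-file + subprocess fallback used above."""
  with tempfile.NamedTemporaryFile() as fp_model, \
       tempfile.NamedTemporaryFile() as fp_toco, \
       tempfile.NamedTemporaryFile() as fp_input, \
       tempfile.NamedTemporaryFile() as fp_output:
    fp_model.write(model_flags)
    fp_toco.write(toco_flags)
    fp_input.write(input_data)
    for fp in (fp_model, fp_toco, fp_input):
      fp.flush()  # make the bytes visible to the child process
    cmd = [binary, fp_model.name, fp_toco.name, fp_input.name,
           fp_output.name]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, close_fds=True)
    stdout, _ = proc.communicate()
    if proc.returncode != 0:
      raise RuntimeError("converter failed:\n" + stdout.decode())
    return fp_output.read()
```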
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7d2e00fe32..e1075971e9 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -131,7 +131,6 @@ enum BuiltinOperator : byte {
BIDIRECTIONAL_SEQUENCE_LSTM = 52,
CAST = 53,
PRELU = 54,
- MAXIMUM = 55,
}
// Options for the builtin operators.
@@ -174,7 +173,6 @@ union BuiltinOptions {
LogSoftmaxOptions,
CastOptions,
DequantizeOptions,
- MaximumOptions,
}
enum Padding : byte { SAME, VALID }
@@ -386,9 +384,6 @@ table CastOptions {
table DequantizeOptions {
}
-table MaximumOptions {
-}
-
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 66a97a1460..86daeaf5cc 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -145,9 +145,6 @@ struct CastOptionsT;
struct DequantizeOptions;
struct DequantizeOptionsT;
-struct MaximumOptions;
-struct MaximumOptionsT;
-
struct OperatorCode;
struct OperatorCodeT;
@@ -258,12 +255,11 @@ enum BuiltinOperator {
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
BuiltinOperator_CAST = 53,
BuiltinOperator_PRELU = 54,
- BuiltinOperator_MAXIMUM = 55,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_MAXIMUM
+ BuiltinOperator_MAX = BuiltinOperator_PRELU
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[54] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[53] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -317,8 +313,7 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[54] {
BuiltinOperator_DELEGATE,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOperator_CAST,
- BuiltinOperator_PRELU,
- BuiltinOperator_MAXIMUM
+ BuiltinOperator_PRELU
};
return values;
}
@@ -380,7 +375,6 @@ inline const char **EnumNamesBuiltinOperator() {
"BIDIRECTIONAL_SEQUENCE_LSTM",
"CAST",
"PRELU",
- "MAXIMUM",
nullptr
};
return names;
@@ -431,12 +425,11 @@ enum BuiltinOptions {
BuiltinOptions_LogSoftmaxOptions = 36,
BuiltinOptions_CastOptions = 37,
BuiltinOptions_DequantizeOptions = 38,
- BuiltinOptions_MaximumOptions = 39,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_MaximumOptions
+ BuiltinOptions_MAX = BuiltinOptions_DequantizeOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[40] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[39] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -476,8 +469,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[40] {
BuiltinOptions_SplitOptions,
BuiltinOptions_LogSoftmaxOptions,
BuiltinOptions_CastOptions,
- BuiltinOptions_DequantizeOptions,
- BuiltinOptions_MaximumOptions
+ BuiltinOptions_DequantizeOptions
};
return values;
}
@@ -523,7 +515,6 @@ inline const char **EnumNamesBuiltinOptions() {
"LogSoftmaxOptions",
"CastOptions",
"DequantizeOptions",
- "MaximumOptions",
nullptr
};
return names;
@@ -690,10 +681,6 @@ template<> struct BuiltinOptionsTraits<DequantizeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions;
};
-template<> struct BuiltinOptionsTraits<MaximumOptions> {
- static const BuiltinOptions enum_value = BuiltinOptions_MaximumOptions;
-};
-
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1029,14 +1016,6 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_DequantizeOptions ?
reinterpret_cast<const DequantizeOptionsT *>(value) : nullptr;
}
- MaximumOptionsT *AsMaximumOptions() {
- return type == BuiltinOptions_MaximumOptions ?
- reinterpret_cast<MaximumOptionsT *>(value) : nullptr;
- }
- const MaximumOptionsT *AsMaximumOptions() const {
- return type == BuiltinOptions_MaximumOptions ?
- reinterpret_cast<const MaximumOptionsT *>(value) : nullptr;
- }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -3780,46 +3759,6 @@ inline flatbuffers::Offset<DequantizeOptions> CreateDequantizeOptions(
flatbuffers::Offset<DequantizeOptions> CreateDequantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-struct MaximumOptionsT : public flatbuffers::NativeTable {
- typedef MaximumOptions TableType;
- MaximumOptionsT() {
- }
-};
-
-struct MaximumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef MaximumOptionsT NativeTableType;
- bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) &&
- verifier.EndTable();
- }
- MaximumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(MaximumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<MaximumOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-};
-
-struct MaximumOptionsBuilder {
- flatbuffers::FlatBufferBuilder &fbb_;
- flatbuffers::uoffset_t start_;
- explicit MaximumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
- : fbb_(_fbb) {
- start_ = fbb_.StartTable();
- }
- MaximumOptionsBuilder &operator=(const MaximumOptionsBuilder &);
- flatbuffers::Offset<MaximumOptions> Finish() {
- const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<MaximumOptions>(end);
- return o;
- }
-};
-
-inline flatbuffers::Offset<MaximumOptions> CreateMaximumOptions(
- flatbuffers::FlatBufferBuilder &_fbb) {
- MaximumOptionsBuilder builder_(_fbb);
- return builder_.Finish();
-}
-
-flatbuffers::Offset<MaximumOptions> CreateMaximumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -4051,9 +3990,6 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const DequantizeOptions *builtin_options_as_DequantizeOptions() const {
return builtin_options_type() == BuiltinOptions_DequantizeOptions ? static_cast<const DequantizeOptions *>(builtin_options()) : nullptr;
}
- const MaximumOptions *builtin_options_as_MaximumOptions() const {
- return builtin_options_type() == BuiltinOptions_MaximumOptions ? static_cast<const MaximumOptions *>(builtin_options()) : nullptr;
- }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -4232,10 +4168,6 @@ template<> inline const DequantizeOptions *Operator::builtin_options_as<Dequanti
return builtin_options_as_DequantizeOptions();
}
-template<> inline const MaximumOptions *Operator::builtin_options_as<MaximumOptions>() const {
- return builtin_options_as_MaximumOptions();
-}
-
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5764,29 +5696,6 @@ inline flatbuffers::Offset<DequantizeOptions> CreateDequantizeOptions(flatbuffer
_fbb);
}
-inline MaximumOptionsT *MaximumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new MaximumOptionsT();
- UnPackTo(_o, _resolver);
- return _o;
-}
-
-inline void MaximumOptions::UnPackTo(MaximumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
- (void)_o;
- (void)_resolver;
-}
-
-inline flatbuffers::Offset<MaximumOptions> MaximumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateMaximumOptions(_fbb, _o, _rehasher);
-}
-
-inline flatbuffers::Offset<MaximumOptions> CreateMaximumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
- (void)_rehasher;
- (void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MaximumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
- return tflite::CreateMaximumOptions(
- _fbb);
-}
-
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -6119,10 +6028,6 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const DequantizeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
- case BuiltinOptions_MaximumOptions: {
- auto ptr = reinterpret_cast<const MaximumOptions *>(obj);
- return verifier.VerifyTable(ptr);
- }
default: return false;
}
}
@@ -6293,10 +6198,6 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const DequantizeOptions *>(obj);
return ptr->UnPack(resolver);
}
- case BuiltinOptions_MaximumOptions: {
- auto ptr = reinterpret_cast<const MaximumOptions *>(obj);
- return ptr->UnPack(resolver);
- }
default: return nullptr;
}
}
@@ -6455,10 +6356,6 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const DequantizeOptionsT *>(value);
return CreateDequantizeOptions(_fbb, ptr, _rehasher).Union();
}
- case BuiltinOptions_MaximumOptions: {
- auto ptr = reinterpret_cast<const MaximumOptionsT *>(value);
- return CreateMaximumOptions(_fbb, ptr, _rehasher).Union();
- }
default: return 0;
}
}
@@ -6617,10 +6514,6 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new DequantizeOptionsT(*reinterpret_cast<DequantizeOptionsT *>(u.value));
break;
}
- case BuiltinOptions_MaximumOptions: {
- value = new MaximumOptionsT(*reinterpret_cast<MaximumOptionsT *>(u.value));
- break;
- }
default:
break;
}
@@ -6818,11 +6711,6 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
- case BuiltinOptions_MaximumOptions: {
- auto ptr = reinterpret_cast<MaximumOptionsT *>(value);
- delete ptr;
- break;
- }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 12b7b3c350..555ea90034 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -36,7 +36,6 @@ gen_zipped_test_files(
"local_response_norm.zip",
"log_softmax.zip",
"max_pool.zip",
- "maximum.zip",
"mean.zip",
"mul.zip",
"pad.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 8045052452..cb5c500136 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -862,41 +862,6 @@ def make_log_softmax_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_maximum_tests(zip_path):
- """Make a set of tests to do maximum."""
-
- test_parameters = [{
- "input_dtype": [tf.float32],
- "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
- "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
- }]
-
- def build_graph(parameters):
- """Build the maximum op testing graph."""
- input_tensor_1 = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input_1",
- shape=parameters["input_shape_1"])
- input_tensor_2 = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input_2",
- shape=parameters["input_shape_2"])
-
- out = tf.maximum(input_tensor_1, input_tensor_2)
- return [input_tensor_1, input_tensor_2], [out]
-
- def build_inputs(parameters, sess, inputs, outputs):
- values = [
- create_tensor_data(parameters["input_dtype"],
- parameters["input_shape_1"]),
- create_tensor_data(parameters["input_dtype"],
- parameters["input_shape_2"])
- ]
- return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
-
- make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-
-
def make_binary_op_tests_func(binary_operator):
"""Return a function that does a test on a binary operator."""
return lambda zip_path: make_binary_op_tests(zip_path, binary_operator)
@@ -2012,7 +1977,6 @@ def main(unused_args):
"exp.zip": make_exp_tests,
"log_softmax.zip": make_log_softmax_tests,
"lstm.zip": make_lstm_tests,
- "maximum.zip": make_maximum_tests,
}
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
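
The deleted make_maximum_tests generator followed the standard pattern in this file: declare placeholders, apply the op, then capture concrete input/output pairs from a session run. A minimal standalone sketch of what one generated example checks (TF 1.x API; the broadcastable shape pair here is illustrative):

```python
import numpy as np
import tensorflow as tf

shape_1, shape_2 = [4, 2, 3], [3]  # a broadcastable pair
input_1 = tf.placeholder(dtype=tf.float32, name="input_1", shape=shape_1)
input_2 = tf.placeholder(dtype=tf.float32, name="input_2", shape=shape_2)
out = tf.maximum(input_1, input_2)

with tf.Session() as sess:
  v1 = np.random.uniform(-1, 1, shape_1).astype(np.float32)
  v2 = np.random.uniform(-1, 1, shape_2).astype(np.float32)
  result = sess.run(out, feed_dict={input_1: v1, input_2: v2})
  # A conforming TFLite kernel is expected to reproduce this output.
  assert np.array_equal(result, np.maximum(v1, v2))
```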
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 6697b86e79..a4a7283508 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -253,7 +253,6 @@ INSTANTIATE_TESTS(l2_pool)
INSTANTIATE_TESTS(l2norm)
INSTANTIATE_TESTS(local_response_norm)
INSTANTIATE_TESTS(log_softmax)
-INSTANTIATE_TESTS(maximum)
INSTANTIATE_TESTS(max_pool)
INSTANTIATE_TESTS(mean)
INSTANTIATE_TESTS(mul)
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0989bfe5a3..f23249cfa1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -863,8 +863,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
ops.emplace_back(new SimpleOperator<LogSoftmaxOperator>(
"LOG_SOFTMAX", OperatorType::kLogSoftmax));
- ops.emplace_back(new SimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index f7a213ecfc..9c19f8d464 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -109,8 +109,6 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
OperatorType::kLogSoftmax);
- CheckSimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum);
}
TEST_F(OperatorTest, BuiltinAdd) {