aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-30 02:41:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 02:45:14 -0700
commit333f9c03950a1b6afb8a902b2dc3d883be490b86 (patch)
tree1c759f72f699df5078f085a517334ce8da8f1fec /tensorflow/contrib/lite
parent9e0b05bbc4bb88d1b34fb2147429dc4ad7bd25cd (diff)
Implementation of logical_or.
PiperOrigin-RevId: 206549781
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/build_def.bzl31
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md12
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD15
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h32
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc121
-rw-r--r--tensorflow/contrib/lite/kernels/logical_test.cc87
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py33
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc1
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc3
-rw-r--r--tensorflow/contrib/lite/toco/model.h12
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
16 files changed, 358 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 2e91632459..efe8857344 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -2,8 +2,8 @@
load(
"//tensorflow:tensorflow.bzl",
- "tf_cc_test",
"tf_cc_shared_object",
+ "tf_cc_test",
)
def tflite_copts():
@@ -125,19 +125,21 @@ def tflite_jni_binary(
linkopts = linkopts,
)
-def tflite_cc_shared_object(name,
- copts=tflite_copts(),
- linkopts=[],
- linkstatic=1,
- deps=[]):
- """Builds a shared object for TFLite."""
- tf_cc_shared_object(
- name=name,
- copts=copts,
- linkstatic=linkstatic,
- linkopts=linkopts + tflite_jni_linkopts(),
- framework_so=[],
- deps=deps)
+def tflite_cc_shared_object(
+ name,
+ copts = tflite_copts(),
+ linkopts = [],
+ linkstatic = 1,
+ deps = []):
+ """Builds a shared object for TFLite."""
+ tf_cc_shared_object(
+ name = name,
+ copts = copts,
+ linkstatic = linkstatic,
+ linkopts = linkopts + tflite_jni_linkopts(),
+ framework_so = [],
+ deps = deps,
+ )
def tf_to_tflite(name, src, options, out):
"""Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
@@ -243,6 +245,7 @@ def generated_test_models():
"local_response_norm",
"log_softmax",
"log",
+ "logical_or",
"lstm",
"max_pool",
"maximum",
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 2ea7aeaa5d..aa65ec9988 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -831,6 +831,18 @@ Outputs {
}
```
+**LOGICAL_OR**
+
+```
+Inputs {
+ 0: a list of tensors.
+ 1: a list of tensors.
+}
+Outputs {
+ 0: A tensor of logical_or output tensors.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 026ab4de03..329c98f91e 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -170,6 +170,7 @@ cc_library(
"hashtable_lookup.cc",
"l2norm.cc",
"local_response_norm.cc",
+ "logical.cc",
"lsh_projection.cc",
"lstm.cc",
"maximum_minimum.cc",
@@ -1185,6 +1186,20 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "logical_test",
+ size = "small",
+ srcs = ["logical_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index ce16394082..714613b96e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4255,6 +4255,38 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
}
}
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
new file mode 100644
index 0000000000..3dc39bf79a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace logical {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for logical op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterprete the opaque data provided by user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteBool) {
+ context->ReportError(context, "Logical ops only support bool type.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
+ const std::function<bool(bool, bool)>& func) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (data->requires_broadcast) {
+ reference_ops::BroadcastLogical(
+ GetTensorData<bool>(input1), GetTensorDims(input1),
+ GetTensorData<bool>(input2), GetTensorDims(input2),
+ GetTensorData<bool>(output), GetTensorDims(output), func);
+ } else {
+ reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
+ GetTensorData<bool>(input2), GetTensorDims(input2),
+ GetTensorData<bool>(output), GetTensorDims(output),
+ func);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
+ const auto logical_or_func = std::logical_or<bool>();
+ return LogicalImpl(context, node, logical_or_func);
+}
+
+} // namespace
+} // namespace logical
+
+TfLiteRegistration* Register_LOGICAL_OR() {
+ // Init, Free, Prepare, Eval are satisfying the Interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
+ logical::LogicalOrEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc
new file mode 100644
index 0000000000..382008245b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/logical_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class LogicalOpModel : public SingleOpModel {
+ public:
+ LogicalOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape, BuiltinOperator op) {
+ input1_ = AddInput(TensorType_BOOL);
+ input2_ = AddInput(TensorType_BOOL);
+ output_ = AddOutput(TensorType_BOOL);
+ ConfigureBuiltinOp(op);
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_LOGICAL_OR: {
+ SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions,
+ CreateLogicalOrOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
+};
+
+TEST(LogicalTest, LogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, BroadcastLogicalOr) {
+ LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR);
+ model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+ model.PopulateTensor<bool>(model.input2(), {false});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index da69b85041..e632728841 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -108,6 +108,7 @@ TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_ONE_HOT();
+TfLiteRegistration* Register_LOGICAL_OR();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -199,6 +200,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
+ AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 4234d0b811..a95b26220d 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -231,6 +231,7 @@ _TF_TYPE_INFO = {
tf.int32: (np.int32, "INT32"),
tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
tf.int64: (np.int64, "INT64"),
+ tf.bool: (np.bool, "BOOL"),
}
@@ -244,7 +245,8 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
elif dtype in (tf.int32, tf.uint8, tf.int64):
value = np.random.randint(min_value, max_value+1, shape)
-
+ elif dtype == tf.bool:
+ value = np.random.choice([True, False], size=shape)
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
dtype)
@@ -2982,6 +2984,35 @@ def make_pack_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_logical_or_tests(zip_path):
+ """Make a set of tests to do logical_or."""
+
+ test_parameters = [{
+ "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the logical_or op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1])
+ out = tf.logical_or(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(tf.bool,
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(tf.bool,
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 9983e59910..378212cb74 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1925,6 +1925,21 @@ void ConvertLogicalNotOperator(const Model& model,
*logical_op->add_input() = src_op.inputs[0];
}
+void ConvertLogicalOrOperator(const Model& model,
+ const LogicalOrOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
+ logical_or_op->set_op(op_name);
+ logical_or_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *logical_or_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*logical_or_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2175,6 +2190,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kOneHot) {
ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalOr) {
+ ConvertLogicalOrOperator(model,
+ static_cast<const LogicalOrOperator&>(src_op),
+ "LogicalOr", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 0f94006f34..9cec6d65f3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -65,6 +65,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kAny:
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
+ case OperatorType::kLogicalOr:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5aa0fddf57..3c9379fd87 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1673,6 +1673,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kSin:
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
+ case OperatorType::kLogicalOr:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 7e593029e5..9a3db5c888 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1914,9 +1914,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
{"Log", ConvertSimpleOperator<LogOperator, 1>},
- {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
+ {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2>},
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
+ {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
{"MatMul", ConvertMatMulOperator},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index a3827977fd..7d0dbfcc05 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -147,6 +147,7 @@ enum class OperatorType : uint8 {
kAny,
kLogicalAnd,
kLogicalNot,
+ kLogicalOr,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1790,6 +1791,17 @@ struct OneHotOperator : Operator {
int axis = -1;
};
+// LogicalOr operator:
+//
+// Inputs:
+// Inputs[0]: required: A Bool tensor.
+// Inputs[1]: required: A Bool tensor.
+//
+// TensorFlow equivalent: LogicalOr.
+struct LogicalOrOperator : Operator {
+ LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 769e350ea9..9380168f30 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1350,6 +1350,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
+ ops.emplace_back(new SimpleOperator<LogicalOrOperator>(
+ "LOGICAL_OR", OperatorType::kLogicalOr));
// Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 7e1e32ae54..384f7c118d 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -127,6 +127,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
+ CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR",
+ OperatorType::kLogicalOr);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 93c30dd0f8..7f5da251d0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -403,6 +403,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Any)
HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE