diff options
author | 2018-07-30 02:41:59 -0700 | |
---|---|---|
committer | 2018-07-30 02:45:14 -0700 | |
commit | 333f9c03950a1b6afb8a902b2dc3d883be490b86 (patch) | |
tree | 1c759f72f699df5078f085a517334ce8da8f1fec /tensorflow/contrib/lite | |
parent | 9e0b05bbc4bb88d1b34fb2147429dc4ad7bd25cd (diff) |
Implementation of logical_or.
PiperOrigin-RevId: 206549781
Diffstat (limited to 'tensorflow/contrib/lite')
16 files changed, 358 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 2e91632459..efe8857344 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -2,8 +2,8 @@ load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", "tf_cc_shared_object", + "tf_cc_test", ) def tflite_copts(): @@ -125,19 +125,21 @@ def tflite_jni_binary( linkopts = linkopts, ) -def tflite_cc_shared_object(name, - copts=tflite_copts(), - linkopts=[], - linkstatic=1, - deps=[]): - """Builds a shared object for TFLite.""" - tf_cc_shared_object( - name=name, - copts=copts, - linkstatic=linkstatic, - linkopts=linkopts + tflite_jni_linkopts(), - framework_so=[], - deps=deps) +def tflite_cc_shared_object( + name, + copts = tflite_copts(), + linkopts = [], + linkstatic = 1, + deps = []): + """Builds a shared object for TFLite.""" + tf_cc_shared_object( + name = name, + copts = copts, + linkstatic = linkstatic, + linkopts = linkopts + tflite_jni_linkopts(), + framework_so = [], + deps = deps, + ) def tf_to_tflite(name, src, options, out): """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. @@ -243,6 +245,7 @@ def generated_test_models(): "local_response_norm", "log_softmax", "log", + "logical_or", "lstm", "max_pool", "maximum", diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 2ea7aeaa5d..aa65ec9988 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -831,6 +831,18 @@ Outputs { } ``` +**LOGICAL_OR** + +``` +Inputs { + 0: a list of tensors. + 1: a list of tensors. +} +Outputs { + 0: A tensor of logical_or output tensors. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 026ab4de03..329c98f91e 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -170,6 +170,7 @@ cc_library( "hashtable_lookup.cc", "l2norm.cc", "local_response_norm.cc", + "logical.cc", "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", @@ -1185,6 +1186,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "logical_test", + size = "small", + srcs = ["logical_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index ce16394082..714613b96e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4255,6 +4255,38 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, } } +inline void Logical(const bool* input1_data, const Dims<4>& input1_dims, + const bool* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims, + const std::function<bool(bool, bool)>& func) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); + } +} + +inline void BroadcastLogical(const bool* input1_data, + const Dims<4>& input1_dims, + const bool* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims, + const std::function<bool(bool, bool)>& func) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc new file mode 100644 index 0000000000..3dc39bf79a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace logical { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for logical op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteBool) { + context->ReportError(context, "Logical ops only support bool type."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, + const std::function<bool(bool, bool)>& func) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (data->requires_broadcast) { + reference_ops::BroadcastLogical( + GetTensorData<bool>(input1), GetTensorDims(input1), + GetTensorData<bool>(input2), GetTensorDims(input2), + GetTensorData<bool>(output), GetTensorDims(output), func); + } else { + reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1), + GetTensorData<bool>(input2), GetTensorDims(input2), + GetTensorData<bool>(output), GetTensorDims(output), + func); + } + + return kTfLiteOk; +} + +TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_or_func = std::logical_or<bool>(); + return LogicalImpl(context, node, logical_or_func); +} + +} // namespace +} // namespace logical + +TfLiteRegistration* Register_LOGICAL_OR() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalOrEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc new file mode 100644 index 0000000000..382008245b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LogicalOpModel : public SingleOpModel { + public: + LogicalOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, BuiltinOperator op) { + input1_ = AddInput(TensorType_BOOL); + input2_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_LOGICAL_OR: { + SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions, + CreateLogicalOrOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } +}; + +TEST(LogicalTest, LogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor<bool>(model.input1(), {true, false, false, true}); + model.PopulateTensor<bool>(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor<bool>(model.input1(), {true, false, false, true}); + model.PopulateTensor<bool>(model.input2(), {false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index da69b85041..e632728841 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -108,6 +108,7 @@ TfLiteRegistration* Register_POW(); TfLiteRegistration* Register_FAKE_QUANT(); TfLiteRegistration* Register_PACK(); TfLiteRegistration* Register_ONE_HOT(); +TfLiteRegistration* Register_LOGICAL_OR(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -199,6 +200,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); AddBuiltin(BuiltinOperator_PACK, Register_PACK()); AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); + AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 4234d0b811..a95b26220d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -231,6 +231,7 @@ _TF_TYPE_INFO = { tf.int32: (np.int32, "INT32"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"), tf.int64: (np.int64, "INT64"), + tf.bool: (np.bool, "BOOL"), } @@ -244,7 +245,8 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): value = (max_value-min_value)*np.random.random_sample(shape)+min_value elif dtype in (tf.int32, tf.uint8, tf.int64): value = np.random.randint(min_value, max_value+1, shape) - + elif dtype == tf.bool: + value = np.random.choice([True, False], size=shape) return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype( dtype) @@ -2982,6 +2984,35 @@ def make_pack_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_logical_or_tests(zip_path): + """Make a set of tests to do logical_or.""" + + test_parameters = [{ + "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the logical_or op testing graph.""" + input_value1 = tf.placeholder( + dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1]) + out = tf.logical_or(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 9983e59910..378212cb74 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1925,6 +1925,21 @@ void ConvertLogicalNotOperator(const Model& model, *logical_op->add_input() = src_op.inputs[0]; } +void ConvertLogicalOrOperator(const Model& model, + const LogicalOrOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node(); + logical_or_op->set_op(op_name); + logical_or_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_or_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*logical_or_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2175,6 +2190,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kOneHot) { ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalOr) { + ConvertLogicalOrOperator(model, + static_cast<const LogicalOrOperator&>(src_op), + "LogicalOr", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 0f94006f34..9cec6d65f3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -65,6 +65,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kAny: case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5aa0fddf57..3c9379fd87 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1673,6 +1673,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kSin: case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 7e593029e5..9a3db5c888 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1914,9 +1914,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>}, {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>}, {"Log", ConvertSimpleOperator<LogOperator, 1>}, - {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>}, + {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2>}, {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>}, + {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, {"MatMul", ConvertMatMulOperator}, {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index a3827977fd..7d0dbfcc05 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -147,6 +147,7 @@ enum class OperatorType : uint8 { kAny, kLogicalAnd, kLogicalNot, + kLogicalOr, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1790,6 +1791,17 @@ struct OneHotOperator : Operator { int axis = -1; }; +// LogicalOr operator: +// +// Inputs: +// Inputs[0]: required: A Bool tensor. +// Inputs[1]: required: A Bool tensor. +// +// TensorFlow equivalent: LogicalOr. +struct LogicalOrOperator : Operator { + LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 769e350ea9..9380168f30 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1350,6 +1350,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice)); ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow)); + ops.emplace_back(new SimpleOperator<LogicalOrOperator>( + "LOGICAL_OR", OperatorType::kLogicalOr)); // Element-wise operator ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 7e1e32ae54..384f7c118d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -127,6 +127,8 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt); CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt); CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow); + CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR", + OperatorType::kLogicalOr); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 93c30dd0f8..7f5da251d0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -403,6 +403,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Any) HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) HANDLE_OPERATORTYPENAME_CASE(LogicalNot) + HANDLE_OPERATORTYPENAME_CASE(LogicalOr) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE |