diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-01 16:33:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-01 16:35:52 -0700 |
commit | f5dbc1e16622f433f41f195bb33f56d674a004ce (patch) | |
tree | 8a08ec5c43192415056e0695337dd26e61256fcb /tensorflow/contrib/lite/toco/import_tensorflow_test.cc | |
parent | fb8f040f2a927c6df149238da7c4278cf781d081 (diff) |
Check for overflow in shape calculation.
PiperOrigin-RevId: 195017114
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc new file mode 100644 index 0000000000..5dc78f73ad --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/import_tensorflow.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +using port::Status; +using tensorflow::AttrValue; +using tensorflow::DT_BOOL; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::DT_QUINT8; +using tensorflow::DT_STRING; +using tensorflow::NodeDef; + +namespace internal { +Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, + Model*); +} // namespace internal + +namespace { + +class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { + protected: + ShapeImportTest() {} + + void BuildConstNode(std::initializer_list<int64_t> shape, + tensorflow::DataType dtype, int64_t num_elements, + NodeDef* node) { + node->set_op("Const"); + node->set_name("Node1"); + + // An attribute describing the type of this const node. + AttrValue dtype_attr; + SetAttrValue(dtype, &dtype_attr); + (*node->mutable_attr())["dtype"] = dtype_attr; + + // An attribute describing the content of this const node. + tensorflow::TensorProto t; + t.set_dtype(dtype); + auto* s = t.mutable_tensor_shape(); + for (auto d : shape) { + s->add_dim()->set_size(d); + } + + // TODO(ahentz): also need to test via tensor_content() + switch (dtype) { + case DT_FLOAT: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_float_val(i / 10000.0); + } + break; + case DT_INT32: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int_val(i % std::numeric_limits<int>::max()); + } + break; + case DT_QUINT8: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int_val(i % std::numeric_limits<uint8_t>::max()); + } + break; + case DT_INT64: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int64_val(i); + } + break; + case DT_STRING: + break; + case DT_BOOL: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_bool_val(i % 2); + } + break; + default: + break; + } + + AttrValue value_attr; + SetAttrValue(t, &value_attr); + (*node->mutable_attr())["value"] = value_attr; + } + + Status ImportNode(const NodeDef& node) { + Model model; + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), + &model); + } +}; + +std::vector<tensorflow::DataType> TestTypes() { + return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8}; +} + +TEST_P(ShapeImportTest, ShapeElementIsNegative) { + NodeDef node; + BuildConstNode({1, -2, 10}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Tensor shape should not include negative values (while processing " + "node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ShapeElementTooLarge) { + NodeDef node; + BuildConstNode({3000000000}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Shape element overflows (while processing node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ShapeTooLarge) { + NodeDef node; + BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Tensor shape is too large (while processing node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ValidShapeButZeroElements) { + NodeDef node; + BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_THAT(status.error_message(), + ::testing::MatchesRegex( + "Neither input_content nor .*_val have the right dimensions " + "for this .* tensor .while processing node 'Node1'.")); +} +INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +} // namespace +} // namespace toco |