aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-13 15:34:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 15:39:27 -0700
commite59ddcca727340a8b45694a28cd9f52352607e63 (patch)
tree333e4da24dae2071c8946afd2c8a9def74f68889
parentea52ecd836098e0b1d37325cf1b91133f908547e (diff)
Automated rollback of commit 6b507a6de855a6f988100904229b7f46a5652b88
PiperOrigin-RevId: 212890622
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc18
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc75
3 files changed, 5 insertions, 89 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 72c71b2841..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -331,7 +331,6 @@ cc_library(
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
] + select({
# Placeholder for internal darwin rule.
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index eb36b3411d..9bc23c4b3c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,7 +58,6 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
-using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -1080,23 +1079,6 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
- } else {
- const tensorflow::OpDef* op_def = nullptr;
- if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
- for (const auto& output_arg : op_def->output_arg()) {
- if (HasAttr(node, output_arg.type_attr())) {
- op->output_data_types.push_back(
- ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
- } else {
- LOG(INFO) << "Op node missing output type attribute: " << node.name();
- }
- }
- }
- if (op->output_data_types.empty()) {
- // TODO(b/113613439): Figure out how to propagate types for custom ops
- // that have no OpDef.
- LOG(INFO) << "Unable to determine output type for op: " << node.op();
- }
}
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index da248826a7..a00e136dd6 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -49,17 +49,6 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
-Status ImportNode(const NodeDef& node, Model* model) {
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
- converter);
-}
-
-Status ImportNode(const NodeDef& node) {
- Model model;
- return ImportNode(node, &model);
-}
-
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -120,24 +109,12 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
-};
-
-class TypeImportTest : public ::testing::TestWithParam<
- std::pair<tensorflow::DataType, ArrayDataType>> {
- protected:
- TypeImportTest() {}
-
- void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
- NodeDef* node) {
- node->set_op(op_name);
- node->set_name("Node1");
-
- node->add_input();
- node->set_input(0, "Node0");
- AttrValue dtype_attr;
- SetAttrValue(dtype, &dtype_attr);
- (*node->mutable_attr())["T"] = dtype_attr;
+ Status ImportNode(const NodeDef& node) {
+ Model model;
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
+ converter);
}
};
@@ -190,47 +167,5 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
-std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
- return {{DT_FLOAT, ArrayDataType::kFloat},
- {DT_INT32, ArrayDataType::kInt32},
- {DT_INT64, ArrayDataType::kInt64}};
-}
-
-TEST_P(TypeImportTest, BasicTypeInference) {
- NodeDef node;
- BuildUnaryNode("Atan", GetParam().first, &node);
-
- Model model;
- EXPECT_TRUE(ImportNode(node, &model).ok());
-
- ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
- ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
- const TensorFlowUnsupportedOperator* op =
- static_cast<const TensorFlowUnsupportedOperator*>(
- model.operators[0].get());
- ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
-}
-INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
- ::testing::ValuesIn(UnaryTestTypes()));
-
-TEST(ImportTest, FailedTypeInference) {
- // Create a unary op with no Type ("T") annotation.
- NodeDef node;
- node.set_op("Atan");
- node.set_name("Node1");
- node.add_input();
- node.set_input(0, "Node0");
-
- Model model;
- EXPECT_TRUE(ImportNode(node, &model).ok());
-
- ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
- ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
- const TensorFlowUnsupportedOperator* op =
- static_cast<const TensorFlowUnsupportedOperator*>(
- model.operators[0].get());
- ASSERT_TRUE(op->output_data_types.empty());
-}
-
} // namespace
} // namespace toco