aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-06 14:49:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 14:56:36 -0800
commit697fa59286888d75f704629f4181dcca1bf6786a (patch)
tree8ec49e40d9aaf924755a87907617a17f530f1263 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parente70d5621fa3d75868a328440d37db2af47a19cd8 (diff)
Simplification of propagate_array_data_types:
No need to distinguish between ops for which we propagate the type of the first input, and ops for which we require all inputs to have the same type then propagate that type. Doing only the former is a mild relaxation of what we've been doing, allows to drop much code, and makes this code far less frequently needing to be updated again in the future as we drop a long non-canonical list of operator types being special-cased. PiperOrigin-RevId: 178156653
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc47
1 files changed, 3 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 1ff4e827aa..550e0408aa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -24,19 +24,6 @@ limitations under the License.
namespace toco {
namespace {
-
-ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
- const Operator& op) {
- CHECK_GT(op.inputs.size(), 0);
- const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
- for (const auto& input : op.inputs) {
- const auto& array = model.GetArray(input);
- CHECK(array.data_type == data_type)
- << " Unexpected: this operator has inputs with different data types.";
- }
- return data_type;
-}
-
void SetDataTypeForAllOutputs(Model* model, Operator* op,
ArrayDataType data_type) {
for (const auto& output : op->outputs) {
@@ -75,34 +62,6 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else if (op->type == OperatorType::kTensorFlowShape) {
// These operators are assumed to produce int32 outputs.
SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
- } else if (op->type == OperatorType::kAveragePool ||
- op->type == OperatorType::kMaxPool ||
- op->type == OperatorType::kL2Pool ||
- op->type == OperatorType::kConv ||
- op->type == OperatorType::kDepthwiseConv ||
- op->type == OperatorType::kFullyConnected ||
- op->type == OperatorType::kTensorFlowMax ||
- op->type == OperatorType::kTensorFlowMin ||
- op->type == OperatorType::kPad ||
- op->type == OperatorType::kStridedSlice ||
- op->type == OperatorType::kTensorFlowReshape ||
- op->type == OperatorType::kSlice ||
- op->type == OperatorType::kSqueeze ||
- op->type == OperatorType::kTensorFlowSum ||
- op->type == OperatorType::kTensorFlowSwitch ||
- op->type == OperatorType::kTensorFlowTile ||
- op->type == OperatorType::kTensorFlowAll ||
- op->type == OperatorType::kReorderAxes ||
- op->type == OperatorType::kTensorFlowConcatV2 ||
- op->type == OperatorType::kFloor ||
- op->type == OperatorType::kGather ||
- op->type == OperatorType::kSpaceToBatchND ||
- op->type == OperatorType::kBatchToSpaceND ||
- op->type == OperatorType::kMean) {
- // These operators produce outputs with the same type as their 1st input
- CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kTensorFlowSplit ||
op->type == OperatorType::kTensorFlowConcat) {
// These operators produce an output with the same type as their 2nd input
@@ -125,9 +84,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->arrays[output]->data_type = data_type;
}
} else {
- // These operators produce an output with the same type as any of their
- // inputs, which must always have the same type.
- const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op);
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
SetDataTypeForAllOutputs(model, op, data_type);
}
// Return true if any output data type changed, false if none changed.