aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-08 13:25:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-08 13:30:15 -0800
commitdc04e89bc6f0421bf77ac69f21c1f2f57618f53c (patch)
tree49b05183e6a6dd2e8a5afff97a0fe35615de1127 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent1afc6149ed0649971d83fe8e9748056285dcf332 (diff)
Adding support for new TensorFlow operators. Also adding a transformation to convert an ExpandDims into a Reshape op.
PiperOrigin-RevId: 178418377
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 550e0408aa..1d92bcbccd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -59,13 +59,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kTensorFlowGreaterEqual) {
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
- } else if (op->type == OperatorType::kTensorFlowShape) {
+ } else if (op->type == OperatorType::kRank ||
+ op->type == OperatorType::kTensorFlowShape) {
// These operators are assumed to produce int32 outputs.
SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
} else if (op->type == OperatorType::kTensorFlowSplit ||
- op->type == OperatorType::kTensorFlowConcat) {
+ op->type == OperatorType::kTensorFlowConcat ||
+ op->type == OperatorType::kFill) {
// These operators produce an output with the same type as their 2nd input
- CHECK_GT(op->inputs.size(), 1);
+ CHECK_GE(op->inputs.size(), 2);
const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
@@ -83,6 +85,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
auto data_type = unsupported_op->output_data_types[i];
model->arrays[output]->data_type = data_type;
}
+ } else if (op->type == OperatorType::kExpandDims) {
+ // Yield on ExpandDim until it is converted to Reshape
+ return false;
} else {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);