aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-07 11:53:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-07 11:57:15 -0800
commit58fe7d26afa435560e7a0d8ca6fc8d670d2477da (patch)
treeb54039207f46348a83023da4a9148d4ad4a22a1d /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent85d02dcef3b0f0900b3d363056be4e177d4d70ab (diff)
Support for transpose convolution. Includes striding, and a reference implementation.
PiperOrigin-RevId: 188210975
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index bde947f78d..778da39bf1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -71,6 +71,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
CHECK_GE(op->inputs.size(), 2);
const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kTransposeConv) {
+ // These operators produce an output with the same type as their 3rd input
+ CHECK_GE(op->inputs.size(), 3);
+ const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
// Data type of the Cast op is specified.
CHECK_EQ(op->outputs.size(), 1);