aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 14:18:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 14:25:30 -0700
commit840aeb0ce9bd0f0a1c275edc9fe6d51eff5cf33f (patch)
treece3f002656fef2e12f28ef9b42de55acabb1d938 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentd943de372a989ca6bc44058e35ba9f26591b42b4 (diff)
Merged commit includes the following changes:
200617269 by A. Unique TensorFlower: Internal change -- 200603378 by jpienaar: The output of the merge should be the value's and not the original output port. The output port of the IfOp is already taken into account by selecting the merge node and the output of the merge should be the value used (which is the 0th output of the merge node). -- 200601721 by A. Unique TensorFlower: Basic support for tf.tile that multiplies a single axis. -- 200600686 by A. Unique TensorFlower: Internal change. -- PiperOrigin-RevId: 200617269
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc53
1 file changed, 45 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index b6f0d96900..e7da9051d8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1509,6 +1509,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
}
}
+// Shape propagation for a TensorFlow Tile op: output dim i is
+// input dim i * multiples[i]. Yields (returns without setting the output
+// shape) until the input shape, the multiples shape, and the constant
+// multiples buffer have all been resolved by earlier transformations.
+void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // We have already run.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& multiples_array = model->GetArray(op->inputs[1]);
+ if (!multiples_array.has_shape()) {
+ // Yield until the multiples shape has been resolved.
+ return;
+ }
+ if (!multiples_array.buffer) {
+ // Yield until the multiples input is a known constant.
+ return;
+ }
+ CHECK(multiples_array.data_type == ArrayDataType::kInt32)
+ << "Tile multiples input must be int32";
+
+ std::vector<int32> const& multiples =
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // One multiplier is required per input dimension; a shorter or longer
+ // multiples vector is rejected rather than broadcast.
+ CHECK_EQ(multiples.size(), input_shape.dimensions_count())
+ << "Tile multiples input " << op->inputs[1]
+ << " must be same length as input dimensions";
+
+ // Each output dimension is the corresponding input dimension repeated
+ // multiples[i] times.
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->resize(multiples.size());
+ for (int i = 0; i < mutable_dims->size(); ++i) {
+ (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1627,14 +1669,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
break;
- case OperatorType::kTensorFlowTile:
- // We don't currently implement the propagation of fixed sizes through
- // a TensorFlow Tile.
- //
- // Fortunately, we don't need to: so far, we have only dealt with Tile
- // or Slice ops in subgraphs that are identified as L2Normalization.
- // See IdentifyL2Normalization.
- break;
case OperatorType::kTensorFlowSwitch:
// We can't know the sizes of the outputs until we have resolved the
// predicate, and once we have resolved the predicate, the whole
@@ -1738,6 +1772,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSparseToDenseOperator(model,
static_cast<SparseToDenseOperator*>(op));
break;
+ case OperatorType::kTensorFlowTile:
+ ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);