diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-14 14:18:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-14 14:25:30 -0700 |
commit | 840aeb0ce9bd0f0a1c275edc9fe6d51eff5cf33f (patch) | |
tree | ce3f002656fef2e12f28ef9b42de55acabb1d938 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | d943de372a989ca6bc44058e35ba9f26591b42b4 (diff) |
Merged commit includes the following changes:
200617269 by A. Unique TensorFlower:
Internal change
--
200603378 by jpienaar:
The output of the merge should be the value's and not the original output port.
The output port of the IfOp is already taken into account by selecting the
merge node and the output of the merge should be the value used (which is the 0th
output of the merge node).
--
200601721 by A. Unique TensorFlower:
Basic support for tf.tile that multiplies a single axis.
--
200600686 by A. Unique TensorFlower:
Internal change.
--
PiperOrigin-RevId: 200617269
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 53 |
1 files changed, 45 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index b6f0d96900..e7da9051d8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1509,6 +1509,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { } } +void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& multiples_array = model->GetArray(op->inputs[1]); + if (!multiples_array.has_shape()) { + // Yield until multiples shape been resolved. + return; + } + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return; + } + CHECK(multiples_array.data_type == ArrayDataType::kInt32) + << "Tile multiples input must be int32"; + + std::vector<int32> const& multiples = + multiples_array.GetBuffer<ArrayDataType::kInt32>().data; + CHECK_EQ(multiples.size(), input_shape.dimensions_count()) + << "Tile multiples input " << op->inputs[1] + << " must be same length as input dimensions"; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(multiples.size()); + for (int i = 0; i < mutable_dims->size(); ++i) { + (*mutable_dims)[i] = input_shape.dims(i) * multiples[i]; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1627,14 +1669,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast<SliceOperator*>(op)); break; - case OperatorType::kTensorFlowTile: - // We don't currently implement the propagation of fixed sizes through - // a TensorFlow Tile. - // - // Fortunately, we don't need to: so far, we have only dealt with Tile - // or Slice ops in subgraphs that are identified as L2Normalization. - // See IdentifyL2Normalization. - break; case OperatorType::kTensorFlowSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole @@ -1738,6 +1772,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSparseToDenseOperator(model, static_cast<SparseToDenseOperator*>(op)); break; + case OperatorType::kTensorFlowTile: + ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); |