diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-17 13:29:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-17 13:33:14 -0800 |
commit | 29baea36e3b374a852ad3dedc1c3719016febdc4 (patch) | |
tree | dbeca52751cc4fea3c2b3fd776468043d1900f0d | |
parent | d8697935d334bb0f2e1c9bccfe9a2a7bee9785cc (diff) |
Adding support for to resolve constant FloorDiv, FloorMod, StridedSlice, Stack, Rank and Range.
PiperOrigin-RevId: 182261052
17 files changed, 934 insertions, 156 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 741f9d4bfb..cea5b4e92d 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -171,6 +171,7 @@ cc_library( srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", "graph_transformations/convert_pure_conv_to_depthwise.cc", + "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", "graph_transformations/drop_fake_quant.cc", @@ -207,7 +208,10 @@ cc_library( "graph_transformations/resolve_constant_concatenation.cc", "graph_transformations/resolve_constant_fake_quant.cc", "graph_transformations/resolve_constant_fill.cc", - "graph_transformations/resolve_constant_tensorflow_shape.cc", + "graph_transformations/resolve_constant_range.cc", + "graph_transformations/resolve_constant_shape_or_rank.cc", + "graph_transformations/resolve_constant_stack.cc", + "graph_transformations/resolve_constant_strided_slice.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_mean_attributes.cc", "graph_transformations/resolve_pad_attributes.cc", diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index 62e7282d16..d4da8f5dfe 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -239,8 +239,8 @@ void AllocateTransientArrays(Model* model, // is a misnormer, should read 'workspace'. for (const auto& array_pair : ordered_arrays_map) { const string& array_name = array_pair.first; - const auto& array_lifespan = array_lifespans.find(array_name)->second; - if (array_lifespan.persistent) { + auto it = array_lifespans.find(array_name); + if (it != array_lifespans.end() && it->second.persistent) { AllocateTransientArray(*model, array_name, &allocator, transient_data_alignment); } @@ -282,8 +282,8 @@ void AllocateTransientArrays(Model* model, std::size_t persistent_alloc_size = 0; for (const auto& array_pair : ordered_arrays_map) { const string& array_name = array_pair.first; - const auto& array_lifespan = array_lifespans.find(array_name)->second; - if (array_lifespan.persistent) { + auto it = array_lifespans.find(array_name); + if (it != array_lifespans.end() && it->second.persistent) { persistent_alloc_size += TransientArraySize(*model, array_name, transient_data_alignment); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc new file mode 100644 index 0000000000..a234c20924 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { + auto transpose_it = model->operators.begin() + op_index; + if (transpose_it->get()->type != OperatorType::kTranspose) { + return false; + } + TransposeOperator* transpose_op = + static_cast<TransposeOperator*>(transpose_it->get()); + + const auto& output_array = *model->arrays[transpose_op->outputs[0]]; + if (!output_array.has_shape()) { + // Yield until PropagateFixedSizes has been run on this op. + return false; + } + // Note: We can assume we have error checked inputs in PropagateFixedSizes. + + // This transpose is trivial if we only have one non-unitary dimension. + std::vector<int> const& dims = output_array.shape().dims(); + unsigned non_unitary_axis_count = 0; + for (int i = 0; i < dims.size(); i++) { + if (dims[i] != 1) { + non_unitary_axis_count++; + } + } + if (non_unitary_axis_count > 1) { + // Transpose is not trivial + return false; + } + + // This transpose is trivial. Replace it with a Reshape op. + auto* reshape_op = new TensorFlowReshapeOperator; + + // Copy input and output + reshape_op->inputs.push_back(transpose_op->inputs[0]); + reshape_op->outputs = transpose_op->outputs; + + // Create a new input array for the shape input + string perm_array_name = transpose_op->inputs[1]; + string shape_array_name = toco::AvailableArrayName(*model, perm_array_name); + Array& shape_array = model->GetOrCreateArray(shape_array_name); + *(shape_array.mutable_shape()->mutable_dims()) = { + 1, static_cast<int>(dims.size())}; + reshape_op->inputs.push_back(shape_array_name); + shape_array.data_type = ArrayDataType::kInt32; + auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>(); + shape_buffer.data = dims; + + // Delete perm array if unused + if (IsDiscardableArray(*model, perm_array_name) && + CountOpsWithInput(*model, perm_array_name) == 1) { + model->arrays.erase(perm_array_name); + } + + // Replace the operator in the graph. + const auto reshape_it = model->operators.emplace(transpose_it, reshape_op); + transpose_it = reshape_it + 1; + CHECK_EQ(transpose_it->get(), transpose_op); + model->operators.erase(transpose_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 785dad8596..9300ab53a7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -114,6 +114,7 @@ void RunGraphTransformations(Model* model, const string& message, // List of all graph transformations DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) @@ -159,7 +160,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) -DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) DECLARE_GRAPH_TRANSFORMATION(Dequantize) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 4fe127544b..c6f17cf319 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -61,7 +61,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); } else if (op->type == OperatorType::kRank || op->type == OperatorType::kTensorFlowShape) { - // These operators are assumed to produce int32 outputs. + // These operators only produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); } else if (op->type == OperatorType::kTensorFlowSplit || op->type == OperatorType::kTensorFlowConcat || @@ -80,6 +80,20 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->outputs.size(), 1); auto* argmax_op = static_cast<ArgMaxOperator*>(op); model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type; + } else if (op->type == OperatorType::kRange) { + auto* range_op = static_cast<RangeOperator*>(op); + // Output type of the Range op can be set via an attribute + ArrayDataType data_type; + if (range_op->dtype != ArrayDataType::kNone) { + // Use the type if specified + data_type = range_op->dtype; + } else { + // Otherwise use the first input + CHECK_GE(op->inputs.size(), 1); + data_type = model->arrays[op->inputs[0]]->data_type; + } + CHECK_EQ(op->outputs.size(), 1); + SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kTensorFlowUnsupported) { auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); if (unsupported_op->output_data_types.size() != op->outputs.size()) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 8f181e78d8..a939efb4db 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -316,25 +316,30 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { void ProcessTensorFlowReshapeOperator(Model* model, TensorFlowReshapeOperator* op) { auto& output_array = *model->arrays[op->outputs[0]]; - // Bail if we already have output dims if (output_array.has_shape()) { + // We have already run return; } const auto& input_array = *model->arrays[op->inputs[0]]; - // Yield until input dims have been resolved. if (!input_array.has_shape()) { + // Yield until input dims have been resolved. return; } const auto& input_shape = input_array.shape(); - const string& shape_name = op->inputs[1]; - auto& shape_array = model->GetArray(shape_name); - // Yield until the shape is resolved as a constant array + auto& shape_array = model->GetArray(op->inputs[1]); + if (!shape_array.has_shape()) { + // Yield until target_shape shape been resolved. + return; + } if (!shape_array.buffer) { + // Yield until the target_shape is constant return; } - CHECK(shape_array.data_type == ArrayDataType::kInt32); + CHECK(shape_array.data_type == ArrayDataType::kInt32) + << "Reshape dims must be int32"; + // shape_data is the raw array of ints describing the shape // in the TensorFlow node. We intentionally make a copy here, rather than // modify wildcards in-place below, because in some graphs, the same shape @@ -357,12 +362,18 @@ void ProcessTensorFlowReshapeOperator(Model* model, } const int input_flat_size = RequiredBufferSizeForShape(input_shape); if (has_wildcard) { + CHECK_GE(input_flat_size, product_non_wildcard_dims) + << "Array not large enough to fill the requested dimensions for " + "Reshape op with output \"" + << op->outputs[0] << "\". Are your input shapes correct?"; shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; } auto& output_shape = *output_array.mutable_shape(); *output_shape.mutable_dims() = shape_data; - const int output_flat_size = RequiredBufferSizeForShape(output_shape); - CHECK_EQ(output_flat_size, input_flat_size); + CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape)) + << "Input cannot be reshaped to requested dimensions for Reshape op with " + "output \"" + << op->outputs[0] << "\". Are your input shapes correct?"; } void ProcessSimpleOperator(Model* model, Operator* op) { @@ -535,6 +546,56 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { output_dims[op->axis] = concat_size; } +void ProcessRangeOperator(Model* model, RangeOperator* op) { + CHECK_EQ(op->inputs.size(), 3); + const auto& start_array = *model->arrays[op->inputs[0]]; + if (!start_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& limit_array = *model->arrays[op->inputs[1]]; + if (!limit_array.has_shape()) { + return; + } + const auto& delta_array = *model->arrays[op->inputs[2]]; + if (!delta_array.has_shape()) { + return; + } + + if (!IsConstantParameterArray(*model, op->inputs[0])) { + // Yield until inputs are constant. + return; + } + if (!IsConstantParameterArray(*model, op->inputs[1])) { + return; + } + if (!IsConstantParameterArray(*model, op->inputs[2])) { + return; + } + + CHECK(start_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + CHECK(limit_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + CHECK(delta_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1) + << "Range op inputs must be scalar."; + CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1) + << "Range op inputs must be scalar."; + CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1) + << "Range op inputs must be scalar."; + int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] - + start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) / + delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]); + + // Only set the output shape. Contents are set by ResolveConstantRange. + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({size}); +} + void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { CHECK_EQ(op->inputs.size(), 2); const string& input_name = op->inputs[1]; @@ -885,35 +946,166 @@ void ProcessPadOperator(Model* model, PadOperator* op) { output_array.copy_shape(output_shape); } -void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { - CHECK_EQ(op->inputs.size(), 4); +void ProcessRankOperator(Model* model, RankOperator* op) { + CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + // Shape already propagated + return; + } const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } - // Yield until input dims have been resolved. - if (!input_array.has_shape()) return; + // Only set the output shape. Array contents are set by + // ResolveConstantShapeOrRank. + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({}); +} + +void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { + CHECK_GE(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } - if (op->start_indices.empty()) return; - CHECK_EQ(op->start_indices.size(), op->stop_indices.size()); - CHECK_EQ(op->start_indices.size(), op->strides.size()); + // Only set the output shape. Array contents are set by + // ResolveConstantShapeOrRank. + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({input_array.shape().dimensions_count()}); +} +void ProcessStackOperator(Model* model, StackOperator* op) { + CHECK_GE(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); auto& output_array = *model->arrays[op->outputs[0]]; - if (output_array.has_shape()) return; + if (output_array.has_shape()) { + // Shape already propagated + return; + } - Shape output_shape = input_array.shape(); - std::vector<int>& dims = *output_shape.mutable_dims(); - CHECK_EQ(op->start_indices.size(), dims.size()); + std::unique_ptr<Shape> stacked_shape; + for (const auto& input : op->inputs) { + const auto& input_array = model->GetArray(input); + if (!input_array.has_shape()) { + // Yield until all input dims have been resolved. + return; + } - for (int i = 0; i < op->start_indices.size(); ++i) { - const int mask = 1 << i; - const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i]; - const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i] - : op->stop_indices[i]; - dims[i] = (stop - start) / op->strides[i]; + Shape shape = input_array.shape(); + if (shape.dimensions_count() == 0) { + // Convert 0D scalars to 1D scalars of shape {1}. + shape.mutable_dims()->push_back(1); + } + if (!stacked_shape) { + stacked_shape.reset(new Shape(shape)); + } else { + CHECK(*stacked_shape == shape) << "All input arrays to Stack operators " + "must have the same shape. Input \"" + << input << "\" is different."; + } } - output_array.copy_shape(output_shape); + int axis = op->axis; + if (axis < 0) { + // Handle negative axis + axis += stacked_shape->dims().size() + 1; + } + stacked_shape->mutable_dims()->insert( + stacked_shape->mutable_dims()->begin() + axis, op->inputs.size()); + output_array.copy_shape(*stacked_shape); +} + +void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { + CHECK_GE(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + if (op->start_indices.empty() || op->stop_indices.empty() || + op->strides.empty()) { + // ResolveStridedSliceAttributes has not run yet. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + + if (op->ellipsis_mask != 0) { + // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce + // log noise. However, the TensorFlow logging library does not appear to + // support this. + LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0] + << "\". ellipsis_mask is not supported (mask=" + << op->ellipsis_mask << ")"; + return; + } + if (op->new_axis_mask != 0) { + LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0] + << "\". new_axis_mask is not supported (mask=" + << op->new_axis_mask << ")"; + return; + } + + int dim_count = input_array.shape().dimensions_count(); + CHECK(op->start_indices.size() == dim_count) + << ": Incorrect number of start indices supplied to StridedSlice op with " + "output \"" + << op->outputs[0] << "\". Op requires " << dim_count << " start indices"; + CHECK(op->stop_indices.size() == dim_count) + << ": Incorrect number of stop indices supplied to StridedSlice op with " + "output \"" + << op->outputs[0] << "\". Op requires " << dim_count << " stop indices"; + CHECK(op->strides.size() == dim_count) + << ": Incorrect number of strides supplied to StridedSlice op with " + " output \"" + << op->outputs[0] << "\". Op requires " << dim_count << " strides"; + + // Create output shape + std::vector<int>* dims = output_array.mutable_shape()->mutable_dims(); + + // Compute output shape + for (int i = 0; i < dim_count; ++i) { + const int mask = 1 << i; + int start = (op->begin_mask & mask) ? 0 : op->start_indices[i]; + if (start < 0) { + // handle negative indices + start += input_array.shape().dims(i); + } + int stop = (op->end_mask & mask) ? input_array.shape().dims(i) + : op->stop_indices[i]; + if (stop < 0) { + // handle negative indices + stop += input_array.shape().dims(i); + } + + int dim_size = (stop - start) / op->strides[i]; + if (op->shrink_axis_mask & mask) { + CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when " + "shrinking that axis"; + } else { + dims->push_back(dim_size); + } + } } void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { @@ -971,6 +1163,45 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) { output_array.mutable_shape()->ReplaceDims({batch_size, num_units}); } +void ProcessTransposeOperator(Model* model, TransposeOperator* op) { + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + // We have already run + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& perm_array = model->GetArray(op->inputs[1]); + if (!perm_array.has_shape()) { + // Yield until permutation shape been resolved. + return; + } + if (!perm_array.buffer) { + // Yield until the permutation is constant + return; + } + CHECK(perm_array.data_type == ArrayDataType::kInt32) + << "Transpose permutation input must be int32"; + + std::vector<int32> const& perm = + perm_array.GetBuffer<ArrayDataType::kInt32>().data; + CHECK_EQ(perm.size(), input_shape.dimensions_count()) + << "Transpose permutation input must be same length as input dimensions"; + std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims(); + for (int i = 0; i < perm.size(); i++) { + int axis = perm[i]; + CHECK_GE(axis, 0); + CHECK_LT(axis, input_shape.dimensions_count()); + output_dims->push_back(input_shape.dims(axis)); + } +} + void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { CHECK_EQ(op->inputs.size(), 2); const auto& input_array = *model->arrays[op->inputs[0]]; @@ -1136,13 +1367,19 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { // or else at the moment we will abort. break; case OperatorType::kExpandDims: + // Yield until ExpandDims is converted to Reshape + break; case OperatorType::kRange: + ProcessRangeOperator(model, static_cast<RangeOperator*>(op)); + break; case OperatorType::kRank: + ProcessRankOperator(model, static_cast<RankOperator*>(op)); + break; case OperatorType::kTensorFlowShape: + ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op)); + break; case OperatorType::kStack: - case OperatorType::kTranspose: - // Unimplemented. Hopefully another graph transformation will drop it or - // rewrite it. + ProcessStackOperator(model, static_cast<StackOperator*>(op)); break; case OperatorType::kReorderAxes: ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op)); @@ -1185,6 +1422,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op)); break; + case OperatorType::kTranspose: + ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index a4f198e92f..7777d4f543 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -37,26 +37,19 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { } CHECK_EQ(op->inputs.size(), 3); - if (!IsConstantParameterArray(*model, op->inputs[1]) or + if (!IsConstantParameterArray(*model, op->inputs[1]) || !IsConstantParameterArray(*model, op->inputs[2])) return false; - // Handling block_shape. - const auto& block_shape_array = *model->arrays[op->inputs[1]]; - if (!block_shape_array.has_shape()) return false; - const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); - CHECK_EQ(block_shape_dims.size(), 1); - std::vector<int> block_shape_buffer = - block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; - for (int i = 0; i < block_shape_dims[0]; ++i) { - op->block_shape.push_back(block_shape_buffer[i]); - } - - // Handling crops. + // Handle crops const auto& crops_array = *model->arrays[op->inputs[2]]; if (!crops_array.has_shape()) return false; const std::vector<int>& crops_dims = crops_array.shape().dims(); - CHECK_EQ(crops_dims.size(), 2); + if (crops_dims.size() != 2) { + // Code only handles crops of 2 dimensions. Perhaps another transformation + // will delete this op. + return false; + } std::vector<int> crops_buffer = crops_array.GetBuffer<ArrayDataType::kInt32>().data; for (int i = 0; i < crops_dims[0]; ++i) { @@ -64,6 +57,17 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { op->after_crops.push_back(crops_buffer[i * 2 + 1]); } + // Handle block_shape + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + if (!block_shape_array.has_shape()) return false; + const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); + CHECK_EQ(block_shape_dims.size(), 1); + std::vector<int> block_shape_buffer = + block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < block_shape_dims[0]; ++i) { + op->block_shape.push_back(block_shape_buffer[i]); + } + return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index 53e1be7a05..fd51df4058 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -141,6 +141,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, outval = val0 - val1; } else if (binary_op->type == OperatorType::kDiv) { outval = val0 / val1; + } else if (binary_op->type == OperatorType::kFloorDiv) { + outval = floor(val0 / val1); + } else if (binary_op->type == OperatorType::kFloorMod) { + outval = val0 - (floor(val0 / val1) * val1); } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { outval = std::min(val0, val1); } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { @@ -191,6 +195,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kFloorDiv && + binary_op->type != OperatorType::kFloorMod && binary_op->type != OperatorType::kTensorFlowMinimum && binary_op->type != OperatorType::kTensorFlowMaximum && binary_op->type != OperatorType::kTensorFlowLess && diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc new file mode 100644 index 0000000000..383d54aa5a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* base_op = it->get(); + if (base_op->type != OperatorType::kRange) { + return false; + } + auto* op = static_cast<RangeOperator*>(base_op); + + CHECK_EQ(op->inputs.size(), 3); + const auto& start_array = *model->arrays[op->inputs[0]]; + if (!start_array.has_shape()) { + // Yield until all input dims have been resolved. + return false; + } + const auto& limit_array = *model->arrays[op->inputs[1]]; + if (!limit_array.has_shape()) { + // Yield until all input dims have been resolved. + return false; + } + const auto& delta_array = *model->arrays[op->inputs[2]]; + if (!delta_array.has_shape()) { + // Yield until all input dims have been resolved. + return false; + } + + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(*model, input)) { + // yield if any input is mutable + return false; + } + } + + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return false; + } + + CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1) + << "Range op inputs must be scalar."; + CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1) + << "Range op inputs must be scalar."; + CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1) + << "Range op inputs must be scalar."; + + CHECK(start_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + CHECK(limit_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + CHECK(delta_array.data_type == ArrayDataType::kInt32) + << "Range op inputs must be int32."; + + // Compute buffer contents + int start = start_array.GetBuffer<ArrayDataType::kInt32>().data[0]; + int limit = limit_array.GetBuffer<ArrayDataType::kInt32>().data[0]; + int delta = delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]; + auto& buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>(); + buffer.data.clear(); + for (int32 val = start; val < limit; val += delta) { + buffer.data.push_back(val); + } + CHECK_EQ(floor((limit - start) / delta), buffer.data.size()); + CHECK_EQ(buffer.data.size(), output_array.shape().dims()[0]); + + // Delete the input array if no longer used + if (IsDiscardableArray(*model, op->inputs[0]) && + CountOpsWithInput(*model, op->inputs[0]) == 1) { + model->arrays.erase(op->inputs[0]); + } + if (IsDiscardableArray(*model, op->inputs[1]) && + CountOpsWithInput(*model, op->inputs[1]) == 1) { + model->arrays.erase(op->inputs[1]); + } + if (IsDiscardableArray(*model, op->inputs[2]) && + CountOpsWithInput(*model, op->inputs[2]) == 1) { + model->arrays.erase(op->inputs[2]); + } + + // Delete the operator + model->operators.erase(it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc new file mode 100644 index 0000000000..e15ee7805c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + if (!(op->type == OperatorType::kTensorFlowShape || + op->type == OperatorType::kRank)) { + return false; + } + + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been resolved + return false; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until the input array's shape has been resolved. + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been resolved. + return false; + } + + // Compute the output + CHECK(!output_array.buffer); + auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>(); + if (op->type == OperatorType::kTensorFlowShape) { + // Copy the input shape into the output buffer. + output_buffer.data = input_array.shape().dims(); + } else if (op->type == OperatorType::kRank) { + // Copy the dimension count into the output buffer. + output_buffer.data.resize(1); + output_buffer.data[0] = input_array.shape().dimensions_count(); + } + + // Delete the input array if no longer used + if (IsDiscardableArray(*model, op->inputs[0]) && + CountOpsWithInput(*model, op->inputs[0]) == 1) { + model->arrays.erase(op->inputs[0]); + } + + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc new file mode 100644 index 0000000000..86c76141a4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template <ArrayDataType Type> +void Stack(Model* model, StackOperator const& op) { + auto& output_array = model->GetArray(op.outputs[0]); + CHECK(output_array.data_type == Type); + + // Create a buffer for the output array + std::vector<DataType<Type>>& output_data = + output_array.GetMutableBuffer<Type>().data; + output_data.resize(RequiredBufferSizeForShape(output_array.shape())); + + // Stack inputs into buffer + CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis"; + int dst_offset = 0; + for (int i = 0; i < op.inputs.size(); i++) { + // Append array data to output for each input array + const auto& input_array = model->GetArray(op.inputs[i]); + int input_size = RequiredBufferSizeForShape(input_array.shape()); + memcpy(&output_data[dst_offset], &input_array.GetBuffer<Type>().data[0], + input_size * sizeof(Type)); + dst_offset += input_size; + } + CHECK_EQ(dst_offset, output_data.size()); +} + +} // namespace + +bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kStack) { + return false; + } + const auto* op = static_cast<const StackOperator*>(base_op); + + CHECK_GE(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes + return false; + } + + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(*model, input)) { + // Yield if any input is mutable + return false; + } + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Stack<ArrayDataType::kFloat>(model, *op); + break; + case ArrayDataType::kUint8: + Stack<ArrayDataType::kUint8>(model, *op); + break; + case ArrayDataType::kInt32: + Stack<ArrayDataType::kInt32>(model, *op); + break; + case ArrayDataType::kInt64: + Stack<ArrayDataType::kInt64>(model, *op); + break; + default: + LOG(FATAL) << "Unsupported data type given to Stack op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used + for (const auto& input : op->inputs) { + if (IsDiscardableArray(*model, input) && + CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + + // Erase the operator + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc new file mode 100644 index 0000000000..3976d9cbb4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -0,0 +1,198 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape, + int axis) { + int start; + if (op.begin_mask & 1 << axis) { + // If begin mask bit is set, use the first element + start = 0; + } else { + // Otherwise, use the specified element + start = op.start_indices[axis]; + if (start < 0) { + // Handle negative indices + start += input_shape.dims(axis); + } + } + return start; +} + +int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape, + int axis) { + int stop; + if (op.end_mask & (1 << axis)) { + // If end mask bit set, use the last element + stop = input_shape.dims(axis); + } else { + // Otherwise, use the specified element + stop = op.stop_indices[axis]; + if (stop < 0) { + // Handle negative indices + stop += input_shape.dims(axis); + } + } + return stop; +} + +template <ArrayDataType Type> +void StridedSlice(StridedSliceOperator const& op, Array const& input_array, + Array* output_array) { + // The TensorFlow documentation for StridedSlice is a bit ambiguous in places + // (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice). + // Use the source code at /third_party/tensorflow/core/util/strided_op.cc as + // "master documentation". + + CHECK(input_array.data_type == Type); + CHECK(output_array->data_type == Type); + CHECK_EQ(op.ellipsis_mask, 0); + CHECK_EQ(op.new_axis_mask, 0); + + int num_input_axes = op.start_indices.size(); + CHECK_EQ(num_input_axes, op.stop_indices.size()); + CHECK_EQ(num_input_axes, op.strides.size()); + for (int i = 0; i < op.strides.size(); i++) { + CHECK_GE(op.strides[i], 0) << "Negative strides usupported"; + } + + // Create a buffer for the output array + std::vector<DataType<Type>>& output_data = + output_array->GetMutableBuffer<Type>().data; + output_data.resize(RequiredBufferSizeForShape(output_array->shape())); + + // Initialize source coordinate + Shape const& input_shape = input_array.shape(); + Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>(); + std::vector<int> src_coord(op.start_indices.size()); + for (int axis = 0; axis < num_input_axes; axis++) { + src_coord[axis] = StartForAxis(op, input_shape, axis); + } + + // In order to handle any number (N) of dimensions, we copy elements one by + // one and treat the source coordinate as an N digit number (src_coord here). + // Each "digit" is incremented individually (by the stride). When it overflows + // (becomes greater than the stop), that digit is reset and a carry flag is + // used to increment the next digit. + int dst_offset = 0; + do { + // Copy element. + output_data[dst_offset] = input_buffer.data[Offset(input_shape, src_coord)]; + + // Compute next source input coordinates. + bool carry = true; + for (int axis = 0; axis < num_input_axes; axis++) { + // Increment this axis if we carried from the previous one + if (carry) { + src_coord[axis] += op.strides[axis]; + } + + // Check if we've overflowed. + if (src_coord[axis] >= StopForAxis(op, input_shape, axis)) { + // Reset axis and set carry + src_coord[axis] = StartForAxis(op, input_shape, axis); + carry = true; + } else { + carry = false; + } + } + // increment destination buffer offset + dst_offset++; + } while (dst_offset < output_data.size()); +} + +} // anonymous namespace + +bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kStridedSlice) { + return false; + } + + const StridedSliceOperator* op = + static_cast<const StridedSliceOperator*>(base_op); + + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes + return false; + } + + if (op->start_indices.empty() || op->stop_indices.empty() || + op->strides.empty()) { + // Attributes have not resolved yet. + return false; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until the value shape has been resolved. + return false; + } + if (!IsConstantParameterArray(*model, op->inputs[0])) { + // Yield until the value is constant. + return false; + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + StridedSlice<ArrayDataType::kFloat>(*op, input_array, &output_array); + break; + case ArrayDataType::kUint8: + StridedSlice<ArrayDataType::kUint8>(*op, input_array, &output_array); + break; + case ArrayDataType::kInt32: + StridedSlice<ArrayDataType::kInt32>(*op, input_array, &output_array); + break; + case ArrayDataType::kInt64: + StridedSlice<ArrayDataType::kInt64>(*op, input_array, &output_array); + break; + default: + LOG(FATAL) + << "Unsupported data type input to StridedSlice op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input array if no longer used + if (IsDiscardableArray(*model, op->inputs[0]) && + CountOpsWithInput(*model, op->inputs[0]) == 1) { + model->arrays.erase(op->inputs[0]); + } + + // Erase the operator + model->operators.erase(it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc deleted file mode 100644 index 8cc6db1619..0000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include <cstddef> -#include <memory> -#include <string> -#include <unordered_map> -#include <vector> - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) { - const auto tfshape_it = model->operators.begin() + op_index; - const auto* tfshape_base_op = tfshape_it->get(); - if (tfshape_base_op->type != OperatorType::kTensorFlowShape) { - return false; - } - - const auto* tfshape_op = - static_cast<const TensorFlowShapeOperator*>(tfshape_base_op); - - const auto& input_array = model->GetArray(tfshape_op->inputs[0]); - auto& output_array = model->GetArray(tfshape_op->outputs[0]); - - // Yield until the input array's shape has been resolved. - if (!input_array.has_shape()) { - return false; - } - - // Create a buffer for the output array, making it a constant array, and - // copy the input shape into the output buffer. - CHECK(!output_array.buffer); - auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>(); - output_buffer.data = input_array.shape().dims(); - - // Erase the input array if no longer used - if (IsDiscardableArray(*model, tfshape_op->inputs[0]) && - CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) { - model->arrays.erase(tfshape_op->inputs[0]); - } - model->operators.erase(tfshape_it); - - return true; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index 0e4c66544d..a73f16735c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -44,22 +44,15 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { !IsConstantParameterArray(*model, op->inputs[paddings_index])) return false; - // Handling block_shape. - const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]]; - if (!block_shape_array.has_shape()) return false; - const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); - CHECK_EQ(block_shape_dims.size(), 1); - std::vector<int> block_shape_buffer = - block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; - for (int i = 0; i < block_shape_dims[0]; ++i) { - op->block_shape.push_back(block_shape_buffer[i]); - } - - // Handling paddings. + // Handle paddings. const auto& paddings_array = *model->arrays[op->inputs[paddings_index]]; if (!paddings_array.has_shape()) return false; const std::vector<int>& paddings_dims = paddings_array.shape().dims(); - CHECK_EQ(paddings_dims.size(), 2); + if (paddings_dims.size() != 2) { + // Code only handles padding of 2 dimensions. Perhaps another transformation + // will delete this op. + return false; + } std::vector<int> paddings_buffer = paddings_array.GetBuffer<ArrayDataType::kInt32>().data; for (int i = 0; i < paddings_dims[0]; ++i) { @@ -67,6 +60,17 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { op->after_paddings.push_back(paddings_buffer[i * 2 + 1]); } + // Handle block_shape. + const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]]; + if (!block_shape_array.has_shape()) return false; + const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); + CHECK_EQ(block_shape_dims.size(), 1); + std::vector<int> block_shape_buffer = + block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < block_shape_dims[0]; ++i) { + op->block_shape.push_back(block_shape_buffer[i]); + } + return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index 97946182ef..dbe69adcbd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <memory> -#include <string> -#include <unordered_map> -#include <vector> - #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -30,19 +25,14 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { if (slice_op->type != OperatorType::kStridedSlice) return false; auto* op = static_cast<StridedSliceOperator*>(slice_op); - if (!op->start_indices.empty()) return false; + if (!op->start_indices.empty()) { + // We have already resolved these attributes + return false; + } CHECK_EQ(op->inputs.size(), 4); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; - if (!IsConstantParameterArray(*model, op->inputs[3])) return false; - const auto& start_array = *model->arrays[op->inputs[1]]; if (!start_array.has_shape()) return false; - if (toco::RequiredBufferSizeForShape(start_array.shape()) != 4) { - // Only 4D arrays are supported for now. - return false; - } const auto& stop_array = *model->arrays[op->inputs[2]]; if (!stop_array.has_shape()) return false; @@ -50,12 +40,24 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { const auto& stride_array = *model->arrays[op->inputs[3]]; if (!stride_array.has_shape()) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data; op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data; op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data; - // TODO(dkalenichenko): Delete the extra inputs? + CHECK_GE(op->start_indices.size(), 1); + CHECK_LE(op->start_indices.size(), 4); + CHECK_EQ(op->stop_indices.size(), op->start_indices.size()); + CHECK_EQ(op->strides.size(), op->stop_indices.size()); + // Ideally, we would remove the input arrays after they have been resolved. + // However, we must then reconstitute these input arrays for all supported + // export formats. For now, leave the arrays so we don't have to modify our + // exporters. Ideally, we wouldn't have op attributes, and would work directly + // with the input arrays. return true; } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index f07fef117f..995e9d67ca 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1181,21 +1181,6 @@ void ConvertStridedSliceOperator(const NodeDef& node, CHECK_EQ(node.op(), "StridedSlice"); CheckInputsCount(node, tf_import_flags, 4); - // Only a subset of the full TF op functionality is supported now. - if ( // No 64-bit indices. - GetDataTypeAttr(node, "Index") != DT_INT32 || - // No dimensionality changes. - GetIntAttr(node, "new_axis_mask") != 0 || - GetIntAttr(node, "shrink_axis_mask") != 0 || - // No sparse indices. - GetIntAttr(node, "ellipsis_mask") != 0 || - // Only 4D tensors are supported. - GetIntAttr(node, "begin_mask") > 15 || - GetIntAttr(node, "end_mask") > 15) { - ConvertUnsupportedOperator(node, tf_import_flags, model); - return; - } - auto* op = new StridedSliceOperator; for (const auto& input : node.input()) { op->inputs.push_back(input); diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 92ec0bcba8..0bcf4596de 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -52,6 +52,7 @@ void MakeGeneralGraphTransformationsSet( GraphTransformationsSet* transformations) { CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); + transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ResolveReshapeAttributes); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); @@ -68,6 +69,9 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); + transformations->Add(new ResolveConstantRange); + transformations->Add(new ResolveConstantStack); + transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveTensorFlowSqueeze); @@ -86,7 +90,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveSliceAttributes); transformations->Add(new ResolveMeanAttributes); transformations->Add(new ResolveTransposeAttributes); - transformations->Add(new ResolveConstantTensorFlowShape); + transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); } |