aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-17 13:29:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-17 13:33:14 -0800
commit29baea36e3b374a852ad3dedc1c3719016febdc4 (patch)
treedbeca52751cc4fea3c2b3fd776468043d1900f0d
parentd8697935d334bb0f2e1c9bccfe9a2a7bee9785cc (diff)
Adding support for to resolve constant FloorDiv, FloorMod, StridedSlice, Stack, Rank and Range.
PiperOrigin-RevId: 182261052
-rw-r--r--tensorflow/contrib/lite/toco/BUILD6
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc85
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc298
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc32
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc107
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc70
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc113
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc198
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc62
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc32
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc15
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc6
17 files changed, 934 insertions, 156 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 741f9d4bfb..cea5b4e92d 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -171,6 +171,7 @@ cc_library(
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
"graph_transformations/convert_pure_conv_to_depthwise.cc",
+ "graph_transformations/convert_trivial_transpose_to_reshape.cc",
"graph_transformations/create_im2col_arrays.cc",
"graph_transformations/dequantize.cc",
"graph_transformations/drop_fake_quant.cc",
@@ -207,7 +208,10 @@ cc_library(
"graph_transformations/resolve_constant_concatenation.cc",
"graph_transformations/resolve_constant_fake_quant.cc",
"graph_transformations/resolve_constant_fill.cc",
- "graph_transformations/resolve_constant_tensorflow_shape.cc",
+ "graph_transformations/resolve_constant_range.cc",
+ "graph_transformations/resolve_constant_shape_or_rank.cc",
+ "graph_transformations/resolve_constant_stack.cc",
+ "graph_transformations/resolve_constant_strided_slice.cc",
"graph_transformations/resolve_constant_unary.cc",
"graph_transformations/resolve_mean_attributes.cc",
"graph_transformations/resolve_pad_attributes.cc",
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 62e7282d16..d4da8f5dfe 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -239,8 +239,8 @@ void AllocateTransientArrays(Model* model,
// is a misnormer, should read 'workspace'.
for (const auto& array_pair : ordered_arrays_map) {
const string& array_name = array_pair.first;
- const auto& array_lifespan = array_lifespans.find(array_name)->second;
- if (array_lifespan.persistent) {
+ auto it = array_lifespans.find(array_name);
+ if (it != array_lifespans.end() && it->second.persistent) {
AllocateTransientArray(*model, array_name, &allocator,
transient_data_alignment);
}
@@ -282,8 +282,8 @@ void AllocateTransientArrays(Model* model,
std::size_t persistent_alloc_size = 0;
for (const auto& array_pair : ordered_arrays_map) {
const string& array_name = array_pair.first;
- const auto& array_lifespan = array_lifespans.find(array_name)->second;
- if (array_lifespan.persistent) {
+ auto it = array_lifespans.find(array_name);
+ if (it != array_lifespans.end() && it->second.persistent) {
persistent_alloc_size +=
TransientArraySize(*model, array_name, transient_data_alignment);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
new file mode 100644
index 0000000000..a234c20924
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -0,0 +1,85 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
+ auto transpose_it = model->operators.begin() + op_index;
+ if (transpose_it->get()->type != OperatorType::kTranspose) {
+ return false;
+ }
+ TransposeOperator* transpose_op =
+ static_cast<TransposeOperator*>(transpose_it->get());
+
+ const auto& output_array = *model->arrays[transpose_op->outputs[0]];
+ if (!output_array.has_shape()) {
+ // Yield until PropagateFixedSizes has been run on this op.
+ return false;
+ }
+ // Note: We can assume we have error checked inputs in PropagateFixedSizes.
+
+ // This transpose is trivial if we only have one non-unitary dimension.
+ std::vector<int> const& dims = output_array.shape().dims();
+ unsigned non_unitary_axis_count = 0;
+ for (int i = 0; i < dims.size(); i++) {
+ if (dims[i] != 1) {
+ non_unitary_axis_count++;
+ }
+ }
+ if (non_unitary_axis_count > 1) {
+ // Transpose is not trivial
+ return false;
+ }
+
+ // This transpose is trivial. Replace it with a Reshape op.
+ auto* reshape_op = new TensorFlowReshapeOperator;
+
+ // Copy input and output
+ reshape_op->inputs.push_back(transpose_op->inputs[0]);
+ reshape_op->outputs = transpose_op->outputs;
+
+ // Create a new input array for the shape input
+ string perm_array_name = transpose_op->inputs[1];
+ string shape_array_name = toco::AvailableArrayName(*model, perm_array_name);
+ Array& shape_array = model->GetOrCreateArray(shape_array_name);
+ *(shape_array.mutable_shape()->mutable_dims()) = {
+ 1, static_cast<int>(dims.size())};
+ reshape_op->inputs.push_back(shape_array_name);
+ shape_array.data_type = ArrayDataType::kInt32;
+ auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ shape_buffer.data = dims;
+
+ // Delete perm array if unused
+ if (IsDiscardableArray(*model, perm_array_name) &&
+ CountOpsWithInput(*model, perm_array_name) == 1) {
+ model->arrays.erase(perm_array_name);
+ }
+
+ // Replace the operator in the graph.
+ const auto reshape_it = model->operators.emplace(transpose_it, reshape_op);
+ transpose_it = reshape_it + 1;
+ CHECK_EQ(transpose_it->get(), transpose_op);
+ model->operators.erase(transpose_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 785dad8596..9300ab53a7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -114,6 +114,7 @@ void RunGraphTransformations(Model* model, const string& message,
// List of all graph transformations
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
@@ -159,7 +160,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 4fe127544b..c6f17cf319 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -61,7 +61,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
} else if (op->type == OperatorType::kRank ||
op->type == OperatorType::kTensorFlowShape) {
- // These operators are assumed to produce int32 outputs.
+ // These operators only produce int32 outputs.
SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
} else if (op->type == OperatorType::kTensorFlowSplit ||
op->type == OperatorType::kTensorFlowConcat ||
@@ -80,6 +80,20 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->outputs.size(), 1);
auto* argmax_op = static_cast<ArgMaxOperator*>(op);
model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+ } else if (op->type == OperatorType::kRange) {
+ auto* range_op = static_cast<RangeOperator*>(op);
+ // Output type of the Range op can be set via an attribute
+ ArrayDataType data_type;
+ if (range_op->dtype != ArrayDataType::kNone) {
+ // Use the type if specified
+ data_type = range_op->dtype;
+ } else {
+ // Otherwise use the first input
+ CHECK_GE(op->inputs.size(), 1);
+ data_type = model->arrays[op->inputs[0]]->data_type;
+ }
+ CHECK_EQ(op->outputs.size(), 1);
+ SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kTensorFlowUnsupported) {
auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
if (unsupported_op->output_data_types.size() != op->outputs.size()) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 8f181e78d8..a939efb4db 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -316,25 +316,30 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
void ProcessTensorFlowReshapeOperator(Model* model,
TensorFlowReshapeOperator* op) {
auto& output_array = *model->arrays[op->outputs[0]];
- // Bail if we already have output dims
if (output_array.has_shape()) {
+ // We have already run
return;
}
const auto& input_array = *model->arrays[op->inputs[0]];
- // Yield until input dims have been resolved.
if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
return;
}
const auto& input_shape = input_array.shape();
- const string& shape_name = op->inputs[1];
- auto& shape_array = model->GetArray(shape_name);
- // Yield until the shape is resolved as a constant array
+ auto& shape_array = model->GetArray(op->inputs[1]);
+ if (!shape_array.has_shape()) {
+ // Yield until target_shape shape been resolved.
+ return;
+ }
if (!shape_array.buffer) {
+ // Yield until the target_shape is constant
return;
}
- CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ CHECK(shape_array.data_type == ArrayDataType::kInt32)
+ << "Reshape dims must be int32";
+
// shape_data is the raw array of ints describing the shape
// in the TensorFlow node. We intentionally make a copy here, rather than
// modify wildcards in-place below, because in some graphs, the same shape
@@ -357,12 +362,18 @@ void ProcessTensorFlowReshapeOperator(Model* model,
}
const int input_flat_size = RequiredBufferSizeForShape(input_shape);
if (has_wildcard) {
+ CHECK_GE(input_flat_size, product_non_wildcard_dims)
+ << "Array not large enough to fill the requested dimensions for "
+ "Reshape op with output \""
+ << op->outputs[0] << "\". Are your input shapes correct?";
shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
}
auto& output_shape = *output_array.mutable_shape();
*output_shape.mutable_dims() = shape_data;
- const int output_flat_size = RequiredBufferSizeForShape(output_shape);
- CHECK_EQ(output_flat_size, input_flat_size);
+ CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
+ << "Input cannot be reshaped to requested dimensions for Reshape op with "
+ "output \""
+ << op->outputs[0] << "\". Are your input shapes correct?";
}
void ProcessSimpleOperator(Model* model, Operator* op) {
@@ -535,6 +546,56 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
output_dims[op->axis] = concat_size;
}
+void ProcessRangeOperator(Model* model, RangeOperator* op) {
+ CHECK_EQ(op->inputs.size(), 3);
+ const auto& start_array = *model->arrays[op->inputs[0]];
+ if (!start_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& limit_array = *model->arrays[op->inputs[1]];
+ if (!limit_array.has_shape()) {
+ return;
+ }
+ const auto& delta_array = *model->arrays[op->inputs[2]];
+ if (!delta_array.has_shape()) {
+ return;
+ }
+
+ if (!IsConstantParameterArray(*model, op->inputs[0])) {
+ // Yield until inputs are constant.
+ return;
+ }
+ if (!IsConstantParameterArray(*model, op->inputs[1])) {
+ return;
+ }
+ if (!IsConstantParameterArray(*model, op->inputs[2])) {
+ return;
+ }
+
+ CHECK(start_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+ CHECK(limit_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+ CHECK(delta_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+ CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+ CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+ CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+ int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
+ start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
+ delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
+
+ // Only set the output shape. Contents are set by ResolveConstantRange.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({size});
+}
+
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
const string& input_name = op->inputs[1];
@@ -885,35 +946,166 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
output_array.copy_shape(output_shape);
}
-void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
- CHECK_EQ(op->inputs.size(), 4);
+void ProcessRankOperator(Model* model, RankOperator* op) {
+ CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
- // Yield until input dims have been resolved.
- if (!input_array.has_shape()) return;
+ // Only set the output shape. Array contents are set by
+ // ResolveConstantShapeOrRank.
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({});
+}
+
+void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
+ CHECK_GE(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
- if (op->start_indices.empty()) return;
- CHECK_EQ(op->start_indices.size(), op->stop_indices.size());
- CHECK_EQ(op->start_indices.size(), op->strides.size());
+ // Only set the output shape. Array contents are set by
+ // ResolveConstantShapeOrRank.
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({input_array.shape().dimensions_count()});
+}
+void ProcessStackOperator(Model* model, StackOperator* op) {
+ CHECK_GE(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
auto& output_array = *model->arrays[op->outputs[0]];
- if (output_array.has_shape()) return;
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
- Shape output_shape = input_array.shape();
- std::vector<int>& dims = *output_shape.mutable_dims();
- CHECK_EQ(op->start_indices.size(), dims.size());
+ std::unique_ptr<Shape> stacked_shape;
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ if (!input_array.has_shape()) {
+ // Yield until all input dims have been resolved.
+ return;
+ }
- for (int i = 0; i < op->start_indices.size(); ++i) {
- const int mask = 1 << i;
- const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
- const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i]
- : op->stop_indices[i];
- dims[i] = (stop - start) / op->strides[i];
+ Shape shape = input_array.shape();
+ if (shape.dimensions_count() == 0) {
+ // Convert 0D scalars to 1D scalars of shape {1}.
+ shape.mutable_dims()->push_back(1);
+ }
+ if (!stacked_shape) {
+ stacked_shape.reset(new Shape(shape));
+ } else {
+ CHECK(*stacked_shape == shape) << "All input arrays to Stack operators "
+ "must have the same shape. Input \""
+ << input << "\" is different.";
+ }
}
- output_array.copy_shape(output_shape);
+ int axis = op->axis;
+ if (axis < 0) {
+ // Handle negative axis
+ axis += stacked_shape->dims().size() + 1;
+ }
+ stacked_shape->mutable_dims()->insert(
+ stacked_shape->mutable_dims()->begin() + axis, op->inputs.size());
+ output_array.copy_shape(*stacked_shape);
+}
+
+void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
+ CHECK_GE(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ if (op->start_indices.empty() || op->stop_indices.empty() ||
+ op->strides.empty()) {
+ // ResolveStridedSliceAttributes has not run yet.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+
+ if (op->ellipsis_mask != 0) {
+ // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce
+ // log noise. However, the TensorFlow logging library does not appear to
+ // support this.
+ LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
+ << "\". ellipsis_mask is not supported (mask="
+ << op->ellipsis_mask << ")";
+ return;
+ }
+ if (op->new_axis_mask != 0) {
+ LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
+ << "\". new_axis_mask is not supported (mask="
+ << op->new_axis_mask << ")";
+ return;
+ }
+
+ int dim_count = input_array.shape().dimensions_count();
+ CHECK(op->start_indices.size() == dim_count)
+ << ": Incorrect number of start indices supplied to StridedSlice op with "
+ "output \""
+ << op->outputs[0] << "\". Op requires " << dim_count << " start indices";
+ CHECK(op->stop_indices.size() == dim_count)
+ << ": Incorrect number of stop indices supplied to StridedSlice op with "
+ "output \""
+ << op->outputs[0] << "\". Op requires " << dim_count << " stop indices";
+ CHECK(op->strides.size() == dim_count)
+ << ": Incorrect number of strides supplied to StridedSlice op with "
+ " output \""
+ << op->outputs[0] << "\". Op requires " << dim_count << " strides";
+
+ // Create output shape
+ std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
+
+ // Compute output shape
+ for (int i = 0; i < dim_count; ++i) {
+ const int mask = 1 << i;
+ int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
+ if (start < 0) {
+ // handle negative indices
+ start += input_array.shape().dims(i);
+ }
+ int stop = (op->end_mask & mask) ? input_array.shape().dims(i)
+ : op->stop_indices[i];
+ if (stop < 0) {
+ // handle negative indices
+ stop += input_array.shape().dims(i);
+ }
+
+ int dim_size = (stop - start) / op->strides[i];
+ if (op->shrink_axis_mask & mask) {
+ CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when "
+ "shrinking that axis";
+ } else {
+ dims->push_back(dim_size);
+ }
+ }
}
void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
@@ -971,6 +1163,45 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
}
+void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ // We have already run
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& perm_array = model->GetArray(op->inputs[1]);
+ if (!perm_array.has_shape()) {
+ // Yield until permutation shape been resolved.
+ return;
+ }
+ if (!perm_array.buffer) {
+ // Yield until the permutation is constant
+ return;
+ }
+ CHECK(perm_array.data_type == ArrayDataType::kInt32)
+ << "Transpose permutation input must be int32";
+
+ std::vector<int32> const& perm =
+ perm_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(perm.size(), input_shape.dimensions_count())
+ << "Transpose permutation input must be same length as input dimensions";
+ std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
+ for (int i = 0; i < perm.size(); i++) {
+ int axis = perm[i];
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, input_shape.dimensions_count());
+ output_dims->push_back(input_shape.dims(axis));
+ }
+}
+
void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = *model->arrays[op->inputs[0]];
@@ -1136,13 +1367,19 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
// or else at the moment we will abort.
break;
case OperatorType::kExpandDims:
+ // Yield until ExpandDims is converted to Reshape
+ break;
case OperatorType::kRange:
+ ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
+ break;
case OperatorType::kRank:
+ ProcessRankOperator(model, static_cast<RankOperator*>(op));
+ break;
case OperatorType::kTensorFlowShape:
+ ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
+ break;
case OperatorType::kStack:
- case OperatorType::kTranspose:
- // Unimplemented. Hopefully another graph transformation will drop it or
- // rewrite it.
+ ProcessStackOperator(model, static_cast<StackOperator*>(op));
break;
case OperatorType::kReorderAxes:
ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
@@ -1185,6 +1422,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kSvdf:
ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
break;
+ case OperatorType::kTranspose:
+ ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
index a4f198e92f..7777d4f543 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -37,26 +37,19 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->inputs.size(), 3);
- if (!IsConstantParameterArray(*model, op->inputs[1]) or
+ if (!IsConstantParameterArray(*model, op->inputs[1]) ||
!IsConstantParameterArray(*model, op->inputs[2]))
return false;
- // Handling block_shape.
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- if (!block_shape_array.has_shape()) return false;
- const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
- CHECK_EQ(block_shape_dims.size(), 1);
- std::vector<int> block_shape_buffer =
- block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
- for (int i = 0; i < block_shape_dims[0]; ++i) {
- op->block_shape.push_back(block_shape_buffer[i]);
- }
-
- // Handling crops.
+ // Handle crops
const auto& crops_array = *model->arrays[op->inputs[2]];
if (!crops_array.has_shape()) return false;
const std::vector<int>& crops_dims = crops_array.shape().dims();
- CHECK_EQ(crops_dims.size(), 2);
+ if (crops_dims.size() != 2) {
+ // Code only handles crops of 2 dimensions. Perhaps another transformation
+ // will delete this op.
+ return false;
+ }
std::vector<int> crops_buffer =
crops_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < crops_dims[0]; ++i) {
@@ -64,6 +57,17 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
op->after_crops.push_back(crops_buffer[i * 2 + 1]);
}
+ // Handle block_shape
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ if (!block_shape_array.has_shape()) return false;
+ const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
+ CHECK_EQ(block_shape_dims.size(), 1);
+ std::vector<int> block_shape_buffer =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < block_shape_dims[0]; ++i) {
+ op->block_shape.push_back(block_shape_buffer[i]);
+ }
+
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index 53e1be7a05..fd51df4058 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -141,6 +141,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
outval = val0 - val1;
} else if (binary_op->type == OperatorType::kDiv) {
outval = val0 / val1;
+ } else if (binary_op->type == OperatorType::kFloorDiv) {
+ outval = floor(val0 / val1);
+ } else if (binary_op->type == OperatorType::kFloorMod) {
+ outval = val0 - (floor(val0 / val1) * val1);
} else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
outval = std::min(val0, val1);
} else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
@@ -191,6 +195,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kFloorDiv &&
+ binary_op->type != OperatorType::kFloorMod &&
binary_op->type != OperatorType::kTensorFlowMinimum &&
binary_op->type != OperatorType::kTensorFlowMaximum &&
binary_op->type != OperatorType::kTensorFlowLess &&
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
new file mode 100644
index 0000000000..383d54aa5a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* base_op = it->get();
+ if (base_op->type != OperatorType::kRange) {
+ return false;
+ }
+ auto* op = static_cast<RangeOperator*>(base_op);
+
+ CHECK_EQ(op->inputs.size(), 3);
+ const auto& start_array = *model->arrays[op->inputs[0]];
+ if (!start_array.has_shape()) {
+ // Yield until all input dims have been resolved.
+ return false;
+ }
+ const auto& limit_array = *model->arrays[op->inputs[1]];
+ if (!limit_array.has_shape()) {
+ // Yield until all input dims have been resolved.
+ return false;
+ }
+ const auto& delta_array = *model->arrays[op->inputs[2]];
+ if (!delta_array.has_shape()) {
+ // Yield until all input dims have been resolved.
+ return false;
+ }
+
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(*model, input)) {
+ // yield if any input is mutable
+ return false;
+ }
+ }
+
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return false;
+ }
+
+ CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+ CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+ CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
+ << "Range op inputs must be scalar.";
+
+ CHECK(start_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+ CHECK(limit_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+ CHECK(delta_array.data_type == ArrayDataType::kInt32)
+ << "Range op inputs must be int32.";
+
+ // Compute buffer contents
+ int start = start_array.GetBuffer<ArrayDataType::kInt32>().data[0];
+ int limit = limit_array.GetBuffer<ArrayDataType::kInt32>().data[0];
+ int delta = delta_array.GetBuffer<ArrayDataType::kInt32>().data[0];
+ auto& buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ buffer.data.clear();
+ for (int32 val = start; val < limit; val += delta) {
+ buffer.data.push_back(val);
+ }
+ CHECK_EQ(floor((limit - start) / delta), buffer.data.size());
+ CHECK_EQ(buffer.data.size(), output_array.shape().dims()[0]);
+
+ // Delete the input array if no longer used
+ if (IsDiscardableArray(*model, op->inputs[0]) &&
+ CountOpsWithInput(*model, op->inputs[0]) == 1) {
+ model->arrays.erase(op->inputs[0]);
+ }
+ if (IsDiscardableArray(*model, op->inputs[1]) &&
+ CountOpsWithInput(*model, op->inputs[1]) == 1) {
+ model->arrays.erase(op->inputs[1]);
+ }
+ if (IsDiscardableArray(*model, op->inputs[2]) &&
+ CountOpsWithInput(*model, op->inputs[2]) == 1) {
+ model->arrays.erase(op->inputs[2]);
+ }
+
+ // Delete the operator
+ model->operators.erase(it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
new file mode 100644
index 0000000000..e15ee7805c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+ if (!(op->type == OperatorType::kTensorFlowShape ||
+ op->type == OperatorType::kRank)) {
+ return false;
+ }
+
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been resolved
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until the input array's shape has been resolved.
+ return false;
+ }
+
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been resolved.
+ return false;
+ }
+
+ // Compute the output
+ CHECK(!output_array.buffer);
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ if (op->type == OperatorType::kTensorFlowShape) {
+ // Copy the input shape into the output buffer.
+ output_buffer.data = input_array.shape().dims();
+ } else if (op->type == OperatorType::kRank) {
+ // Copy the dimension count into the output buffer.
+ output_buffer.data.resize(1);
+ output_buffer.data[0] = input_array.shape().dimensions_count();
+ }
+
+ // Delete the input array if no longer used
+ if (IsDiscardableArray(*model, op->inputs[0]) &&
+ CountOpsWithInput(*model, op->inputs[0]) == 1) {
+ model->arrays.erase(op->inputs[0]);
+ }
+
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
new file mode 100644
index 0000000000..86c76141a4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType Type>
+void Stack(Model* model, StackOperator const& op) {
+ auto& output_array = model->GetArray(op.outputs[0]);
+ CHECK(output_array.data_type == Type);
+
+ // Create a buffer for the output array
+ std::vector<DataType<Type>>& output_data =
+ output_array.GetMutableBuffer<Type>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array.shape()));
+
+ // Stack inputs into buffer
+ CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis";
+ int dst_offset = 0;
+ for (int i = 0; i < op.inputs.size(); i++) {
+ // Append array data to output for each input array
+ const auto& input_array = model->GetArray(op.inputs[i]);
+ int input_size = RequiredBufferSizeForShape(input_array.shape());
+ memcpy(&output_data[dst_offset], &input_array.GetBuffer<Type>().data[0],
+ input_size * sizeof(Type));
+ dst_offset += input_size;
+ }
+ CHECK_EQ(dst_offset, output_data.size());
+}
+
+} // namespace
+
+bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kStack) {
+ return false;
+ }
+ const auto* op = static_cast<const StackOperator*>(base_op);
+
+ CHECK_GE(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return false;
+ }
+
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes
+ return false;
+ }
+
+ for (const auto& input : op->inputs) {
+ if (!IsConstantParameterArray(*model, input)) {
+ // Yield if any input is mutable
+ return false;
+ }
+ }
+
+ CHECK(!output_array.buffer);
+ switch (output_array.data_type) {
+ case ArrayDataType::kFloat:
+ Stack<ArrayDataType::kFloat>(model, *op);
+ break;
+ case ArrayDataType::kUint8:
+ Stack<ArrayDataType::kUint8>(model, *op);
+ break;
+ case ArrayDataType::kInt32:
+ Stack<ArrayDataType::kInt32>(model, *op);
+ break;
+ case ArrayDataType::kInt64:
+ Stack<ArrayDataType::kInt64>(model, *op);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type given to Stack op with output \""
+ << op->outputs[0] << "\"";
+ break;
+ }
+
+ // Erase input arrays if no longer used
+ for (const auto& input : op->inputs) {
+ if (IsDiscardableArray(*model, input) &&
+ CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+
+ // Erase the operator
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
new file mode 100644
index 0000000000..3976d9cbb4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -0,0 +1,198 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
+ int axis) {
+ int start;
+ if (op.begin_mask & 1 << axis) {
+ // If begin mask bit is set, use the first element
+ start = 0;
+ } else {
+ // Otherwise, use the specified element
+ start = op.start_indices[axis];
+ if (start < 0) {
+ // Handle negative indices
+ start += input_shape.dims(axis);
+ }
+ }
+ return start;
+}
+
+int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
+ int axis) {
+ int stop;
+ if (op.end_mask & (1 << axis)) {
+ // If end mask bit set, use the last element
+ stop = input_shape.dims(axis);
+ } else {
+ // Otherwise, use the specified element
+ stop = op.stop_indices[axis];
+ if (stop < 0) {
+ // Handle negative indices
+ stop += input_shape.dims(axis);
+ }
+ }
+ return stop;
+}
+
+template <ArrayDataType Type>
+void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
+ Array* output_array) {
+ // The TensorFlow documentation for StridedSlice is a bit ambiguous in places
+ // (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice).
+ // Use the source code at /third_party/tensorflow/core/util/strided_op.cc as
+ // "master documentation".
+
+ CHECK(input_array.data_type == Type);
+ CHECK(output_array->data_type == Type);
+ CHECK_EQ(op.ellipsis_mask, 0);
+ CHECK_EQ(op.new_axis_mask, 0);
+
+ int num_input_axes = op.start_indices.size();
+ CHECK_EQ(num_input_axes, op.stop_indices.size());
+ CHECK_EQ(num_input_axes, op.strides.size());
+ for (int i = 0; i < op.strides.size(); i++) {
+ CHECK_GE(op.strides[i], 0) << "Negative strides usupported";
+ }
+
+ // Create a buffer for the output array
+ std::vector<DataType<Type>>& output_data =
+ output_array->GetMutableBuffer<Type>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+
+ // Initialize source coordinate
+ Shape const& input_shape = input_array.shape();
+ Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
+ std::vector<int> src_coord(op.start_indices.size());
+ for (int axis = 0; axis < num_input_axes; axis++) {
+ src_coord[axis] = StartForAxis(op, input_shape, axis);
+ }
+
+ // In order to handle any number (N) of dimensions, we copy elements one by
+ // one and treat the source coordinate as an N digit number (src_coord here).
+ // Each "digit" is incremented individually (by the stride). When it overflows
+ // (becomes greater than the stop), that digit is reset and a carry flag is
+ // used to increment the next digit.
+ int dst_offset = 0;
+ do {
+ // Copy element.
+ output_data[dst_offset] = input_buffer.data[Offset(input_shape, src_coord)];
+
+ // Compute next source input coordinates.
+ bool carry = true;
+ for (int axis = 0; axis < num_input_axes; axis++) {
+ // Increment this axis if we carried from the previous one
+ if (carry) {
+ src_coord[axis] += op.strides[axis];
+ }
+
+ // Check if we've overflowed.
+ if (src_coord[axis] >= StopForAxis(op, input_shape, axis)) {
+ // Reset axis and set carry
+ src_coord[axis] = StartForAxis(op, input_shape, axis);
+ carry = true;
+ } else {
+ carry = false;
+ }
+ }
+ // increment destination buffer offset
+ dst_offset++;
+ } while (dst_offset < output_data.size());
+}
+
+} // anonymous namespace
+
+bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kStridedSlice) {
+ return false;
+ }
+
+ const StridedSliceOperator* op =
+ static_cast<const StridedSliceOperator*>(base_op);
+
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return false;
+ }
+
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes
+ return false;
+ }
+
+ if (op->start_indices.empty() || op->stop_indices.empty() ||
+ op->strides.empty()) {
+ // Attributes have not resolved yet.
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until the value shape has been resolved.
+ return false;
+ }
+ if (!IsConstantParameterArray(*model, op->inputs[0])) {
+ // Yield until the value is constant.
+ return false;
+ }
+
+ CHECK(!output_array.buffer);
+ switch (output_array.data_type) {
+ case ArrayDataType::kFloat:
+ StridedSlice<ArrayDataType::kFloat>(*op, input_array, &output_array);
+ break;
+ case ArrayDataType::kUint8:
+ StridedSlice<ArrayDataType::kUint8>(*op, input_array, &output_array);
+ break;
+ case ArrayDataType::kInt32:
+ StridedSlice<ArrayDataType::kInt32>(*op, input_array, &output_array);
+ break;
+ case ArrayDataType::kInt64:
+ StridedSlice<ArrayDataType::kInt64>(*op, input_array, &output_array);
+ break;
+ default:
+ LOG(FATAL)
+ << "Unsupported data type input to StridedSlice op with output \""
+ << op->outputs[0] << "\"";
+ break;
+ }
+
+ // Erase input array if no longer used
+ if (IsDiscardableArray(*model, op->inputs[0]) &&
+ CountOpsWithInput(*model, op->inputs[0]) == 1) {
+ model->arrays.erase(op->inputs[0]);
+ }
+
+ // Erase the operator
+ model->operators.erase(it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
deleted file mode 100644
index 8cc6db1619..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstddef>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) {
- const auto tfshape_it = model->operators.begin() + op_index;
- const auto* tfshape_base_op = tfshape_it->get();
- if (tfshape_base_op->type != OperatorType::kTensorFlowShape) {
- return false;
- }
-
- const auto* tfshape_op =
- static_cast<const TensorFlowShapeOperator*>(tfshape_base_op);
-
- const auto& input_array = model->GetArray(tfshape_op->inputs[0]);
- auto& output_array = model->GetArray(tfshape_op->outputs[0]);
-
- // Yield until the input array's shape has been resolved.
- if (!input_array.has_shape()) {
- return false;
- }
-
- // Create a buffer for the output array, making it a constant array, and
- // copy the input shape into the output buffer.
- CHECK(!output_array.buffer);
- auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
- output_buffer.data = input_array.shape().dims();
-
- // Erase the input array if no longer used
- if (IsDiscardableArray(*model, tfshape_op->inputs[0]) &&
- CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) {
- model->arrays.erase(tfshape_op->inputs[0]);
- }
- model->operators.erase(tfshape_it);
-
- return true;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
index 0e4c66544d..a73f16735c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
@@ -44,22 +44,15 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
!IsConstantParameterArray(*model, op->inputs[paddings_index]))
return false;
- // Handling block_shape.
- const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
- if (!block_shape_array.has_shape()) return false;
- const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
- CHECK_EQ(block_shape_dims.size(), 1);
- std::vector<int> block_shape_buffer =
- block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
- for (int i = 0; i < block_shape_dims[0]; ++i) {
- op->block_shape.push_back(block_shape_buffer[i]);
- }
-
- // Handling paddings.
+ // Handle paddings.
const auto& paddings_array = *model->arrays[op->inputs[paddings_index]];
if (!paddings_array.has_shape()) return false;
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
- CHECK_EQ(paddings_dims.size(), 2);
+ if (paddings_dims.size() != 2) {
+ // Code only handles padding of 2 dimensions. Perhaps another transformation
+ // will delete this op.
+ return false;
+ }
std::vector<int> paddings_buffer =
paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < paddings_dims[0]; ++i) {
@@ -67,6 +60,17 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
op->after_paddings.push_back(paddings_buffer[i * 2 + 1]);
}
+ // Handle block_shape.
+ const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
+ if (!block_shape_array.has_shape()) return false;
+ const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
+ CHECK_EQ(block_shape_dims.size(), 1);
+ std::vector<int> block_shape_buffer =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < block_shape_dims[0]; ++i) {
+ op->block_shape.push_back(block_shape_buffer[i]);
+ }
+
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index 97946182ef..dbe69adcbd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -30,19 +25,14 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
if (slice_op->type != OperatorType::kStridedSlice) return false;
auto* op = static_cast<StridedSliceOperator*>(slice_op);
- if (!op->start_indices.empty()) return false;
+ if (!op->start_indices.empty()) {
+ // We have already resolved these attributes
+ return false;
+ }
CHECK_EQ(op->inputs.size(), 4);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
-
const auto& start_array = *model->arrays[op->inputs[1]];
if (!start_array.has_shape()) return false;
- if (toco::RequiredBufferSizeForShape(start_array.shape()) != 4) {
- // Only 4D arrays are supported for now.
- return false;
- }
const auto& stop_array = *model->arrays[op->inputs[2]];
if (!stop_array.has_shape()) return false;
@@ -50,12 +40,24 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
const auto& stride_array = *model->arrays[op->inputs[3]];
if (!stride_array.has_shape()) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+
op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
- // TODO(dkalenichenko): Delete the extra inputs?
+ CHECK_GE(op->start_indices.size(), 1);
+ CHECK_LE(op->start_indices.size(), 4);
+ CHECK_EQ(op->stop_indices.size(), op->start_indices.size());
+ CHECK_EQ(op->strides.size(), op->stop_indices.size());
+ // Ideally, we would remove the input arrays after they have been resolved.
+ // However, we must then reconstitute these input arrays for all supported
+ // export formats. For now, leave the arrays so we don't have to modify our
+ // exporters. Ideally, we wouldn't have op attributes, and would work directly
+ // with the input arrays.
return true;
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index f07fef117f..995e9d67ca 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1181,21 +1181,6 @@ void ConvertStridedSliceOperator(const NodeDef& node,
CHECK_EQ(node.op(), "StridedSlice");
CheckInputsCount(node, tf_import_flags, 4);
- // Only a subset of the full TF op functionality is supported now.
- if ( // No 64-bit indices.
- GetDataTypeAttr(node, "Index") != DT_INT32 ||
- // No dimensionality changes.
- GetIntAttr(node, "new_axis_mask") != 0 ||
- GetIntAttr(node, "shrink_axis_mask") != 0 ||
- // No sparse indices.
- GetIntAttr(node, "ellipsis_mask") != 0 ||
- // Only 4D tensors are supported.
- GetIntAttr(node, "begin_mask") > 15 ||
- GetIntAttr(node, "end_mask") > 15) {
- ConvertUnsupportedOperator(node, tf_import_flags, model);
- return;
- }
-
auto* op = new StridedSliceOperator;
for (const auto& input : node.input()) {
op->inputs.push_back(input);
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 92ec0bcba8..0bcf4596de 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -52,6 +52,7 @@ void MakeGeneralGraphTransformationsSet(
GraphTransformationsSet* transformations) {
CHECK(transformations->empty());
transformations->Add(new ConvertExpandDimsToReshape);
+ transformations->Add(new ConvertTrivialTransposeToReshape);
transformations->Add(new ResolveReshapeAttributes);
transformations->Add(new PropagateArrayDataTypes);
transformations->Add(new PropagateFixedSizes);
@@ -68,6 +69,9 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveBatchNormalization);
transformations->Add(new ResolveConstantBinaryOperator);
transformations->Add(new ResolveConstantFill);
+ transformations->Add(new ResolveConstantRange);
+ transformations->Add(new ResolveConstantStack);
+ transformations->Add(new ResolveConstantStridedSlice);
transformations->Add(new ResolveConstantUnaryOperator);
transformations->Add(new ResolveTensorFlowMerge);
transformations->Add(new ResolveTensorFlowSqueeze);
@@ -86,7 +90,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveSliceAttributes);
transformations->Add(new ResolveMeanAttributes);
transformations->Add(new ResolveTransposeAttributes);
- transformations->Add(new ResolveConstantTensorFlowShape);
+ transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
}