Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
 tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 153 +++++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 126 insertions(+), 27 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 82b3ab96fe..a03b589bae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -437,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model,
product_non_wildcard_dims *= shape_data[i];
}
}
+
const int input_flat_size = RequiredBufferSizeForShape(input_shape);
if (has_wildcard) {
CHECK_GE(input_flat_size, product_non_wildcard_dims)
@@ -445,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model,
<< op->outputs[0] << "\". Are your input shapes correct?";
shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
}
+
+ if (shape_data.size() == 1 && shape_data[0] == 0) {
+ // We have reshaped a scalar, so preserve it as a scalar.
+ shape_data.clear();
+ }
+
auto& output_shape = *output_array.mutable_shape();
*output_shape.mutable_dims() = shape_data;
CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
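
A minimal standalone sketch (assumed names and example values, not TOCO code) of the scalar-preservation rule added above: a lone dim of 0 read from the shape input denotes "reshape to scalar", and clearing the dims vector yields a rank-0 shape whose flat size is still 1 (the empty product):

    #include <cassert>
    #include <vector>

    int main() {
      // Shape read from the Reshape op's shape input (assumed example value).
      std::vector<int> shape_data = {0};
      // A single dim of 0 denotes a scalar target, so drop all dims.
      if (shape_data.size() == 1 && shape_data[0] == 0) {
        shape_data.clear();
      }
      assert(shape_data.empty());  // rank-0: flat size is the empty product, 1
      return 0;
    }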
@@ -522,12 +529,14 @@ void ProcessAddNOperator(Model* model, Operator* op) {
bool KeepDims(const Operator& op) {
switch (op.type) {
- case OperatorType::kMin: // Reduction Min
+ case OperatorType::kReduceMin: // Reduction Min
return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
- case OperatorType::kMax: // Reduction Max
+ case OperatorType::kReduceMax: // Reduction Max
return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
case OperatorType::kSum:
return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
+ case OperatorType::kReduceProd:
+ return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
case OperatorType::kMean:
return static_cast<const MeanOperator&>(op).keep_dims;
default:
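
For reference, a hedged sketch (plain std::vector, not TOCO's Shape class; helper name is illustrative) of what the keep_dims flag returned here means for a reduction's output shape:

    #include <set>
    #include <vector>

    // Assumed helper, for illustration only: reduce `dims` over `axes`.
    std::vector<int> ReducedShape(const std::vector<int>& dims,
                                  const std::set<int>& axes, bool keep_dims) {
      std::vector<int> out;
      for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
        if (axes.count(i)) {
          if (keep_dims) out.push_back(1);  // reduced axis kept with size 1
        } else {
          out.push_back(dims[i]);  // untouched axis copied through
        }
      }
      return out;
    }
    // ReducedShape({2, 3, 4}, {1}, /*keep_dims=*/true)  -> {2, 1, 4}
    // ReducedShape({2, 3, 4}, {1}, /*keep_dims=*/false) -> {2, 4}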
@@ -1034,20 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
return;
}
+ // Yield until the axis has been resolved.
+ if (!op->axis) {
+ return;
+ }
+ int axis = op->axis.value();
+
const auto& input_shape = input_array.shape();
const auto& indices_shape = indices_array.shape();
QCHECK_GE(input_shape.dimensions_count(), 1);
op->input_rank = input_shape.dimensions_count();
+ QCHECK_LT(axis, op->input_rank);
- // We only support 1-D indices.
- QCHECK_EQ(indices_shape.dimensions_count(), 1);
-
- // Copy the input dimensions to the output except for dimension 0,
+ // Copy the input dimensions to the output except for the axis dimension,
// where the dimension of indices_shape is used.
- // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
auto output_dims = output_array.mutable_shape()->mutable_dims();
- output_dims->push_back(indices_shape.dims(0));
- for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ for (int dim = 0; dim < axis; ++dim) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+ for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
+ output_dims->push_back(indices_shape.dims(dim));
+ }
+ for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
output_dims->push_back(input_shape.dims(dim));
}
}
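
The three loops above implement the standard Gather shape rule, output = input[0:axis] ++ indices_shape ++ input[axis+1:]. A self-contained sketch under that assumption (vector-based shapes, not TOCO types):

    #include <vector>

    std::vector<int> GatherOutputShape(const std::vector<int>& input,
                                       const std::vector<int>& indices,
                                       int axis) {
      std::vector<int> out(input.begin(), input.begin() + axis);
      out.insert(out.end(), indices.begin(), indices.end());
      out.insert(out.end(), input.begin() + axis + 1, input.end());
      return out;
    }
    // GatherOutputShape({4, 5, 6}, {2, 3}, /*axis=*/1) -> {4, 2, 3, 6}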
@@ -1193,7 +1210,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
output_shape->ReplaceDims({input_array.shape().dimensions_count()});
}
-void ProcessStackOperator(Model* model, StackOperator* op) {
+void ProcessPackOperator(Model* model, PackOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
@@ -1202,7 +1219,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
return;
}
- std::unique_ptr<Shape> stacked_shape;
+ std::unique_ptr<Shape> packed_shape;
for (const auto& input : op->inputs) {
const auto& input_array = model->GetArray(input);
if (!input_array.has_shape()) {
@@ -1211,23 +1228,23 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
}
Shape shape = input_array.shape();
- if (!stacked_shape) {
- stacked_shape.reset(new Shape(shape));
+ if (!packed_shape) {
+ packed_shape.reset(new Shape(shape));
} else {
- CHECK(*stacked_shape == shape) << "All input arrays to Stack operators "
- "must have the same shape. Input \""
- << input << "\" is different.";
+ CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
+ "must have the same shape. Input \""
+ << input << "\" is different.";
}
}
int axis = op->axis;
if (axis < 0) {
// Handle negative axis
- axis += stacked_shape->dims().size() + 1;
+ axis += packed_shape->dims().size() + 1;
}
- stacked_shape->mutable_dims()->insert(
- stacked_shape->mutable_dims()->begin() + axis, op->inputs.size());
- output_array.copy_shape(*stacked_shape);
+ packed_shape->mutable_dims()->insert(
+ packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
+ output_array.copy_shape(*packed_shape);
}
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
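
Pack's shape rule, sketched with assumed vector-based shapes: N identically shaped inputs gain a new axis of size N, and a negative axis counts from rank + 1:

    #include <vector>

    std::vector<int> PackOutputShape(std::vector<int> input_shape,
                                     int num_inputs, int axis) {
      if (axis < 0) {
        axis += static_cast<int>(input_shape.size()) + 1;  // e.g. -1 -> rank
      }
      input_shape.insert(input_shape.begin() + axis, num_inputs);
      return input_shape;
    }
    // Packing 3 inputs of shape {2, 4} at axis -1 -> {2, 4, 3}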
@@ -1407,7 +1424,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1501,6 +1519,65 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
}
}
+void ProcessAnyOperator(Model* model, AnyOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // We have already run.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.has_shape()) {
+ // Yield until the reduction indices shape has been resolved.
+ return;
+ }
+ if (!reduction_indices_array.buffer) {
+ // Yield until the reduction indices are constant.
+ return;
+ }
+ CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32)
+ << "Any reduction input must be int32";
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
+ }
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
+ }
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (op->keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
+ }
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
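
The wrapping step in ProcessAnyOperator normalizes negative reduction indices before collecting them into the set; a minimal standalone version (function name assumed):

    // Maps an index in [-rank, rank) into [0, rank); out-of-range values are
    // rejected with CHECK(false) in the hunk above.
    int WrapReductionIndex(int index, int rank) {
      return index < 0 ? index + rank : index;
    }
    // WrapReductionIndex(-1, 4) == 3; WrapReductionIndex(2, 4) == 2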
@@ -1539,6 +1616,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kFloor:
case OperatorType::kExp:
case OperatorType::kSin:
+ case OperatorType::kLogicalAnd:
+ case OperatorType::kLogicalNot:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
@@ -1607,9 +1686,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kL2Pool:
ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
break;
- case OperatorType::kMin: // Reduction Min
- case OperatorType::kMax: // Reduction Max
+ case OperatorType::kReduceMin: // Reduction Min
+ case OperatorType::kReduceMax: // Reduction Max
case OperatorType::kSum:
+ case OperatorType::kReduceProd:
case OperatorType::kMean:
ProcessTensorFlowReductionOperator(model, op);
break;
@@ -1658,8 +1738,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kShape:
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
break;
- case OperatorType::kStack:
- ProcessStackOperator(model, static_cast<StackOperator*>(op));
+ case OperatorType::kPack:
+ ProcessPackOperator(model, static_cast<PackOperator*>(op));
break;
case OperatorType::kReorderAxes:
ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
@@ -1699,10 +1779,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
+ break;
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
break;
- case OperatorType::kUnsupported:
+ case OperatorType::kUnsupported: {
+ const auto* unsupported_op =
+ static_cast<TensorFlowUnsupportedOperator*>(op);
+ // The attribute may not be specified; in that case, skip shape propagation.
+ if (unsupported_op->output_shapes.size() < op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < op->outputs.size(); ++i) {
+ const string& output = op->outputs[i];
+ model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
+ }
break;
+ }
case OperatorType::kSvdf:
ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
break;
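
The kUnsupported case above yields (returns false) until the recorded output_shapes cover every output; a hedged sketch of that guard-then-copy pattern with illustrative stand-in types:

    #include <cstddef>
    #include <vector>

    struct Shape { std::vector<int> dims; };  // stand-in for toco::Shape

    // Illustration only: copy each recorded shape onto the matching output,
    // or report "not ready" so the transformation retries on a later pass.
    bool CopyRecordedShapes(const std::vector<Shape>& recorded,
                            std::vector<Shape>* outputs) {
      if (recorded.size() < outputs->size()) {
        return false;  // attribute missing or incomplete: yield
      }
      for (std::size_t i = 0; i < outputs->size(); ++i) {
        (*outputs)[i] = recorded[i];
      }
      return true;
    }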
@@ -1726,6 +1822,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTile:
ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
break;
+ case OperatorType::kAny:
+ ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);