aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 08:00:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 08:03:31 -0700
commitbe571938196fb191f260a2c45176d406e6c19a13 (patch)
tree31dd2158201796c4eb994dab8cae70d397df1bb1 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (diff)
Adding support for RandomUniform. Basic support for op import/export of RandomUniform, and constant resolution with static seeds.
PiperOrigin-RevId: 191293897
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc164
1 files changed, 97 insertions, 67 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 778da39bf1..89ad58f887 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -50,78 +50,108 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
- if (op->type == OperatorType::kDequantize ||
- op->type == OperatorType::kResizeBilinear) {
- // These operators unconditionally produce float outputs
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
- } else if (op->type == OperatorType::kTensorFlowLess ||
- op->type == OperatorType::kTensorFlowLessEqual ||
- op->type == OperatorType::kTensorFlowGreater ||
- op->type == OperatorType::kTensorFlowGreaterEqual) {
- // These operators unconditionally produce bool outputs
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
- } else if (op->type == OperatorType::kRank ||
- op->type == OperatorType::kTensorFlowShape) {
- // These operators only produce int32 outputs.
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
- } else if (op->type == OperatorType::kTensorFlowSplit ||
- op->type == OperatorType::kTensorFlowConcat ||
- op->type == OperatorType::kFill) {
- // These operators produce an output with the same type as their 2nd input
- CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kTransposeConv) {
- // These operators produce an output with the same type as their 3rd input
- CHECK_GE(op->inputs.size(), 3);
- const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kCast) {
- // Data type of the Cast op is specified.
- CHECK_EQ(op->outputs.size(), 1);
- auto* cast_op = static_cast<CastOperator*>(op);
- model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
- } else if (op->type == OperatorType::kArgMax) {
- // Data type of the ArgMax op is specified.
- CHECK_EQ(op->outputs.size(), 1);
- auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
- } else if (op->type == OperatorType::kRange) {
- auto* range_op = static_cast<RangeOperator*>(op);
- // Output type of the Range op can be set via an attribute
- ArrayDataType data_type;
- if (range_op->dtype != ArrayDataType::kNone) {
- // Use the type if specified
- data_type = range_op->dtype;
- } else {
- // Otherwise use the first input
- CHECK_GE(op->inputs.size(), 1);
- data_type = model->GetArray(op->inputs[0]).data_type;
+ switch (op->type) {
+ case OperatorType::kDequantize:
+ case OperatorType::kResizeBilinear:
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ break;
+ case OperatorType::kTensorFlowLess:
+ case OperatorType::kTensorFlowLessEqual:
+ case OperatorType::kTensorFlowGreater:
+ case OperatorType::kTensorFlowGreaterEqual:
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ break;
+ case OperatorType::kRank:
+ case OperatorType::kTensorFlowShape:
+ // These operators only produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ break;
+ case OperatorType::kTensorFlowSplit:
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kFill: {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GE(op->inputs.size(), 2);
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
}
- CHECK_EQ(op->outputs.size(), 1);
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kTensorFlowUnsupported) {
- auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
- // Some output tensors from the op could be eliminated by optimization.
- // This can make unsupported_op->output_data_types have more elements than
- // op->outputs.
- if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+ case OperatorType::kTransposeConv: {
+ // These operators produce an output with the same type as their 3rd input
+ CHECK_GE(op->inputs.size(), 3);
+ const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
+ case OperatorType::kCast: {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
+ break;
+ }
+ case OperatorType::kArgMax: {
+ // Data type of the ArgMax op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmax_op = static_cast<ArgMaxOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
+ break;
+ }
+ case OperatorType::kRange: {
+ auto* range_op = static_cast<RangeOperator*>(op);
+ // Output type of the Range op can be set via an attribute
+ ArrayDataType data_type;
+ if (range_op->dtype != ArrayDataType::kNone) {
+ // Use the type if specified
+ data_type = range_op->dtype;
+ } else {
+ // Otherwise use the first input
+ CHECK_GE(op->inputs.size(), 1);
+ data_type = model->GetArray(op->inputs[0]).data_type;
+ }
+ CHECK_EQ(op->outputs.size(), 1);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
+ case OperatorType::kRandomUniform: {
+ auto* rand_op = static_cast<RandomUniformOperator*>(op);
+ // The output type of RandomUniform is specified with an attribute
+ if (rand_op->dtype == ArrayDataType::kNone) {
+ return false;
+ }
+ CHECK_EQ(op->outputs.size(), 1);
+ SetDataTypeForAllOutputs(model, op, rand_op->dtype);
+ break;
+ }
+ case OperatorType::kTensorFlowUnsupported: {
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ // Some output tensors from the op could be eliminated by optimization.
+ // This can make unsupported_op->output_data_types have more elements than
+ // op->outputs.
+ if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < op->outputs.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->GetArray(output).data_type = data_type;
+ }
+ break;
+ }
+ case OperatorType::kExpandDims: {
+      // Yield on ExpandDims until it is converted to Reshape
return false;
}
- for (int i = 0; i < op->outputs.size(); ++i) {
- auto output = op->outputs[i];
- auto data_type = unsupported_op->output_data_types[i];
- model->GetArray(output).data_type = data_type;
+ default: {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
}
- } else if (op->type == OperatorType::kExpandDims) {
- // Yield on ExpandDim until it is converted to Reshape
- return false;
- } else {
- // These operators produce outputs with the same type as their 1st input
- CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
}
+
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {