diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-02 08:00:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 08:03:31 -0700 |
commit | be571938196fb191f260a2c45176d406e6c19a13 (patch) | |
tree | 31dd2158201796c4eb994dab8cae70d397df1bb1 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | 1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (diff) |
Adding support for RandomUniform. Basic support for op import/export of RandomUniform, and constant resolution with static seeds.
PiperOrigin-RevId: 191293897
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 164 |
1 files changed, 97 insertions, 67 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 778da39bf1..89ad58f887 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -50,78 +50,108 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { old_output_data_types[output] = model->GetArray(output).data_type; } // Do the actual output data types propagation. - if (op->type == OperatorType::kDequantize || - op->type == OperatorType::kResizeBilinear) { - // These operators unconditionally produce float outputs - SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); - } else if (op->type == OperatorType::kTensorFlowLess || - op->type == OperatorType::kTensorFlowLessEqual || - op->type == OperatorType::kTensorFlowGreater || - op->type == OperatorType::kTensorFlowGreaterEqual) { - // These operators unconditionally produce bool outputs - SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); - } else if (op->type == OperatorType::kRank || - op->type == OperatorType::kTensorFlowShape) { - // These operators only produce int32 outputs. - SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); - } else if (op->type == OperatorType::kTensorFlowSplit || - op->type == OperatorType::kTensorFlowConcat || - op->type == OperatorType::kFill) { - // These operators produce an output with the same type as their 2nd input - CHECK_GE(op->inputs.size(), 2); - const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kTransposeConv) { - // These operators produce an output with the same type as their 3rd input - CHECK_GE(op->inputs.size(), 3); - const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kCast) { - // Data type of the Cast op is specified. - CHECK_EQ(op->outputs.size(), 1); - auto* cast_op = static_cast<CastOperator*>(op); - model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; - } else if (op->type == OperatorType::kArgMax) { - // Data type of the ArgMax op is specified. - CHECK_EQ(op->outputs.size(), 1); - auto* argmax_op = static_cast<ArgMaxOperator*>(op); - model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; - } else if (op->type == OperatorType::kRange) { - auto* range_op = static_cast<RangeOperator*>(op); - // Output type of the Range op can be set via an attribute - ArrayDataType data_type; - if (range_op->dtype != ArrayDataType::kNone) { - // Use the type if specified - data_type = range_op->dtype; - } else { - // Otherwise use the first input - CHECK_GE(op->inputs.size(), 1); - data_type = model->GetArray(op->inputs[0]).data_type; + switch (op->type) { + case OperatorType::kDequantize: + case OperatorType::kResizeBilinear: + // These operators unconditionally produce float outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); + break; + case OperatorType::kTensorFlowLess: + case OperatorType::kTensorFlowLessEqual: + case OperatorType::kTensorFlowGreater: + case OperatorType::kTensorFlowGreaterEqual: + // These operators unconditionally produce bool outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); + break; + case OperatorType::kRank: + case OperatorType::kTensorFlowShape: + // These operators only produce int32 outputs. + SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); + break; + case OperatorType::kTensorFlowSplit: + case OperatorType::kTensorFlowConcat: + case OperatorType::kFill: { + // These operators produce an output with the same type as their 2nd input + CHECK_GE(op->inputs.size(), 2); + const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; } - CHECK_EQ(op->outputs.size(), 1); - SetDataTypeForAllOutputs(model, op, data_type); - } else if (op->type == OperatorType::kTensorFlowUnsupported) { - auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); - // Some output tensors from the op could be eliminated by optimization. - // This can make unsupported_op->output_data_types have more elements than - // op->outputs. - if (unsupported_op->output_data_types.size() < op->outputs.size()) { + case OperatorType::kTransposeConv: { + // These operators produce an output with the same type as their 3rd input + CHECK_GE(op->inputs.size(), 3); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kCast: { + // Data type of the Cast op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* cast_op = static_cast<CastOperator*>(op); + model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; + break; + } + case OperatorType::kArgMax: { + // Data type of the ArgMax op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmax_op = static_cast<ArgMaxOperator*>(op); + model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; + break; + } + case OperatorType::kRange: { + auto* range_op = static_cast<RangeOperator*>(op); + // Output type of the Range op can be set via an attribute + ArrayDataType data_type; + if (range_op->dtype != ArrayDataType::kNone) { + // Use the type if specified + data_type = range_op->dtype; + } else { + // Otherwise use the first input + CHECK_GE(op->inputs.size(), 1); + data_type = model->GetArray(op->inputs[0]).data_type; + } + CHECK_EQ(op->outputs.size(), 1); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kRandomUniform: { + auto* rand_op = static_cast<RandomUniformOperator*>(op); + // The output type of RandomUniform is specified with an attribute + if (rand_op->dtype == ArrayDataType::kNone) { + return false; + } + CHECK_EQ(op->outputs.size(), 1); + SetDataTypeForAllOutputs(model, op, rand_op->dtype); + break; + } + case OperatorType::kTensorFlowUnsupported: { + auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); + // Some output tensors from the op could be eliminated by optimization. + // This can make unsupported_op->output_data_types have more elements than + // op->outputs. + if (unsupported_op->output_data_types.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + auto output = op->outputs[i]; + auto data_type = unsupported_op->output_data_types[i]; + model->GetArray(output).data_type = data_type; + } + break; + } + case OperatorType::kExpandDims: { + // Yield on ExpandDim until it is converted to Reshape return false; } - for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; - model->GetArray(output).data_type = data_type; + default: { + // These operators produce outputs with the same type as their 1st input + CHECK_GT(op->inputs.size(), 0); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; } - } else if (op->type == OperatorType::kExpandDims) { - // Yield on ExpandDim until it is converted to Reshape - return false; - } else { - // These operators produce outputs with the same type as their 1st input - CHECK_GT(op->inputs.size(), 0); - const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; - SetDataTypeForAllOutputs(model, op, data_type); } + // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { if (old_output_data_types[output] != model->GetArray(output).data_type) { |