diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-02 08:00:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 08:03:31 -0700 |
commit | be571938196fb191f260a2c45176d406e6c19a13 (patch) | |
tree | 31dd2158201796c4eb994dab8cae70d397df1bb1 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (diff) |
Adding support for RandomUniform. Basic support for op import/export of RandomUniform, and constant resolution with static seeds.
PiperOrigin-RevId: 191293897
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 676736cfc5..b96d698675 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -392,8 +392,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { depth * block_size * block_size})); } -void ProcessFillOperator(Model* model, FillOperator* op) { - CHECK_EQ(op->inputs.size(), 2); +void ProcessOpWithShapeInput(Model* model, Operator* op) { CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { @@ -1529,7 +1528,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast<SpaceToDepthOperator*>(op)); break; case OperatorType::kFill: - ProcessFillOperator(model, static_cast<FillOperator*>(op)); + CHECK_EQ(op->inputs.size(), 2); + ProcessOpWithShapeInput(model, op); break; case OperatorType::kFullyConnected: ProcessFullyConnectedOperator(model, @@ -1659,6 +1659,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { // transforms that remove them, so we avoid propagating shapes through // them and let things settle once they've been removed. break; + case OperatorType::kRandomUniform: + CHECK_EQ(op->inputs.size(), 1); + ProcessOpWithShapeInput(model, op); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); |