aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 08:00:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 08:03:31 -0700
commitbe571938196fb191f260a2c45176d406e6c19a13 (patch)
tree31dd2158201796c4eb994dab8cae70d397df1bb1 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (diff)
Adding support for RandomUniform. Basic support for op import/export of RandomUniform, and constant resolution with static seeds.
PiperOrigin-RevId: 191293897
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 676736cfc5..b96d698675 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -392,8 +392,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
depth * block_size * block_size}));
}
-void ProcessFillOperator(Model* model, FillOperator* op) {
- CHECK_EQ(op->inputs.size(), 2);
+void ProcessOpWithShapeInput(Model* model, Operator* op) {
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
@@ -1529,7 +1528,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<SpaceToDepthOperator*>(op));
break;
case OperatorType::kFill:
- ProcessFillOperator(model, static_cast<FillOperator*>(op));
+ CHECK_EQ(op->inputs.size(), 2);
+ ProcessOpWithShapeInput(model, op);
break;
case OperatorType::kFullyConnected:
ProcessFullyConnectedOperator(model,
@@ -1659,6 +1659,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
// transforms that remove them, so we avoid propagating shapes through
// them and let things settle once they've been removed.
break;
+ case OperatorType::kRandomUniform:
+ CHECK_EQ(op->inputs.size(), 1);
+ ProcessOpWithShapeInput(model, op);
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);