aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-31 06:05:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 06:07:59 -0700
commit7e2e57410eb40c0512dc573955fd256a6c787741 (patch)
treeec345a16ed486ec5a964ac5d6be20bde7d7b401c /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentca4bda919793cc2578e5c0f7440525261da16fdf (diff)
implementation of sparse_to_dense
PiperOrigin-RevId: 198710452
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 9d1d27f3ef..adb241da32 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
*output_array.mutable_shape()->mutable_dims() = output_dims;
}
+// Infers the output shape of a SparseToDense op from its output_shape
+// input (inputs[1]). Yields (returns without setting a shape) until that
+// input has both a known shape and a constant buffer, since shape
+// propagation may run before constant propagation has resolved it.
+void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
+  // SparseToDense takes: sparse_indices, output_shape, sparse_values,
+  // default_value.
+  CHECK_EQ(op->inputs.size(), 4);
+
+  const Array& output_shape_array = model->GetArray(op->inputs[1]);
+  if (!output_shape_array.has_shape()) return;
+  CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
+
+  // Output should not go over four dimensions.
+  CHECK_LE(output_shape_array.shape().dims(0), 4);
+
+  const string& output_name = op->outputs[0];
+  Array& output_array = model->GetArray(output_name);
+  if (output_array.has_shape()) return;
+
+  // Yield until the output_shape input has been resolved to a constant;
+  // GetBuffer() below CHECK-fails on an array without a buffer.
+  if (!output_shape_array.buffer) return;
+
+  CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
+        output_shape_array.data_type == ArrayDataType::kInt64);
+  if (output_shape_array.data_type == ArrayDataType::kInt32) {
+    *output_array.mutable_shape()->mutable_dims() =
+        output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+  } else {
+    // int64 shape entries are narrowed element-wise into the int dims
+    // vector via std::copy's implicit conversion.
+    const std::vector<int64>& output_shape_data =
+        output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
+    std::copy(
+        output_shape_data.begin(), output_shape_data.end(),
+        std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
+  }
+}
+
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1700,6 +1728,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 1);
ProcessOpWithShapeInput(model, op);
break;
+ case OperatorType::kSparseToDense:
+ ProcessSparseToDenseOperator(model,
+ static_cast<SparseToDenseOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);