diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-08 11:22:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 11:26:51 -0700 |
commit | a1d3ebbe6f40768ea5ccf6beef9e905bce207b42 (patch) | |
tree | 21f4ba44273ba4b2f6caf10f56c20012f11fbd6e | |
parent | e041ae7dcc8916a4f5b358c45a68a6ec731e86e2 (diff) |
Fixing for TopKV2 shape propagation in Toco.
PiperOrigin-RevId: 207917321
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 19 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 20 |
2 files changed, 22 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 73c05e21b6..1dfd086f5a 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1670,7 +1670,7 @@ def make_shape_tests(zip_path): }] def build_graph(parameters): - """Build the topk op testing graph.""" + """Build the shape op testing graph.""" # Note that we intentionally leave out the shape from the input placeholder # to prevent the Shape operation from being optimized out during conversion. input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") @@ -2318,6 +2318,7 @@ def make_topk_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[10], [5, 20]], + "input_k": [None, 1, 3], }] def build_graph(parameters): @@ -2326,15 +2327,23 @@ def make_topk_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - k = tf.constant(3, name="k") + if parameters["input_k"] is not None: + k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[]) + else: + k = tf.constant(3, name="k") out = tf.nn.top_k(input_value, k) - return [input_value], [out[1]] + return [input_value, k], [out[1]] def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_value]))) + if parameters["input_k"] is not None: + k = np.array(parameters["input_k"], dtype=np.int32) + return [input_value, k], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, k]))) + else: + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 3c9379fd87..91e290439a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1082,27 +1082,23 @@ void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { } // Yield until input dims have been resolved. - if (!input_values.has_shape()) { + if (!input_values.has_shape() || !input_k.has_shape()) { return; } - const auto& input_values_shape = input_values.shape(); - auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); - auto output_values_dims = output_values.mutable_shape()->mutable_dims(); - for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { - output_indexes_dims->push_back(input_values_shape.dims(dim)); - output_values_dims->push_back(input_values_shape.dims(dim)); - } // If the value is initialized, we can specify the last dimension, otherwise // unknown. if (input_k.buffer) { + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0]; output_indexes_dims->push_back(k_value); output_values_dims->push_back(k_value); - - } else { - output_indexes_dims->push_back(0); - output_values_dims->push_back(0); } } |