author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-08 11:22:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-08 11:26:51 -0700
commit    a1d3ebbe6f40768ea5ccf6beef9e905bce207b42 (patch)
tree      21f4ba44273ba4b2f6caf10f56c20012f11fbd6e
parent    e041ae7dcc8916a4f5b358c45a68a6ec731e86e2 (diff)
Fix TopKV2 shape propagation in Toco.
PiperOrigin-RevId: 207917321
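For context, the case this change addresses is a graph whose k argument to tf.nn.top_k arrives as a runtime-fed scalar rather than a baked-in constant. A minimal standalone sketch of such a graph (TF 1.x API, as used by the test generator below; shapes and names here are illustrative, not taken from the patch):

import numpy as np
import tensorflow as tf  # TF 1.x API, matching contrib/lite's test generator

# k is fed at runtime instead of being a tf.constant, so a converter cannot
# know the last output dimension of TopKV2 at conversion time.
input_value = tf.placeholder(dtype=tf.float32, shape=[5, 20], name="input")
k = tf.placeholder(dtype=tf.int32, shape=[], name="input_k")
_, indices = tf.nn.top_k(input_value, k)

with tf.Session() as sess:
  result = sess.run(indices, feed_dict={
      input_value: np.random.rand(5, 20).astype(np.float32),
      k: 3,
  })
  print(result.shape)  # (5, 3) at runtime; unknown until k is resolved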
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py                         19
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  20
2 files changed, 22 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 73c05e21b6..1dfd086f5a 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1670,7 +1670,7 @@ def make_shape_tests(zip_path):
   }]
 
   def build_graph(parameters):
-    """Build the topk op testing graph."""
+    """Build the shape op testing graph."""
     # Note that we intentionally leave out the shape from the input placeholder
     # to prevent the Shape operation from being optimized out during conversion.
     input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
@@ -2318,6 +2318,7 @@ def make_topk_tests(zip_path):
   test_parameters = [{
       "input_dtype": [tf.float32, tf.int32],
       "input_shape": [[10], [5, 20]],
+      "input_k": [None, 1, 3],
   }]
 
   def build_graph(parameters):
@@ -2326,15 +2327,23 @@ def make_topk_tests(zip_path):
     input_value = tf.placeholder(
         dtype=parameters["input_dtype"],
         name="input",
         shape=parameters["input_shape"])
-    k = tf.constant(3, name="k")
+    if parameters["input_k"] is not None:
+      k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+    else:
+      k = tf.constant(3, name="k")
     out = tf.nn.top_k(input_value, k)
-    return [input_value], [out[1]]
+    return [input_value, k], [out[1]]
 
   def build_inputs(parameters, sess, inputs, outputs):
     input_value = create_tensor_data(parameters["input_dtype"],
                                      parameters["input_shape"])
-    return [input_value], sess.run(
-        outputs, feed_dict=dict(zip(inputs, [input_value])))
+    if parameters["input_k"] is not None:
+      k = np.array(parameters["input_k"], dtype=np.int32)
+      return [input_value, k], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_value, k])))
+    else:
+      return [input_value], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_value])))
 
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
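One subtlety in the build_inputs change above, worth noting (an observation, not part of the patch): when input_k is None, build_graph still returns the constant k in its input list, while build_inputs feeds only input_value. That is safe because dict(zip(...)) truncates to the shorter sequence:

# zip() stops at the shorter argument, so the unpaired constant tensor is
# simply dropped from the feed_dict (stand-in values, for illustration).
inputs = ["input:0", "k:0"]
values = [[1.0, 2.0]]
print(dict(zip(inputs, values)))  # {'input:0': [1.0, 2.0]}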
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 3c9379fd87..91e290439a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1082,27 +1082,23 @@ void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
   }
 
   // Yield until input dims have been resolved.
-  if (!input_values.has_shape()) {
+  if (!input_values.has_shape() || !input_k.has_shape()) {
     return;
   }
 
-  const auto& input_values_shape = input_values.shape();
-  auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
-  auto output_values_dims = output_values.mutable_shape()->mutable_dims();
-  for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
-    output_indexes_dims->push_back(input_values_shape.dims(dim));
-    output_values_dims->push_back(input_values_shape.dims(dim));
-  }
   // If the value is initialized, we can specify the last dimension, otherwise
   // unknown.
   if (input_k.buffer) {
+    const auto& input_values_shape = input_values.shape();
+    auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
+    auto output_values_dims = output_values.mutable_shape()->mutable_dims();
+    for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
+      output_indexes_dims->push_back(input_values_shape.dims(dim));
+      output_values_dims->push_back(input_values_shape.dims(dim));
+    }
     const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
     output_indexes_dims->push_back(k_value);
     output_values_dims->push_back(k_value);
-
-  } else {
-    output_indexes_dims->push_back(0);
-    output_values_dims->push_back(0);
   }
 }
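Net effect of the Toco hunk: before this change, the leading output dimensions were filled in immediately and a placeholder last dimension of 0 was pushed whenever k was not yet a constant; after it, shape propagation yields entirely until k's buffer is available and then writes the full output shape at once. A hypothetical Python sketch of the resulting rule (illustrative names, not Toco's API):

def topk_v2_output_shape(input_shape, k):
  """Sketch of TopKV2 shape propagation after this change.

  Returns None (shape unresolved; the transformation yields and retries on a
  later pass) unless k is a known constant; otherwise the output shape is
  input_shape[:-1] + [k].
  """
  if k is None:  # k not yet constant-folded: leave outputs shapeless
    return None
  return list(input_shape[:-1]) + [k]

assert topk_v2_output_shape([5, 20], 3) == [5, 3]
assert topk_v2_output_shape([5, 20], None) is None  # previously: [5, 0]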