aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-26 10:17:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 10:21:06 -0800
commit3b08cd35bc108f48b4f63d73af7a53eb8a1169f9 (patch)
tree4acb6d08b8978c51499073b25ec187e3f0e57fc1 /tensorflow/compiler/xla/service/shape_inference.cc
parentf4e70be18b104fbb2efeefeb83bea190aec12727 (diff)
Generalize the gather_indices dimension that stores indices
This is now exposed as a index_vector_dim dimension number. Also fixed an off-by-one error in ValidateGatherDimensionNumbers in the expression computing output_shape_rank. PiperOrigin-RevId: 187040748
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc42
1 files changed, 27 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c9692757b2..607a672025 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2467,27 +2467,27 @@ static Status ValidateGatherDimensionNumbers(
const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size();
+ output_window_dim_count + gather_indices_shape.size() - 1;
for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
int64 window_index = dim_numbers.output_window_dims(i);
if (window_index < 0 || window_index >= output_shape_rank) {
return InvalidArgument(
"Window index %d in gather op is out of bounds; got %lld, but should "
- "have been in"
- "[0,%lld)",
+ "have been in [0,%lld)",
i, window_index, output_shape_rank);
}
}
if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape.back()) {
+ gather_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "There must be exactly as many elements in gather_dims_to_operand_dims "
- "as there are elements in the last dimension of %%gather_indices; got: "
- "%d, expected %lld",
+ "Gather op has %d elements in gather_dims_to_operand_dims and the "
+ "bound of dimension index_vector_dim=%lld of gather_indices is "
+ "%lld. These two numbers must be equal.",
dim_numbers.gather_dims_to_operand_dims_size(),
- gather_indices_shape.back());
+ dim_numbers.index_vector_dim(),
+ gather_indices_shape[dim_numbers.index_vector_dim()]);
}
for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
@@ -2550,24 +2550,33 @@ static Status ValidateGatherDimensionNumbers(
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
gather_indices_shape, "gather indices operand of gather op"));
- if (gather_indices_shape.dimensions_size() < 1) {
+ if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
return InvalidArgument(
- "Gather indices parameter must at least of rank 1; got %s",
+ "Gather indices parameter must be an integral tensor; got %s",
ShapeUtil::HumanString(gather_indices_shape).c_str());
}
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
+ // index_vector_dim is rank(P). The bounds of this expanded shape is
+ // stored in expanded_gather_indices_shape.
+
+ if (gather_indices_shape.dimensions_size() <
+ gather_dim_numbers.index_vector_dim() ||
+ gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather indices parameter must be an integral tensor; got %s",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ "Gather index leaf dimension must be within [0, rank(gather_indices) + "
+ "1). rank(gather_indices) is %d and gather index leaf dimension is "
+ "%lld.",
+ gather_indices_shape.dimensions_size(),
+ gather_dim_numbers.index_vector_dim());
}
std::vector<int64> expanded_gather_indices_shape;
- // We implicitly reshape gather indices of shape P[N] to P[N,1].
expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
c_copy(gather_indices_shape.dimensions(),
std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() == 1) {
+ if (expanded_gather_indices_shape.size() ==
+ gather_dim_numbers.index_vector_dim()) {
expanded_gather_indices_shape.push_back(1);
}
@@ -2632,6 +2641,9 @@ static Status ValidateGatherDimensionNumbers(
}
current_bound = window_bounds[window_dims_seen++];
} else {
+ if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
+ gather_dims_seen++;
+ }
current_bound = expanded_gather_indices_shape[gather_dims_seen++];
}