Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc  |  29
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 0a2bf939c1..3df1911d07 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1445,7 +1445,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
- if (dnums.spatial_dimensions_size() !=
+ if (dnums.input_spatial_dimensions_size() !=
dnums.kernel_spatial_dimensions_size()) {
return InvalidArgument(
"Both arguments to convolution must have same number of dimensions.\n"
@@ -1453,7 +1453,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
window.DebugString().c_str());
}
- const int num_spatial_dims = dnums.spatial_dimensions_size();
+ const int num_spatial_dims = dnums.input_spatial_dimensions_size();
if (window.dimensions_size() != num_spatial_dims) {
return InvalidArgument(
"Window must have same number of dimensions as dimension numbers.\n"
@@ -1482,8 +1482,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
std::vector<int64> input_dnums(num_dims);
input_dnums[0] = dnums.input_batch_dimension();
input_dnums[1] = dnums.input_feature_dimension();
- std::copy(dnums.spatial_dimensions().begin(),
- dnums.spatial_dimensions().end(), input_dnums.begin() + 2);
+ std::copy(dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
std::sort(input_dnums.begin(), input_dnums.end());
std::vector<int64> window_dnums(num_dims);
@@ -1493,12 +1493,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
std::sort(window_dnums.begin(), window_dnums.end());
+ std::vector<int64> output_dnums(num_dims);
+ output_dnums[0] = dnums.output_batch_dimension();
+ output_dnums[1] = dnums.output_feature_dimension();
+ std::copy(dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
+ std::sort(output_dnums.begin(), output_dnums.end());
+
std::vector<int64> expected_dnums(num_dims);
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
- !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) {
+ !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
+ !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
return InvalidArgument(
"A dimension number is out of range in convolution: %s",
dnums.DebugString().c_str());
@@ -1516,10 +1524,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"once: %s",
dnums.DebugString().c_str());
}
+ if (output_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Output dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
std::vector<int64> input_spatial_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i));
+ input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
}
const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
@@ -1567,7 +1581,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
dimensions[dnums.output_batch_dimension()] = input_batch;
dimensions[dnums.output_feature_dimension()] = kernel_output_features;
for (int i = 0; i < num_spatial_dims; ++i) {
- dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i);
+ dimensions[dnums.output_spatial_dimensions(i)] =
+ window_output_shape.dimensions(i);
}
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
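
The validation this diff adds for the output dimension numbers follows the same pattern already used for the input and kernel dimension numbers: collect the batch, feature, and spatial dimension indices, sort them, and require them to be exactly the permutation 0..num_dims-1. Below is a minimal standalone sketch of that check, under stated assumptions: the function name IsPermutationOfAllDims and the use of a plain std::vector are illustrative only, not the real ConvolutionDimensionNumbers proto or XLA helper API.

// Sketch only: checks that `dnums` names every dimension in [0, num_dims)
// exactly once, mirroring the sort-then-compare-to-iota pattern in the diff.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

bool IsPermutationOfAllDims(const std::vector<int64_t>& dnums,
                            int64_t num_dims) {
  std::vector<int64_t> sorted = dnums;
  std::sort(sorted.begin(), sorted.end());
  std::vector<int64_t> expected(num_dims);
  std::iota(expected.begin(), expected.end(), 0);  // 0, 1, ..., num_dims-1
  return sorted == expected;
}

With this helper, the new output-side check in the diff corresponds to building a vector from output_batch_dimension, output_feature_dimension, and the output_spatial_dimensions, and rejecting the convolution with InvalidArgument when the helper returns false.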