-rw-r--r--  tensorflow/compiler/xla/BUILD             |   3
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 247
2 files changed, 81 insertions(+), 169 deletions(-)
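This change replaces the hand-rolled dilated/strided convolution loop in ReferenceUtil::ConvArray4DGeneralDimensionsDilated with a small HLO computation that is run through HloEvaluator. As a minimal sketch of that build-and-evaluate pattern, using the API surface named in the diff (the trivial scalar-add computation is illustrative only and not part of the change):

    // Sketch only: build a one-op HLO computation and evaluate it on the
    // host. Header paths as of this change; not a buildable target.
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/service/hlo_computation.h"
    #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    std::unique_ptr<xla::Literal> EvaluateScalarAdd() {
      xla::HloComputation::Builder b("scalar_add");
      auto* lhs = b.AddInstruction(xla::HloInstruction::CreateConstant(
          xla::Literal::CreateR0<float>(1.0f)));
      auto* rhs = b.AddInstruction(xla::HloInstruction::CreateConstant(
          xla::Literal::CreateR0<float>(2.0f)));
      b.AddInstruction(xla::HloInstruction::CreateBinary(
          lhs->shape(), xla::HloOpcode::kAdd, lhs, rhs));
      xla::HloEvaluator evaluator;
      // Same call shape as in reference_util.cc below: evaluate the built
      // computation with no argument literals and take the result literal.
      return evaluator.Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
    }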
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index e0a03a78f1..ba90b13b38 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -563,6 +563,9 @@ cc_library(
":xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_evaluator",
+ "//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 7ef5c6d916..c851c38ea4 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
std::pair<int64, int64> kernel_stride, Padding padding,
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
ConvolutionDimensionNumbers dnums) {
- std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
- std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
-
- const int64 ksy = kernel_stride.first;
- const int64 ksx = kernel_stride.second;
- const int64 dy = lhs_dilation.first;
- const int64 dx = lhs_dilation.second;
- const int64 dky = rhs_dilation.first;
- const int64 dkx = rhs_dilation.second;
- CHECK_GE(dky, 1);
- CHECK_GE(dkx, 1);
- CHECK_GE(dy, 1);
- CHECK_GE(dx, 1);
-
- // Get all dimension sizes in lhs and rhs based on the given convolution
- // dimension configuration.
- const int64 ix = window_util::DilatedBound(
- lhs_dimensions[dnums.spatial_dimensions(1)], dx);
- const int64 iy = window_util::DilatedBound(
- lhs_dimensions[dnums.spatial_dimensions(0)], dy);
- const int64 iz = lhs_dimensions[dnums.feature_dimension()];
- const int64 samples = lhs_dimensions[dnums.batch_dimension()];
- const int64 kx = window_util::DilatedBound(
- rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
- const int64 ky = window_util::DilatedBound(
- rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
- const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
- {
- const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
- CHECK_EQ(kiz, iz);
+ HloComputation::Builder b("ConvArray4DGeneralDimensionsDilated");
+ auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
+ auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
+
+ std::array<int64, 2> ordered_kernel_strides;
+ std::array<int64, 2> ordered_input_dimensions;
+ std::array<int64, 2> ordered_kernel_dimensions;
+ if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
+ ordered_kernel_strides[0] = kernel_stride.second;
+ ordered_kernel_strides[1] = kernel_stride.first;
+ } else {
+ ordered_kernel_strides[0] = kernel_stride.first;
+ ordered_kernel_strides[1] = kernel_stride.second;
}
- if (padding == Padding::kSame) {
- // We reject same padding with kernel striding, since it's somewhat
- // nonsensical. We can always follow up to implement this with the desired
- // semantics if anybody actually uses it.
- CHECK_EQ(1, ksy);
- CHECK_EQ(1, ksx);
- }
-
- const int64 ox =
- padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
- const int64 oy =
- padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
- const int64 istartx =
- padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
- const int64 istarty =
- padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
- // Create the output result array and reset the values to 0.
- std::array<int64, 4> result_dimensions;
- result_dimensions[dnums.batch_dimension()] = samples;
- result_dimensions[dnums.feature_dimension()] = oz;
- result_dimensions[dnums.spatial_dimensions(0)] = oy;
- result_dimensions[dnums.spatial_dimensions(1)] = ox;
+ ordered_input_dimensions[0] =
+ lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
+ ordered_input_dimensions[1] =
+ lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
+ ordered_kernel_dimensions[0] =
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ ordered_kernel_dimensions[1] =
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+
+ std::vector<std::pair<int64, int64>> paddings =
+ MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
+ ordered_kernel_strides, padding);
+ CHECK_EQ(paddings.size(), 2);
+
+ Window window;
+
+ WindowDimension dim;
+ dim.set_size(
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ dim.set_stride(kernel_stride.first);
+ dim.set_padding_low(paddings[0].first);
+ dim.set_padding_high(paddings[0].second);
+ dim.set_window_dilation(rhs_dilation.first);
+ dim.set_base_dilation(lhs_dilation.first);
+ *window.add_dimensions() = dim;
+
+ WindowDimension dim2;
+ dim2.set_size(
+ rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ dim2.set_stride(kernel_stride.second);
+ dim2.set_padding_low(paddings[1].first);
+ dim2.set_padding_high(paddings[1].second);
+ dim2.set_window_dilation(rhs_dilation.second);
+ dim2.set_base_dilation(lhs_dilation.second);
+ *window.add_dimensions() = dim2;
+
+ const Shape& shape =
+ ShapeInference::InferConvolveShape(lhs_literal->shape(),
+ rhs_literal->shape(), window, dnums)
+ .ConsumeValueOrDie();
+
+ HloInstruction* lhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
+ HloInstruction* rhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ shape, lhs_instruction, rhs_instruction, window, dnums));
+
+ HloEvaluator evaluator;
+ std::unique_ptr<Literal> result_literal =
+ evaluator.Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+
+ CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
auto result =
- MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
- result_dimensions[2], result_dimensions[3]);
- result->Fill(0.0);
-
- const auto is_int32 = [](int64 x) {
- return x >= std::numeric_limits<int32>::min() &&
- x <= std::numeric_limits<int32>::max();
- };
-
- // 64-bit idiv/mod are much more expensive than 32-bit idiv/imod (at least
- // on x86-64), so we avoid them where possible.
- const auto fast_idiv64 = [&](int64 a, int64 b) {
- if (is_int32(a) && is_int32(b)) {
- return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
- }
- return a / b;
- };
- const auto fast_imod64 = [&](int64 a, int64 b) {
- if (is_int32(a) && is_int32(b)) {
- return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
- }
- return a % b;
- };
-
- // Lambda to access the lhs operand at the given 4D index.
- const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
- int64 width) {
- if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
- return 0.0f;
- }
+ MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
+ result_literal->shape().dimensions(1),
+ result_literal->shape().dimensions(2),
+ result_literal->shape().dimensions(3));
+
+ result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+ *value = result_literal->Get<float>(indices);
+ });
- std::array<int64, 4> index;
- index[dnums.batch_dimension()] = batch;
- index[dnums.feature_dimension()] = feature;
- index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
- index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
- return lhs(index[0], index[1], index[2], index[3]);
- };
-
- // Lambda to access the rhs operand at the given 4D index. height_over_dky
- // should be equal to height / dky, and width_over_dkx should be equal to
- // width / dkx. (This is an optimization to avoid doing divisions.)
- const auto rhs_element =
- [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
- int64 width, int64 height_over_dky, int64 width_over_dkx) {
- DCHECK_EQ(height % dky, 0);
- DCHECK_EQ(width % dkx, 0);
- DCHECK_EQ(height / dky, height_over_dky);
- DCHECK_EQ(width / dkx, width_over_dkx);
-
- std::array<int64, 4> index;
- index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
- index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
- index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
- index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
- return rhs(index[0], index[1], index[2], index[3]);
- };
-
- // Lambda to access the result data at the given 4D index.
- const auto result_element = [&](int64 batch, int64 kernel_output_feature,
- int64 height, int64 width) -> float& {
- std::array<int64, 4> index;
- index[dnums.batch_dimension()] = batch;
- index[dnums.feature_dimension()] = kernel_output_feature;
- index[dnums.spatial_dimensions(0)] = height;
- index[dnums.spatial_dimensions(1)] = width;
- return (*result)(index[0], index[1], index[2], index[3]);
- };
-
- for (int64 oyi = 0; oyi < oy; ++oyi) {
- for (int64 oxi = 0; oxi < ox; ++oxi) {
- for (int64 sample = 0; sample < samples; ++sample) {
- for (int64 izi = 0; izi < iz; ++izi) {
- for (int64 ozi = 0; ozi < oz; ++ozi) {
- for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
- kyi += dky, kyi_over_dky++) {
- for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
- kxi += dkx, kxi_over_dkx++) {
- int64 iyi = istarty + ksy * oyi + kyi;
- int64 ixi = istartx + ksx * oxi + kxi;
- float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
- ? 0.0
- : lhs_element(sample, izi, iyi, ixi);
- float gain =
- rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
- float addend = input * gain;
- result_element(sample, ozi, oyi, oxi) += addend;
- }
- }
- }
- }
- }
- }
- }
- if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
- iz == 0) {
- LOG(INFO) << "Output will be trivially empty because one of these "
- "dimensions is 0: samples: "
- << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
- << " oy: " << oy << " oz: " << oz << " iz: " << iz;
- return result;
- }
- bool trivial = true;
- auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
- float value) {
- if (value != 0.0) {
- trivial = false;
- }
- };
- lhs.Each(check_trivial);
- if (trivial) {
- LOG(FATAL) << "LHS is all 0.0.";
- }
- trivial = true;
- rhs.Each(check_trivial);
- if (trivial) {
- LOG(FATAL) << "RHS is all 0.0.";
- }
return result;
}
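For context on the bounds the deleted loop nest computed, and that window_util/MakePadding still compute on the evaluator path, here is a standalone sketch of the formulas under the usual XLA window definitions. The names below are local stand-ins for illustration, not the XLA functions themselves:

    #include <algorithm>
    #include <cstdint>
    #include <utility>

    // Length of a dimension after inserting (dilation - 1) zeros between
    // elements, e.g. bound = 3, dilation = 2: x.x.x -> 5.
    int64_t DilatedBound(int64_t bound, int64_t dilation) {
      return bound == 0 ? 0 : (bound - 1) * dilation + 1;
    }

    // Number of positions of a window of size `window_size` stepped by
    // `stride` over a dimension of size `bound` with kValid padding.
    int64_t StridedBound(int64_t bound, int64_t window_size, int64_t stride) {
      return window_size > bound ? 0 : (bound - window_size) / stride + 1;
    }

    // kSame padding keeps ceil(bound / stride) outputs; the total padding is
    // split with the smaller half on the low side, as MakePadding does.
    std::pair<int64_t, int64_t> SamePadding(int64_t bound, int64_t window_size,
                                            int64_t stride) {
      const int64_t output = (bound + stride - 1) / stride;
      const int64_t total =
          std::max<int64_t>((output - 1) * stride + window_size - bound, 0);
      return {total / 2, total - total / 2};
    }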