diff options
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 247 |
2 files changed, 81 insertions, 169 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index e0a03a78f1..ba90b13b38 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -563,6 +563,9 @@ cc_library( ":xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_evaluator", + "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 7ef5c6d916..c851c38ea4 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -20,6 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" @@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( std::pair<int64, int64> kernel_stride, Padding padding, std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums) { - std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}}; - std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}}; - - const int64 ksy = kernel_stride.first; - const int64 ksx = kernel_stride.second; - const int64 dy = lhs_dilation.first; - const int64 dx = lhs_dilation.second; - const int64 dky = rhs_dilation.first; - const int64 dkx = rhs_dilation.second; - CHECK_GE(dky, 1); - CHECK_GE(dkx, 1); - CHECK_GE(dy, 1); - CHECK_GE(dx, 1); - - // Get all dimension sizes in lhs and rhs based on the given convolution - // dimension configuration. - const int64 ix = window_util::DilatedBound( - lhs_dimensions[dnums.spatial_dimensions(1)], dx); - const int64 iy = window_util::DilatedBound( - lhs_dimensions[dnums.spatial_dimensions(0)], dy); - const int64 iz = lhs_dimensions[dnums.feature_dimension()]; - const int64 samples = lhs_dimensions[dnums.batch_dimension()]; - const int64 kx = window_util::DilatedBound( - rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx); - const int64 ky = window_util::DilatedBound( - rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky); - const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()]; - { - const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()]; - CHECK_EQ(kiz, iz); + HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); + auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs); + auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs); + + std::array<int64, 2> ordered_kernel_strides; + std::array<int64, 2> ordered_input_dimensions; + std::array<int64, 2> ordered_kernel_dimensions; + if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) { + ordered_kernel_strides[0] = kernel_stride.second; + ordered_kernel_strides[1] = kernel_stride.first; + } else { + ordered_kernel_strides[0] = kernel_stride.first; + ordered_kernel_strides[1] = kernel_stride.second; } - if (padding == Padding::kSame) { - // We reject same padding with kernel striding, since it's somewhat - // nonsensical. We can always follow up to implement this with the desired - // semantics if anybody actually uses it. - CHECK_EQ(1, ksy); - CHECK_EQ(1, ksx); - } - - const int64 ox = - padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx); - const int64 oy = - padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy); - const int64 istartx = - padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2; - const int64 istarty = - padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2; - // Create the output result array and reset the values to 0. - std::array<int64, 4> result_dimensions; - result_dimensions[dnums.batch_dimension()] = samples; - result_dimensions[dnums.feature_dimension()] = oz; - result_dimensions[dnums.spatial_dimensions(0)] = oy; - result_dimensions[dnums.spatial_dimensions(1)] = ox; + ordered_input_dimensions[0] = + lhs_literal->shape().dimensions(dnums.spatial_dimensions(0)); + ordered_input_dimensions[1] = + lhs_literal->shape().dimensions(dnums.spatial_dimensions(1)); + ordered_kernel_dimensions[0] = + rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + ordered_kernel_dimensions[1] = + rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + + std::vector<std::pair<int64, int64>> paddings = + MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, + ordered_kernel_strides, padding); + CHECK_EQ(paddings.size(), 2); + + Window window; + + WindowDimension dim; + dim.set_size( + rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + dim.set_stride(kernel_stride.first); + dim.set_padding_low(paddings[0].first); + dim.set_padding_high(paddings[0].second); + dim.set_window_dilation(rhs_dilation.first); + dim.set_base_dilation(lhs_dilation.first); + *window.add_dimensions() = dim; + + WindowDimension dim2; + dim2.set_size( + rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + dim2.set_stride(kernel_stride.second); + dim2.set_padding_low(paddings[1].first); + dim2.set_padding_high(paddings[1].second); + dim2.set_window_dilation(rhs_dilation.second); + dim2.set_base_dilation(lhs_dilation.second); + *window.add_dimensions() = dim2; + + const Shape& shape = + ShapeInference::InferConvolveShape(lhs_literal->shape(), + rhs_literal->shape(), window, dnums) + .ConsumeValueOrDie(); + + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + HloEvaluator evaluator; + std::unique_ptr<Literal> result_literal = + evaluator.Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = - MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1], - result_dimensions[2], result_dimensions[3]); - result->Fill(0.0); - - const auto is_int32 = [](int64 x) { - return x >= std::numeric_limits<int32>::min() && - x <= std::numeric_limits<int32>::max(); - }; - - // 64-bit idiv/mod are much more expensive x86-64 than 32-bit idiv/imod (at - // least on x86-64), so we avoid them where possible. - const auto fast_idiv64 = [&](int64 a, int64 b) { - if (is_int32(a) && is_int32(b)) { - return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b)); - } - return a / b; - }; - const auto fast_imod64 = [&](int64 a, int64 b) { - if (is_int32(a) && is_int32(b)) { - return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b)); - } - return a % b; - }; - - // Lambda to access the lhs operand at the given 4D index. - const auto lhs_element = [&](int64 batch, int64 feature, int64 height, - int64 width) { - if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) { - return 0.0f; - } + MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0), + result_literal->shape().dimensions(1), + result_literal->shape().dimensions(2), + result_literal->shape().dimensions(3)); + + result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { + *value = result_literal->Get<float>(indices); + }); - std::array<int64, 4> index; - index[dnums.batch_dimension()] = batch; - index[dnums.feature_dimension()] = feature; - index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy); - index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx); - return lhs(index[0], index[1], index[2], index[3]); - }; - - // Lambda to access the rhs operand at the given 4D index. height_over_dky - // should be equal to height / dky, and width_over_dkx should be equal to - // width / dkx. (This is an optimization to avoid doing divisions.) - const auto rhs_element = - [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height, - int64 width, int64 height_over_dky, int64 width_over_dkx) { - DCHECK_EQ(height % dky, 0); - DCHECK_EQ(width % dkx, 0); - DCHECK_EQ(height / dky, height_over_dky); - DCHECK_EQ(width / dkx, width_over_dkx); - - std::array<int64, 4> index; - index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; - index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; - index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; - index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; - return rhs(index[0], index[1], index[2], index[3]); - }; - - // Lambda to access the result data at the given 4D index. - const auto result_element = [&](int64 batch, int64 kernel_output_feature, - int64 height, int64 width) -> float& { - std::array<int64, 4> index; - index[dnums.batch_dimension()] = batch; - index[dnums.feature_dimension()] = kernel_output_feature; - index[dnums.spatial_dimensions(0)] = height; - index[dnums.spatial_dimensions(1)] = width; - return (*result)(index[0], index[1], index[2], index[3]); - }; - - for (int64 oyi = 0; oyi < oy; ++oyi) { - for (int64 oxi = 0; oxi < ox; ++oxi) { - for (int64 sample = 0; sample < samples; ++sample) { - for (int64 izi = 0; izi < iz; ++izi) { - for (int64 ozi = 0; ozi < oz; ++ozi) { - for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky; - kyi += dky, kyi_over_dky++) { - for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx; - kxi += dkx, kxi_over_dkx++) { - int64 iyi = istarty + ksy * oyi + kyi; - int64 ixi = istartx + ksx * oxi + kxi; - float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) - ? 0.0 - : lhs_element(sample, izi, iyi, ixi); - float gain = - rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx); - float addend = input * gain; - result_element(sample, ozi, oyi, oxi) += addend; - } - } - } - } - } - } - } - if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 || - iz == 0) { - LOG(INFO) << "Output will be trivially empty because one of these " - "dimensions is 0: samples: " - << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox - << " oy: " << oy << " oz: " << oz << " iz: " << iz; - return result; - } - bool trivial = true; - auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices, - float value) { - if (value != 0.0) { - trivial = false; - } - }; - lhs.Each(check_trivial); - if (trivial) { - LOG(FATAL) << "LHS is all 0.0."; - } - trivial = true; - rhs.Each(check_trivial); - if (trivial) { - LOG(FATAL) << "RHS is all 0.0."; - } return result; } |