diff options
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 75 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 34 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/gather_operation_test.cc | 79 |
3 files changed, 127 insertions, 61 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 21c6f7d358..bd68685153 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1565,7 +1565,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. + // to officially document different behavior. start_index_value = ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); llvm::Value* operand_dim_size = @@ -1610,19 +1610,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to - // generate. This index "unsafe" as in the components in here may be - // out of bounds. - IrArray::Index unsafe_operand_index(index_type); - - // First copy in the window indices to unsafe_operand_index. - for (int64 i = 0, e = operand_shape.dimensions_size(), - unsafe_operand_index_dim = 0; + // generate. + IrArray::Index operand_index(index_type); + + // First copy in the window indices to operand_index. Also collect a mapping + // from operand dimension to output window dimension. Elided window dimensions + // map to -1. + std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1); + for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(index.GetConstantWithIndexType(0)); + operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - unsafe_operand_index.push_back( - index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + int64 output_window_dim = + dim_numbers.output_window_dims(operand_index_dim++); + operand_to_output_dim[i] = output_window_dim; + operand_index.push_back(index[output_window_dim]); } } @@ -1641,20 +1644,42 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( } } - auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, - int64 dim) { + auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc(index_component, index_type); - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = - ir_builder_->CreateAdd( - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], - gather_dim_component_extended); + int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + int64 output_dim = operand_to_output_dim[operand_dim]; + // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. + // This means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : output_shape.dimensions(output_dim); + int64 largest_valid_start_index = + operand_shape.dimensions(operand_dim) - output_dim_size; + CHECK_GE(largest_valid_start_index, 0); + + // Clamp the gather index so that the gather region fits in the operand. + // gather_dim_component_extended_inbound = + // clamp(gather_dim_component_extended, 0, largest_valid_start_index); + + // TODO(b/111078873): This is implementation defined behavior. + + bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); + auto gather_dim_component_extended_inbound = EmitIntegralMin( + index.GetConstantWithIndexType(largest_valid_start_index), + EmitIntegralMax(index.GetConstantWithIndexType(0), + gather_dim_component_extended, + /*is_signed=*/is_signed), + /*is_signed=*/is_signed); + + operand_index[operand_dim] = ir_builder_->CreateAdd( + operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, 0); + add_to_operand_index(gather_dim_component, 0); } else { int64 index_vector_size = indices_shape.dimensions(dim_numbers.index_vector_dim()); @@ -1663,18 +1688,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, i); + add_to_operand_index(gather_dim_component, i); } } - - IrArray::Index safe_operand_index(index_type); - for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { - safe_operand_index.push_back(ir_builder_->CreateURem( - unsafe_operand_index[i], - index.GetConstantWithIndexType(operand_shape.dimensions(i)))); - } - - return operand_generator(safe_operand_index); + return operand_generator(operand_index); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( @@ -1706,7 +1723,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. + // to officially document different behavior. start_index_value = ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); llvm::Value* input_dim_size = diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index f4fd9ba926..c5f6fe3fd8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -775,6 +775,12 @@ class OutputWindowIndexToInputIndex { return ArraySlice<int64>(input_index_); } + // Returns for a given 'input_dim' the corresponding output dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_output_index(int64 input_dim) { + return input_dim_value_to_output_index_[input_dim]; + } + private: // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. @@ -845,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { // corresponding index in the input shape. std::vector<int64> input_index(operand.shape().dimensions_size()); std::vector<int64> output_index(gather->shape().dimensions_size()); + std::vector<int64> input_gather_index_clamped( + operand.shape().dimensions_size()); OutputGatherIndexToInputIndex output_gather_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), @@ -866,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } + for (int i = 0, e = input_gather_index.size(); i < e; i++) { + int64 output_dim = + output_window_index_to_input_index.input_dim_value_to_output_index(i); + // If 'output_dim' is -1, it means 'i' is an elided window dim. This means + // we set the iteration index to 0, so for the purpose of the following + // calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : shape.dimensions(output_dim); + // Clamp the gather index so that the gather region fits in the operand. + // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // operand_shape.dimensions(i) - + // output_dim_size); + input_gather_index_clamped[i] = + std::min(operand_shape.dimensions(i) - output_dim_size, + std::max(0LL, input_gather_index[i])); + } for (int i = 0, e = input_index.size(); i < e; i++) { - // TODO(b/74360564): We should implement whatever out of bounds behavior - // we decide for dynamic-slice here as well. - input_index[i] = (input_gather_index[i] + input_window_index[i]) % - operand_shape.dimensions(i); - if (input_index[i] < 0) { - input_index[i] += operand_shape.dimensions(i); - } + input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + DCHECK_GT(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( result->CopyElementFrom(operand, input_index, output_index)); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 9178b50595..c5ca64fa3f 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -22,9 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -// NB! TODO(b/74360564): These tests do not test out of bounds behavior since -// that hasn't been specced yet. - namespace xla { namespace { @@ -273,10 +270,6 @@ ENTRY main { XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { // Out of bounds indices must not crash, and the indices in range should // produce the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, - // we should get rid of the mask and check that backends produce the same - // value for OOB indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -290,29 +283,45 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr<Literal> in_bounds_mask = - LiteralUtil::CreateR1<int32>({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. + + const string hlo_text = R"( +HloModule BatchDynamicSlice - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { // Negative indices must not crash, and the indices in range should produce // the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for negative - // accesses, we should get rid of the mask and check that backends produce the - // same value for negative indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -326,20 +335,40 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr<Literal> in_bounds_mask = - LiteralUtil::CreateR1<int32>({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = u32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = u32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = u32[6]{0} reshape(gather) +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { |