about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/llvm_ir
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com> 2018-07-18 07:02:44 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-18 07:06:51 -0700
commit: a46c9ab4419402182c34404f0f57c1f7b6b51858 (patch)
tree: bdeb1b5249f2555398a3cd2ac850d4d833fe8659 /tensorflow/compiler/xla/service/llvm_ir
parent: ff791a7fde3605493bef70de8a9c9779541daf66 (diff)
Support unsigned indices for in-place DynamicUpdateSlice.
For unsigned indices, we need to use unsigned comparisons when clamping the start_indices.
Also rename the files from ops.* to dynamic_update_slice_util.*

PiperOrigin-RevId: 205072344
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir')
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/BUILD | 6
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc (renamed from tensorflow/compiler/xla/service/llvm_ir/ops.cc) | 23
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h (renamed from tensorflow/compiler/xla/service/llvm_ir/ops.h) | 6
3 files changed, 20 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 6f1e04a1c6..c14a5bfb53 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -164,9 +164,9 @@ cc_library(
)
cc_library(
- name = "ops",
- srcs = ["ops.cc"],
- hdrs = ["ops.h"],
+ name = "dynamic_update_slice_util",
+ srcs = ["dynamic_update_slice_util.cc"],
+ hdrs = ["dynamic_update_slice_util.h"],
deps = [
":fused_ir_emitter",
":ir_array",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index 3b298f4746..7048fcfdc9 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
@@ -38,8 +38,8 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
// Emits a sequential loop if launch_dimensions is null.
static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& update_shape, const ElementGenerator& start_indices_generator,
- ElementGenerator update_array_generator, const IrArray& output_array,
- const gpu::LaunchDimensions* launch_dimensions,
+ bool is_signed, ElementGenerator update_array_generator,
+ const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) {
const Shape& output_shape = output_array.GetShape();
@@ -59,17 +59,20 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
+ // to officially document different behavior.
llvm::Value* max_bound =
ir_builder->CreateSub(output_dim_size, update_dim_size);
llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0);
start_index[i] = ir_builder->CreateSelect(
- ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]),
+ ir_builder->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
+ zero, start_index[i]),
zero, start_index[i]);
start_index[i] = ir_builder->CreateSelect(
- ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound,
- start_index[i]),
+ ir_builder->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
+ max_bound, start_index[i]),
max_bound, start_index[i]);
}
@@ -122,8 +125,9 @@ Status EmitDynamicUpdateSliceInPlace(
return update_array.EmitReadArrayElement(index, ir_builder);
};
+ bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
return EmitDynamicUpdateSliceInPlaceImpl(
- update_shape, start_indices_generator, update_array_generator,
+ update_shape, start_indices_generator, is_signed, update_array_generator,
output_array, /*launch_dimensions=*/nullptr, name, ir_builder);
}
@@ -170,8 +174,9 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
ElementGenerator start_indices_generator =
fused_emitter.GetGenerator(start_indices);
+ bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
return EmitDynamicUpdateSliceInPlaceImpl(
- update_shape, start_indices_generator, update_array_generator,
+ update_shape, start_indices_generator, is_signed, update_array_generator,
fusion_output_array, launch_dimensions, IrName(fusion), ir_builder);
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index 175b081e84..7f73fb6b29 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
@@ -90,4 +90,4 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace(
} // namespace llvm_ir
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_