diff options
author | 2018-10-09 14:19:07 -0700 | |
---|---|---|
committer | 2018-10-09 14:28:11 -0700 | |
commit | fa1542234857acf56af6e7f0dbe8d2084a18fa00 (patch) | |
tree | 1254448bf59e0fc3330d421059f53e0258dc56b6 /tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc | |
parent | b145f46b735fe1e383be6629cafaa5269b07b7fb (diff) |
[XLA:GPU] Pattern match atomic "apply" into an atomic store
Otherwise we'd emit a CAS loop.
PiperOrigin-RevId: 216421161
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc | 58 |
1 files changed, 58 insertions, 0 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc
new file mode 100644
index 0000000000..6b18c4c637
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuAtomicTest : public GpuCodegenTest {};
+
+TEST_F(GpuAtomicTest, TestStore) {
+  const char* hlo_string = R"(
+    HloModule TensorFlowScatterV1
+
+    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+      lhs = s32[] parameter(0)
+      ROOT rhs = s32[] parameter(1)
+    }
+
+    ENTRY main {
+      operand = s32[3,3] parameter(0)
+      indices = s32[2] parameter(1)
+      updates = s32[2,3] parameter(2)
+      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+          to_apply=update_s32,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+    }
+)";
+
+  CompileAndVerifyIr(hlo_string, R"(
+CHECK: store atomic{{.*}}unordered, align 4
+)");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
```